Add coords in pipeline
This commit is contained in:
parent
06649497ae
commit
7bff8e20d2
@ -11,12 +11,12 @@ doc.
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def run(gt, rasters, remove=None):
|
def run(gt, rasters, coords, remove=None):
|
||||||
X = []
|
X = []
|
||||||
y = []
|
y = []
|
||||||
groups = []
|
groups = []
|
||||||
|
|
||||||
for i, (gti, rastersi) in enumerate(zip(gt, rasters)):
|
for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
|
||||||
# Create vectors
|
# Create vectors
|
||||||
X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1)
|
X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1)
|
||||||
y_raw = gti
|
y_raw = gti
|
||||||
@ -28,7 +28,7 @@ def run(gt, rasters, remove=None):
|
|||||||
|
|
||||||
X += [X_raw[lbl]]
|
X += [X_raw[lbl]]
|
||||||
y += [y_raw[lbl]]
|
y += [y_raw[lbl]]
|
||||||
groups += [np.repeat(i, lbl.sum())]
|
groups += [np.repeat(coordsi, lbl.sum())]
|
||||||
|
|
||||||
X = np.concatenate(X)
|
X = np.concatenate(X)
|
||||||
y = np.concatenate(y)
|
y = np.concatenate(y)
|
||||||
|
|||||||
@ -9,8 +9,8 @@ doc.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import rasterio as rio
|
|
||||||
import logging
|
import logging
|
||||||
|
import rasterio as rio
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -21,6 +21,7 @@ def run(rasters_path, rasters_name, gt_suffix='gt.tif'):
|
|||||||
|
|
||||||
gt = []
|
gt = []
|
||||||
rasters = []
|
rasters = []
|
||||||
|
coords = []
|
||||||
|
|
||||||
for gtn in gt_names:
|
for gtn in gt_names:
|
||||||
gt += [load_tif(gtn)]
|
gt += [load_tif(gtn)]
|
||||||
@ -29,7 +30,9 @@ def run(rasters_path, rasters_name, gt_suffix='gt.tif'):
|
|||||||
load_tif(gtn.as_posix().replace(gt_suffix, '') + n)
|
load_tif(gtn.as_posix().replace(gt_suffix, '') + n)
|
||||||
for n in rasters_name}]
|
for n in rasters_name}]
|
||||||
|
|
||||||
return gt, rasters
|
coords += ['_'.join(gtn.stem.split('_')[:2])]
|
||||||
|
|
||||||
|
return gt, rasters, coords
|
||||||
|
|
||||||
|
|
||||||
def load_tif(path):
|
def load_tif(path):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user