Add coords in pipeline

This commit is contained in:
Florent Guiotte 2020-05-30 11:06:20 +02:00
parent 06649497ae
commit 7bff8e20d2
2 changed files with 8 additions and 5 deletions

View File

@ -11,12 +11,12 @@ doc.
import numpy as np import numpy as np
def run(gt, rasters, remove=None): def run(gt, rasters, coords, remove=None):
X = [] X = []
y = [] y = []
groups = [] groups = []
for i, (gti, rastersi) in enumerate(zip(gt, rasters)): for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
# Create vectors # Create vectors
X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1) X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1)
y_raw = gti y_raw = gti
@ -28,7 +28,7 @@ def run(gt, rasters, remove=None):
X += [X_raw[lbl]] X += [X_raw[lbl]]
y += [y_raw[lbl]] y += [y_raw[lbl]]
groups += [np.repeat(i, lbl.sum())] groups += [np.repeat(coordsi, lbl.sum())]
X = np.concatenate(X) X = np.concatenate(X)
y = np.concatenate(y) y = np.concatenate(y)

View File

@ -9,8 +9,8 @@ doc.
""" """
from pathlib import Path from pathlib import Path
import rasterio as rio
import logging import logging
import rasterio as rio
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -21,6 +21,7 @@ def run(rasters_path, rasters_name, gt_suffix='gt.tif'):
gt = [] gt = []
rasters = [] rasters = []
coords = []
for gtn in gt_names: for gtn in gt_names:
gt += [load_tif(gtn)] gt += [load_tif(gtn)]
@ -29,7 +30,9 @@ def run(rasters_path, rasters_name, gt_suffix='gt.tif'):
load_tif(gtn.as_posix().replace(gt_suffix, '') + n) load_tif(gtn.as_posix().replace(gt_suffix, '') + n)
for n in rasters_name}] for n in rasters_name}]
return gt, rasters coords += ['_'.join(gtn.stem.split('_')[:2])]
return gt, rasters, coords
def load_tif(path): def load_tif(path):