minigrida/minigrida/descriptors/daps.py
2020-07-07 11:07:13 +02:00

68 lines
1.8 KiB
Python

#!/usr/bin/env python
# file daps.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 07 juil. 2020
#!/usr/bin/env python
# file aps.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 30 mai 2020
"""Abstract
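Differential attribute profiles (DAPs) descriptor: computes the
differential attribute profiles of each raster with the ``sap`` package
and flattens them into labelled feature vectors.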
"""
import numpy as np
import sap
from multiprocessing import Pool
def _diff_attribute_profiles(*args):
    """Return the differential attribute profiles of one image.

    All positional arguments are forwarded unchanged to
    ``sap.attribute_profiles`` and ``.diff()`` is applied to the resulting
    profiles. Kept as a module-level function (not a lambda) so that
    ``multiprocessing.Pool.starmap`` can pickle it.
    """
    return sap.attribute_profiles(*args).diff()
def run(gt, rasters, coords, remove,
        attributes, adjacency='4', filtering='direct', dtype=np.float32):
    """Compute differential attribute profile features for several tiles.

    For every tile, the differential attribute profiles (DAPs) of each of
    its rasters are computed in parallel, concatenated, and vectorized.
    Samples whose label appears in ``remove`` are then discarded.

    Parameters
    ----------
    gt : iterable of ndarray
        Ground-truth label array of each tile.
    rasters : iterable of dict
        For each tile, a mapping ``name -> raster``.
    coords : iterable
        For each tile, a value repeated once per kept sample to build
        ``groups``. Presumably a scalar tile identifier: a sequence would be
        flattened by ``np.repeat``.
    remove : iterable or None
        Labels to discard; a falsy value keeps every sample.
    attributes : dict or sequence of dict
        Attribute settings forwarded to ``sap.attribute_profiles``. A single
        dict is applied to every raster of every tile. Otherwise, one entry
        per raster, matched in the iteration order of each tile's mapping.
    adjacency, filtering
        Forwarded positionally to ``sap.attribute_profiles``.
    dtype : numpy dtype
        Dtype of the returned feature matrix.

    Returns
    -------
    X : ndarray
        Feature vectors of the kept samples of all tiles, concatenated.
        Presumably ``(n_samples, n_features)``; this assumes every vectorized
        profile image has the same shape as its tile's ``gt`` (TODO confirm).
    y : ndarray
        Labels of the kept samples.
    groups : ndarray
        The tile's ``coords`` value, repeated for each kept sample.
    Xn : list of str
        Feature names, built once from the first tile that yields any.
    """
    X, y, groups = [], [], []
    Xn = None

    # One worker pool shared by every tile rather than one per tile. The
    # context manager also releases the workers if anything below raises.
    with Pool() as pool:
        for gti, rastersi, coordsi in zip(gt, rasters, coords):
            # Compute EAP. Broadcast a single attribute dict to this tile's
            # rasters without rebinding ``attributes``. Rebinding it made
            # later tiles reuse the first tile's list, which zip() silently
            # truncated when a tile had more rasters.
            if isinstance(attributes, dict):
                attributes_i = [attributes] * len(rastersi)
            else:
                attributes_i = attributes
            eap = pool.starmap(_diff_attribute_profiles, [
                (raster, attribute, adjacency, name, filtering)
                for (name, raster), attribute
                in zip(rastersi.items(), attributes_i)])
            eap = sap.concatenate(eap)

            # Feature names are built once; assumes every tile yields the
            # same profiles in the same order (TODO confirm).
            # NOTE(review): relies on the private helper sap.profiles._title.
            if not Xn:
                Xn = [' '.join((a.description['tree']['image_name'],
                                a.description['attribute'],
                                sap.profiles._title(p)))
                      for a in eap for p in a.description['profiles']]

            # Create vectors: stack the images yielded by eap.vectorize()
            # and move that stacking axis last, so that indexing with a
            # label-shaped mask keeps whole feature vectors.
            X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1).astype(dtype)
            y_raw = gti

            # Remove unwanted labels from X, y. Plain ``bool`` is used
            # because the ``np.bool`` alias was removed in NumPy 1.24.
            keep = np.ones_like(y_raw, dtype=bool)
            for label in remove if remove else []:
                keep &= y_raw != label

            X.append(X_raw[keep])
            y.append(y_raw[keep])
            groups.append(np.repeat(coordsi, keep.sum()))

    X = np.concatenate(X)
    y = np.concatenate(y)
    groups = np.concatenate(groups)

    return X, y, groups, Xn