Pool for descriptors

This commit is contained in:
Florent Guiotte 2020-06-11 11:45:41 +02:00
parent 4cffb2fc73
commit a70c782931
2 changed files with 25 additions and 21 deletions

View File

@ -10,17 +10,15 @@ doc.
import numpy as np
import sap
from joblib import Memory
memory = Memory(location='cache/', verbose=0)
from multiprocessing import Pool
@memory.cache
def _attribute_profiles(*kwargs):
return sap.attribute_profiles(*kwargs)
def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direct'):
def run(gt, rasters, coords, remove,
attributes, adjacency='4', filtering='direct', dtype=np.float32):
X = []
y = []
groups = []
@ -29,9 +27,13 @@ def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direc
for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
# Compute EAP
attributes = [attributes] * len(rastersi) if isinstance(attributes, dict) else attributes
eap = []
for (name, raster), attribute in zip(rastersi.items(), attributes):
eap += [_attribute_profiles(raster, attribute, adjacency, name, filtering)]
pool = Pool()
eap = pool.starmap(_attribute_profiles, [
(raster, attribute, adjacency, name, filtering)
for (name, raster), attribute
in zip(rastersi.items(), attributes)])
pool.close()
pool.join()
eap = sap.concatenate(eap)
Xn = [' '.join((a['tree']['image_name'],
@ -40,7 +42,7 @@ def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direc
for a in eap.description for p in a['profiles']] if not Xn else Xn
# Create vectors
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1)
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1).astype(dtype)
y_raw = gti
# Remove unwanted label X, y

View File

@ -10,17 +10,15 @@ doc.
import numpy as np
import sap
from joblib import Memory
memory = Memory(location='cache/', verbose=0)
from multiprocessing import Pool
@memory.cache
def _self_dual_attribute_profiles(*kwargs):
return sap.self_dual_attribute_profiles(*kwargs)
def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direct'):
def run(gt, rasters, coords, remove,
attributes, adjacency='4', filtering='direct', dtype=np.float32):
X = []
y = []
groups = []
@ -29,9 +27,13 @@ def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direc
for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
# Compute EAP
attributes = [attributes] * len(rastersi) if isinstance(attributes, dict) else attributes
eap = []
for (name, raster), attribute in zip(rastersi.items(), attributes):
eap += [_self_dual_attribute_profiles(raster, attribute, adjacency, name, filtering)]
pool = Pool()
eap = pool.starmap(_self_dual_attribute_profiles, [
(raster, attribute, adjacency, name, filtering)
for (name, raster), attribute
in zip(rastersi.items(), attributes)])
pool.close()
pool.join()
eap = sap.concatenate(eap)
Xn = [' '.join((a['tree']['image_name'],
@ -40,7 +42,7 @@ def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direc
for a in eap.description for p in a['profiles']] if not Xn else Xn
# Create vectors
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1)
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1).astype(dtype)
y_raw = gti
# Remove unwanted label X, y