Pool for descriptors
This commit is contained in:
parent
4cffb2fc73
commit
a70c782931
@ -10,17 +10,15 @@ doc.
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sap
|
import sap
|
||||||
from joblib import Memory
|
from multiprocessing import Pool
|
||||||
|
|
||||||
memory = Memory(location='cache/', verbose=0)
|
|
||||||
|
|
||||||
|
|
||||||
@memory.cache
|
|
||||||
def _attribute_profiles(*kwargs):
|
def _attribute_profiles(*kwargs):
|
||||||
return sap.attribute_profiles(*kwargs)
|
return sap.attribute_profiles(*kwargs)
|
||||||
|
|
||||||
|
|
||||||
def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direct'):
|
def run(gt, rasters, coords, remove,
|
||||||
|
attributes, adjacency='4', filtering='direct', dtype=np.float32):
|
||||||
X = []
|
X = []
|
||||||
y = []
|
y = []
|
||||||
groups = []
|
groups = []
|
||||||
@ -29,9 +27,13 @@ def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direc
|
|||||||
for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
|
for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
|
||||||
# Compute EAP
|
# Compute EAP
|
||||||
attributes = [attributes] * len(rastersi) if isinstance(attributes, dict) else attributes
|
attributes = [attributes] * len(rastersi) if isinstance(attributes, dict) else attributes
|
||||||
eap = []
|
pool = Pool()
|
||||||
for (name, raster), attribute in zip(rastersi.items(), attributes):
|
eap = pool.starmap(_attribute_profiles, [
|
||||||
eap += [_attribute_profiles(raster, attribute, adjacency, name, filtering)]
|
(raster, attribute, adjacency, name, filtering)
|
||||||
|
for (name, raster), attribute
|
||||||
|
in zip(rastersi.items(), attributes)])
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
eap = sap.concatenate(eap)
|
eap = sap.concatenate(eap)
|
||||||
|
|
||||||
Xn = [' '.join((a['tree']['image_name'],
|
Xn = [' '.join((a['tree']['image_name'],
|
||||||
@ -40,7 +42,7 @@ def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direc
|
|||||||
for a in eap.description for p in a['profiles']] if not Xn else Xn
|
for a in eap.description for p in a['profiles']] if not Xn else Xn
|
||||||
|
|
||||||
# Create vectors
|
# Create vectors
|
||||||
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1)
|
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1).astype(dtype)
|
||||||
y_raw = gti
|
y_raw = gti
|
||||||
|
|
||||||
# Remove unwanted label X, y
|
# Remove unwanted label X, y
|
||||||
|
|||||||
@ -10,17 +10,15 @@ doc.
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sap
|
import sap
|
||||||
from joblib import Memory
|
from multiprocessing import Pool
|
||||||
|
|
||||||
memory = Memory(location='cache/', verbose=0)
|
|
||||||
|
|
||||||
|
|
||||||
@memory.cache
|
|
||||||
def _self_dual_attribute_profiles(*kwargs):
|
def _self_dual_attribute_profiles(*kwargs):
|
||||||
return sap.self_dual_attribute_profiles(*kwargs)
|
return sap.self_dual_attribute_profiles(*kwargs)
|
||||||
|
|
||||||
|
|
||||||
def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direct'):
|
def run(gt, rasters, coords, remove,
|
||||||
|
attributes, adjacency='4', filtering='direct', dtype=np.float32):
|
||||||
X = []
|
X = []
|
||||||
y = []
|
y = []
|
||||||
groups = []
|
groups = []
|
||||||
@ -29,18 +27,22 @@ def run(gt, rasters, coords, remove, attributes, adjacency='4', filtering='direc
|
|||||||
for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
|
for i, (gti, rastersi, coordsi) in enumerate(zip(gt, rasters, coords)):
|
||||||
# Compute EAP
|
# Compute EAP
|
||||||
attributes = [attributes] * len(rastersi) if isinstance(attributes, dict) else attributes
|
attributes = [attributes] * len(rastersi) if isinstance(attributes, dict) else attributes
|
||||||
eap = []
|
pool = Pool()
|
||||||
for (name, raster), attribute in zip(rastersi.items(), attributes):
|
eap = pool.starmap(_self_dual_attribute_profiles, [
|
||||||
eap += [_self_dual_attribute_profiles(raster, attribute, adjacency, name, filtering)]
|
(raster, attribute, adjacency, name, filtering)
|
||||||
|
for (name, raster), attribute
|
||||||
|
in zip(rastersi.items(), attributes)])
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
eap = sap.concatenate(eap)
|
eap = sap.concatenate(eap)
|
||||||
|
|
||||||
Xn = [' '.join((a['tree']['image_name'],
|
Xn = [' '.join((a['tree']['image_name'],
|
||||||
a['attribute'],
|
a['attribute'],
|
||||||
*[str(v) for v in p.values()]))
|
*[str(v) for v in p.values()]))
|
||||||
for a in eap.description for p in a['profiles']] if not Xn else Xn
|
for a in eap.description for p in a['profiles']] if not Xn else Xn
|
||||||
|
|
||||||
# Create vectors
|
# Create vectors
|
||||||
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1)
|
X_raw = np.moveaxis(np.array(list(eap.vectorize())), 0, -1).astype(dtype)
|
||||||
y_raw = gti
|
y_raw = gti
|
||||||
|
|
||||||
# Remove unwanted label X, y
|
# Remove unwanted label X, y
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user