Merge branch 'develop'
This commit is contained in:
commit
f3076caa1e
@ -19,15 +19,22 @@ class Split:
|
|||||||
If used with a split first description, make sure you use compatible
|
If used with a split first description, make sure you use compatible
|
||||||
settings.
|
settings.
|
||||||
|
|
||||||
|
Use `sub_sample (and `random_state`) parameters to sub sample the training
|
||||||
|
set.
|
||||||
|
|
||||||
|
return xtrain, xtest, ytrain, ytest, test_index
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, remove_unclassified=True):
|
def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, sub_sample=1.0, random_state=0, remove_unclassified=True):
|
||||||
self._att = attributes
|
self._att = attributes
|
||||||
self._gt = ground_truth
|
self._gt = ground_truth
|
||||||
self._n = n_test
|
self._n = n_test
|
||||||
self._d = order_dim
|
self._d = order_dim
|
||||||
self._s = 0
|
self._s = 0
|
||||||
self._r = remove_unclassified
|
self._r = remove_unclassified
|
||||||
|
self._ssp = sub_sample
|
||||||
|
self._rs = random_state
|
||||||
|
|
||||||
self._size = ground_truth.shape[order_dim]
|
self._size = ground_truth.shape[order_dim]
|
||||||
self._step = int(ground_truth.shape[order_dim] / n_test)
|
self._step = int(ground_truth.shape[order_dim] / n_test)
|
||||||
@ -48,10 +55,14 @@ class Split:
|
|||||||
unclassified = self._gt == 0
|
unclassified = self._gt == 0
|
||||||
train_index = ~test_index & ~unclassified
|
train_index = ~test_index & ~unclassified
|
||||||
|
|
||||||
|
# Sub sample training
|
||||||
|
np.random.seed(self._rs)
|
||||||
|
train_index &= np.random.random(train_index.shape) < self._ssp
|
||||||
|
|
||||||
|
# Remove unclassified
|
||||||
if self._r:
|
if self._r:
|
||||||
test_index &= ~unclassified
|
test_index &= ~unclassified
|
||||||
|
|
||||||
#ipdb.set_trace()
|
|
||||||
xtrain = self._att[train_index]
|
xtrain = self._att[train_index]
|
||||||
xtest = self._att[test_index]
|
xtest = self._att[test_index]
|
||||||
ytrain = self._gt[train_index]
|
ytrain = self._gt[train_index]
|
||||||
|
|||||||
@ -68,7 +68,20 @@ def run(rasters, treshold=1e4, areas=None, sd=None, moi=None, split=1, split_dim
|
|||||||
for i, cut in enumerate(dcuts):
|
for i, cut in enumerate(dcuts):
|
||||||
view[i*step:(i+1)*step+1] = np.moveaxis(cut, 0, d)
|
view[i*step:(i+1)*step+1] = np.moveaxis(cut, 0, d)
|
||||||
|
|
||||||
return descriptors
|
# Merge with original
|
||||||
|
loader = ld2dap.LoadTIFF(rasters)
|
||||||
|
raw_in = ld2dap.RawInput(descriptors, vout.metadata)
|
||||||
|
merger = ld2dap.Merger()
|
||||||
|
final_out = ld2dap.RawOutput()
|
||||||
|
|
||||||
|
final_out.input = merger
|
||||||
|
merger.input = loader
|
||||||
|
merger.second.input = raw_in
|
||||||
|
|
||||||
|
loader.run()
|
||||||
|
raw_in.run()
|
||||||
|
|
||||||
|
return final_out.data
|
||||||
|
|
||||||
def version():
|
def version():
|
||||||
return 'v0.0'
|
return 'v0.0'
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user