Merge branch 'develop'
This commit is contained in:
commit
f3076caa1e
@ -19,15 +19,22 @@ class Split:
|
||||
If used with a split first description, make sure you use compatible
|
||||
settings.
|
||||
|
||||
Use `sub_sample (and `random_state`) parameters to sub sample the training
|
||||
set.
|
||||
|
||||
return xtrain, xtest, ytrain, ytest, test_index
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, remove_unclassified=True):
|
||||
def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, sub_sample=1.0, random_state=0, remove_unclassified=True):
|
||||
self._att = attributes
|
||||
self._gt = ground_truth
|
||||
self._n = n_test
|
||||
self._d = order_dim
|
||||
self._s = 0
|
||||
self._r = remove_unclassified
|
||||
self._ssp = sub_sample
|
||||
self._rs = random_state
|
||||
|
||||
self._size = ground_truth.shape[order_dim]
|
||||
self._step = int(ground_truth.shape[order_dim] / n_test)
|
||||
@ -48,10 +55,14 @@ class Split:
|
||||
unclassified = self._gt == 0
|
||||
train_index = ~test_index & ~unclassified
|
||||
|
||||
# Sub sample training
|
||||
np.random.seed(self._rs)
|
||||
train_index &= np.random.random(train_index.shape) < self._ssp
|
||||
|
||||
# Remove unclassified
|
||||
if self._r:
|
||||
test_index &= ~unclassified
|
||||
|
||||
#ipdb.set_trace()
|
||||
xtrain = self._att[train_index]
|
||||
xtest = self._att[test_index]
|
||||
ytrain = self._gt[train_index]
|
||||
|
||||
@ -68,7 +68,20 @@ def run(rasters, treshold=1e4, areas=None, sd=None, moi=None, split=1, split_dim
|
||||
for i, cut in enumerate(dcuts):
|
||||
view[i*step:(i+1)*step+1] = np.moveaxis(cut, 0, d)
|
||||
|
||||
return descriptors
|
||||
# Merge with original
|
||||
loader = ld2dap.LoadTIFF(rasters)
|
||||
raw_in = ld2dap.RawInput(descriptors, vout.metadata)
|
||||
merger = ld2dap.Merger()
|
||||
final_out = ld2dap.RawOutput()
|
||||
|
||||
final_out.input = merger
|
||||
merger.input = loader
|
||||
merger.second.input = raw_in
|
||||
|
||||
loader.run()
|
||||
raw_in.run()
|
||||
|
||||
return final_out.data
|
||||
|
||||
def version():
|
||||
return 'v0.0'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user