Merge branch 'develop'

This commit is contained in:
Florent Guiotte 2018-10-12 10:55:53 +02:00
commit 851617c7f6

View File

@ -19,15 +19,22 @@ class Split:
If used with a split first description, make sure you use compatible
settings.
Use the `sub_sample` (and `random_state`) parameters to sub-sample the training
set.
return xtrain, xtest, ytrain, ytest, test_index
"""
def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, remove_unclassified=True):
def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, sub_sample=1.0, random_state=0, remove_unclassified=True):
self._att = attributes
self._gt = ground_truth
self._n = n_test
self._d = order_dim
self._s = 0
self._r = remove_unclassified
self._ssp = sub_sample
self._rs = random_state
self._size = ground_truth.shape[order_dim]
self._step = int(ground_truth.shape[order_dim] / n_test)
@ -48,10 +55,14 @@ class Split:
unclassified = self._gt == 0
train_index = ~test_index & ~unclassified
# Sub sample training
np.random.seed(self._rs)
train_index &= np.random.random(train_index.shape) < self._ssp
# Remove unclassified
if self._r:
test_index &= ~unclassified
#ipdb.set_trace()
xtrain = self._att[train_index]
xtest = self._att[test_index]
ytrain = self._gt[train_index]