From a2214ee5fe71b9387afbeeed643b798902d33dc4 Mon Sep 17 00:00:00 2001 From: Karamaz0V1 Date: Thu, 11 Oct 2018 19:01:45 +0200 Subject: [PATCH] Add sub sampling parameters in Split --- cvgenerators/jurse.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/cvgenerators/jurse.py b/cvgenerators/jurse.py index c68eb20..24e1b46 100644 --- a/cvgenerators/jurse.py +++ b/cvgenerators/jurse.py @@ -19,15 +19,22 @@ class Split: If used with a split first description, make sure you use compatible settings. + Use `sub_sample (and `random_state`) parameters to sub sample the training + set. + + return xtrain, xtest, ytrain, ytest, test_index + """ - def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, remove_unclassified=True): + def __init__(self, ground_truth, attributes, n_test=2, order_dim=0, sub_sample=1.0, random_state=0, remove_unclassified=True): self._att = attributes self._gt = ground_truth self._n = n_test self._d = order_dim self._s = 0 self._r = remove_unclassified + self._ssp = sub_sample + self._rs = random_state self._size = ground_truth.shape[order_dim] self._step = int(ground_truth.shape[order_dim] / n_test) @@ -48,10 +55,14 @@ class Split: unclassified = self._gt == 0 train_index = ~test_index & ~unclassified + # Sub sample training + np.random.seed(self._rs) + train_index &= np.random.random(train_index.shape) < self._ssp + + # Remove unclassified if self._r: test_index &= ~unclassified - #ipdb.set_trace() xtrain = self._att[train_index] xtest = self._att[test_index] ytrain = self._gt[train_index]