Quick fix on cvgenerator

This commit is contained in:
Florent Guiotte 2020-07-08 17:46:22 +02:00
parent a2bca010f5
commit 2af2656837

View File

@ -22,4 +22,4 @@ class GroupStratifiedShuffleSplit(GroupKFold):
for train, test in splits: for train, test in splits:
sss_train, sss_test = next(StratifiedShuffleSplit(1, train_size=self.train_size, random_state=self.random_state).split(X[train], y[train], groups[train])) sss_train, sss_test = next(StratifiedShuffleSplit(1, train_size=self.train_size, random_state=self.random_state).split(X[train], y[train], groups[train]))
yield sss_train, test yield train[sss_train], test