#!/usr/bin/env python
# file generic.py
# author Florent Guiotte
# version 0.0
# date 08 Jul. 2020
#
# NOTE: recovered from a patch that creates
# minigrida/cvgenerators/generic.py and re-exports its contents from
# minigrida/cvgenerators/__init__.py with `from .generic import *`.
"""Generic cross-validation generators.

Defines :class:`GroupStratifiedShuffleSplit`, a splitter that combines the
group-wise folds of ``GroupKFold`` with a stratified sub-sampling of each
training fold.
"""

import numpy as np
from sklearn.model_selection import GroupKFold, StratifiedShuffleSplit


class GroupStratifiedShuffleSplit(GroupKFold):
    """Group K-fold whose training folds are sub-sampled with stratification.

    Folds are produced by ``GroupKFold``, so the group-based separation of
    train and test indices comes from the parent class. Each training fold is
    then reduced by a single ``StratifiedShuffleSplit`` draw on ``y``; the
    test fold is returned untouched.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds, forwarded to ``GroupKFold``.
    train_size : float, int or None, default=None
        Forwarded as ``train_size`` to ``StratifiedShuffleSplit``; it is
        applied to each training fold, not to the whole data set. ``None``
        defers to ``StratifiedShuffleSplit``'s own default.
    random_state : int, RandomState instance or None, default=None
        Forwarded to ``StratifiedShuffleSplit``. The same value is passed for
        every fold.
    """

    def __init__(self, n_splits=5, train_size=None, random_state=None):
        super().__init__(n_splits=n_splits)
        self.train_size = train_size
        self.random_state = random_state

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test sets.

        Parameters
        ----------
        X : array-like of shape (n_samples, ...)
            Data to split; forwarded to ``GroupKFold.split``.
        y : array-like of shape (n_samples,)
            Labels used to stratify the sub-sampling of each training fold.
            Required, even though the signature keeps the ``None`` default
            for compatibility with the splitter interface.
        groups : array-like of shape (n_samples,)
            Group labels; forwarded to ``GroupKFold.split``.

        Yields
        ------
        train : ndarray
            Indices into ``X`` of the sub-sampled training fold.
        test : ndarray
            Indices into ``X`` of the test fold, exactly as produced by
            ``GroupKFold``.

        Raises
        ------
        ValueError
            If ``y`` is None, since stratification needs labels.
        """
        if y is None:
            raise ValueError("The 'y' parameter should not be None.")
        labels = np.asarray(y)

        for train, test in super().split(X, y, groups):
            subsampler = StratifiedShuffleSplit(
                n_splits=1,
                train_size=self.train_size,
                random_state=self.random_state,
            )
            # Only the labels drive the stratified draw, so a placeholder of
            # the right length stands in for the data and X itself never has
            # to support fancy indexing.
            placeholder = np.zeros(len(train))
            kept, _ = next(subsampler.split(placeholder, labels[train]))

            # `kept` holds positions inside `train`; translate them back to
            # indices into the full data set before yielding, otherwise the
            # training indices would point at the wrong samples and could
            # overlap the test fold.
            yield train[kept], test