Add GroupStratifiedShuffleSplit
This commit is contained in:
parent
6f92c575f6
commit
a2bca010f5
@ -0,0 +1 @@
|
|||||||
|
from .generic import *
|
||||||
25
minigrida/cvgenerators/generic.py
Normal file
25
minigrida/cvgenerators/generic.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# file generic.py
|
||||||
|
# author Florent Guiotte <florent.guiotte@irisa.fr>
|
||||||
|
# version 0.0
|
||||||
|
# date 08 juil. 2020
|
||||||
|
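"""Generic cross-validation generators.

Provides GroupStratifiedShuffleSplit, a splitter built on scikit-learn's
GroupKFold and StratifiedShuffleSplit.
"""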
|
||||||
|
|
||||||
|
from sklearn.model_selection import GroupKFold, StratifiedShuffleSplit
|
||||||
|
|
||||||
|
|
||||||
|
class GroupStratifiedShuffleSplit(GroupKFold):
    """Group K-fold whose training folds are sub-sampled with stratification.

    Folds are first produced by ``GroupKFold`` using ``groups``. For each
    fold, a single ``StratifiedShuffleSplit`` draw is then made inside the
    training part (stratified on ``y``) and only the samples selected by
    that draw are kept for training. The test part of the fold is returned
    unchanged.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds, forwarded to ``GroupKFold``.
    train_size : float, int or None, default=None
        Forwarded as ``train_size`` to ``StratifiedShuffleSplit``; it is
        interpreted relative to the training part of each fold.
    random_state : int, RandomState instance or None, default=None
        Forwarded as ``random_state`` to ``StratifiedShuffleSplit``.
    """

    def __init__(self, n_splits=5, train_size=None, random_state=None):
        super().__init__(n_splits)
        self.train_size = train_size
        self.random_state = random_state

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test sets.

        Parameters
        ----------
        X : array-like
            Data to split; must support indexing with an integer array.
        y : array-like
            Class labels used for stratification; must support indexing
            with an integer array. Required despite the ``None`` default,
            which is kept for signature compatibility.
        groups : array-like
            Group labels forwarded to ``GroupKFold.split``.

        Yields
        ------
        train : ndarray
            Indices into ``X`` of the stratified sub-sample of the fold's
            training part.
        test : ndarray
            Indices into ``X`` of the fold's test part.

        Raises
        ------
        ValueError
            If ``y`` is ``None``, since stratification needs the labels.
        """
        if y is None:
            raise ValueError("The 'y' parameter should not be None.")

        for train, test in super().split(X, y, groups):
            sampler = StratifiedShuffleSplit(
                n_splits=1,
                train_size=self.train_size,
                random_state=self.random_state,
            )
            # The sampler only sees X[train], so the indices it returns are
            # positions within ``train`` rather than indices into X.
            kept, _ = next(sampler.split(X[train], y[train]))
            # Map the positions back to indices into X; yielding ``kept``
            # as-is would select unrelated samples, possibly test ones.
            yield train[kept], test
|
||||||
Loading…
Reference in New Issue
Block a user