Add GroupStratifiedShuffleSplit
This commit is contained in:
parent
6f92c575f6
commit
a2bca010f5
@ -0,0 +1 @@
|
||||
from .generic import *
|
||||
25
minigrida/cvgenerators/generic.py
Normal file
25
minigrida/cvgenerators/generic.py
Normal file
@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
# file generic.py
|
||||
# author Florent Guiotte <florent.guiotte@irisa.fr>
|
||||
# version 0.0
|
||||
# date 08 juil. 2020
|
||||
"""Abstract
|
||||
|
||||
doc.
|
||||
"""
|
||||
|
||||
from sklearn.model_selection import GroupKFold, StratifiedShuffleSplit
|
||||
|
||||
|
||||
class GroupStratifiedShuffleSplit(GroupKFold):
|
||||
def __init__(self, n_splits=5, train_size=None, random_state=None):
|
||||
super().__init__(n_splits)
|
||||
self.train_size = train_size
|
||||
self.random_state = random_state
|
||||
|
||||
def split(self, X, y=None, groups=None):
|
||||
splits = [s for s in super().split(X, y, groups)]
|
||||
for train, test in splits:
|
||||
sss_train, sss_test = next(StratifiedShuffleSplit(1, train_size=self.train_size, random_state=self.random_state).split(X[train], y[train], groups[train]))
|
||||
|
||||
yield sss_train, test
|
||||
Loading…
Reference in New Issue
Block a user