Add GroupStratifiedShuffleSplit

This commit is contained in:
Florent Guiotte 2020-07-08 17:23:50 +02:00
parent 6f92c575f6
commit a2bca010f5
2 changed files with 26 additions and 0 deletions

View File

@ -0,0 +1 @@
from .generic import *

View File

@ -0,0 +1,25 @@
#!/usr/bin/env python
# file generic.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 08 juil. 2020
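"""Generic cross-validation splitters.

Defines GroupStratifiedShuffleSplit, a group-aware K-fold splitter built on
scikit-learn's GroupKFold and StratifiedShuffleSplit.
"""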
from sklearn.model_selection import GroupKFold, StratifiedShuffleSplit
class GroupStratifiedShuffleSplit(GroupKFold):
    """Group K-fold splitter whose training folds are stratified sub-samples.

    The folds themselves come from ``GroupKFold``. For every fold, the
    training indices are then sub-sampled by a single
    ``StratifiedShuffleSplit`` draw stratified on ``y``; the test indices of
    the fold are yielded unchanged.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds, forwarded to ``GroupKFold``.
    train_size : float, int or None, default=None
        Forwarded to ``StratifiedShuffleSplit`` to size the sub-sample drawn
        from each fold's training indices.
    random_state : int, RandomState instance or None, default=None
        Forwarded to ``StratifiedShuffleSplit``.
    """

    def __init__(self, n_splits=5, train_size=None, random_state=None):
        super().__init__(n_splits)
        # Assigned after the parent constructor so these values are the ones
        # stored on the instance.
        self.train_size = train_size
        self.random_state = random_state

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test sets.

        Parameters
        ----------
        X : array-like
            Data to split; forwarded to ``GroupKFold.split``.
        y : array-like
            Labels used to stratify the training sub-sample. Required.
        groups : array-like
            Group labels; forwarded to ``GroupKFold.split``.

        Yields
        ------
        train : ndarray
            Indices into the full dataset of the stratified sub-sample of
            the fold's training set.
        test : ndarray
            Indices into the full dataset of the fold's test set.

        Raises
        ------
        ValueError
            If ``y`` is None, since stratification needs the labels.
        """
        # Local import: the enclosing module does not import numpy itself.
        import numpy as np

        if y is None:
            raise ValueError(
                "The 'y' parameter should not be None: labels are required "
                "to stratify the training sub-sample."
            )
        # Coerce so plain sequences can be fancy-indexed by the fold indices.
        y = np.asarray(y)

        sampler = StratifiedShuffleSplit(
            n_splits=1,
            train_size=self.train_size,
            random_state=self.random_state,
        )
        for train, test in super().split(X, y, groups):
            # The stratified sampler only needs the number of samples from
            # its X argument, so a placeholder of matching length is enough.
            placeholder = np.zeros(len(train))
            sub_idx, _ = next(sampler.split(placeholder, y[train]))
            # sub_idx holds positions *within* `train`; map them back to
            # indices of the full dataset before yielding.
            yield train[sub_idx], test