51 lines
1.8 KiB
Python
51 lines
1.8 KiB
Python
#!/usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
# \file CrossValidationGenerator.py
|
|
# \brief Cross-validation generator yielding train/test splits of an attribute cube and its ground truth
|
|
# \author Florent Guiotte <florent.guiotte@gmail.com>
|
|
# \version 0.1
|
|
# \date 28 Mar 2018
|
|
#
|
|
# TODO details
|
|
|
|
import numpy as np
|
|
|
|
class CVG:
    """Cross-validation generator over an attribute cube and its ground truth.

    The object is its own iterator and yields ``n_test`` successive splits.
    For split ``k`` the indices along axis ``order_dim`` are cut into
    ``n_test`` contiguous bands; band ``k`` is selected by ``train_filter``
    (the *train* part) and every other band forms the *test* part.

    Each iteration returns ``(Xtrain, Xtest, Ytrain, Ytest, train_filter)``:

    * ``Xtrain`` / ``Xtest``: selected attributes reshaped to
      ``(-1, attributes.shape[2])``.
    * ``Ytrain`` / ``Ytest``: selected ground truth flattened to 1D.
    * ``train_filter``: boolean mask of length ``attributes.shape[order_dim]``
      marking the band used for training.

    The iterator is single-use: once exhausted, build a new instance to
    iterate again.
    """

    def __init__(self, attributes, ground_truth, n_test=2, order_dim=0):
        """Store the data and validate the split configuration.

        Args:
            attributes: array whose first two axes match ``ground_truth`` and
                whose third axis (``shape[2]``) is used as the feature axis.
            ground_truth: array with the same first two dimensions as
                ``attributes``.
            n_test: number of splits to generate (at least 1).
            order_dim: axis (0 or 1) along which the bands are cut.

        Raises:
            ValueError: if the first two dimensions of ``attributes`` and
                ``ground_truth`` differ, if ``order_dim`` is not 0 or 1, or
                if ``n_test`` is lower than 1.
        """
        # Validate before touching any state so a failed construction never
        # leaves a half-initialised object behind.
        if attributes.shape[0] != ground_truth.shape[0] or \
                attributes.shape[1] != ground_truth.shape[1]:
            raise ValueError('attributes and ground_truth must have the same 2D shape')

        # __next__ only knows how to mask axis 0 or axis 1; any other value
        # would build a mask whose length does not match the indexed axis.
        if order_dim not in (0, 1):
            raise ValueError('order_dim must be 0 or 1')

        # A non-positive split count can never reach the stop condition in a
        # meaningful way (and a negative one used to iterate forever).
        if n_test < 1:
            raise ValueError('n_test must be at least 1')

        self._order = order_dim
        self._ntests = n_test
        self._actual_ntest = 0
        self._size = attributes.shape[order_dim]
        self._att = attributes
        self._gt = ground_truth

    def __iter__(self):
        """Return the generator itself (iterator protocol)."""
        return self

    def __next__(self):
        """Return the next ``(Xtrain, Xtest, Ytrain, Ytest, train_filter)``.

        Raises:
            StopIteration: once ``n_test`` splits have been produced.
        """
        if self._actual_ntest >= self._ntests:
            raise StopIteration

        # Width of one band; may be fractional when the axis length is not a
        # multiple of n_test, in which case band sizes differ by one element.
        step = self._size / self._ntests

        # Shift indices so the current band starts at 0 (wrapping around the
        # axis), then keep those falling inside the first `step` positions.
        train_filter = (np.arange(self._size) - step * self._actual_ntest) % self._size < step
        test_filter = ~train_filter
        n_features = self._att.shape[2]

        if self._order == 0:
            Xtrain = self._att[train_filter].reshape(-1, n_features)
            Xtest = self._att[test_filter].reshape(-1, n_features)
            Ytrain = self._gt[train_filter].reshape(-1)
            Ytest = self._gt[test_filter].reshape(-1)
        else:
            Xtrain = self._att[:, train_filter].reshape(-1, n_features)
            Xtest = self._att[:, test_filter].reshape(-1, n_features)
            Ytrain = self._gt[:, train_filter].reshape(-1)
            Ytest = self._gt[:, test_filter].reshape(-1)

        self._actual_ntest += 1

        return (Xtrain, Xtest, Ytrain, Ytest, train_filter)
|