#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file CrossValidationGenerator.py
# \brief Cross-validation generators for attribute profile descriptors.
# \author Florent Guiotte
# \version 0.1
# \date 28 Mar 2018
#
# TODO details

import numpy as np


class CVG_legacy:
    """Legacy cross-validation generator splitting along one of the first two axes.

    Axis ``order_dim`` of ``attributes`` / ``ground_truth`` is cut into
    ``n_test`` contiguous chunks of roughly equal length. Each iteration
    selects ONE chunk as the *training* set and everything else as the *test*
    set. Note this is the reverse of conventional k-fold, where a single fold
    is held out for testing.

    The instance is a one-shot iterator: ``__iter__`` returns ``self`` and
    the fold counter is never reset.

    Args:
        attributes: Array whose first two axes must match ``ground_truth``.
            Folds are flattened with ``reshape(-1, attributes.shape[2])``, so
            it needs at least three dimensions (presumably
            rows x cols x features -- TODO confirm).
        ground_truth: Target array whose first two axes match
            ``attributes``. Each fold's targets are flattened to 1D.
        n_test: Number of chunks, i.e. number of folds produced. ``0`` or a
            negative value yields no folds.
        order_dim: Axis to split along; must be ``0`` or ``1``.

    Raises:
        ValueError: If the first two axes of ``attributes`` and
            ``ground_truth`` differ, or if ``order_dim`` is not 0 or 1.
    """

    def __init__(self, attributes, ground_truth, n_test=2, order_dim=0):
        if attributes.shape[0] != ground_truth.shape[0] or \
           attributes.shape[1] != ground_truth.shape[1]:
            raise ValueError('attributes and ground_truth must have the same 2D shape')
        # Only the two leading (shared) axes can be split. Any other value
        # would produce a mask whose length does not match the sliced axis.
        if order_dim not in (0, 1):
            raise ValueError('order_dim must be 0 or 1')

        self._order = order_dim
        self._ntests = n_test
        self._actual_ntest = 0
        self._size = attributes.shape[order_dim]
        self._att = attributes
        self._gt = ground_truth

    def __iter__(self):
        """Return ``self``; the generator can be consumed only once."""
        return self

    def __next__(self):
        """Return the next fold as ``(Xtrain, Xtest, Ytrain, Ytest, train_filter)``.

        ``train_filter`` is the boolean mask along ``order_dim`` that selects
        the training chunk; the test set is its complement. X arrays are
        flattened to ``(n_samples, attributes.shape[2])`` and Y arrays to 1D.

        Raises:
            StopIteration: Once ``n_test`` folds have been produced.
        """
        # ``>=`` rather than ``==`` so that a negative or non-integer n_test
        # terminates instead of iterating forever.
        if self._actual_ntest >= self._ntests:
            raise StopIteration

        # True division (Python 3): chunk bounds may be fractional, so the
        # chunks together cover the whole axis instead of leaving a
        # remainder that never becomes a training chunk.
        step = self._size / self._ntests
        # Shift indices so the current chunk starts at 0 (modulo the axis
        # length), then keep the indices that fall within one step.
        train_filter = (np.arange(self._size) - step * self._actual_ntest) % self._size < step
        test_filter = ~train_filter

        axis = self._order
        n_features = self._att.shape[2]
        # np.compress with a boolean mask along `axis` selects the same
        # slices, in the same order, as att[mask] / att[:, mask].
        Xtrain = np.compress(train_filter, self._att, axis=axis).reshape(-1, n_features)
        Xtest = np.compress(test_filter, self._att, axis=axis).reshape(-1, n_features)
        Ytrain = np.compress(train_filter, self._gt, axis=axis).reshape(-1)
        Ytest = np.compress(test_filter, self._gt, axis=axis).reshape(-1)

        self._actual_ntest += 1
        return (Xtrain, Xtest, Ytrain, Ytest, train_filter)


class APsCVG:
    """Cross Validation Generator for Attribute Profiles Descriptors"""

    # NOTE(review): this class appears truncated in this copy of the file.
    # As written, ``__init__`` evaluates the undefined name ``d`` (NameError
    # on construction), and it would otherwise return names it never
    # assigns. Presumably the fold-generation logic was lost; restore it
    # from version control rather than relying on this stub.
    def __init__(self, ground_truth, attributes, cv_count=5, label_ignore=None):
        d
        return xtrain, xtest, ytrain, ytest, test_index