ld2daps/CVGenerators/CrossValidationGenerator.py

98 lines
3.4 KiB
Python

#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file CrossValidationGenerator.py
# \brief Cross-validation generators that split 2D ground-truth maps into train/test sets
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 28 Mar 2018
#
# TODO details
import numpy as np
class CVG_legacy:
    """Legacy cross-validation generator based on contiguous bands.

    The image is cut into ``n_test`` contiguous bands along axis
    ``order_dim`` (0 = rows, 1 = columns). At fold ``k`` the ``k``-th band
    is used as the *training* set and every other band as the *test* set.

    Parameters
    ----------
    attributes : ndarray
        Feature cube; ``attributes.shape[2]`` is used as the feature length,
        so it must be at least 3-D.
    ground_truth : ndarray
        Label map whose first two dimensions match ``attributes``.
    n_test : int
        Number of folds (bands).
    order_dim : {0, 1}
        Axis along which the bands are cut.

    Raises
    ------
    ValueError
        If the first two dimensions of ``attributes`` and ``ground_truth``
        differ, or if ``order_dim`` is not 0 or 1.
    """

    def __init__(self, attributes, ground_truth, n_test=2, order_dim=0):
        if attributes.shape[0] != ground_truth.shape[0] or \
                attributes.shape[1] != ground_truth.shape[1]:
            raise ValueError('attributes and ground_truth must have the same 2D shape')
        # The band mask is built along ``order_dim`` but only axes 0 and 1
        # are spatial; any other value would mis-index the arrays.
        if order_dim not in (0, 1):
            raise ValueError('order_dim must be 0 or 1')
        self._order = order_dim
        self._ntests = n_test
        self._actual_ntest = 0
        self._size = attributes.shape[order_dim]
        self._att = attributes
        self._gt = ground_truth

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next fold.

        Returns
        -------
        tuple
            ``(Xtrain, Xtest, Ytrain, Ytest, train_filter)``: the feature rows
            (reshaped to ``(-1, attributes.shape[2])``) and flattened labels of
            the train and test pixels, plus the 1-D boolean band mask along
            ``order_dim`` that selected the training band.
        """
        # ``>=`` rather than ``==`` so a non-positive n_test stops at once.
        if self._actual_ntest >= self._ntests:
            raise StopIteration
        # True (float) division: the bands then tile the whole axis even when
        # _size is not a multiple of _ntests.
        step = self._size / self._ntests
        train_filter = (np.arange(self._size) - step * self._actual_ntest) % self._size < step
        test_filter = ~train_filter
        axis = self._order
        n_features = self._att.shape[2]
        xtrain = np.compress(train_filter, self._att, axis=axis).reshape(-1, n_features)
        xtest = np.compress(test_filter, self._att, axis=axis).reshape(-1, n_features)
        ytrain = np.compress(train_filter, self._gt, axis=axis).reshape(-1)
        ytest = np.compress(test_filter, self._gt, axis=axis).reshape(-1)
        self._actual_ntest += 1
        return (xtrain, xtest, ytrain, ytest, train_filter)
class APsCVG:
    """Cross Validation Generator for Attribute Profiles Descriptors.

    Each fold calls ``semantic_cvg`` on the ground truth to obtain a split
    map (1 = train, 2 = test, anything else = unused) and yields the
    attribute vectors and labels of the train and test pixels.

    Parameters
    ----------
    ground_truth : ndarray
        Label map whose first two dimensions match ``attributes``.
    attributes : ndarray
        Feature cube; ``attributes.shape[2]`` is used as the feature length,
        so it must be at least 3-D.
    n_test : int
        Number of folds to generate.
    label_ignore : optional
        When not ``None``, pixels whose label equals this value are excluded
        from both the train and the test sets of every fold.

    Raises
    ------
    ValueError
        If the first two dimensions of ``attributes`` and ``ground_truth``
        differ.
    """

    def __init__(self, ground_truth, attributes, n_test=5, label_ignore=None):
        if attributes.shape[0] != ground_truth.shape[0] or \
                attributes.shape[1] != ground_truth.shape[1]:
            raise ValueError('attributes and ground_truth must have the same 2D shape')
        self._gt = ground_truth
        self._att = attributes
        self._cv_count = n_test
        self._actual_count = 0
        self._label_ignore = label_ignore

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next fold.

        Returns
        -------
        tuple
            ``(xtrain, xtest, ytrain, ytest, test_index)``: feature rows
            (reshaped to ``(-1, attributes.shape[2])``) and flattened labels
            of the train and test pixels, plus the boolean test mask with the
            shape of ``ground_truth``.
        """
        # ``>=`` rather than ``==`` so a non-positive n_test stops at once.
        if self._actual_count >= self._cv_count:
            raise StopIteration
        split_map = semantic_cvg(self._gt, self._cv_count, self._actual_count)
        if self._label_ignore is not None:
            # Drop the requested label from both sets of this fold.
            split_map[self._gt == self._label_ignore] = 0
        train_mask = split_map == 1
        test_mask = split_map == 2
        n_features = self._att.shape[2]
        xtrain = self._att[train_mask].reshape(-1, n_features)
        xtest = self._att[test_mask].reshape(-1, n_features)
        ytrain = self._gt[train_mask].reshape(-1)
        ytest = self._gt[test_mask].reshape(-1)
        self._actual_count += 1
        return xtrain, xtest, ytrain, ytest, test_mask
def semantic_cvg(gt, nb_split, step=0):
    """Build the train/test split map of one cross-validation fold.

    For every label of ``gt`` except the smallest one (``np.unique`` returns
    labels sorted), the pixels of that label are taken in C (row-major)
    order and cut into chunks of ``count // nb_split`` pixels. Chunk number
    ``step`` becomes the test set; every other pixel of that label becomes
    the training set.

    Parameters
    ----------
    gt : ndarray
        Label map (any number of dimensions).
    nb_split : int
        Number of folds; must be >= 1.
    step : int, optional
        Index of the fold to build, expected in ``[0, nb_split)``.

    Returns
    -------
    ndarray
        Array with the shape and dtype of ``gt``: 0 for pixels of the
        smallest label, 1 for training pixels, 2 for test pixels.

    Raises
    ------
    ValueError
        If ``nb_split`` is smaller than 1.

    Notes
    -----
    The last ``count % nb_split`` pixels of each label never fall in a test
    chunk, and a label with fewer than ``nb_split`` pixels gets no test
    pixels at all.
    """
    if nb_split < 1:
        raise ValueError('nb_split must be >= 1')
    labels, counts = np.unique(gt, return_counts=True)
    split = np.zeros_like(gt)
    # The smallest label is skipped and stays 0 in the split map.
    for label, label_count in zip(labels[1:], counts[1:]):
        # Floor division instead of ``count * (1 / nb_split)``: the float
        # product can round below the exact quotient
        # (e.g. 49 * (1 / 49) == 0.9999999999999999).
        chunk = int(label_count // nb_split)
        positions = np.nonzero(gt == label)
        start, stop = chunk * step, chunk * (step + 1)
        split[positions] = 1
        split[tuple(axis_idx[start:stop] for axis_idx in positions)] = 2
    return split