{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cross-validation generator prototype\n",
    "\n",
    "Prototype of `cvg`, a generator that splits a 2D grid of samples (e.g. image pixels) into contiguous folds along rows or columns. It is followed by a run of the project's `APsCVG` on the 2018 IEEE GRSS DFC training ground truth.\n",
    "\n",
    "**Caveat:** `cvg` returns the selected fold (about `1/n_test` of the grid) as the *train* part and the remaining folds as the *test* part."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `cvg`: contiguous-fold generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class cvg:\n",
    "    \"\"\"Contiguous-fold cross-validation generator over a 2D grid of samples.\n",
    "\n",
    "    The grid is split along ``order_dim`` into ``n_test`` contiguous,\n",
    "    non-overlapping folds; fold ``k`` holds the indices ``i`` with\n",
    "    ``i * n_test // size == k``. Iterating yields one tuple per fold::\n",
    "\n",
    "        (Xtrain, Xtest, Ytrain, Ytest, train_filter)\n",
    "\n",
    "    NOTE: the current fold is returned as the \"train\" part and all\n",
    "    remaining folds as the \"test\" part.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    attributes : ndarray, shape (H, W, F)\n",
    "        Feature vector of every grid cell.\n",
    "    ground_truth : ndarray, shape (H, W)\n",
    "        Label of every grid cell; first two dimensions must match ``attributes``.\n",
    "    n_test : int, default 2\n",
    "        Number of folds, i.e. of iterations. Must be >= 1.\n",
    "    order_dim : {0, 1}, default 0\n",
    "        Split axis: 0 splits rows, 1 splits columns.\n",
    "\n",
    "    Yields\n",
    "    ------\n",
    "    Xtrain, Xtest : ndarray, shape (-1, F)\n",
    "        Flattened features of the current fold / of the other folds.\n",
    "    Ytrain, Ytest : ndarray, shape (-1, 1)\n",
    "        Matching flattened labels.\n",
    "    train_filter : ndarray of bool, shape (size,)\n",
    "        Mask of the current fold along ``order_dim``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        On inconsistent shapes, ``order_dim`` not in {0, 1}, or ``n_test`` < 1.\n",
    "\n",
    "    The object is its own single-use iterator; create a new one to iterate again.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, attributes, ground_truth, n_test=2, order_dim=0):\n",
    "        # Validate up front so bad input fails here, not lazily in __next__.\n",
    "        if attributes.ndim < 3 or ground_truth.ndim < 2:\n",
    "            raise ValueError('attributes must be at least 3D and ground_truth at least 2D')\n",
    "        if attributes.shape[:2] != ground_truth.shape[:2]:\n",
    "            raise ValueError('attributes and ground_truth must have the same 2D shape')\n",
    "        if order_dim not in (0, 1):\n",
    "            raise ValueError('order_dim must be 0 (rows) or 1 (columns)')\n",
    "        if n_test < 1:\n",
    "            raise ValueError('n_test must be >= 1')\n",
    "\n",
    "        self._order = order_dim\n",
    "        self._ntests = n_test\n",
    "        self._actual_ntest = 0\n",
    "        self._size = attributes.shape[order_dim]\n",
    "        self._att = attributes\n",
    "        self._gt = ground_truth\n",
    "        # Fold index of every position along the split axis. Integer arithmetic\n",
    "        # keeps folds disjoint; a float step (size / n_test) can put a boundary\n",
    "        # index into two folds (e.g. size=30, n_test=9 puts index 10 in folds 2 and 3).\n",
    "        self._fold_of = np.arange(self._size) * n_test // max(self._size, 1)\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        # '>=' so extra calls past the last fold keep raising StopIteration.\n",
    "        if self._actual_ntest >= self._ntests:\n",
    "            raise StopIteration\n",
    "\n",
    "        train_filter = self._fold_of == self._actual_ntest\n",
    "        test_filter = ~train_filter\n",
    "        n_features = self._att.shape[2]\n",
    "\n",
    "        # np.compress along the split axis is equivalent to arr[mask] (axis 0)\n",
    "        # or arr[:, mask] (axis 1), so one code path serves both orders.\n",
    "        Xtrain = np.compress(train_filter, self._att, axis=self._order).reshape(-1, n_features)\n",
    "        Xtest = np.compress(test_filter, self._att, axis=self._order).reshape(-1, n_features)\n",
    "        Ytrain = np.compress(train_filter, self._gt, axis=self._order).reshape(-1, 1)\n",
    "        Ytest = np.compress(test_filter, self._gt, axis=self._order).reshape(-1, 1)\n",
    "\n",
    "        self._actual_ntest += 1\n",
    "\n",
    "        return (Xtrain, Xtest, Ytrain, Ytest, train_filter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sanity checks on synthetic data\n",
    "\n",
    "Column-wise 10-fold split of a 100x200 grid with 10 features per cell. Each figure highlights the columns that are returned as *train* for that fold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.arange(100 * 200 * 10).reshape(100, 200, 10)\n",
    "Y = np.arange(100 * 200).reshape(100, 200)\n",
    "\n",
    "for fold, (xn, xt, yn, yt, t) in enumerate(cvg(X, Y, 10, 1)):\n",
    "    # Every grid cell lands in exactly one of the two parts.\n",
    "    assert len(xn) + len(xt) == Y.size\n",
    "    disp = np.zeros(Y.shape)\n",
    "    disp[:, t] = 1.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(disp)\n",
    "    ax.set(title=f'fold {fold}: columns returned as train', xlabel='column', ylabel='row')\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folds must cover the split axis exactly once. size=30, n_test=9 is a case where\n",
    "# a float-based fold boundary would put index 10 into two folds.\n",
    "masks = [t for *_, t in cvg(np.zeros((30, 4, 2)), np.zeros((30, 4)), 9, 0)]\n",
    "assert (np.sum(masks, axis=0) == 1).all()\n",
    "assert [int(m.sum()) for m in masks] == [4, 3, 3, 4, 3, 3, 4, 3, 3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Project generator `APsCVG` on DFC 2018 ground truth\n",
    "\n",
    "Requires the sibling `CrossValidationGenerator` module and the `triskele` Python bindings. Both are located through paths relative to this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "from CrossValidationGenerator import APsCVG\n",
    "sys.path.append('../triskele/python')\n",
    "import triskele"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n",
    "att = np.array((gt, gt))\n",
    "att = np.moveaxis(att, 0, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: `APsCVG` is called here with the ground truth first, as `(gt, att, n)`. This is the reverse of `cvg(attributes, ground_truth, ...)` above; **TODO** confirm it against the signature in `CrossValidationGenerator`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for xt, xv, yt, yv, ti in APsCVG(gt, att, 1):\n",
    "    print(xt.shape, yt.shape, xv.shape, yv.shape)\n",
    "    plt.imshow(ti * 1.)\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}