ld2daps/Notebooks/Cross Validation Generator.ipynb

134 lines
3.9 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class cvg:\n",
"    \"\"\"Iterator over stripe-wise cross-validation splits of an image stack.\n",
"\n",
"    Axis ``order_dim`` (0 or 1) of the image is cut into ``n_test``\n",
"    contiguous stripes. Each call to ``next`` selects the next stripe and\n",
"    returns it as the *training* set; all remaining positions form the\n",
"    test set. The iterator is single-use: once exhausted it stays exhausted.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    attributes : np.ndarray\n",
"        3D array (H, W, D) holding D attribute values per (row, column).\n",
"    ground_truth : np.ndarray\n",
"        Array whose first two dimensions are (H, W).\n",
"    n_test : int\n",
"        Number of stripes, i.e. number of splits produced.\n",
"    order_dim : int\n",
"        Axis along which the stripes are cut: 0 (rows) or 1 (columns).\n",
"\n",
"    Yields\n",
"    ------\n",
"    tuple\n",
"        ``(Xtrain, Xtest, Ytrain, Ytest, train_filter)``: X arrays have\n",
"        shape (n_positions, D), Y arrays have shape (n_positions, 1), and\n",
"        ``train_filter`` is the boolean mask along ``order_dim`` that\n",
"        selects the training stripe.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If ``order_dim`` is not 0 or 1, if ``attributes`` is not 3D, or if\n",
"        the first two dimensions of ``attributes`` and ``ground_truth``\n",
"        differ.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, attributes, ground_truth, n_test=2, order_dim=0):\n",
"        if order_dim not in (0, 1):\n",
"            raise ValueError('order_dim must be 0 or 1')\n",
"        if attributes.ndim != 3:\n",
"            raise ValueError('attributes must be a 3D array (H, W, D)')\n",
"        if attributes.shape[:2] != ground_truth.shape[:2]:\n",
"            raise ValueError('attributes and ground_truth must have the same 2D shape')\n",
"\n",
"        self._order = order_dim\n",
"        self._ntests = n_test\n",
"        self._actual_ntest = 0\n",
"        self._size = attributes.shape[order_dim]\n",
"        self._att = attributes\n",
"        self._gt = ground_truth\n",
"\n",
"    def __iter__(self):\n",
"        return self\n",
"\n",
"    def __next__(self):\n",
"        # '>=' rather than '==': a non-positive n_test stops immediately\n",
"        # instead of looping forever on empty splits.\n",
"        if self._actual_ntest >= self._ntests:\n",
"            raise StopIteration\n",
"\n",
"        # Stripe k spans [ceil(k*size/n), ceil((k+1)*size/n)). Exact integer\n",
"        # bounds guarantee the stripes partition the axis; a floating-point\n",
"        # step could misplace an index sitting on a rounded boundary.\n",
"        k, n = self._actual_ntest, self._ntests\n",
"        start = -(-(k * self._size) // n)\n",
"        stop = -(-((k + 1) * self._size) // n)\n",
"        index = np.arange(self._size)\n",
"        train_filter = (index >= start) & (index < stop)\n",
"        test_filter = ~train_filter\n",
"\n",
"        # compress() selects along the chosen axis, so one code path serves\n",
"        # both order_dim == 0 (rows) and order_dim == 1 (columns).\n",
"        axis = self._order\n",
"        n_features = self._att.shape[2]\n",
"        Xtrain = self._att.compress(train_filter, axis=axis).reshape(-1, n_features)\n",
"        Xtest = self._att.compress(test_filter, axis=axis).reshape(-1, n_features)\n",
"        Ytrain = self._gt.compress(train_filter, axis=axis).reshape(-1, 1)\n",
"        Ytest = self._gt.compress(test_filter, axis=axis).reshape(-1, 1)\n",
"\n",
"        self._actual_ntest += 1\n",
"\n",
"        return (Xtrain, Xtest, Ytrain, Ytest, train_filter)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X = np.arange(100 * 200 * 10).reshape(100, 200, 10)\n",
"Y = np.arange(100 * 200).reshape(100, 200)\n",
"\n",
"# Cut the 200 columns (order_dim=1) into 10 stripes and show, for each\n",
"# split, which columns cvg selects for training.\n",
"for fold, (x_train, x_test, y_train, y_test, train_mask) in enumerate(cvg(X, Y, 10, 1)):\n",
"    # X and Y rows must line up, and train + test sizes must add up to\n",
"    # the total number of pixels.\n",
"    assert x_train.shape[0] == y_train.shape[0] and x_test.shape[0] == y_test.shape[0]\n",
"    assert y_train.shape[0] + y_test.shape[0] == Y.size\n",
"    disp = np.zeros(Y.shape)\n",
"    disp[:, train_mask] = 1.\n",
"    fig, ax = plt.subplots()\n",
"    ax.imshow(disp)\n",
"    ax.set(title='Split {}: training columns'.format(fold), xlabel='column', ylabel='row')\n",
"    plt.show()\n",
"    plt.close(fig)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"from CrossValidationGenerator import APsCVG\n",
"sys.path.append('../triskele/python')\n",
"import triskele"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the ground-truth raster (the _TR suffix presumably marks the\n",
"# training split) and stack two identical copies of it along a new axis 2,\n",
"# giving an attribute array with two channels.\n",
"gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n",
"att = np.stack((gt, gt), axis=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print the array shapes and display the split mask for each split from APsCVG.\n",
"# NOTE(review): arguments are passed as (gt, att, 1), the reverse order of\n",
"# cvg(attributes, ground_truth, ...), and the loop names assume APsCVG yields\n",
"# (X train, X val, y train, y val, mask) -- confirm against APsCVG.\n",
"for x_train, x_val, y_train, y_val, split_mask in APsCVG(gt, att, 1):\n",
"    print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)\n",
"    plt.imshow(split_mask * 1.)\n",
"    plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}