134 lines
3.9 KiB
Plaintext
134 lines
3.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class cvg:\n",
|
|
" def __init__(self, attributes, ground_truth, n_test=2, order_dim=0):\n",
|
|
" self._order = order_dim\n",
|
|
" self._ntests = n_test\n",
|
|
" self._actual_ntest = 0\n",
|
|
" self._size = attributes.shape[order_dim]\n",
|
|
" self._att = attributes\n",
|
|
" self._gt = ground_truth\n",
|
|
" \n",
|
|
" if attributes.shape[0] != ground_truth.shape[0] or \\\n",
|
|
" attributes.shape[1] != ground_truth.shape[1] :\n",
|
|
" raise ValueError('attributes and ground_truth must have the same 2D shape')\n",
|
|
" \n",
|
|
" def __iter__(self):\n",
|
|
" return self\n",
|
|
" \n",
|
|
" def __next__(self):\n",
|
|
" if self._actual_ntest == self._ntests:\n",
|
|
" raise StopIteration\n",
|
|
" \n",
|
|
" step = self._size / self._ntests\n",
|
|
" train_filter = (np.arange(self._size) - step * self._actual_ntest) % self._size < step\n",
|
|
" \n",
|
|
" if self._order == 0:\n",
|
|
" Xtrain = self._att[train_filter].reshape(-1, self._att.shape[2])\n",
|
|
" Xtest = self._att[train_filter == False].reshape(-1, self._att.shape[2])\n",
|
|
" Ytrain = self._gt[train_filter].reshape(-1, 1)\n",
|
|
" Ytest = self._gt[train_filter == False].reshape(-1, 1)\n",
|
|
" else:\n",
|
|
" Xtrain = self._att[:,train_filter].reshape(-1, self._att.shape[2])\n",
|
|
" Xtest = self._att[:,train_filter == False].reshape(-1, self._att.shape[2])\n",
|
|
" Ytrain = self._gt[:,train_filter].reshape(-1, 1)\n",
|
|
" Ytest = self._gt[:,train_filter == False].reshape(-1, 1)\n",
|
|
"\n",
|
|
" \n",
|
|
" self._actual_ntest += 1\n",
|
|
" \n",
|
|
" return (Xtrain, Xtest, Ytrain, Ytest, train_filter)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"X = np.arange(100*200*10).reshape(100, 200, 10)\n",
|
|
"Y = np.arange(100 * 200).reshape(100, 200)\n",
|
|
"\n",
|
|
"for xn, xt, yn, yt, t in cvg(X, Y, 10, 1):\n",
|
|
" disp = np.zeros(Y.shape)\n",
|
|
" disp[:,t] = 1.\n",
|
|
" plt.imshow(disp)\n",
|
|
" plt.show()\n",
|
|
" \n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sys\n",
|
|
"sys.path.append('..')\n",
|
|
"from CrossValidationGenerator import APsCVG\n",
|
|
"sys.path.append('../triskele/python')\n",
|
|
"import triskele"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n",
|
|
"att = np.array((gt, gt))\n",
|
|
"att = np.moveaxis(att, 0, 2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"for xt, xv, yt, yv, ti in APsCVG(gt, att, 1):\n",
|
|
" print(xt.shape, yt.shape, xv.shape, yv.shape)\n",
|
|
" plt.imshow(ti * 1.)\n",
|
|
" plt.show()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|