CVG ready

This commit is contained in:
Florent Guiotte 2018-07-09 12:56:39 +02:00
parent 1badfb9710
commit 4c04b1227f
4 changed files with 72 additions and 15 deletions

View File

@ -51,8 +51,47 @@ class CVG_legacy:
class APsCVG:
"""Cross Validation Generator for Attribute Profiles Descriptors"""
def __init__(self, ground_truth, attributes, cv_count=5, label_ignore=None):
d
def __init__(self, ground_truth, attributes, n_test=5, label_ignore=None):
self._gt = ground_truth
self._att = attributes
self._cv_count = n_test
self._actual_count = 0
if attributes.shape[0] != ground_truth.shape[0] or \
attributes.shape[1] != ground_truth.shape[1] :
raise ValueError('attributes and ground_truth must have the same 2D shape')
def __iter__(self):
return self
def __next__(self):
if self._cv_count == self._actual_count:
raise StopIteration
split_map = semantic_cvg(self._gt, self._cv_count, self._actual_count)
xtrain = self._att[split_map == 1].reshape(-1, self._att.shape[2])
xtest = self._att[split_map == 2].reshape(-1, self._att.shape[2])
ytrain = self._gt[split_map == 1].reshape(-1)
ytest = self._gt[split_map == 2].reshape(-1)
test_index = split_map == 2
self._actual_count += 1
return xtrain, xtest, ytrain, ytest, test_index
def semantic_cvg(gt, nb_split, step=0):
count = np.unique(gt, return_counts=True)
test_part = 1 / nb_split
split = np.zeros_like(gt)
for lbli, lblc in zip(count[0][1:], count[1][1:]):
treshold = int(lblc * test_part)
#print('lbli:{}, count:{}, train:{}'.format(lbli, lblc, treshold))
f = np.nonzero(gt == lbli)
t_int, t_ext = treshold * step, treshold * (step + 1)
split[f[0], f[1]] = 1
split[f[0][t_int:t_ext], f[1][t_int:t_ext]] = 2
return split

View File

@ -77,7 +77,7 @@
" '../Data/phase1_rasters/Intensity_C3/UH17_GI3F051_TR.tif',\n",
" #'../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif',\n",
" #'../Res/HVR/C123_num_returns_0_5_nearest.tif',\n",
" #'../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n",
" '../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n",
"\n",
"]"
]

View File

@ -79,7 +79,34 @@
"metadata": {},
"outputs": [],
"source": [
"from CrossValidationGenerator import CVG"
"import sys\n",
"sys.path.append('..')\n",
"from CrossValidationGenerator import APsCVG\n",
"sys.path.append('../triskele/python')\n",
"import triskele"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n",
"att = np.array((gt, gt))\n",
"att = np.moveaxis(att, 0, 2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for xt, xv, yt, yv, ti in APsCVG(gt, att, 2):\n",
" print(xt.shape, yt.shape, xv.shape, yv.shape)\n",
" plt.imshow(ti * 1.)\n",
" plt.show()"
]
}
],

View File

@ -13,7 +13,7 @@
"# Triskele\n",
"import sys\n",
"from pathlib import Path\n",
"triskele_path = Path('../triskele')\n",
"triskele_path = Path('../triskele/python')\n",
"sys.path.append(str(triskele_path.resolve()))\n",
"import triskele"
]
@ -133,7 +133,7 @@
"metadata": {},
"outputs": [],
"source": [
"nb_split = 10\n",
"nb_split = 5\n",
"\n",
"def semantic_cvg(gt, nb_split, step=0):\n",
" test_part = 1 / nb_split\n",
@ -162,15 +162,6 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gt == lbli"
]
},
{
"cell_type": "markdown",
"metadata": {},