CVG ready
This commit is contained in:
parent
1badfb9710
commit
4c04b1227f
@ -51,8 +51,47 @@ class CVG_legacy:
|
|||||||
|
|
||||||
class APsCVG:
|
class APsCVG:
|
||||||
"""Cross Validation Generator for Attribute Profiles Descriptors"""
|
"""Cross Validation Generator for Attribute Profiles Descriptors"""
|
||||||
def __init__(self, ground_truth, attributes, cv_count=5, label_ignore=None):
|
def __init__(self, ground_truth, attributes, n_test=5, label_ignore=None):
|
||||||
d
|
self._gt = ground_truth
|
||||||
|
self._att = attributes
|
||||||
|
self._cv_count = n_test
|
||||||
|
self._actual_count = 0
|
||||||
|
|
||||||
|
if attributes.shape[0] != ground_truth.shape[0] or \
|
||||||
|
attributes.shape[1] != ground_truth.shape[1] :
|
||||||
|
raise ValueError('attributes and ground_truth must have the same 2D shape')
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self._cv_count == self._actual_count:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
split_map = semantic_cvg(self._gt, self._cv_count, self._actual_count)
|
||||||
|
xtrain = self._att[split_map == 1].reshape(-1, self._att.shape[2])
|
||||||
|
xtest = self._att[split_map == 2].reshape(-1, self._att.shape[2])
|
||||||
|
ytrain = self._gt[split_map == 1].reshape(-1)
|
||||||
|
ytest = self._gt[split_map == 2].reshape(-1)
|
||||||
|
test_index = split_map == 2
|
||||||
|
|
||||||
|
self._actual_count += 1
|
||||||
|
|
||||||
return xtrain, xtest, ytrain, ytest, test_index
|
return xtrain, xtest, ytrain, ytest, test_index
|
||||||
|
|
||||||
|
def semantic_cvg(gt, nb_split, step=0):
|
||||||
|
count = np.unique(gt, return_counts=True)
|
||||||
|
|
||||||
|
test_part = 1 / nb_split
|
||||||
|
|
||||||
|
split = np.zeros_like(gt)
|
||||||
|
|
||||||
|
for lbli, lblc in zip(count[0][1:], count[1][1:]):
|
||||||
|
treshold = int(lblc * test_part)
|
||||||
|
#print('lbli:{}, count:{}, train:{}'.format(lbli, lblc, treshold))
|
||||||
|
f = np.nonzero(gt == lbli)
|
||||||
|
t_int, t_ext = treshold * step, treshold * (step + 1)
|
||||||
|
split[f[0], f[1]] = 1
|
||||||
|
split[f[0][t_int:t_ext], f[1][t_int:t_ext]] = 2
|
||||||
|
|
||||||
|
return split
|
||||||
|
@ -77,7 +77,7 @@
|
|||||||
" '../Data/phase1_rasters/Intensity_C3/UH17_GI3F051_TR.tif',\n",
|
" '../Data/phase1_rasters/Intensity_C3/UH17_GI3F051_TR.tif',\n",
|
||||||
" #'../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif',\n",
|
" #'../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif',\n",
|
||||||
" #'../Res/HVR/C123_num_returns_0_5_nearest.tif',\n",
|
" #'../Res/HVR/C123_num_returns_0_5_nearest.tif',\n",
|
||||||
" #'../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n",
|
" '../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n",
|
||||||
"\n",
|
"\n",
|
||||||
"]"
|
"]"
|
||||||
]
|
]
|
||||||
|
@ -79,7 +79,34 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from CrossValidationGenerator import CVG"
|
"import sys\n",
|
||||||
|
"sys.path.append('..')\n",
|
||||||
|
"from CrossValidationGenerator import APsCVG\n",
|
||||||
|
"sys.path.append('../triskele/python')\n",
|
||||||
|
"import triskele"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n",
|
||||||
|
"att = np.array((gt, gt))\n",
|
||||||
|
"att = np.moveaxis(att, 0, 2)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for xt, xv, yt, yv, ti in APsCVG(gt, att, 2):\n",
|
||||||
|
" print(xt.shape, yt.shape, xv.shape, yv.shape)\n",
|
||||||
|
" plt.imshow(ti * 1.)\n",
|
||||||
|
" plt.show()"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
"# Triskele\n",
|
"# Triskele\n",
|
||||||
"import sys\n",
|
"import sys\n",
|
||||||
"from pathlib import Path\n",
|
"from pathlib import Path\n",
|
||||||
"triskele_path = Path('../triskele')\n",
|
"triskele_path = Path('../triskele/python')\n",
|
||||||
"sys.path.append(str(triskele_path.resolve()))\n",
|
"sys.path.append(str(triskele_path.resolve()))\n",
|
||||||
"import triskele"
|
"import triskele"
|
||||||
]
|
]
|
||||||
@ -133,7 +133,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"nb_split = 10\n",
|
"nb_split = 5\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def semantic_cvg(gt, nb_split, step=0):\n",
|
"def semantic_cvg(gt, nb_split, step=0):\n",
|
||||||
" test_part = 1 / nb_split\n",
|
" test_part = 1 / nb_split\n",
|
||||||
@ -162,15 +162,6 @@
|
|||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"gt == lbli"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
Loading…
Reference in New Issue
Block a user