CVG ready
This commit is contained in:
parent
1badfb9710
commit
4c04b1227f
@ -51,8 +51,47 @@ class CVG_legacy:
|
||||
|
||||
class APsCVG:
|
||||
"""Cross Validation Generator for Attribute Profiles Descriptors"""
|
||||
def __init__(self, ground_truth, attributes, cv_count=5, label_ignore=None):
|
||||
d
|
||||
def __init__(self, ground_truth, attributes, n_test=5, label_ignore=None):
|
||||
self._gt = ground_truth
|
||||
self._att = attributes
|
||||
self._cv_count = n_test
|
||||
self._actual_count = 0
|
||||
|
||||
if attributes.shape[0] != ground_truth.shape[0] or \
|
||||
attributes.shape[1] != ground_truth.shape[1] :
|
||||
raise ValueError('attributes and ground_truth must have the same 2D shape')
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._cv_count == self._actual_count:
|
||||
raise StopIteration
|
||||
|
||||
split_map = semantic_cvg(self._gt, self._cv_count, self._actual_count)
|
||||
xtrain = self._att[split_map == 1].reshape(-1, self._att.shape[2])
|
||||
xtest = self._att[split_map == 2].reshape(-1, self._att.shape[2])
|
||||
ytrain = self._gt[split_map == 1].reshape(-1)
|
||||
ytest = self._gt[split_map == 2].reshape(-1)
|
||||
test_index = split_map == 2
|
||||
|
||||
self._actual_count += 1
|
||||
|
||||
return xtrain, xtest, ytrain, ytest, test_index
|
||||
|
||||
def semantic_cvg(gt, nb_split, step=0):
|
||||
count = np.unique(gt, return_counts=True)
|
||||
|
||||
test_part = 1 / nb_split
|
||||
|
||||
split = np.zeros_like(gt)
|
||||
|
||||
for lbli, lblc in zip(count[0][1:], count[1][1:]):
|
||||
treshold = int(lblc * test_part)
|
||||
#print('lbli:{}, count:{}, train:{}'.format(lbli, lblc, treshold))
|
||||
f = np.nonzero(gt == lbli)
|
||||
t_int, t_ext = treshold * step, treshold * (step + 1)
|
||||
split[f[0], f[1]] = 1
|
||||
split[f[0][t_int:t_ext], f[1][t_int:t_ext]] = 2
|
||||
|
||||
return split
|
||||
|
@ -77,7 +77,7 @@
|
||||
" '../Data/phase1_rasters/Intensity_C3/UH17_GI3F051_TR.tif',\n",
|
||||
" #'../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif',\n",
|
||||
" #'../Res/HVR/C123_num_returns_0_5_nearest.tif',\n",
|
||||
" #'../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n",
|
||||
" '../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n",
|
||||
"\n",
|
||||
"]"
|
||||
]
|
||||
|
@ -79,7 +79,34 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from CrossValidationGenerator import CVG"
|
||||
"import sys\n",
|
||||
"sys.path.append('..')\n",
|
||||
"from CrossValidationGenerator import APsCVG\n",
|
||||
"sys.path.append('../triskele/python')\n",
|
||||
"import triskele"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n",
|
||||
"att = np.array((gt, gt))\n",
|
||||
"att = np.moveaxis(att, 0, 2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for xt, xv, yt, yv, ti in APsCVG(gt, att, 2):\n",
|
||||
" print(xt.shape, yt.shape, xv.shape, yv.shape)\n",
|
||||
" plt.imshow(ti * 1.)\n",
|
||||
" plt.show()"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -13,7 +13,7 @@
|
||||
"# Triskele\n",
|
||||
"import sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"triskele_path = Path('../triskele')\n",
|
||||
"triskele_path = Path('../triskele/python')\n",
|
||||
"sys.path.append(str(triskele_path.resolve()))\n",
|
||||
"import triskele"
|
||||
]
|
||||
@ -133,7 +133,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nb_split = 10\n",
|
||||
"nb_split = 5\n",
|
||||
"\n",
|
||||
"def semantic_cvg(gt, nb_split, step=0):\n",
|
||||
" test_part = 1 / nb_split\n",
|
||||
@ -162,15 +162,6 @@
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gt == lbli"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
Loading…
Reference in New Issue
Block a user