From 4c04b1227f6ffdb955606f56e1157ff6b74cf8cc Mon Sep 17 00:00:00 2001 From: Karamaz0V1 Date: Mon, 9 Jul 2018 12:56:39 +0200 Subject: [PATCH] CVG ready --- CrossValidationGenerator.py | 43 ++++++++++++++++++- Notebooks/Attribute Profiles Classifier.ipynb | 2 +- Notebooks/Cross Validation Generator.ipynb | 29 ++++++++++++- ...und truth split train and validation.ipynb | 13 +----- 4 files changed, 72 insertions(+), 15 deletions(-) diff --git a/CrossValidationGenerator.py b/CrossValidationGenerator.py index 1875912..a4038bb 100644 --- a/CrossValidationGenerator.py +++ b/CrossValidationGenerator.py @@ -51,8 +51,47 @@ class CVG_legacy: class APsCVG: """Cross Validation Generator for Attribute Profiles Descriptors""" - def __init__(self, ground_truth, attributes, cv_count=5, label_ignore=None): - d + def __init__(self, ground_truth, attributes, n_test=5, label_ignore=None): + self._gt = ground_truth + self._att = attributes + self._cv_count = n_test + self._actual_count = 0 + if attributes.shape[0] != ground_truth.shape[0] or \ + attributes.shape[1] != ground_truth.shape[1] : + raise ValueError('attributes and ground_truth must have the same 2D shape') + def __iter__(self): + return self + + def __next__(self): + if self._cv_count == self._actual_count: + raise StopIteration + + split_map = semantic_cvg(self._gt, self._cv_count, self._actual_count) + xtrain = self._att[split_map == 1].reshape(-1, self._att.shape[2]) + xtest = self._att[split_map == 2].reshape(-1, self._att.shape[2]) + ytrain = self._gt[split_map == 1].reshape(-1) + ytest = self._gt[split_map == 2].reshape(-1) + test_index = split_map == 2 + + self._actual_count += 1 + return xtrain, xtest, ytrain, ytest, test_index + +def semantic_cvg(gt, nb_split, step=0): + count = np.unique(gt, return_counts=True) + + test_part = 1 / nb_split + + split = np.zeros_like(gt) + + for lbli, lblc in zip(count[0][1:], count[1][1:]): + treshold = int(lblc * test_part) + #print('lbli:{}, count:{}, train:{}'.format(lbli, lblc, treshold)) + f = np.nonzero(gt == lbli) + t_int, t_ext = treshold * step, treshold * (step + 1) + split[f[0], f[1]] = 1 + split[f[0][t_int:t_ext], f[1][t_int:t_ext]] = 2 + + return split diff --git a/Notebooks/Attribute Profiles Classifier.ipynb b/Notebooks/Attribute Profiles Classifier.ipynb index cf68f55..8af1ee8 100644 --- a/Notebooks/Attribute Profiles Classifier.ipynb +++ b/Notebooks/Attribute Profiles Classifier.ipynb @@ -77,7 +77,7 @@ " '../Data/phase1_rasters/Intensity_C3/UH17_GI3F051_TR.tif',\n", " #'../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif',\n", " #'../Res/HVR/C123_num_returns_0_5_nearest.tif',\n", - " #'../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n", + " '../Res/HVR noisy/C123_num_returns_0_5_nearest.tif'\n", "\n", "]" ] diff --git a/Notebooks/Cross Validation Generator.ipynb b/Notebooks/Cross Validation Generator.ipynb index 6287810..a1296c9 100644 --- a/Notebooks/Cross Validation Generator.ipynb +++ b/Notebooks/Cross Validation Generator.ipynb @@ -79,7 +79,34 @@ "metadata": {}, "outputs": [], "source": [ - "from CrossValidationGenerator import CVG" + "import sys\n", + "sys.path.append('..')\n", + "from CrossValidationGenerator import APsCVG\n", + "sys.path.append('../triskele/python')\n", + "import triskele" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n", + "att = np.array((gt, gt))\n", + "att = np.moveaxis(att, 0, 2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for xt, xv, yt, yv, ti in APsCVG(gt, att, 2):\n", + " print(xt.shape, yt.shape, xv.shape, yv.shape)\n", + " plt.imshow(ti * 1.)\n", + " plt.show()" ] } ], diff --git a/Notebooks/DFC2018 Ground truth split train and validation.ipynb b/Notebooks/DFC2018 Ground truth split train and validation.ipynb index db41e9e..c2caead 100644 --- a/Notebooks/DFC2018 Ground truth split train and validation.ipynb +++ b/Notebooks/DFC2018 Ground truth split train and validation.ipynb @@ -13,7 +13,7 @@ "# Triskele\n", "import sys\n", "from pathlib import Path\n", - "triskele_path = Path('../triskele')\n", + "triskele_path = Path('../triskele/python')\n", "sys.path.append(str(triskele_path.resolve()))\n", "import triskele" ] @@ -133,7 +133,7 @@ "metadata": {}, "outputs": [], "source": [ - "nb_split = 10\n", + "nb_split = 5\n", "\n", "def semantic_cvg(gt, nb_split, step=0):\n", " test_part = 1 / nb_split\n", @@ -162,15 +162,6 @@ "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gt == lbli" - ] - }, { "cell_type": "markdown", "metadata": {},