From c90ee468b2a662d2fcfb6e5270d6962506cf49bb Mon Sep 17 00:00:00 2001 From: Karamaz0V1 Date: Mon, 9 Jul 2018 18:02:40 +0200 Subject: [PATCH] Classification and Scores --- Notebooks/Attribute Profiles Classifier.ipynb | 78 ++++- Notebooks/Classification Scores.ipynb | 298 ++++++++++++++++++ Notebooks/Cross Validation Generator.ipynb | 2 +- 3 files changed, 374 insertions(+), 4 deletions(-) create mode 100644 Notebooks/Classification Scores.ipynb diff --git a/Notebooks/Attribute Profiles Classifier.ipynb b/Notebooks/Attribute Profiles Classifier.ipynb index 8af1ee8..b4ec61e 100644 --- a/Notebooks/Attribute Profiles Classifier.ipynb +++ b/Notebooks/Attribute Profiles Classifier.ipynb @@ -17,7 +17,9 @@ "source": [ "## Setup\n", "\n", - "### Packages" + "### Packages\n", + "\n", + "#### Attributes Profile" ] }, { @@ -38,6 +40,26 @@ "import triskele" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Classifier" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import metrics\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "import pandas as pd\n", + "import pickle\n", + "from CrossValidationGenerator import APsCVG" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -148,7 +170,7 @@ "outputs": [], "source": [ "areas = [10., 100.]\n", - "areas.extend([x * 1e3 for x in range(1,100,2)])\n", + "areas.extend([x * 1e3 for x in range(1,100,1)])\n", "plt.plot(areas, '.')\n", "plt.show()" ] @@ -206,7 +228,8 @@ "metadata": {}, "outputs": [], "source": [ - "out_vectors.data.shape" + "att = out_vectors.data\n", + "att.shape, att.dtype" ] }, { @@ -229,6 +252,55 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cross Valid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prediction = np.zeros_like(gt)\n", + "\n", + "for xt, xv, yt, yv, ti in APsCVG(gt, att, 
5):\n", + " plt.imshow(ti * 1.)\n", + " plt.show()\n", + " \n", + " rfc = RandomForestClassifier(n_jobs=-1, random_state=0, n_estimators=100, verbose=True)\n", + " rfc.fit(xt, yt)\n", + " \n", + " ypred = rfc.predict(xv)\n", + " \n", + " prediction[ti] = ypred" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=figsize)\n", + "plt.imshow(prediction)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imsave('../Res/tmppred.png', prediction)\n", + "plt.imsave('../Res/gt.png', gt)\n", + "triskele.write('../Res/tmppred_8.tif', prediction)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/Notebooks/Classification Scores.ipynb b/Notebooks/Classification Scores.ipynb new file mode 100644 index 0000000..376f74f --- /dev/null +++ b/Notebooks/Classification Scores.ipynb @@ -0,0 +1,298 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generic Classification Scores for DFC 2018" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn import metrics\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "\n", + "# Triskele\n", + "import sys\n", + "from pathlib import Path\n", + "triskele_path = Path('../triskele/python')\n", + "sys.path.append(str(triskele_path.resolve()))\n", + "import triskele\n", + "\n", + "figsize = np.array((16, 9))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Classes Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_dfc_lbl = pd.read_csv('../labels.csv')\n", + "df_meta_idx = pd.read_csv('../metaclass_indexes.csv')\n", + "df_meta_lbl = pd.read_csv('../metaclass_labels.csv')\n", + "\n", + 
"df_dfc_lbl.merge(df_meta_idx).merge(df_meta_lbl)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "meta_idx = np.array(df_meta_idx['metaclass_index'], dtype=np.uint8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Ground Truth and Prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gt = triskele.read('../Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif')\n", + "pred = triskele.read('../Res/tmppred.tif')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Display Classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax_gt, ax_pred) = plt.subplots(2, figsize=figsize * 2)\n", + "ax_gt.imshow(gt)\n", + "ax_gt.set_title('Ground Truth')\n", + "ax_pred.imshow(pred)\n", + "ax_pred.set_title('Prediction')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Display Meta Classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, (ax_gt, ax_pred) = plt.subplots(2, figsize=figsize * 2)\n", + "ax_gt.imshow(meta_idx[gt])\n", + "ax_gt.set_title('Ground Truth')\n", + "ax_pred.imshow(meta_idx[pred])\n", + "ax_pred.set_title('Prediction')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Metrics\n", + "\n", + "### Classes\n", + "\n", + "#### Confusion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f = np.nonzero(pred)\n", + "pred_s = pred[f].flatten()\n", + "gt_s = gt[f].flatten()\n", + "\n", + "ct = pd.crosstab(gt_s, pred_s, # rows = ground truth, columns = prediction\n", + " rownames=['Reference'], colnames=['Prediction'],\n", + " margins=True, margins_name='Total',\n", + " normalize=False # all, index, columns\n", + " 
)\n", + "ct" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Scores\n", + "\n", + "##### Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics.accuracy_score(gt_s, pred_s)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Kappa" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics.cohen_kappa_score(gt_s, pred_s)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Precision, Recall, f1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics.precision_recall_fscore_support(gt_s, pred_s)\n", + "print(metrics.classification_report(gt_s, pred_s))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Meta Classes\n", + "\n", + "#### Confusion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f = np.nonzero(pred)\n", + "m_pred_s = meta_idx[pred_s]\n", + "m_gt_s = meta_idx[gt_s]\n", + "\n", + "ct = pd.crosstab(m_gt_s, m_pred_s, # rows = ground truth, columns = prediction\n", + " rownames=['Reference'], colnames=['Prediction'],\n", + " margins=True, margins_name='Total',\n", + " normalize=False # all, index, columns\n", + " )\n", + "ct" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Scores\n", + "\n", + "##### Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics.accuracy_score(m_gt_s, m_pred_s)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Kappa" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics.cohen_kappa_score(m_gt_s, m_pred_s)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + 
"##### Precision, Recall, f1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics.precision_recall_fscore_support(m_gt_s, m_pred_s)\n", + "print(metrics.classification_report(m_gt_s, m_pred_s))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Notebooks/Cross Validation Generator.ipynb b/Notebooks/Cross Validation Generator.ipynb index a1296c9..d95afb2 100644 --- a/Notebooks/Cross Validation Generator.ipynb +++ b/Notebooks/Cross Validation Generator.ipynb @@ -103,7 +103,7 @@ "metadata": {}, "outputs": [], "source": [ - "for xt, xv, yt, yv, ti in APsCVG(gt, att, 2):\n", + "for xt, xv, yt, yv, ti in APsCVG(gt, att, 1):\n", " print(xt.shape, yt.shape, xv.shape, yv.shape)\n", " plt.imshow(ti * 1.)\n", " plt.show()"