Refactor Model

This commit is contained in:
Florent Guiotte 2021-03-10 13:17:55 +01:00
parent f37c9da47a
commit 8fd8dd3374

View File

@ -13,17 +13,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"data": {
"text/plain": [
"<Figure size 432x288 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"import rasterio as rio\n", "import rasterio as rio\n",
"import sap\n", "import sap\n",
@ -33,16 +23,26 @@
"import ipywidgets as ipw\n", "import ipywidgets as ipw\n",
"import higra as hg\n", "import higra as hg\n",
"import numpy as np\n", "import numpy as np\n",
"from pathlib import Path\n",
"\n", "\n",
"plt.style.use('dark_background')\n", "plt.style.use('dark_background')\n",
"plt.set_cmap('plasma')" "plt.set_cmap('plasma')\n",
"plt.close()"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Load DTM" "## Load dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"dtm = rio.open('data/dsm_vox_50cm_tile_-12_0.tif').read()[0]#[-500:,-500:].copy()\n",
"dtm.shape"
] ]
}, },
{ {
@ -51,19 +51,40 @@
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "name": "stdout",
"text/plain": [ "output_type": "stream",
"(2001, 2001)" "text": [
"918 tiles loaded.\n"
] ]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"dtm = rio.open('data/dsm_vox_50cm_tile_-12_0.tif').read()[0]#[-500:,-500:].copy()\n", "data_dir = Path('data')\n",
"dtm.shape" "tiles = list(data_dir.glob('**/*.tif'))\n",
"\n",
"print('{} tiles loaded.'.format(len(tiles)))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"lidar_features = list(set([str(x.stem).split('_vox_')[0] for x in tiles]))\n",
"tile_names = list(set(['_'.join(str(x.stem).split('_')[-2:]) for x in tiles]))\n",
"\n",
"def get_tile(feature, name):\n",
" return list(data_dir.glob('**/{}_vox_50cm_tile_{}.tif'.format(feature, name)))[0]\n",
"\n",
"def load_tile(feature, name):\n",
" model['im'] = rio.open(get_tile(feature, name)).read()[0]\n",
" model['im_hs'] = ui8_clip(hillshades(model['im']), -1, 1)\n",
" model['im_hs_alpha'] = ((model['im_hs'].astype(float) - 127) / 127) ** 2 * .4\n",
" kwargs = {}\n",
" kwargs['vmin'], kwargs['vmax'] = np.quantile(model['im'], (.01, .99))\n",
" model['im_kwargs'] = kwargs\n",
" model['im_mesh'] = np.meshgrid(np.arange(model['im'].shape[1]), np.arange(model['im'].shape[0]))"
] ]
}, },
{ {
@ -75,11 +96,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"metadata": { "metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [] "tags": []
}, },
"outputs": [], "outputs": [],
@ -124,14 +142,71 @@
" hg.Accumulators.min).astype(np.bool)\n", " hg.Accumulators.min).astype(np.bool)\n",
" return node_mask\n", " return node_mask\n",
"\n", "\n",
"dtm_hs = ui8_clip(hillshades(dtm), -1, 1)\n", "#dtm_hs = ui8_clip(hillshades(dtm), -1, 1)\n",
"cm_hud = mpl.colors.LinearSegmentedColormap.from_list('GreenHUD', [(0.,0.,0.,0.), (0.,1.,0.,1.)], 256)\n", "#cm_hud = mpl.colors.LinearSegmentedColormap.from_list('GreenHUD', [(0.,0.,0.,0.), (0.,1.,0.,1.)], 256)\n",
"alpha = ((dtm_hs.astype(float) - 127) / 127) ** 2 * .4\n", "#alpha = ((dtm_hs.astype(float) - 127) / 127) ** 2 * .4\n",
"cX, cY = np.meshgrid(np.arange(dtm.shape[1]), np.arange(dtm.shape[0]))\n", "#cX, cY = np.meshgrid(np.arange(dtm.shape[1]), np.arange(dtm.shape[0]))\n",
"kwargs = {}\n", "#kwargs = {}\n",
"kwargs['vmin'], kwargs['vmax'] = np.quantile(dtm, (.01, .99))\n", "#kwargs['vmin'], kwargs['vmax'] = np.quantile(dtm, (.01, .99))"
"\n", ]
"model = {}" },
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3.9/site-packages/matplotlib/cbook/__init__.py\", line 224, in process\n",
" func(*args, **kwargs)\n",
" File \"/usr/lib/python3.9/site-packages/matplotlib/widgets.py\", line 1555, in release\n",
" self._release(event)\n",
" File \"/usr/lib/python3.9/site-packages/matplotlib/widgets.py\", line 2112, in _release\n",
" self.onselect(self.eventpress, self.eventrelease)\n",
" File \"<ipython-input-6-6d10fd3db2e6>\", line 27, in line_select_callback\n",
" filter_dtm()\n",
" File \"<ipython-input-6-6d10fd3db2e6>\", line 84, in filter_dtm\n",
" t = model['tree_cache']\n",
"KeyError: 'tree_cache'\n",
"<ipython-input-6-6d10fd3db2e6>:30: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" model['dtm_y1'], model['dtm_x1'] = np.int(eclick.xdata), np.int(eclick.ydata)\n",
"<ipython-input-6-6d10fd3db2e6>:31: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" model['dtm_y2'], model['dtm_x2'] = np.int(erelease.xdata), np.int(erelease.ydata)\n",
"<ipython-input-6-6d10fd3db2e6>:38: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" pmask = np.zeros_like(model['im'], dtype=np.bool)\n",
"Traceback (most recent call last):\n",
" File \"/usr/lib/python3.9/site-packages/matplotlib/cbook/__init__.py\", line 224, in process\n",
" func(*args, **kwargs)\n",
" File \"/usr/lib/python3.9/site-packages/matplotlib/widgets.py\", line 1555, in release\n",
" self._release(event)\n",
" File \"/usr/lib/python3.9/site-packages/matplotlib/widgets.py\", line 2112, in _release\n",
" self.onselect(self.eventpress, self.eventrelease)\n",
" File \"<ipython-input-6-6d10fd3db2e6>\", line 32, in dtm_select_callback\n",
" filter_nodes()\n",
" File \"<ipython-input-6-6d10fd3db2e6>\", line 42, in filter_nodes\n",
" show_spectrum()\n",
" File \"<ipython-input-6-6d10fd3db2e6>\", line 45, in show_spectrum\n",
" x, y = model['x'], model['y']\n",
"KeyError: 'x'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"HI!\n"
]
}
],
"source": [
"model = {}\n",
"load_tile(lidar_features[-1], tile_names[-1])"
] ]
}, },
{ {
@ -143,18 +218,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 8,
"metadata": { "metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "81086011013945a497e4ffb7163998f9", "model_id": "baf2371d25d543739572153d8e331efa",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
@ -171,7 +243,7 @@
"<function __main__.spectrum_widget(tree=<class 'sap.trees.MaxTree'>, x='area', y='compactness', x_log=True, y_log=False, hillshade_overlay=True, highlight=False)>" "<function __main__.spectrum_widget(tree=<class 'sap.trees.MaxTree'>, x='area', y='compactness', x_log=True, y_log=False, hillshade_overlay=True, highlight=False)>"
] ]
}, },
"execution_count": 4, "execution_count": 8,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -211,10 +283,9 @@
" filter_nodes()\n", " filter_nodes()\n",
" \n", " \n",
"def filter_nodes():\n", "def filter_nodes():\n",
" print('HI!')\n",
" x1, x2 = model['dtm_x1'], model['dtm_x2']\n", " x1, x2 = model['dtm_x1'], model['dtm_x2']\n",
" y1, y2 = model['dtm_y1'], model['dtm_y2']\n", " y1, y2 = model['dtm_y1'], model['dtm_y2']\n",
" pmask = np.zeros_like(dtm, dtype=np.bool)\n", " pmask = np.zeros_like(model['im'], dtype=np.bool)\n",
" pmask[x1:x2,y1:y2] = True\n", " pmask[x1:x2,y1:y2] = True\n",
" model['pixel_mask'] = pmask\n", " model['pixel_mask'] = pmask\n",
" model['highlight'] = True\n", " model['highlight'] = True\n",
@ -252,10 +323,10 @@
" # No filter\n", " # No filter\n",
" if x1 == x2 == y1 == y2 == 0:\n", " if x1 == x2 == y1 == y2 == 0:\n",
" if not model['overlay']:\n", " if not model['overlay']:\n",
" model['ax_dtm'].imshow(dtm, **kwargs)\n", " model['ax_dtm'].imshow(model['im'], **model['im_kwargs'])\n",
" else:\n", " else:\n",
" model['ax_dtm'].imshow(dtm, **kwargs)\n", " model['ax_dtm'].imshow(model['im'], **model['im_kwargs'])\n",
" model['ax_dtm'].imshow(dtm_hs, cmap=plt.cm.Greys, alpha=alpha)\n", " model['ax_dtm'].imshow(model['im_hs'], cmap=plt.cm.Greys, alpha=model['im_hs_alpha'])\n",
"\n", "\n",
" return\n", " return\n",
" \n", " \n",
@ -270,9 +341,9 @@
" if not model['overlay']:\n", " if not model['overlay']:\n",
" model['ax_dtm'].imshow(img)\n", " model['ax_dtm'].imshow(img)\n",
" else:\n", " else:\n",
" model['ax_dtm'].imshow(dtm, **kwargs)\n", " model['ax_dtm'].imshow(model['im'], **model['im_kwargs'])\n",
" model['ax_dtm'].imshow(dtm_hs, cmap=plt.cm.Greys, alpha=alpha)\n", " model['ax_dtm'].imshow(model['im_hs'], cmap=plt.cm.Greys, alpha=model['im_hs_alpha'])\n",
" model['ax_dtm'].contour(cX, cY, img, [model['tree_cache']._alt[-1]], colors='lime')\n", " model['ax_dtm'].contour(*model['im_mesh'], img, [model['tree_cache']._alt[-1]], colors='lime')\n",
"\n", "\n",
" plt.gcf().canvas.draw()\n", " plt.gcf().canvas.draw()\n",
"\n", "\n",
@ -288,7 +359,7 @@
" model.update({'x': x, 'y': y, 'x_log': x_log, 'y_log': y_log, 'highlight': highlight})\n", " model.update({'x': x, 'y': y, 'x_log': x_log, 'y_log': y_log, 'highlight': highlight})\n",
" if not 'tree_cache' in model or not isinstance(model['tree_cache'], tree):\n", " if not 'tree_cache' in model or not isinstance(model['tree_cache'], tree):\n",
" print('Computing the {} of the DTM...'.format(tree))\n", " print('Computing the {} of the DTM...'.format(tree))\n",
" model['tree_cache'] = tree(dtm)\n", " model['tree_cache'] = tree(model['im'])\n",
" \n", " \n",
" show_spectrum()\n", " show_spectrum()\n",
" \n", " \n",
@ -336,6 +407,13 @@
" x=sap.available_attributes().keys(),\n", " x=sap.available_attributes().keys(),\n",
" y=sap.available_attributes().keys())" " y=sap.available_attributes().keys())"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {