spectra-gui/demo.ipynb
2021-03-10 15:13:24 +01:00

543 lines
20 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Interactive pixel to Spectrum x DTM Analysis in Attribute Space\n",
"\n",
"Fonctionnel, mériterait une réécriture :p (works, but deserves a rewrite)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import rasterio as rio\n",
"import sap\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"import inspect\n",
"import ipywidgets as ipw\n",
"import higra as hg\n",
"import numpy as np\n",
"from pathlib import Path\n",
"\n",
"# Global look: dark background and 'plasma' as the default image colormap.\n",
"plt.style.use('dark_background')\n",
"mpl.rc('image', cmap='plasma')\n",
"# Close any figure left open so the GUI figures start from a clean state.\n",
"plt.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
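"Single-tile loading, kept for reference (markdown, not executed):\n",
"\n",
"```python\n",
"dtm = rio.open('data/dsm_vox_50cm_tile_-12_0.tif').read()[0]#[-500:,-500:].copy()\n",
"dtm.shape\n",
"```"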
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"918 tiles loaded.\n"
]
}
],
"source": [
"# All GeoTIFF tiles found (recursively) under data/.\n",
"data_dir = Path('data')\n",
"tiles = [path for path in data_dir.glob('**/*.tif')]\n",
"\n",
"print(f'{len(tiles)} tiles loaded.')"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"# Sorted for a stable dropdown order (and a stable `[-1]` default pick):\n",
"# the iteration order of a set of strings depends on the per-process hash seed.\n",
"lidar_features = sorted({path.stem.split('_vox_')[0] for path in tiles})\n",
"tile_names = sorted({'_'.join(path.stem.split('_')[-2:]) for path in tiles})\n",
"\n",
"def get_tile(feature, name):\n",
"    \"\"\"Return the path of the `feature` raster of tile `name` (e.g. '-12_0') under `data_dir`.\n",
"\n",
"    Raises FileNotFoundError when no file matches.\n",
"    \"\"\"\n",
"    pattern = '**/{}_vox_50cm_tile_{}.tif'.format(feature, name)\n",
"    match = next(data_dir.glob(pattern), None)\n",
"    if match is None:\n",
"        raise FileNotFoundError('No tile matching {!r} under {}'.format(pattern, data_dir))\n",
"    return match\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup utils"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def ui8_clip(x, vmin=None, vmax=None):\n",
"    \"\"\"Linearly rescale `x` from [vmin, vmax] to 0..255 as a masked uint8 array.\n",
"\n",
"    Values outside [vmin, vmax] are clipped first. `vmin` / `vmax` default to\n",
"    the NaN-ignoring min / max of `x`. NaN pixels are masked in the result.\n",
"    \"\"\"\n",
"    # `is None` rather than truthiness, so an explicit bound of 0 is honoured.\n",
"    vmin = np.nanmin(x) if vmin is None else vmin\n",
"    vmax = np.nanmax(x) if vmax is None else vmax\n",
"\n",
"    span = vmax - vmin\n",
"    if span == 0:\n",
"        # Degenerate range: avoid 0/0 and map everything to 0.\n",
"        scaled = np.zeros(np.shape(x))\n",
"    else:\n",
"        scaled = (np.clip(x, vmin, vmax) - vmin) / span\n",
"    # NaN pixels end up masked; zero them before the cast (NaN -> uint8 is undefined).\n",
"    ui8 = (np.nan_to_num(scaled, nan=0.0) * 255).astype(np.uint8)\n",
"\n",
"    return np.ma.array(ui8, mask=np.isnan(x))\n",
"\n",
"def hillshades(x):\n",
"    \"\"\"Diagonal relief of a 2D raster: each pixel minus its lower-right neighbour.\n",
"\n",
"    Pixels of the last row / column have no such neighbour and stay at 0.\n",
"    The result has the shape and dtype of `x`.\n",
"    \"\"\"\n",
"    relief = np.zeros_like(x)\n",
"    inner = (slice(None, -1), slice(None, -1))\n",
"    relief[inner] = x[inner] - x[1:, 1:]\n",
"    return relief\n",
"\n",
"from sap.spectra import get_bins\n",
"\n",
"def spectrum2d(tree, x_attribute, y_attribute, x_count=100, y_count=100,\n",
"               x_log=False, y_log=False, weighted=True, normalized=True,\n",
"               node_mask=None):\n",
"    \"\"\"2D pattern spectrum of `tree`, optionally restricted to a subset of nodes.\n",
"\n",
"    Local variant of `sap.spectrum2d` (both are called with the same arguments\n",
"    in `show_spectrum`) with an extra `node_mask`: a boolean array over the\n",
"    tree nodes selecting which nodes contribute to the histogram.\n",
"    Returns `(s, xedges, yedges, x_log, y_log)`.\n",
"    \"\"\"\n",
"    x = tree.get_attribute(x_attribute)\n",
"    y = tree.get_attribute(y_attribute)\n",
"\n",
"    # Bins are computed from all nodes (before masking), so they do not depend on node_mask.\n",
"    bins = (get_bins(x, x_count, 'geo' if x_log else 'lin'),\n",
"            get_bins(y, y_count, 'geo' if y_log else 'lin'))\n",
"\n",
"    weights = None\n",
"    if weighted:\n",
"        weights = tree.get_attribute('area')\n",
"        if normalized:\n",
"            weights = weights / tree._image.size\n",
"\n",
"    if node_mask is not None:\n",
"        x, y = x[node_mask], y[node_mask]\n",
"        # weights is None when weighted=False, so only index it when present.\n",
"        if weights is not None:\n",
"            weights = weights[node_mask]\n",
"\n",
"    s, xedges, yedges = np.histogram2d(x.ravel(), y.ravel(), bins=bins, density=None,\n",
"                                       weights=None if weights is None else weights.ravel())\n",
"\n",
"    return s, xedges, yedges, x_log, y_log\n",
"\n",
"def pixel_to_node(tree, mask):\n",
"    \"\"\"Compute the node mask from the pixel mask.\n",
"\n",
"    `mask` is a boolean pixel mask (presumably the shape of the image the tree\n",
"    was built on: one value per leaf once raveled). Leaf values come from\n",
"    `mask` and are combined towards the root with a min; node weights are all\n",
"    1 so they never lower the result, i.e. a node is selected when all the\n",
"    pixels of its subtree are selected (per higra's accumulate_and_min_sequential).\n",
"    Returns one boolean per vertex of `tree._tree`.\n",
"    \"\"\"\n",
"    node_mask = hg.accumulate_and_min_sequential(\n",
"        tree._tree,\n",
"        np.ones(tree._tree.num_vertices(), dtype=np.uint8),\n",
"        mask.ravel(),\n",
"        hg.Accumulators.min)\n",
"    # Builtin bool: the np.bool alias was removed in NumPy 1.24.\n",
"    return node_mask.astype(bool)\n",
"\n",
"#dtm_hs = ui8_clip(hillshades(dtm), -1, 1)\n",
"#cm_hud = mpl.colors.LinearSegmentedColormap.from_list('GreenHUD', [(0.,0.,0.,0.), (0.,1.,0.,1.)], 256)\n",
"#alpha = ((dtm_hs.astype(float) - 127) / 127) ** 2 * .4\n",
"#cX, cY = np.meshgrid(np.arange(dtm.shape[1]), np.arange(dtm.shape[0]))\n",
"#kwargs = {}\n",
"#kwargs['vmin'], kwargs['vmax'] = np.quantile(dtm, (.01, .99))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shared state of the interactive GUIs below, initialised with a default tile.\n",
"# `load_tile` is only defined in a later cell, so the tile is loaded directly\n",
"# here, following the same steps.\n",
"model = {}\n",
"with rio.open(get_tile(lidar_features[-1], tile_names[-1])) as src:\n",
"    model['im'] = src.read(1)\n",
"model['im_hs'] = ui8_clip(hillshades(model['im']), -1, 1)\n",
"model['im_hs_alpha'] = ((model['im_hs'].astype(float) - 127) / 127) ** 2 * .4\n",
"model['im_kwargs'] = dict(zip(('vmin', 'vmax'), np.quantile(model['im'], (.01, .99))))\n",
"model['im_mesh'] = np.meshgrid(np.arange(model['im'].shape[1]), np.arange(model['im'].shape[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interactive GUI"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bb4daff9ec66456ab58f00021b778efd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(Dropdown(description='tree', index=1, options=(('AlphaTree', <class 'sap.trees.AlphaTree…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.spectrum_widget(tree=<class 'sap.trees.MaxTree'>, x='area', y='compactness', x_log=True, y_log=False, hillshade_overlay=True, highlight=False)>"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Matplotlib backend: 'widget' (in the notebook) or 'qt' (separate window);\n",
"# qt provides a snappier GUI.\n",
"#%matplotlib widget\n",
"%matplotlib qt\n",
"plt.style.use('dark_background')\n",
"\n",
"# Initial GUI state: the axes are created below, the attribute names are set\n",
"# by `spectrum_widget`, and both selection rectangles start empty (all 0).\n",
"model.update(\n",
"    ax_spectrum=None,\n",
"    ax_dtm=None,\n",
"    x=None,\n",
"    y=None,\n",
"    **dict.fromkeys(['x1', 'x2', 'y1', 'y2',\n",
"                     'dtm_x1', 'dtm_x2', 'dtm_y1', 'dtm_y2'], 0),\n",
"    overlay=True,\n",
"    pixel_mask=None,\n",
")\n",
"\n",
"\n",
"def line_select_callback(eclick, erelease):\n",
"    \"\"\"Spectrum selector callback: store the selected attribute box, then refilter the DTM.\"\"\"\n",
"    model.update(x1=eclick.xdata, y1=eclick.ydata,\n",
"                 x2=erelease.xdata, y2=erelease.ydata)\n",
"    filter_dtm()\n",
" \n",
"def dtm_select_callback(eclick, erelease):\n",
"    \"\"\"DTM selector callback: store the selected pixel box, then update the spectrum.\n",
"\n",
"    Matplotlib reports xdata as the column and ydata as the row; in `model` the\n",
"    row bounds go to dtm_x1/dtm_x2 and the column bounds to dtm_y1/dtm_y2.\n",
"    \"\"\"\n",
"    # Builtin int (np.int was removed in NumPy 1.24); sorting keeps the slices\n",
"    # used by filter_nodes non-empty whatever the drag direction.\n",
"    rows = sorted((int(eclick.ydata), int(erelease.ydata)))\n",
"    cols = sorted((int(eclick.xdata), int(erelease.xdata)))\n",
"    model['dtm_x1'], model['dtm_x2'] = rows\n",
"    model['dtm_y1'], model['dtm_y2'] = cols\n",
"    filter_nodes()\n",
" \n",
"def filter_nodes():\n",
"    \"\"\"Turn the DTM selection box into a pixel mask, enable highlighting and redraw the spectrum.\"\"\"\n",
"    x1, x2 = model['dtm_x1'], model['dtm_x2']\n",
"    y1, y2 = model['dtm_y1'], model['dtm_y2']\n",
"    # Builtin bool: the np.bool alias was removed in NumPy 1.24.\n",
"    pmask = np.zeros_like(model['im'], dtype=bool)\n",
"    pmask[x1:x2, y1:y2] = True  # rows x1:x2, columns y1:y2\n",
"    model['pixel_mask'] = pmask\n",
"    model['highlight'] = True\n",
"    show_spectrum()\n",
" \n",
"def show_spectrum():\n",
"    \"\"\"Draw the 2D pattern spectrum of the cached tree on the spectrum axes.\n",
"\n",
"    With `model['highlight']`, each bin shows the ratio between the spectrum of\n",
"    the nodes selected by the DTM pixel mask (see `pixel_to_node`) and the full\n",
"    spectrum, on a linear colour scale; otherwise the full spectrum is drawn\n",
"    with `log_scale=True`.\n",
"    \"\"\"\n",
"    x, y = model['x'], model['y']\n",
"    x_log, y_log = model['x_log'], model['y_log']\n",
"    highlight = model['highlight']\n",
"    tree = model['tree_cache']\n",
"\n",
"    # Compute spectrum\n",
"    ps = sap.spectrum2d(tree, x, y, 200, 100, x_log, y_log)\n",
"\n",
"    if highlight:\n",
"        pixel_mask = model['pixel_mask']\n",
"        node_mask = pixel_to_node(tree, pixel_mask) if pixel_mask is not None else None\n",
"        ps_mask = spectrum2d(tree, x, y, 200, 100, x_log, y_log, node_mask=node_mask)\n",
"        # Masked / full ratio; empty bins of the full spectrum stay NaN, without\n",
"        # divide-by-zero RuntimeWarnings.\n",
"        ps[0][:] = np.divide(ps_mask[0], ps[0],\n",
"                             out=np.full_like(ps[0], np.nan, dtype=float),\n",
"                             where=ps[0] != 0)\n",
"\n",
"    # Display spectrum\n",
"    plt.sca(model['ax_spectrum'])\n",
"    plt.cla()\n",
"    sap.show_spectrum(*ps, log_scale=not highlight)\n",
"    plt.xlabel(x)\n",
"    plt.ylabel(y)\n",
"    plt.grid(which='minor', color='#BBBBBB', linestyle=':')\n",
"    plt.grid(True)\n",
"\n",
"    plt.gcf().canvas.draw()\n",
"\n",
"\n",
"def filter_dtm():\n",
"    \"\"\"Redraw the DTM axes according to the rectangle selected on the spectrum.\n",
"\n",
"    With no selection (all bounds at 0) the DTM is drawn as is. Otherwise the\n",
"    cached tree is reconstructed with `filtering='subtractive'`, passing the\n",
"    nodes whose (x, y) attributes fall outside the selected box; the result is\n",
"    drawn directly or, with the overlay on, outlined in lime over the DTM.\n",
"    \"\"\"\n",
"    x1, x2 = model['x1'], model['x2']\n",
"    y1, y2 = model['y1'], model['y2']\n",
"    ax = model['ax_dtm']\n",
"    overlay = model['overlay']\n",
"\n",
"    def draw_dtm():\n",
"        # DTM with its colour limits, plus the relief layer when the overlay is on.\n",
"        ax.imshow(model['im'], **model['im_kwargs'])\n",
"        if overlay:\n",
"            ax.imshow(model['im_hs'], cmap=plt.cm.Greys, alpha=model['im_hs_alpha'])\n",
"\n",
"    ax.clear()\n",
"\n",
"    if x1 == x2 == y1 == y2 == 0:\n",
"        # No filter\n",
"        draw_dtm()\n",
"    else:\n",
"        tree = model['tree_cache']\n",
"        X = tree.get_attribute(model['x'])\n",
"        Y = tree.get_attribute(model['y'])\n",
"        outside = (X < x1) | (X > x2) | (Y < y1) | (Y > y2)\n",
"        img = tree.reconstruct(outside, filtering='subtractive')\n",
"        if overlay:\n",
"            draw_dtm()\n",
"            # Contour level `_alt[-1]`: presumably the root altitude -- TODO confirm.\n",
"            ax.contour(*model['im_mesh'], img, [tree._alt[-1]], colors='lime')\n",
"        else:\n",
"            ax.imshow(img)\n",
"\n",
"    # Redraw in every case, including the no-selection path (e.g. overlay toggle).\n",
"    ax.figure.canvas.draw()\n",
"\n",
"def spectrum_widget(tree=sap.MaxTree, x='area', y='compactness', x_log=True, y_log=False,\n",
"                    hillshade_overlay=True, highlight=False):\n",
"    \"\"\"`ipw.interact` callback: store the chosen options and redraw.\n",
"\n",
"    A change of `hillshade_overlay` alone only redraws the DTM. Otherwise the\n",
"    options go into `model`, the tree is rebuilt when its type changed, the\n",
"    spectrum is redrawn and both rectangle selectors are recreated.\n",
"    \"\"\"\n",
"    # Update only toggle of hillshade_overlay\n",
"    if hillshade_overlay != model['overlay']:\n",
"        model['overlay'] = hillshade_overlay\n",
"        filter_dtm()\n",
"        return\n",
"\n",
"    # Update model and tree computation if needed\n",
"    model.update({'x': x, 'y': y, 'x_log': x_log, 'y_log': y_log, 'highlight': highlight})\n",
"    if 'tree_cache' not in model or not isinstance(model['tree_cache'], tree):\n",
"        print('Computing the {} of the DTM...'.format(tree))\n",
"        model['tree_cache'] = tree(model['im'])\n",
"\n",
"    show_spectrum()\n",
"\n",
"    # The selectors live on the `selector` / `selector_dtm` placeholders. Retire\n",
"    # the previous ones explicitly rather than relying on garbage collection.\n",
"    for holder, ax, onselect, rgb in ((selector, model['ax_spectrum'], line_select_callback, (0, 1, 0)),\n",
"                                      (selector_dtm, model['ax_dtm'], dtm_select_callback, (1, 1, 0))):\n",
"        previous = getattr(holder, 'rs', None)\n",
"        if previous is not None:\n",
"            previous.disconnect_events()\n",
"            previous.set_visible(False)\n",
"        holder.rs = _make_rectangle_selector(ax, onselect, rgb)\n",
"\n",
"\n",
"def _make_rectangle_selector(ax, onselect, rgb):\n",
"    \"\"\"Dashed rectangle selector on `ax` (left/right buttons), filled with `rgb` at 20% opacity.\"\"\"\n",
"    # NOTE: drawtype/rectprops were deprecated in Matplotlib 3.5 (use props=...).\n",
"    return mpl.widgets.RectangleSelector(ax, onselect,\n",
"                                         drawtype='box', useblit=True,\n",
"                                         button=[1, 3],  # don't use middle button\n",
"                                         minspanx=5, minspany=5,\n",
"                                         spancoords='pixels',\n",
"                                         rectprops=dict(facecolor=(*rgb, .2), edgecolor=rgb,\n",
"                                                        fill=True, ls='--', lw=2),\n",
"                                         interactive=True)\n",
"\n",
"\n",
"# Figure: spectrum on the first two thirds of the width, DTM on the last third.\n",
"plt.close()\n",
"fig = plt.figure('Spectrum', figsize=(18, 5), constrained_layout=True)\n",
"grid = mpl.gridspec.GridSpec(1, 3, figure=fig)\n",
"model.update(ax_spectrum=fig.add_subplot(grid[:2]),\n",
"             ax_dtm=fig.add_subplot(grid[-1]))\n",
"\n",
"\n",
"# Placeholders used as attribute holders: `spectrum_widget` stores the\n",
"# RectangleSelector of each axes in their `.rs` attribute.\n",
"def selector(event):\n",
"    pass\n",
"\n",
"\n",
"def selector_dtm(event):\n",
"    pass\n",
"\n",
"\n",
"# Initial DTM display (no selection yet), then the interactive controls.\n",
"filter_dtm()\n",
"\n",
"tree_classes = inspect.getmembers(\n",
"    sap.trees, lambda t: inspect.isclass(t) and issubclass(t, sap.Tree) and t != sap.Tree)\n",
"ipw.interact(spectrum_widget,\n",
"             tree=tree_classes,\n",
"             x=sap.available_attributes().keys(),\n",
"             y=sap.available_attributes().keys())"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('AlphaTree', sap.trees.AlphaTree),\n",
" ('MaxTree', sap.trees.MaxTree),\n",
" ('MinTree', sap.trees.MinTree),\n",
" ('OmegaTree', sap.trees.OmegaTree),\n",
" ('TosTree', sap.trees.TosTree)]"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def _is_concrete_tree(t):\n",
"    \"\"\"True for members of sap.trees that are subclasses of sap.Tree, the base class excluded.\"\"\"\n",
"    return inspect.isclass(t) and issubclass(t, sap.Tree) and t != sap.Tree\n",
"\n",
"inspect.getmembers(sap.trees, _is_concrete_tree)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"def load_tile(feature=None, name=None):\n",
"    \"\"\"Load a DTM tile and its display helpers into `model`.\n",
"\n",
"    `feature` / `name` default to the values of the 'LiDAR feature' and\n",
"    'Tile ID' dropdowns (`model['lidar_feature']`, `model['tile_name']`), so a\n",
"    tile can also be loaded explicitly before those widgets exist.\n",
"    Sets `im` (first band), `im_hs` (uint8 diagonal relief), `im_hs_alpha`\n",
"    (relief overlay opacity), `im_kwargs` (1%-99% quantile colour limits) and\n",
"    `im_mesh` (pixel coordinate grids used for contours).\n",
"    \"\"\"\n",
"    if feature is None:\n",
"        feature = model['lidar_feature'].value\n",
"    if name is None:\n",
"        name = model['tile_name'].value\n",
"\n",
"    # Context manager so the raster file is closed once read.\n",
"    with rio.open(get_tile(feature, name)) as src:\n",
"        model['im'] = src.read(1)\n",
"    model['im_hs'] = ui8_clip(hillshades(model['im']), -1, 1)\n",
"    # Opacity grows with the distance of the relief value from mid-grey (127), up to ~0.4.\n",
"    model['im_hs_alpha'] = ((model['im_hs'].astype(float) - 127) / 127) ** 2 * .4\n",
"    kwargs = {}\n",
"    kwargs['vmin'], kwargs['vmax'] = np.quantile(model['im'], (.01, .99))\n",
"    model['im_kwargs'] = kwargs\n",
"    model['im_mesh'] = np.meshgrid(np.arange(model['im'].shape[1]), np.arange(model['im'].shape[0]))\n",
"\n",
"def load_tree():\n",
"    \"\"\"Build the hierarchy chosen in the 'Hierarchy' dropdown on `model['im']` and cache it.\"\"\"\n",
"    tree_class = model['tree_type'].value\n",
"    model['tree_cache'] = tree_class(model['im'])\n",
"\n",
"def reset_selections():\n",
"    \"\"\"Clear both selection rectangles (all bounds back to 0), turn the overlay on and highlighting off.\"\"\"\n",
"    cleared = dict.fromkeys(('x1', 'x2', 'y1', 'y2', 'dtm_x1', 'dtm_x2', 'dtm_y1', 'dtm_y2'), 0)\n",
"    model.update(cleared, overlay=True, highlight=False)"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "66018424ba894f8d89736b36481ef5f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HBox(children=(Dropdown(description='Tile ID', options=('-11_7', '-12_-1', '-13_-3', '-10_-1', …"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tree_classes = inspect.getmembers(sap.trees, lambda t: inspect.isclass(t) and issubclass(t, sap.Tree) and t != sap.Tree)\n",
"attribute_names = list(sap.available_attributes().keys())\n",
"\n",
"# Data / hierarchy pickers, read through `model` by `load_tile` and `load_tree`.\n",
"model['lidar_feature'] = ipw.Dropdown(description='LiDAR feature', options=lidar_features)\n",
"model['tile_name'] = ipw.Dropdown(description='Tile ID', options=tile_names)\n",
"model['tree_type'] = ipw.Dropdown(description='Hierarchy', value=sap.trees.MaxTree, options=tree_classes)\n",
"load_button = ipw.Button(description='Analyse', icon='rocket')\n",
"\n",
"# Display options. They are NOT stored in `model` under their own keys:\n",
"# `show_spectrum` / `filter_dtm` expect plain values there (attribute names,\n",
"# booleans), so `sync_options` copies the widget values into `model` instead.\n",
"option_widgets = {\n",
"    'x': ipw.Dropdown(description='Attribute X', options=attribute_names),\n",
"    'y': ipw.Dropdown(description='Attribute Y', options=attribute_names),\n",
"    'x_log': ipw.Checkbox(description='X log'),\n",
"    'y_log': ipw.Checkbox(description='Y log'),\n",
"    'overlay': ipw.Checkbox(description='Filtering overlay', value=True),\n",
"    'highlight': ipw.Checkbox(description='Spectrum highlight'),\n",
"}\n",
"\n",
"\n",
"def sync_options():\n",
"    \"\"\"Copy the current value of every display-option widget into `model`.\"\"\"\n",
"    model.update({key: widget.value for key, widget in option_widgets.items()})\n",
"\n",
"\n",
"def refresh():\n",
"    \"\"\"Redraw the spectrum and the DTM with the current display options.\"\"\"\n",
"    sync_options()\n",
"    show_spectrum()\n",
"    filter_dtm()\n",
"\n",
"\n",
"def load_change(_button):\n",
"    \"\"\"'Analyse' button handler: load the selected tile and hierarchy, then redraw.\"\"\"\n",
"    print('Loading tile...')\n",
"    load_tile()\n",
"    print('Loading tree...')\n",
"    load_tree()\n",
"    print('Loading complete')\n",
"    reset_selections()\n",
"    # A pixel mask from the previous tile may not match the new tile's shape.\n",
"    model['pixel_mask'] = None\n",
"    # Display options then follow the checkboxes (refresh re-syncs them).\n",
"    refresh()\n",
"    plt.show()\n",
"\n",
"\n",
"def option_change(_change):\n",
"    \"\"\"Redraw when a display option changes, once a tile and its tree are loaded.\"\"\"\n",
"    if 'im' in model and 'tree_cache' in model:\n",
"        refresh()\n",
"\n",
"\n",
"# Feature / tile / hierarchy pickers only take effect through the 'Analyse' button.\n",
"for widget in option_widgets.values():\n",
"    widget.observe(option_change, names='value')\n",
"load_button.on_click(load_change)\n",
"\n",
"ipw.VBox([\n",
"    ipw.HBox([model['tile_name'], model['lidar_feature'], model['tree_type'], load_button]),\n",
"    ipw.HBox([option_widgets['x'], option_widgets['x_log']]),\n",
"    ipw.HBox([option_widgets['y'], option_widgets['y_log']]),\n",
"    option_widgets['highlight'],\n",
"    option_widgets['overlay'],\n",
"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}