spectra-gui/spectra_app.py
2024-02-14 17:47:24 +02:00

506 lines
15 KiB
Python

import logging
import matplotlib as mpl
from matplotlib import path, patches, pyplot as plt
import numpy as np
from skimage import measure
import ipywidgets as ipw
import higra as hg
import sap
from sap.spectra import get_bins
# Module logger: DEBUG and above go to a console handler with timestamps.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
# Attach the console handler only once. Re-executing this module (e.g.
# ``importlib.reload`` or notebook autoreload) returns the same logger object,
# and adding another handler would print every message twice.
if not log.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    log.addHandler(ch)
def compute_contours(image, threshold):
    """Trace the outlines of the binary region ``image > threshold``.

    The boolean mask is passed to ``skimage.measure.find_contours`` at level 0
    with ``fully_connected='high'`` and ``positive_orientation='high'``.
    The result is the list of contour arrays returned by skimage, with
    (row, col) coordinates per skimage's convention.
    """
    above = image > threshold
    return measure.find_contours(above, 0,
                                 fully_connected='high',
                                 positive_orientation='high')
def convert_contours_to_verts(contours):
    """Pack several contours into the vertices/codes of one matplotlib Path.

    Each contour contributes one ``MOVETO`` followed by ``LINETO`` codes, so
    the contours stay disconnected inside a single compound path. Vertex
    coordinates are flipped column-wise, turning (row, col) into (x, y).

    Returns
    -------
    verts : ndarray, shape (N, 2)
    codes : ndarray or None
        For an empty ``contours`` a single (0, 0) placeholder vertex is
        returned together with ``None`` codes.
    """
    if len(contours) == 0:
        return np.zeros((1, 2)), None
    code_chunks = []
    for contour in contours:
        chunk = np.repeat(path.Path.LINETO, len(contour))
        chunk[0] = path.Path.MOVETO  # start a new sub-path for every contour
        code_chunks.append(chunk)
    stacked = np.concatenate(contours)
    return stacked[:, ::-1], np.concatenate(code_chunks)
def get_contours_verts(image, threshold):
    """Trace the contours of ``image > threshold`` and return Path ``(verts, codes)``."""
    contours = compute_contours(image, threshold)
    return convert_contours_to_verts(contours)
def ui8_clip(x, vmin=None, vmax=None):
    """Linearly rescale ``x`` to 8-bit values after clipping to ``[vmin, vmax]``.

    Parameters
    ----------
    x : array_like
        Plain (non-masked) numeric input; NaNs are allowed and end up masked.
    vmin, vmax : float, optional
        Values mapped to 0 and 255. ``None`` means use the NaN-ignoring
        minimum / maximum of ``x``. An explicit ``0`` is honoured; the former
        truthiness test silently replaced it with the data extreme.

    Returns
    -------
    numpy.ma.MaskedArray of uint8
        Rescaled values (truncated toward zero), masked where ``x`` is NaN.
        A constant input (``vmax == vmin``) maps to all zeros instead of
        dividing by zero.
    """
    x = np.asarray(x)
    vmin = np.nanmin(x) if vmin is None else vmin
    vmax = np.nanmax(x) if vmax is None else vmax
    nan_mask = np.isnan(x)
    span = vmax - vmin
    if span:
        scaled = (np.clip(x, vmin, vmax) - vmin) / span
    else:
        scaled = np.zeros(x.shape)
    # Zero-fill NaNs before the integer cast: NaN -> uint8 is undefined and
    # warns. Those cells are masked anyway.
    scaled = np.where(nan_mask, 0.0, scaled)
    ui8 = (scaled * 255).astype(np.uint8)
    return np.ma.array(ui8, mask=nan_mask)
def hillshades(x):
    """Cheap hill-shading via a diagonal finite difference.

    ``out[i, j] = x[i, j] - x[i + 1, j + 1]`` wherever the lower-right
    neighbour exists. The last row and column stay 0. The output has the
    shape and dtype of ``x``.
    """
    shade = np.zeros_like(x)
    upper_left = x[:-1, :-1]
    lower_right = x[1:, 1:]
    shade[:-1, :-1] = upper_left - lower_right
    return shade
def pixel_to_node(tree, mask):
    """Compute the node mask from the pixel mask.

    Parameters
    ----------
    tree : sap tree
        Tree built on the image; its underlying higra tree is ``tree._tree``.
    mask : ndarray of bool
        Pixel selection with the image's shape. It is flattened to seed the
        tree leaves (presumably one leaf per pixel).

    Returns
    -------
    ndarray of bool
        One flag per tree vertex. With higra's
        ``accumulate_and_min_sequential``, every internal node takes the min of
        its own weight (all ones here) and its children's values. A node is
        therefore ``True`` only when every leaf below it is selected.
    """
    node_mask = hg.accumulate_and_min_sequential(
        tree._tree,
        np.ones(tree._tree.num_vertices(), dtype=np.uint8),
        mask.ravel(),
        hg.Accumulators.min,
    )
    # ``np.bool`` was a deprecated alias of the builtin and was removed in
    # NumPy 1.24 (AttributeError). ``bool`` is the identical dtype.
    return node_mask.astype(bool)
def load_image(image):
    """Load ``image`` into the global ``model`` and derive its display layers.

    Stores, in this order:
    - the raw image;
    - an 8-bit hill-shade layer and its per-pixel alpha;
    - the imshow colour limits (1st / 99th percentiles);
    - a pixel meshgrid;
    - a ``sap.MaxTree`` built on the image.
    """
    im = image
    model['im'] = im
    shade = ui8_clip(hillshades(im), -1, 1)
    model['im_hs'] = shade
    # Alpha grows quadratically with the distance from neutral grey (127),
    # reaching about 0.4 at the extremes, so flat terrain stays transparent.
    model['im_hs_alpha'] = ((shade.astype(float) - 127) / 127) ** 2 * .4
    low, high = np.quantile(im, (.01, .99))
    model['im_kwargs'] = {'vmin': low, 'vmax': high}
    n_rows, n_cols = im.shape[0], im.shape[1]
    model['im_mesh'] = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    model['tree_cache'] = sap.MaxTree(im)
def line_select_callback(eclick, erelease):
    """RectangleSelector callback for the spectrum axis.

    Records the press/release data coordinates as the attribute-space
    selection, then refreshes the LiDAR view.
    """
    log.info('spectrum mouse release callback triggered')
    model.update(x1=eclick.xdata, y1=eclick.ydata,
                 x2=erelease.xdata, y2=erelease.ydata)
    refresh_lidar_axis()
def dtm_select_callback(eclick, erelease):
    """RectangleSelector callback for the LiDAR (DTM) axis.

    Stores the selected rectangle as integer pixel indices and recomputes the
    pixel selection via ``filter_nodes``. The axes are swapped on purpose:
    - the image column (``xdata``) goes to ``dtm_y*``;
    - the row (``ydata``) goes to ``dtm_x*``.
    This matches ``filter_nodes``, which slices the mask as
    ``[x1:x2, y1:y2]`` (row first).
    """
    log.info('lidar mouse release callback triggered')
    # The builtin ``int`` truncates exactly like the ``np.int`` alias did.
    # That alias was removed in NumPy 1.24 and raised AttributeError.
    model['dtm_y1'], model['dtm_x1'] = int(eclick.xdata), int(eclick.ydata)
    model['dtm_y2'], model['dtm_x2'] = int(erelease.xdata), int(erelease.ydata)
    filter_nodes()
def filter_nodes():
    """Turn the LiDAR rectangle into a pixel mask and show its spectrum highlight.

    Builds ``model['pixel_mask']`` from the ``dtm_*`` corners and recomputes
    ``model['spectrum_highlight']``. It then makes sure the spectrum axis
    displays the highlight.
    """
    log.info('filter selected nodes')
    x1, x2 = model['dtm_x1'], model['dtm_x2']
    y1, y2 = model['dtm_y1'], model['dtm_y2']
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the same dtype.
    pmask = np.zeros_like(model['im'], dtype=bool)
    pmask[x1:x2, y1:y2] = True
    model['pixel_mask'] = pmask
    compute_spectrum_highlight()
    if model['highlight'].value:
        refresh_spectrum_axis()
    else:
        # Ticking the checkbox fires its observer, which refreshes the
        # spectrum axis.
        model['highlight'].value = True
def init_spectrum_axis():
    """(Re)draw the pattern-spectrum axis from scratch.

    Steps:
    - recompute the spectrum;
    - clear the axis;
    - draw the histogram with a log colour norm, applying the spectrum's
      log flags to the axis scales;
    - keep the pcolor artist in ``model['spectrum_pcolor']`` so later
      refreshes can update it in place.
    """
    log.info('init spectrum axis...')
    compute_spectrum()
    fig, ax = model['fig'], model['ax_spectrum']
    x_name, y_name = model['x'].value, model['y'].value
    counts, x_edges, y_edges, x_is_log, y_is_log = model['spectrum']

    ax.clear()
    mesh = ax.pcolor(x_edges, y_edges, counts.T,
                     norm=mpl.colors.LogNorm(), animated=False)
    if x_is_log:
        ax.set_xscale('log')
    if y_is_log:
        ax.set_yscale('log')
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_title('Pattern Spectrum')
    ax.grid(True)
    ax.grid(which='minor', color='#BBBBBB', linestyle=':')

    model['spectrum_pcolor'] = mesh
    fig.canvas.draw()
    log.info('init spectrum axis done')
def _compute_spectrum(node_mask=None):
    """Compute the 2-D pattern spectrum for the attributes picked in the UI.

    Reads the tree and the attribute / log-scale widgets from ``model``. It
    calls ``sap.spectrum2d`` with 200 x-bins and 100 y-bins (positional
    arguments). ``node_mask``, if given, restricts which nodes contribute.

    Returns the ``sap.spectrum2d`` tuple, with the first element (the
    histogram) wrapped in a masked array that hides empty (zero) bins.
    """
    x_attr, y_attr = model['x'].value, model['y'].value
    x_log, y_log = model['x_log'].value, model['y_log'].value
    histogram, *extra = sap.spectrum2d(model['tree_cache'], x_attr, y_attr,
                                       200, 100, x_log, y_log,
                                       node_mask=node_mask)
    return (np.ma.array(histogram, mask=histogram == 0), *extra)
def compute_spectrum():
    """Compute the spectrum over all nodes (no node mask) and cache it in ``model['spectrum']``."""
    log.info('compute spectrum...')
    full_spectrum = _compute_spectrum()
    model['spectrum'] = full_spectrum
    log.info('compute spectrum done')
def compute_spectrum_highlight():
    """Compute the share of each spectrum bin contributed by the selected pixels.

    Steps:
    - convert ``model['pixel_mask']`` to a node mask;
    - recompute the spectrum restricted to those nodes;
    - divide it bin-wise by the full spectrum cached in ``model['spectrum']``.

    Bins where the full spectrum is 0 are skipped by ``where=`` and keep the
    zero from ``out``. The result reuses the full spectrum's mask, so empty
    bins stay hidden. Values are presumably in [0, 1], matching the
    ``Normalize(0, 1)`` used when the highlight is displayed. The result is
    stored in ``model['spectrum_highlight']``.

    NOTE(review): ``model['pixel_mask']`` must already hold an array.
    ``pixel_to_node`` calls ``mask.ravel()``, which fails on the initial
    ``None`` set by ``init_model``.
    """
    log.info('compute spectrum highlight...')
    tree = model['tree_cache']
    mask = model['pixel_mask']
    spectrum = model['spectrum']
    # Per-node selection flags derived from the pixel mask.
    node_mask = pixel_to_node(tree, mask)
    # Spectrum of the selected nodes only; it uses the same attribute / bin
    # widgets as the full spectrum.
    spectrum_mask = _compute_spectrum(node_mask)
    # Zero-initialised output: bins skipped by ``where`` keep this 0.
    spectrum_highlight = np.zeros_like(spectrum[0])
    spectrum_highlight = np.divide(spectrum_mask[0],
                                   spectrum[0],
                                   where=spectrum[0]!=0,
                                   out=spectrum_highlight)
    # Hide exactly the bins the full spectrum hides (its empty bins).
    spectrum_highlight.mask = spectrum[0].mask
    model['spectrum_highlight'] = spectrum_highlight
    log.info('compute spectrum highlight done')
def refresh_spectrum_axis():
    """Swap the data shown by the existing spectrum pcolor and blit it.

    - Highlight checkbox ticked: show the selection ratio
      (``model['spectrum_highlight']``) on a linear [0, 1] norm.
    - Otherwise: show the full spectrum on a log norm.

    Highlight may be requested before any LiDAR selection exists, in which
    case there is no ``model['spectrum_highlight']`` yet. The full spectrum
    is then shown instead of raising ``KeyError`` inside the widget callback.
    """
    log.info('refresh spectrum axis...')
    fig = model['fig']
    pc = model['spectrum_pcolor']
    highlight = model['highlight'].value
    if highlight and 'spectrum_highlight' not in model:
        log.warning('no LiDAR selection yet, showing the full spectrum')
        highlight = False
    spectrum = model['spectrum_highlight'] if highlight else model['spectrum'][0]
    # Assumes the pcolor artist holds only the unmasked cells (classic
    # ``pcolor`` behaviour). TODO confirm on matplotlib >= 3.8.
    pc.set_array(spectrum.T.compressed())
    pc.set_norm(mpl.colors.Normalize(0, 1) if highlight else mpl.colors.LogNorm())
    fig.draw_artist(pc)
    fig.canvas.blit(pc.clipbox)
    # Keep the selector's blit background in sync with the new image.
    model['widget_spectrum'].update_background(None)
    log.info('refresh spectrum axis done')
def show_spectrum():
    """Legacy all-in-one spectrum renderer, kept for reference only.

    Every call site in this module is commented out. The spectrum view is now
    driven by ``init_spectrum_axis`` / ``compute_spectrum_highlight`` /
    ``refresh_spectrum_axis``. Calling this function always raises
    ``AssertionError``.

    The guard is an explicit ``raise`` rather than ``assert False``, so it
    still fires under ``python -O``. With asserts stripped, the stale code
    below would otherwise run.
    """
    raise AssertionError('LEGACY FUNCTION NEVER CALLED DAMNIT')
    # --- Unreachable legacy implementation, kept for reference. ---
    x, y = model['x'].value, model['y'].value
    x_log, y_log = model['x_log'].value, model['y_log'].value
    highlight = model['highlight'].value
    # Compute spectrum
    log.info('compute spectrum...')
    ps = sap.spectrum2d(model['tree_cache'], x, y, 200, 100, x_log, y_log)
    log.info('compute done')
    if highlight:
        node_mask = pixel_to_node(model['tree_cache'], model['pixel_mask']) if model['pixel_mask'] is not None else None
        log.info('compute masked spectrum...')
        ps_mask = sap.spectrum2d(model['tree_cache'], x, y, 200, 100, x_log, y_log, node_mask=node_mask)
        ps[0][:] = np.divide(ps_mask[0], ps[0], where=ps[0] != 0, out=np.tile(np.nan, ps_mask[0].shape))
        log.info('compute done')
    # Display spectrum
    plt.sca(model['ax_spectrum'])
    plt.cla()
    log.info('draw spectrum...')
    sap.show_spectrum(*ps, log_scale=not highlight)
    plt.xlabel(x)
    plt.ylabel(y)
    plt.grid(which='minor', color='#BBBBBB', linestyle=':')
    plt.grid(True)
    log.info('draw complete')
    plt.gcf().canvas.blit(model['ax_spectrum'].bbox)
    log.info('matplotlib redraw complete')
def get_attribute_extends():
    """Return the attribute-space (spectrum) selection as ``(x1, x2, y1, y2)``."""
    return tuple(model[key] for key in ('x1', 'x2', 'y1', 'y2'))
def filter_tree():
    """Reconstruct the image keeping only nodes inside the spectrum selection box.

    Nodes whose x attribute is outside ``[x1, x2]``, or whose y attribute is
    outside ``[y1, y2]``, are marked. The marked nodes are passed to a
    ``filtering='subtractive'`` reconstruction, presumably as the nodes to
    delete (sap semantics). The filtered image is stored in ``model['im_f']``.
    """
    log.info('filtering the tree...')
    x_lo, x_hi, y_lo, y_hi = get_attribute_extends()
    tree = model['tree_cache']
    x_attr = tree.get_attribute(model['x'].value)
    y_attr = tree.get_attribute(model['y'].value)
    outside = ((x_attr < x_lo) | (x_attr > x_hi) |
               (y_attr < y_lo) | (y_attr > y_hi))
    model['im_f'] = tree.reconstruct(outside, filtering='subtractive')
    log.info('tree filtering complete')
def refresh_lidar_axis():
    """Redraw the LiDAR contours for the current attribute selection.

    An all-zero selection (the state left by ``reset_selections``) means
    nothing is selected, so no filtering or drawing happens.
    """
    log.info('refresh lidar view...')
    extents = get_attribute_extends()
    if all(value == 0 for value in extents):
        log.info('no filtering applied')
        return
    filter_tree()
    draw_lidar_contours()
    log.info('refresh complete')
def reset_ax(ax):
    """Remove every artist from ``ax``.

    ``Axes.clear()`` already drops all collections and patches. The former
    explicit ``ax.collections.clear()`` / ``ax.patches.clear()`` calls were
    redundant. They also fail on Matplotlib >= 3.7, where those properties
    are read-only ``ArtistList`` views without list-mutating methods.
    """
    ax.clear()
def draw_lidar_contours():
    """Blit the contours of the filtered image over the cached LiDAR background.

    The contour threshold is ``tree._alt[-1]``, presumably the root altitude
    given higra's root-last node ordering. The path therefore outlines the
    pixels of ``model['im_f']`` above it.

    Only the contour patch is redrawn:
    - restore the saved background (``model['dtm_bg']``);
    - draw the animated patch on top;
    - blit the patch's clip box.
    """
    log.info('draw lidar contours...')
    contour_patch = model['dtm_contours']
    threshold = model['tree_cache']._alt[-1]
    verts, codes = get_contours_verts(model['im_f'], threshold)
    outline = contour_patch.get_path()
    outline.vertices, outline.codes = verts, codes
    figure = model['fig']
    figure.canvas.restore_region(model['dtm_bg'])
    figure.draw_artist(contour_patch)
    figure.canvas.blit(contour_patch.clipbox)
    # Keep the selector's own blit background in sync with what is shown.
    model['widget_dtm'].update_background(None)
def init_dtm_contours():
    """Create the contour patch on the LiDAR axis, with a one-point placeholder path.

    ``animated=True`` keeps the patch out of regular canvas draws; it is
    painted explicitly (draw_artist + blit) when contours are updated.
    """
    placeholder = path.Path([(0, 0)])
    contour_patch = patches.PathPatch(placeholder, color='lime', linewidth=2,
                                      fill=False, animated=True)
    model['dtm_contours'] = contour_patch
    model['ax_dtm'].add_artist(contour_patch)
def init_lidar_axis():
    """Draw the LiDAR image layers, then add the contour overlay patch."""
    for step in (_init_lidar_axis, init_dtm_contours):
        step()
def reset_lidar_axis():
    """Re-capture the LiDAR blitting background (e.g. after a resize).

    The contour patch is hidden during the full redraw, so the cached
    background does not contain stale contours. It is made visible again
    afterwards.
    """
    log.info('reset lidar axis')
    contour_patch = model['dtm_contours']
    canvas = model['fig'].canvas
    contour_patch.set_visible(False)
    canvas.draw()
    model['dtm_bg'] = canvas.copy_from_bbox(model['ax_dtm'].bbox)
    contour_patch.set_visible(True)
def _init_lidar_axis():
    """Draw the LiDAR image and its hill-shade overlay, then cache the background."""
    log.info('init lidar view')
    fig, ax = model['fig'], model['ax_dtm']
    # Image layer, colour-limited by the precomputed imshow kwargs.
    ax.imshow(model['im'], **model['im_kwargs'])
    # Hill-shade layer with per-pixel transparency.
    ax.imshow(model['im_hs'], cmap=plt.cm.Greys, alpha=model['im_hs_alpha'])
    ax.set_title('LiDAR')
    fig.canvas.draw()
    # Snapshot restored before each contour blit.
    model['dtm_bg'] = fig.canvas.copy_from_bbox(ax.bbox)
def load_tile():
    """Read the selected LiDAR tile into ``model`` and derive its display layers.

    Reads the first band (index 0 of ``read()``) of the raster at
    ``get_tile(feature, tile_name)``. It then stores the hill-shade layer,
    its alpha, the imshow colour limits (1st / 99th percentiles) and the pixel
    meshgrid. The tree is not rebuilt here; ``load_tree`` does that.

    NOTE(review): ``rio`` (presumably ``rasterio``) and ``get_tile`` are not
    imported or defined in this module, so this raises ``NameError`` as
    written. Its only call site (in ``load_change``) is commented out.
    """
    tile_path = get_tile(model['lidar_feature'].value, model['tile_name'].value)
    # Close the dataset once read instead of leaking the open file handle.
    with rio.open(tile_path) as dataset:
        model['im'] = dataset.read()[0]
    model['im_hs'] = ui8_clip(hillshades(model['im']), -1, 1)
    # Quadratic alpha around neutral grey (127): flat areas stay transparent.
    model['im_hs_alpha'] = ((model['im_hs'].astype(float) - 127) / 127) ** 2 * .4
    kwargs = {}
    kwargs['vmin'], kwargs['vmax'] = np.quantile(model['im'], (.01, .99))
    model['im_kwargs'] = kwargs
    model['im_mesh'] = np.meshgrid(np.arange(model['im'].shape[1]), np.arange(model['im'].shape[0]))
def load_tree():
    """Build the tree type chosen in ``model['tree_type']`` on the current image.

    Assumes the widget value is a callable tree constructor (presumably a
    sap tree class) — TODO confirm.
    """
    tree_factory = model['tree_type'].value
    model['tree_cache'] = tree_factory(model['im'])
def reset_selections():
    """Clear both rectangle selections and reset the display toggles.

    - All selection corners go back to 0, which ``refresh_lidar_axis`` treats
      as "no selection".
    - The overlay checkbox is ticked and the highlight checkbox unticked.
      traitlets notifies observers only on an actual value change.
    """
    corners = ('x1', 'x2', 'y1', 'y2')
    model.update(dict.fromkeys(corners, 0))
    model.update(dict.fromkeys(['dtm_' + corner for corner in corners], 0))
    model['overlay'].value = True
    model['highlight'].value = False
def init_selector():
    """Attach rectangle selectors to both axes. Call after creating plot axes.

    Both selectors:
    - use blitting;
    - react to the left and right mouse buttons only;
    - use 5-pixel minimum spans;
    - stay interactive after release.

    The spectrum selector is green and the DTM selector yellow.
    """
    def _make_selector(ax, callback, rgb):
        return mpl.widgets.RectangleSelector(
            ax, callback,
            useblit=True,
            button=[1, 3],  # don't use middle button
            minspanx=5, minspany=5,
            spancoords='pixels',
            props=dict(facecolor=rgb + (.2,), edgecolor=rgb, fill=True, ls='--', lw=2),
            interactive=True)

    model['widget_spectrum'] = _make_selector(model['ax_spectrum'], line_select_callback, (0, 1, 0))
    model['widget_dtm'] = _make_selector(model['ax_dtm'], dtm_select_callback, (1, 1, 0))
def load_change(change):
    """Widget callback: reset the selections and redraw both axes.

    ``change`` is not used. Reloading the tile and tree
    (``load_tile`` / ``load_tree``) is currently disabled here.
    """
    log.info('change triggered redraw')
    reset_selections()
    for redraw in (init_spectrum_axis, init_lidar_axis):
        redraw()
def resize_callback(resize_event):
    """Canvas resize handler: rebuild the LiDAR blitting background.

    ``resize_event`` is not used; the background is re-captured on every
    event.
    """
    log.info('resize callback triggered')
    reset_lidar_axis()
def spectrum_change(change):
    """Observer for the attribute and log-scale widgets.

    Notifications for traits other than ``value`` are ignored. A value
    change resets the selections and redraws the spectrum axis from scratch.
    """
    if change['name'] == 'value':
        log.info('spetrum change triggered')
        reset_selections()
        init_spectrum_axis()
def spectrum_change_hl(change):
    """Observer for the highlight checkbox: refresh the spectrum display.

    Like ``spectrum_change``, notifications for traits other than ``value``
    are ignored. The blitting refresh therefore only runs when the checkbox
    actually toggles.
    """
    if change['name'] != 'value':
        return
    log.info('spetrum highlight toggle change triggered')
    refresh_spectrum_axis()
def init_model():
    """Populate the global ``model`` with the control widgets and default state."""
    def attribute_dropdown(label, default):
        return ipw.Dropdown(description=label,
                            options=sap.available_attributes().keys(),
                            value=default)

    model.update({
        'ax_spectrum': None,
        'ax_dtm': None,
        'x': attribute_dropdown('Attribute X', 'area'),
        'y': attribute_dropdown('Attribute Y', 'compactness'),
        'x_log': ipw.Checkbox(description='X log', value=True),
        'y_log': ipw.Checkbox(description='Y log', value=False),
        # NOTE: the label spelling is kept as-is (user-visible string).
        'highlight': ipw.Checkbox(description='Spectrum higlight'),
        'overlay': ipw.Checkbox(description='Filtering overlay', value=True),
        # Attribute-space (spectrum) selection; all zeros means "none".
        'x1': 0, 'x2': 0, 'y1': 0, 'y2': 0,
        # Pixel-space (LiDAR) selection.
        'dtm_x1': 0, 'dtm_x2': 0, 'dtm_y1': 0, 'dtm_y2': 0,
        'pixel_mask': None,
        'ax_dtm_xlim': None,
        'ax_dtm_ylim': None,
    })
def init_observer():
    """Wire the control widgets to their callbacks.

    Handlers are registered for the ``value`` trait only. Unrelated trait
    notifications (description, layout, ...) therefore no longer trigger a
    spectrum recomputation or redraw.
    """
    for key in ('x_log', 'y_log', 'x', 'y'):
        model[key].observe(spectrum_change, names='value')
    model['highlight'].observe(spectrum_change_hl, names='value')
# Global application state shared by every callback: widgets, axes, image
# layers, cached tree / spectra and the current selections. It is filled by
# init_model() and the load/init functions.
model = dict()
def init():
    """Application entry point: build the widget model and register its observers."""
    log.info('Initialize spectra app')
    for setup in (init_model, init_observer):
        setup()