"""Interactive pattern-spectrum explorer for a LiDAR raster.

Two matplotlib axes are driven by ipywidgets controls:

* ``model['ax_spectrum']`` -- the 2D pattern spectrum of a ``sap`` max-tree.
  Drawing a rectangle on it filters the tree and outlines the result on the
  raster view.
* ``model['ax_dtm']`` -- the raster with a shading overlay. Drawing a
  rectangle on it shows, per spectrum bin, the share coming from the
  selected pixels.

All state lives in the module-level ``model`` dict, which is populated by
``init()``, ``load_image()`` and the ``init_*`` helpers.
``model['fig']``, ``model['ax_spectrum']`` and ``model['ax_dtm']`` are
expected to be set by the caller -- TODO confirm against the driving notebook.
"""
import logging

import matplotlib as mpl
from matplotlib import path, patches, pyplot as plt
import numpy as np
from skimage import measure
import ipywidgets as ipw
import higra as hg

import sap
from sap.spectra import get_bins

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
# Guard so that re-importing / reloading the module (common in notebooks)
# does not stack one more handler and duplicate every log line.
if not log.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    log.addHandler(ch)

# Number of bins of the 2D pattern spectrum along the X and Y attributes.
SPECTRUM_X_BINS = 200
SPECTRUM_Y_BINS = 100


def compute_contours(image, threshold):
    """Return the contours of the region where ``image > threshold``.

    The boolean image is traced with ``skimage.measure.find_contours`` at
    level 0, using ``positive_orientation='high'`` and
    ``fully_connected='high'``. Each contour is an ``(N, 2)`` array of
    ``(row, col)`` coordinates.
    """
    contours = measure.find_contours(image > threshold, 0,
                                     positive_orientation='high',
                                     fully_connected='high')
    return contours


def convert_contours_to_verts(contours):
    """Pack a list of contours into the vertices/codes of one compound path.

    Each contour starts with a ``MOVETO`` code followed by ``LINETO`` codes,
    so separate contours are not joined. Coordinates are flipped from
    ``(row, col)`` to matplotlib's ``(x, y)``.

    Returns:
        ``(verts, codes)``. When there is no contour, a single ``(0, 0)``
        vertex and ``None`` codes are returned, so a Path can still be built.
    """
    if len(contours) == 0:
        return np.zeros((1, 2)), None
    codes = []
    for c in contours:
        code = np.repeat(path.Path.LINETO, len(c))
        code[0] = path.Path.MOVETO
        codes.append(code)
    verts = np.concatenate(contours)[:, ::-1]
    return verts, np.concatenate(codes)


def get_contours_verts(image, threshold):
    """Contour ``image > threshold`` and return ``(verts, codes)`` for a Path."""
    return convert_contours_to_verts(compute_contours(image, threshold))


def ui8_clip(x, vmin=None, vmax=None):
    """Linearly rescale ``x`` from ``[vmin, vmax]`` to ``uint8`` in ``[0, 255]``.

    Args:
        x: float array; NaN values are allowed.
        vmin, vmax: clipping bounds. ``None`` means use ``nanmin``/``nanmax``
            of ``x``. An explicit 0 is honoured, unlike the former
            truthiness test.

    Returns:
        ``np.ma.MaskedArray`` of ``uint8`` with the NaN pixels of ``x`` masked.
    """
    if vmin is None:
        vmin = np.nanmin(x)
    if vmax is None:
        vmax = np.nanmax(x)
    scaled = (np.clip(x, vmin, vmax) - vmin) / (vmax - vmin) * 255
    return np.ma.array(scaled.astype(np.uint8), mask=np.isnan(x))


def hillshades(x):
    """Return a cheap relief-shading proxy of ``x``.

    Each pixel holds its difference with its lower-right diagonal neighbour.
    The last row and last column are left at 0.
    """
    hs = np.zeros_like(x)
    hs[:-1, :-1] = x[:-1, :-1] - x[1:, 1:]
    return hs


def pixel_to_node(tree, mask):
    """Compute the node mask from the pixel mask.

    Args:
        tree: a ``sap`` tree; its private higra tree ``tree._tree`` is used.
        mask: boolean pixel mask with the image shape.

    Returns:
        Boolean array with one entry per vertex of ``tree._tree``.
        Assumed higra semantics of ``accumulate_and_min_sequential``: leaves
        take ``mask``, and each internal node takes the min of its own weight
        (1 here) and of its children. Under that assumption, a node ends up
        ``True`` only when every pixel below it is in ``mask``.
    """
    node_mask = hg.accumulate_and_min_sequential(
        tree._tree,
        np.ones(tree._tree.num_vertices(), dtype=np.uint8),
        mask.ravel(),
        hg.Accumulators.min,
    )
    # ``np.bool`` was removed in NumPy 1.24; the builtin maps to the same dtype.
    return node_mask.astype(bool)


def _update_display_layers():
    """Derive the display layers of ``model['im']`` and store them in ``model``."""
    im = model['im']
    model['im_hs'] = ui8_clip(hillshades(im), -1, 1)
    # Opacity grows quadratically with the distance from mid-grey (127), up to ~0.4.
    model['im_hs_alpha'] = ((model['im_hs'].astype(float) - 127) / 127) ** 2 * .4
    # Robust colour limits that ignore the lowest/highest 1% of values and NaN
    # (NaN handling matches ui8_clip).
    vmin, vmax = np.nanquantile(im, (.01, .99))
    model['im_kwargs'] = {'vmin': vmin, 'vmax': vmax}
    model['im_mesh'] = np.meshgrid(np.arange(im.shape[1]), np.arange(im.shape[0]))


def load_image(image):
    """Load ``image`` into the model, derive its display layers and build its max-tree."""
    model['im'] = image
    _update_display_layers()
    model['tree_cache'] = sap.MaxTree(model['im'])


def line_select_callback(eclick, erelease):
    """Spectrum ``RectangleSelector`` callback: store the attribute ranges and refilter."""
    log.info('spectrum mouse release callback triggered')
    model['x1'], model['y1'] = eclick.xdata, eclick.ydata
    model['x2'], model['y2'] = erelease.xdata, erelease.ydata
    refresh_lidar_axis()


def dtm_select_callback(eclick, erelease):
    """DTM ``RectangleSelector`` callback: store the selected pixel window.

    ``xdata``/``ydata`` are image (column, row) coordinates. They are stored
    swapped, in ``dtm_y*`` (columns) and ``dtm_x*`` (rows), because
    ``filter_nodes`` slices the mask as ``[x1:x2, y1:y2]``.
    """
    log.info('lidar mouse release callback triggered')
    # ``np.int`` was removed in NumPy 1.24; ``int`` truncates identically.
    # Clamp at 0: a negative bound would wrap around when slicing the mask.
    model['dtm_y1'], model['dtm_x1'] = max(0, int(eclick.xdata)), max(0, int(eclick.ydata))
    model['dtm_y2'], model['dtm_x2'] = max(0, int(erelease.xdata)), max(0, int(erelease.ydata))
    filter_nodes()


def filter_nodes():
    """Turn the DTM selection into a pixel mask and display the spectrum highlight."""
    log.info('filter selected nodes')
    x1, x2 = model['dtm_x1'], model['dtm_x2']
    y1, y2 = model['dtm_y1'], model['dtm_y2']
    pmask = np.zeros_like(model['im'], dtype=bool)
    pmask[x1:x2, y1:y2] = True
    model['pixel_mask'] = pmask
    compute_spectrum_highlight()
    if model['highlight'].value:
        refresh_spectrum_axis()
    else:
        # Setting the checkbox fires spectrum_change_hl, which refreshes the axis.
        model['highlight'].value = True


def init_spectrum_axis():
    """Compute the pattern spectrum and redraw the spectrum axis from scratch."""
    log.info('init spectrum axis...')
    compute_spectrum()
    fig = model['fig']
    ax = model['ax_spectrum']
    x, y = model['x'].value, model['y'].value
    spectrum, xedges, yedges, x_log, y_log = model['spectrum']

    ax.clear()
    # Masked (empty) bins are not drawn by pcolor.
    pc = ax.pcolor(xedges, yedges, spectrum.T, norm=mpl.colors.LogNorm(), animated=False)
    if x_log:
        ax.set_xscale('log')
    if y_log:
        ax.set_yscale('log')
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.set_title('Pattern Spectrum')
    ax.grid(True)
    ax.grid(which='minor', color='#BBBBBB', linestyle=':')
    model['spectrum_pcolor'] = pc
    fig.canvas.draw()
    log.info('init spectrum axis done')


def _compute_spectrum(node_mask=None):
    """Return the 2D pattern spectrum of the cached tree for the current UI settings.

    Args:
        node_mask: optional per-node boolean mask forwarded to
            ``sap.spectrum2d``. It presumably restricts which nodes are
            counted.

    Returns:
        The tuple from ``sap.spectrum2d`` (unpacked elsewhere as
        ``(spectrum, xedges, yedges, x_log, y_log)``). Its first item is
        replaced by a masked array in which empty (zero) bins are masked.
    """
    tree = model['tree_cache']
    x, y = model['x'].value, model['y'].value
    x_log, y_log = model['x_log'].value, model['y_log'].value
    ps = sap.spectrum2d(tree, x, y, SPECTRUM_X_BINS, SPECTRUM_Y_BINS,
                        x_log, y_log, node_mask=node_mask)
    counts = ps[0]
    return (np.ma.array(counts, mask=counts == 0),) + tuple(ps[1:])


def compute_spectrum():
    """Compute the full pattern spectrum into ``model['spectrum']``."""
    log.info('compute spectrum...')
    model['spectrum'] = _compute_spectrum()
    log.info('compute spectrum done')


def compute_spectrum_highlight():
    """Store, per spectrum bin, the share coming from the selected pixels.

    The spectrum restricted to the nodes covered by ``model['pixel_mask']``
    is divided by the full spectrum. Bins that are empty in the full spectrum
    get 0 and stay masked, so the result lines up with the drawn pcolor.
    """
    log.info('compute spectrum highlight...')
    node_mask = pixel_to_node(model['tree_cache'], model['pixel_mask'])
    full = model['spectrum'][0]
    # Work on plain float data: this avoids ufunc/masked-array interplay and
    # integer ``out`` casting errors if the spectrum is ever integer-valued.
    numer = np.ma.getdata(_compute_spectrum(node_mask)[0]).astype(float)
    denom = np.ma.getdata(full).astype(float)
    ratio = np.zeros(denom.shape, dtype=float)
    np.divide(numer, denom, out=ratio, where=denom != 0)
    model['spectrum_highlight'] = np.ma.array(ratio, mask=np.ma.getmaskarray(full))
    log.info('compute spectrum highlight done')


def refresh_spectrum_axis():
    """Push the raw or highlighted spectrum into the existing pcolor and blit it.

    If the highlight is requested before any DTM selection has produced one,
    the raw spectrum is shown instead and a warning is logged.
    """
    log.info('refresh spectrum axis...')
    fig = model['fig']
    pc = model['spectrum_pcolor']
    highlight = model['highlight'].value
    if highlight and model.get('spectrum_highlight') is None:
        log.warning('no spectrum highlight computed yet, showing raw spectrum')
        highlight = False
    spectrum = model['spectrum_highlight'] if highlight else model['spectrum'][0]
    pc.set_array(spectrum.T.compressed())
    # The highlight is a per-bin ratio (expected in [0, 1]); raw counts use a log scale.
    pc.set_norm(mpl.colors.Normalize(0, 1) if highlight else mpl.colors.LogNorm())
    fig.draw_artist(pc)
    fig.canvas.blit(pc.clipbox)
    # Refresh the selector's cached blit background so it includes the new colours.
    model['widget_spectrum'].update_background(None)
    log.info('refresh spectrum axis done')


def show_spectrum():
    """Legacy full redraw of the spectrum axis; must never be called.

    Raises:
        AssertionError: always. This is an explicit ``raise`` rather than
            ``assert`` so the guard survives ``python -O``.
    """
    raise AssertionError('LEGACY FUNCTION NEVER CALLED DAMNIT')
    # Legacy implementation below is unreachable; kept for reference only.
    x, y = model['x'].value, model['y'].value
    x_log, y_log = model['x_log'].value, model['y_log'].value
    highlight = model['highlight'].value
    log.info('compute spectrum...')
    ps = sap.spectrum2d(model['tree_cache'], x, y, SPECTRUM_X_BINS, SPECTRUM_Y_BINS, x_log, y_log)
    log.info('compute done')
    if highlight:
        node_mask = (pixel_to_node(model['tree_cache'], model['pixel_mask'])
                     if model['pixel_mask'] is not None else None)
        log.info('compute masked spectrum...')
        ps_mask = sap.spectrum2d(model['tree_cache'], x, y, SPECTRUM_X_BINS, SPECTRUM_Y_BINS,
                                 x_log, y_log, node_mask=node_mask)
        ps[0][:] = np.divide(ps_mask[0], ps[0], where=ps[0] != 0,
                             out=np.tile(np.nan, ps_mask[0].shape))
        log.info('compute done')
    plt.sca(model['ax_spectrum'])
    plt.cla()
    log.info('draw spectrum...')
    sap.show_spectrum(*ps, log_scale=not highlight)
    plt.xlabel(x)
    plt.ylabel(y)
    plt.grid(which='minor', color='#BBBBBB', linestyle=':')
    plt.grid(True)
    log.info('draw complete')
    plt.gcf().canvas.blit(model['ax_spectrum'].bbox)
    log.info('matplotlib redraw complete')


def get_attribute_extends():
    """Return the spectrum selection as ``(x1, x2, y1, y2)`` in attribute units."""
    x1, x2 = model['x1'], model['x2']
    y1, y2 = model['y1'], model['y2']
    return x1, x2, y1, y2


def filter_tree():
    """Keep only the nodes whose (X, Y) attributes fall inside the spectrum selection.

    Nodes outside ``[x1, x2] x [y1, y2]`` are passed as deleted nodes to a
    subtractive reconstruction. The filtered image is stored in
    ``model['im_f']``.
    """
    log.info('filtering the tree...')
    x1, x2, y1, y2 = get_attribute_extends()
    tree = model['tree_cache']
    X = tree.get_attribute(model['x'].value)
    Y = tree.get_attribute(model['y'].value)
    outside = (X < x1) | (X > x2) | (Y < y1) | (Y > y2)
    model['im_f'] = tree.reconstruct(outside, filtering='subtractive')
    log.info('tree filtering complete')


def refresh_lidar_axis():
    """Filter the tree by the spectrum selection and redraw the DTM contours.

    Does nothing while the selection is still the all-zero reset value.
    """
    log.info('refresh lidar view...')
    x1, x2, y1, y2 = get_attribute_extends()
    if x1 == x2 == y1 == y2 == 0:
        log.info('no filtering applied')
        return
    filter_tree()
    draw_lidar_contours()
    log.info('refresh complete')


def reset_ax(ax):
    """Clear ``ax``.

    ``Axes.clear`` already removes collections and patches. The former extra
    ``ax.collections.clear()``/``ax.patches.clear()`` calls fail on
    Matplotlib >= 3.7, where these are read-only ``ArtistList`` views.
    """
    ax.clear()


def draw_lidar_contours():
    """Outline the filtered image on the DTM axis, redrawing only the contour via blitting."""
    log.info('draw lidar contours...')
    path_patch = model['dtm_contours']
    contour_path = path_patch.get_path()
    # ``_alt[-1]`` is presumably the root altitude (higra stores the root last),
    # so the contour surrounds what rises above the background after filtering.
    root_val = model['tree_cache']._alt[-1]
    contour_path.vertices, contour_path.codes = get_contours_verts(model['im_f'], root_val)
    fig = model['fig']
    fig.canvas.restore_region(model['dtm_bg'])
    fig.draw_artist(path_patch)
    fig.canvas.blit(path_patch.clipbox)
    model['widget_dtm'].update_background(None)


def init_dtm_contours():
    """Create the (initially trivial) animated contour patch on the DTM axis."""
    contour_path = path.Path([(0, 0)])
    path_patch = patches.PathPatch(contour_path, color='lime', linewidth=2,
                                   fill=False, animated=True)
    model['dtm_contours'] = path_patch
    model['ax_dtm'].add_artist(path_patch)


def init_lidar_axis():
    """Draw the DTM axis and attach the contour patch."""
    _init_lidar_axis()
    init_dtm_contours()


def reset_lidar_axis():
    """Redraw the figure and re-capture the DTM background without the contour patch."""
    log.info('reset lidar axis')
    path_patch = model['dtm_contours']
    fig = model['fig']
    ax = model['ax_dtm']
    path_patch.set_visible(False)
    fig.canvas.draw()
    model['dtm_bg'] = fig.canvas.copy_from_bbox(ax.bbox)
    path_patch.set_visible(True)


def _init_lidar_axis():
    """Draw the raster and its shading overlay, then cache the axis background for blitting."""
    log.info('init lidar view')
    ax = model['ax_dtm']
    fig = model['fig']
    ax.imshow(model['im'], **model['im_kwargs'])
    ax.imshow(model['im_hs'], cmap=plt.cm.Greys, alpha=model['im_hs_alpha'])
    ax.set_title('LiDAR')
    fig.canvas.draw()
    model['dtm_bg'] = fig.canvas.copy_from_bbox(ax.bbox)


def load_tile():
    """Read the tile selected in the UI into ``model['im']`` and derive its display layers.

    NOTE(review): ``rio`` (presumably ``rasterio``), ``get_tile`` and the
    ``lidar_feature``/``tile_name`` widgets are not defined or imported in
    this module. Unless they are injected from outside, this function raises
    ``NameError``/``KeyError``.
    """
    # ``with`` closes the dataset; the previous version leaked the open handle.
    with rio.open(get_tile(model['lidar_feature'].value, model['tile_name'].value)) as src:
        model['im'] = src.read()[0]
    _update_display_layers()


def load_tree():
    """Build the tree selected in ``model['tree_type']`` from ``model['im']``."""
    model['tree_cache'] = model['tree_type'].value(model['im'])


def reset_selections():
    """Forget both rectangle selections and reset the overlay/highlight toggles."""
    model.update({
        'x1': 0, 'x2': 0, 'y1': 0, 'y2': 0,
        'dtm_x1': 0, 'dtm_x2': 0, 'dtm_y1': 0, 'dtm_y2': 0,
    })
    # Drop state derived from the DTM selection, so that a later highlight toggle
    # cannot reuse a mask/ratio computed for another spectrum or image.
    model['pixel_mask'] = None
    model.pop('spectrum_highlight', None)
    model['overlay'].value = True
    model['highlight'].value = False


def init_selector():
    """Create the rectangle selectors. Call after creating plot axes."""
    # Spectrum selector: picks attribute ranges used to filter the tree.
    model['widget_spectrum'] = mpl.widgets.RectangleSelector(
        model['ax_spectrum'], line_select_callback,
        useblit=True,
        button=[1, 3],  # don't use middle button
        minspanx=5, minspany=5,
        spancoords='pixels',
        props=dict(facecolor=(0, 1, 0, .2), edgecolor=(0, 1, 0), fill=True, ls='--', lw=2),
        interactive=True)
    # DTM selector: picks a pixel window whose share is highlighted in the spectrum.
    model['widget_dtm'] = mpl.widgets.RectangleSelector(
        model['ax_dtm'], dtm_select_callback,
        useblit=True,
        button=[1, 3],  # don't use middle button
        minspanx=5, minspany=5,
        spancoords='pixels',
        props=dict(facecolor=(1, 1, 0, .2), edgecolor=(1, 1, 0), fill=True, ls='--', lw=2),
        interactive=True)


def load_change(change):
    """Reset selections and redraw both axes for the currently loaded data."""
    log.info('change triggered redraw')
    reset_selections()
    init_spectrum_axis()
    init_lidar_axis()


def resize_callback(resize_event):
    """Figure resize handler: re-capture the DTM background for blitting.

    TODO: consider only reacting once resizing has finished, e.g. by comparing
    ``resize_event.width``/``height`` with the previous event.
    """
    log.info('resize callback triggered')
    reset_lidar_axis()


def spectrum_change(change):
    """Observer of the attribute/log-scale widgets: rebuild the spectrum on value changes."""
    if change['name'] != 'value':
        return
    log.info('spetrum change triggered')
    reset_selections()
    init_spectrum_axis()


def spectrum_change_hl(change):
    """Observer of the highlight checkbox: redraw the spectrum on value changes."""
    # init_observer registers this handler without ``names``, so it sees every
    # trait change of the widget; only ``value`` matters (as in spectrum_change).
    if change['name'] != 'value':
        return
    log.info('spetrum highlight toggle change triggered')
    refresh_spectrum_axis()


def init_model():
    """Populate ``model`` with the UI widgets and empty selection state."""
    attributes = list(sap.available_attributes().keys())
    model.update({
        'ax_spectrum': None,
        'ax_dtm': None,
        'x': ipw.Dropdown(description='Attribute X', options=attributes, value='area'),
        'y': ipw.Dropdown(description='Attribute Y', options=attributes, value='compactness'),
        'x_log': ipw.Checkbox(description='X log', value=True),
        'y_log': ipw.Checkbox(description='Y log', value=False),
        'highlight': ipw.Checkbox(description='Spectrum higlight'),
        'overlay': ipw.Checkbox(description='Filtering overlay', value=True),
        'x1': 0, 'x2': 0, 'y1': 0, 'y2': 0,
        'dtm_x1': 0, 'dtm_x2': 0, 'dtm_y1': 0, 'dtm_y2': 0,
        'pixel_mask': None,
        'ax_dtm_xlim': None,
        'ax_dtm_ylim': None,
    })


def init_observer():
    """Wire the widgets to their change handlers."""
    model['x_log'].observe(spectrum_change)
    model['y_log'].observe(spectrum_change)
    model['highlight'].observe(spectrum_change_hl)
    model['x'].observe(spectrum_change)
    model['y'].observe(spectrum_change)


# Shared application state; filled by init() and the load/init helpers.
model = {}


def init():
    """Create the widgets and observers of the spectra app."""
    log.info('Initialize spectra app')
    init_model()
    init_observer()