135 lines
3.9 KiB
Python
135 lines
3.9 KiB
Python
#!/usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
# \file jurse.py
|
|
# \brief First JURSE test protocol for LiDAR classification with 2D maps.
|
|
# \author Florent Guiotte <florent.guiotte@gmail.com>
|
|
# \version 0.1
|
|
# \date 07 sept. 2018
|
|
#
|
|
# Defines the Jurse protocol: descriptor computation, cross-validated
# classification, and accuracy metrics for a LiDAR 2D-map experiment.
|
|
|
|
import hashlib
|
|
import importlib
|
|
from collections import OrderedDict
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn import metrics
|
|
import triskele
|
|
from .protocol import Protocol, TestError
|
|
|
|
|
|
class Jurse(Protocol):
    """First JURSE test protocol for LiDAR classification with 2D maps.

    This protocol computes attribute profiles (descriptors) on the whole
    scene, then splits the samples into train and test sets for a
    classifier (e.g. a random forest) driven by a cross-validation
    scheme, and finally reports accuracy metrics.
    """

    def __init__(self, expe):
        """Initialize the protocol with the experiment description.

        Parameters
        ----------
        expe : mapping
            Experiment description; must provide the keys
            'ground_truth', 'descriptors_script', 'cross_validation'
            and 'classifier'.
        """
        super().__init__(expe, self.__class__.__name__)

    def _get_hashes(self):
        """Return SHA-1 hashes identifying the experiment settings.

        Returns
        -------
        OrderedDict
            One hex digest per relevant experiment key, plus a 'global'
            digest accumulated over all of them (detects any change in
            the settings as a whole).
        """
        hashes = OrderedDict()
        glob = hashlib.sha1()

        for k in ('ground_truth', 'descriptors_script',
                  'cross_validation', 'classifier'):
            # Hash the textual representation of each setting.
            val = str(self._expe[k]).encode('utf-8')
            hashes[k] = hashlib.sha1(val).hexdigest()
            glob.update(val)
        hashes['global'] = glob.hexdigest()

        return hashes

    def _run(self):
        """Run the whole protocol: description, classification, metrics.

        The classification map is written next to the results base name
        and the results (map path and metric values) are stored in
        ``self._results``.

        Raises
        ------
        TestError
            If the description or the classification step fails.
        """
        self._log.info('Compute descriptors')
        try:
            descriptors = self._compute_descriptors()
        except Exception as e:
            # Chain the original exception so the root cause is kept.
            raise TestError('Error occurred during description') from e
        self._time('description')

        self._log.info('Classify data')
        try:
            classification = self._compute_classification(descriptors)
        except Exception as e:
            raise TestError('Error occurred during classification') from e
        self._time('classification')

        self._log.info('Run metrics')
        # Named to avoid shadowing the module-level sklearn `metrics`.
        metric_results = self._run_metrics(classification, descriptors)
        self._time('metrics')

        cmap = str(self._results_base_name) + '.tif'
        self._log.info('Saving classification map {}'.format(cmap))
        triskele.write(cmap, classification)

        results = OrderedDict()
        results['classification'] = cmap
        results['metrics'] = metric_results
        self._results = results

    def _compute_descriptors(self):
        """Compute the descriptors by running the user-provided script.

        The module named in ``descriptors_script['name']`` is imported
        and its ``run`` function is called with the configured
        parameters.

        Returns
        -------
        ndarray
            Per-pixel descriptors (attribute profiles) for the scene.
        """
        script = self._expe['descriptors_script']

        desc = importlib.import_module(script['name'])
        att = desc.run(**script['parameters'])

        return att

    def _compute_classification(self, descriptors):
        """Classify the scene with cross-validated train/test splits.

        Parameters
        ----------
        descriptors : ndarray
            Per-pixel descriptors produced by `_compute_descriptors`.

        Returns
        -------
        ndarray
            Predicted label map, same shape as the ground truth
            (uint8 — assumes labels fit in 8 bits; zero marks pixels
            never predicted).
        """
        # Ground truth
        gt = self._get_ground_truth()

        # Resolve the cross-validation and classifier callables by
        # package/name, as declared in the experiment.
        cv = self._expe['cross_validation']
        cl = self._expe['classifier']

        cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
        classifier = getattr(importlib.import_module(cl['package']), cl['name'])

        prediction = np.zeros_like(gt, dtype=np.uint8)

        # Each fold yields train/validation samples (xt, yt / xv, yv)
        # and the indices of the validation pixels in the map (ti).
        for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
            rfc = classifier(**cl['parameters'])
            rfc.fit(xt, yt)

            ypred = rfc.predict(xv)

            prediction[ti] = ypred

        return prediction

    def _get_ground_truth(self):
        """Read the ground truth raster and apply optional meta-labeling.

        When 'meta_labels' is present in the experiment, a CSV with
        columns 'index' and 'metaclass_index' remaps original class
        indices to meta-class indices.

        Returns
        -------
        ndarray
            Label map, possibly remapped to meta-classes.
        """
        gt_expe = self._expe['ground_truth']
        gt = triskele.read(gt_expe['raster'])

        # Identity mapping by default: label i maps to itself.
        idx_map = np.arange(gt.max() + 1)

        if 'meta_labels' in gt_expe:
            meta_idx = pd.read_csv(gt_expe['meta_labels'])
            idx = np.array(meta_idx['index'])
            midx = np.array(meta_idx['metaclass_index'])
            idx_map[idx] = midx

        return idx_map[gt]

    def _get_results(self):
        """Return the results dict produced by `_run`."""
        return self._results

    def _run_metrics(self, classification, descriptors):
        """Compute accuracy metrics on the classified pixels.

        Only non-zero predictions are evaluated: zero marks pixels that
        were never part of a validation fold.

        Returns
        -------
        OrderedDict
            Descriptor dimensionality, overall accuracy and Cohen's
            kappa score.
        """
        gt = self._get_ground_truth()

        f = np.nonzero(classification)
        pred = classification[f].ravel()
        gt = gt[f].ravel()

        results = OrderedDict()
        results['dimensions'] = descriptors.shape[-1]
        results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
        results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))

        return results
|