diff --git a/minigrida/database/helpers.py b/minigrida/database/helpers.py index 6413c34..bfe382b 100644 --- a/minigrida/database/helpers.py +++ b/minigrida/database/helpers.py @@ -50,11 +50,11 @@ def create_session(name, desc, project_name, urgency=1): else: print('Session "{}" already exists.'.format(name)) -def bind_testing(): +def connect_testing(): db.bind('sqlite', ':memory:') db.generate_mapping(create_tables=True) -def bind(credentials_file): +def connect(credentials_file): with open(credentials_file) as f: credentials = json.load(f) db.bind(**credentials) diff --git a/logger.py b/minigrida/logger.py similarity index 100% rename from logger.py rename to minigrida/logger.py diff --git a/minigrida/logging.yaml b/minigrida/logging.yaml new file mode 100644 index 0000000..b18e6ed --- /dev/null +++ b/minigrida/logging.yaml @@ -0,0 +1,40 @@ +version: 1 +disable_existing_loggers: False +formatters: + simple: + format: "%(asctime)s %(levelname)s %(name)s: %(message)s" + +handlers: + console: + class: logging.StreamHandler + level: INFO + formatter: simple + stream: ext://sys.stdout + + info_file_handler: + class: logging.handlers.RotatingFileHandler + level: INFO + formatter: simple + filename: Logs/info.log + maxBytes: 10485760 # 10MB + backupCount: 20 + encoding: utf8 + + error_file_handler: + class: logging.handlers.RotatingFileHandler + level: ERROR + formatter: simple + filename: Logs/errors.log + maxBytes: 10485760 # 10MB + backupCount: 20 + encoding: utf8 + +loggers: + my_module: + level: DEBUG + handlers: [console] + propagate: no + +root: + level: DEBUG + handlers: [console, info_file_handler, error_file_handler] diff --git a/minigrida/protocols/jurse2.py b/minigrida/protocols/jurse2.py new file mode 100644 index 0000000..aabf715 --- /dev/null +++ b/minigrida/protocols/jurse2.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# file jurse2.py +# author Florent Guiotte +# version 0.0 +# date 26 mai 2020 +"""Abstract + +doc. +""" + +import importlib +import numpy as np +from sklearn import metrics +import rasterio +from .protocol import Protocol, TestError + + +class Jurse2(Protocol): + """Second JURSE test protocol for LiDAR classification with 2D maps. + + """ + + def __init__(self, expe): + super().__init__(expe, self.__class__.__name__) + + def _get_hashes(self): + hashes = OrderedDict() + glob = hashlib.sha1() + + for k in ['ground_truth', 'descriptors_script', + 'cross_validation', 'classifier']: + val = str(self._expe[k]).encode('utf-8') + hashes[k] = hashlib.sha1(val).hexdigest() + glob.update(val) + hashes['global'] = glob.hexdigest() + + return hashes + + def _run(self): + self._log.info('Compute descriptors') + try: + descriptors = self._compute_descriptors() + except Exception: + raise TestError('Error occured during description') + self._time('description') + + self._log.info('Classify data') + try: + classification = self._compute_classification(descriptors) + except Exception: + raise TestError('Error occured during classification') + self._time('classification') + + self._log.info('Run metrics') + metrics = self._run_metrics(classification, descriptors) + self._time('metrics') + + cmap = str(self._results_base_name) + '.tif' + self._log.info('Saving classification map {}'.format(cmap)) + triskele.write(cmap, classification) + + results = OrderedDict() + results['classification'] = cmap + results['metrics'] = metrics + self._results = results + + def _compute_descriptors(self): + script = self._expe['descriptors_script'] + + desc = importlib.import_module(script['name']) + att = desc.run(**script['parameters']) + + return att + + def _compute_classification(self, descriptors): + # Ground truth + gt = self._get_ground_truth() + + # CrossVal and ML + cv = self._expe['cross_validation'] + cl = self._expe['classifier'] + + cross_val = getattr(importlib.import_module(cv['package']), cv['name']) + classifier = getattr(importlib.import_module(cl['package']), cl['name']) + + prediction = np.zeros_like(gt, dtype=np.uint8) + + for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']): + rfc = classifier(**cl['parameters']) + rfc.fit(xt, yt) + + ypred = rfc.predict(xv) + + prediction[ti] = ypred + + return prediction + + def _get_ground_truth(self): + gt_expe = self._expe['ground_truth'] + gt = triskele.read(gt_expe['raster']) + + # Meta labeling + idx_map = np.arange(gt.max() + 1) + + if 'meta_labels' in gt_expe: + meta_idx = pd.read_csv(gt_expe['meta_labels']) + idx = np.array(meta_idx['index']) + midx = np.array(meta_idx['metaclass_index']) + idx_map[idx] = midx + + return idx_map[gt] + + def _get_results(self): + return self._results + + def _run_metrics(self, classification, descriptors): + gt = self._get_ground_truth() + + f = np.nonzero(classification) + pred = classification[f].ravel() + gt = gt[f].ravel() + + results = OrderedDict() + results['dimensions'] = descriptors.shape[-1] + results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred)) + results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred)) + + return results diff --git a/minigrida/protocols/protocol.py b/minigrida/protocols/protocol.py index b53efc8..73cbbab 100644 --- a/minigrida/protocols/protocol.py +++ b/minigrida/protocols/protocol.py @@ -10,16 +10,12 @@ import logging import time -from collections import OrderedDict - class Protocol: def __init__(self, expe, name=None): self._log = logging.getLogger(name) self._expe = expe self._name = name - self._times = OrderedDict() - self._results_base_name = None self._log.debug('expe loaded: {}'.format(self._expe)) def get_hashes(self): diff --git a/supervisor.py b/minigrida/supervisor.py similarity index 77% rename from supervisor.py rename to minigrida/supervisor.py index 63cacb8..61c7907 100644 --- a/supervisor.py +++ b/minigrida/supervisor.py @@ -3,15 +3,13 @@ # \file supervisor.py # \brief TODO # \author Florent Guiotte -# \version 0.1 +# \version 2.1 # \date 07 sept. 2018 # # TODO details -import yaml import importlib -import hashlib -from collections import OrderedDict +import json import time import os import datetime @@ -21,34 +19,10 @@ import traceback import logging import logger from protocols.protocol import TestError - -ENRICHMENT_DIR = Path('./Enrichment/') -TEST_DIR = ENRICHMENT_DIR / 'Tests' -STAGING_DIR = ENRICHMENT_DIR / 'Staging' -RESULT_DIR = ENRICHMENT_DIR / 'Results' -FAILED_DIR = ENRICHMENT_DIR / 'Failed' +import database log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1])) -def setup_yaml(): - """ Keep yaml ordered, newline string - from https://stackoverflow.com/a/8661021 - """ - represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items()) - yaml.add_representer(OrderedDict, represent_dict_order) - - """ https://stackoverflow.com/a/24291536 """ - yaml.Dumper.org_represent_str = yaml.Dumper.represent_str - yaml.add_representer(str, repr_str, Dumper=yaml.Dumper) - - -def repr_str(dumper, data): - if '\n' in data: - return dumper.represent_scalar(u'tag:yaml.org,2002:str', - data, style='|') - return dumper.org_represent_str(data) - - def update_queue(): tmp_queue = list() for child in TEST_DIR.iterdir(): @@ -188,28 +162,12 @@ class ExpePath: def get_result_path(self): return Path(RESULT_DIR) / self._get_hash_name() - def _check_hash(self, expe): - if self._hash is None: - if 'hashes' in expe: - self._hash = expe['hashes']['global'] - - def _write(self, path, expe): - new_path = Path(path) / self._get_complete_name() - with open(new_path, 'w') as of: - yaml.dump(expe, of, - default_flow_style=False, - encoding=None, - allow_unicode=True) - self._actual.unlink() - self._actual = new_path - - -def watch_folder(): - log.info('Waiting for test') - while not list(TEST_DIR.glob('*.yml')): - time.sleep(3) +def get_database(): def main(): + print('Hello again') + database + return while(True): try: queue = update_queue() @@ -239,5 +197,4 @@ if __name__ == '__main__': logger.setup_logging() log.info('Starting supervisor') - setup_yaml() main()