WIP on refactor

Florent Guiotte 2020-05-26 10:02:32 +02:00
parent 131d687d42
commit d356ff1dd5
6 changed files with 177 additions and 56 deletions

@@ -50,11 +50,11 @@ def create_session(name, desc, project_name, urgency=1):
     else:
         print('Session "{}" already exists.'.format(name))
 
-def bind_testing():
+def connect_testing():
    db.bind('sqlite', ':memory:')
    db.generate_mapping(create_tables=True)
 
-def bind(credentials_file):
+def connect(credentials_file):
    with open(credentials_file) as f:
        credentials = json.load(f)
    db.bind(**credentials)
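
The renamed connect() forwards the keys of a JSON credentials file straight to db.bind(). A minimal usage sketch, assuming Pony ORM (whose API matches the db.bind/db.generate_mapping calls above); the file name and every credential value are illustrative, not taken from this repository:

    import json

    # Hypothetical credentials file: any keyword arguments accepted by
    # Pony's db.bind() can go here.
    creds = {
        'provider': 'postgres',
        'host': 'localhost',
        'user': 'minigrid',
        'password': 'secret',
        'database': 'minigrida',
    }
    with open('credentials.json', 'w') as f:
        json.dump(creds, f)

    connect('credentials.json')  # binds the Database object defined in this module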

minigrida/logging.yaml Normal file

@@ -0,0 +1,40 @@
version: 1
disable_existing_loggers: False

formatters:
  simple:
    format: "%(asctime)s %(levelname)s %(name)s: %(message)s"

handlers:
  console:
    class: logging.StreamHandler
    level: INFO
    formatter: simple
    stream: ext://sys.stdout

  info_file_handler:
    class: logging.handlers.RotatingFileHandler
    level: INFO
    formatter: simple
    filename: Logs/info.log
    maxBytes: 10485760 # 10MB
    backupCount: 20
    encoding: utf8

  error_file_handler:
    class: logging.handlers.RotatingFileHandler
    level: ERROR
    formatter: simple
    filename: Logs/errors.log
    maxBytes: 10485760 # 10MB
    backupCount: 20
    encoding: utf8

loggers:
  my_module:
    level: DEBUG
    handlers: [console]
    propagate: no

root:
  level: DEBUG
  handlers: [console, info_file_handler, error_file_handler]
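
This file follows the standard logging.config dictConfig schema. A minimal sketch of how such a config is typically loaded (the repository's logger.setup_logging(), called in supervisor.py below, presumably does something similar; the path and the fallback behaviour here are assumptions):

    import logging.config
    from pathlib import Path

    import yaml  # PyYAML


    def setup_logging(path='minigrida/logging.yaml'):
        """Configure logging from the YAML file, with a basic fallback."""
        if Path(path).is_file():
            with open(path) as f:
                logging.config.dictConfig(yaml.safe_load(f))
        else:
            logging.basicConfig(level=logging.INFO)

Note that RotatingFileHandler does not create the Logs/ directory on its own; it must exist before the config is loaded.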

protocols/jurse2.py Normal file
@@ -0,0 +1,128 @@
#!/usr/bin/env python
# file jurse2.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 May 2020
"""Abstract

doc.

"""

import importlib
import hashlib
from collections import OrderedDict

import numpy as np
import pandas as pd
from sklearn import metrics
import rasterio
import triskele

from .protocol import Protocol, TestError


class Jurse2(Protocol):
    """Second JURSE test protocol for LiDAR classification with 2D maps."""

    def __init__(self, expe):
        super().__init__(expe, self.__class__.__name__)

    def _get_hashes(self):
        # One hash per experiment component, plus a global hash of the
        # whole experiment description.
        hashes = OrderedDict()
        glob = hashlib.sha1()
        for k in ['ground_truth', 'descriptors_script',
                  'cross_validation', 'classifier']:
            val = str(self._expe[k]).encode('utf-8')
            hashes[k] = hashlib.sha1(val).hexdigest()
            glob.update(val)
        hashes['global'] = glob.hexdigest()
        return hashes

    def _run(self):
        self._log.info('Compute descriptors')
        try:
            descriptors = self._compute_descriptors()
        except Exception:
            raise TestError('Error occurred during description')
        self._time('description')

        self._log.info('Classify data')
        try:
            classification = self._compute_classification(descriptors)
        except Exception:
            raise TestError('Error occurred during classification')
        self._time('classification')

        self._log.info('Run metrics')
        metrics = self._run_metrics(classification, descriptors)
        self._time('metrics')

        cmap = str(self._results_base_name) + '.tif'
        self._log.info('Saving classification map {}'.format(cmap))
        triskele.write(cmap, classification)

        results = OrderedDict()
        results['classification'] = cmap
        results['metrics'] = metrics
        self._results = results

    def _compute_descriptors(self):
        script = self._expe['descriptors_script']
        desc = importlib.import_module(script['name'])
        att = desc.run(**script['parameters'])
        return att

    def _compute_classification(self, descriptors):
        # Ground truth
        gt = self._get_ground_truth()

        # Cross-validation and classifier, resolved dynamically from the
        # experiment description.
        cv = self._expe['cross_validation']
        cl = self._expe['classifier']
        cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
        classifier = getattr(importlib.import_module(cl['package']), cl['name'])

        prediction = np.zeros_like(gt, dtype=np.uint8)
        for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
            rfc = classifier(**cl['parameters'])
            rfc.fit(xt, yt)
            ypred = rfc.predict(xv)
            prediction[ti] = ypred

        return prediction

    def _get_ground_truth(self):
        gt_expe = self._expe['ground_truth']
        gt = triskele.read(gt_expe['raster'])

        # Meta labeling: remap class indices to metaclass indices with a
        # lookup table when a meta_labels CSV is provided.
        idx_map = np.arange(gt.max() + 1)
        if 'meta_labels' in gt_expe:
            meta_idx = pd.read_csv(gt_expe['meta_labels'])
            idx = np.array(meta_idx['index'])
            midx = np.array(meta_idx['metaclass_index'])
            idx_map[idx] = midx

        return idx_map[gt]

    def _get_results(self):
        return self._results

    def _run_metrics(self, classification, descriptors):
        gt = self._get_ground_truth()
        f = np.nonzero(classification)
        pred = classification[f].ravel()
        gt = gt[f].ravel()

        results = OrderedDict()
        results['dimensions'] = descriptors.shape[-1]
        results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
        results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
        return results
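
For reference, a hypothetical experiment description consumed by Jurse2. The four top-level keys are exactly those hashed in _get_hashes(); 'sklearn.ensemble' / 'RandomForestClassifier' is a real import path, while every other module name and parameter is illustrative:

    expe = {
        'ground_truth': {
            'raster': 'Data/ground_truth.tif',
            'meta_labels': 'Data/meta_labels.csv',  # optional, see _get_ground_truth()
        },
        'descriptors_script': {
            'name': 'descriptors.attributes',  # imported with importlib, must expose run()
            'parameters': {'attribute': 'area'},
        },
        'cross_validation': {
            'package': 'crossval',  # must yield (xt, xv, yt, yv, ti) tuples
            'name': 'kfold',
            'parameters': {'n_splits': 5},
        },
        'classifier': {
            'package': 'sklearn.ensemble',
            'name': 'RandomForestClassifier',
            'parameters': {'n_estimators': 100, 'n_jobs': -1},
        },
    }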

protocols/protocol.py
@@ -10,16 +10,12 @@
 import logging
 import time
-from collections import OrderedDict
 
 
 class Protocol:
     def __init__(self, expe, name=None):
         self._log = logging.getLogger(name)
         self._expe = expe
         self._name = name
-        self._times = OrderedDict()
-        self._results_base_name = None
-
         self._log.debug('expe loaded: {}'.format(self._expe))
 
     def get_hashes(self):

supervisor.py
@@ -3,15 +3,13 @@
 # \file supervisor.py
 # \brief TODO
 # \author Florent Guiotte <florent.guiotte@gmail.com>
-# \version 0.1
+# \version 2.1
 # \date 07 sept. 2018
 #
 # TODO details
 
-import yaml
 import importlib
-import hashlib
-from collections import OrderedDict
+import json
 import time
 import os
 import datetime
@@ -21,34 +19,10 @@ import traceback
 import logging
 import logger
 from protocols.protocol import TestError
+import database
 
-ENRICHMENT_DIR = Path('./Enrichment/')
-TEST_DIR = ENRICHMENT_DIR / 'Tests'
-STAGING_DIR = ENRICHMENT_DIR / 'Staging'
-RESULT_DIR = ENRICHMENT_DIR / 'Results'
-FAILED_DIR = ENRICHMENT_DIR / 'Failed'
-
 log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
 
-def setup_yaml():
-    """ Keep yaml ordered, newline string
-
-    from https://stackoverflow.com/a/8661021
-    """
-    represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
-    yaml.add_representer(OrderedDict, represent_dict_order)
-
-    """ https://stackoverflow.com/a/24291536 """
-    yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
-    yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)
-
-
-def repr_str(dumper, data):
-    if '\n' in data:
-        return dumper.represent_scalar(u'tag:yaml.org,2002:str',
-                                       data, style='|')
-    return dumper.org_represent_str(data)
-
 def update_queue():
     tmp_queue = list()
     for child in TEST_DIR.iterdir():
@@ -188,28 +162,12 @@ class ExpePath:
     def get_result_path(self):
         return Path(RESULT_DIR) / self._get_hash_name()
 
-    def _check_hash(self, expe):
-        if self._hash is None:
-            if 'hashes' in expe:
-                self._hash = expe['hashes']['global']
-
-    def _write(self, path, expe):
-        new_path = Path(path) / self._get_complete_name()
-        with open(new_path, 'w') as of:
-            yaml.dump(expe, of,
-                      default_flow_style=False,
-                      encoding=None,
-                      allow_unicode=True)
-        self._actual.unlink()
-        self._actual = new_path
-
-
-def watch_folder():
-    log.info('Waiting for test')
-    while not list(TEST_DIR.glob('*.yml')):
-        time.sleep(3)
+def get_database():
 
 def main():
+    print('Hello again')
+    database
+    return
     while(True):
         try:
             queue = update_queue()
@@ -239,5 +197,4 @@ if __name__ == '__main__':
     logger.setup_logging()
     log.info('Starting supervisor')
-    setup_yaml()
 
     main()
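
The new get_database() is left as an empty stub in this WIP commit, and main() currently short-circuits after a bare database expression. A plausible completion, not part of the commit, wiring in the connect()/connect_testing() helpers from the first file (presumably the database module imported above); the credentials path is hypothetical and Path is assumed available from supervisor.py's existing pathlib import:

    def get_database():
        # Sketch only: bind the ORM using the refactored helpers.
        credentials = Path('./credentials.json')
        if credentials.is_file():
            database.connect(credentials)
        else:
            database.connect_testing()  # in-memory sqlite for testing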