WIP on refactor

commit d356ff1dd5
parent 131d687d42
@@ -50,11 +50,11 @@ def create_session(name, desc, project_name, urgency=1):
     else:
         print('Session "{}" already exists.'.format(name))
 
-def bind_testing():
+def connect_testing():
     db.bind('sqlite', ':memory:')
     db.generate_mapping(create_tables=True)
 
-def bind(credentials_file):
+def connect(credentials_file):
     with open(credentials_file) as f:
         credentials = json.load(f)
     db.bind(**credentials)
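Note on the rename above: `db.bind('sqlite', ':memory:')` and `db.generate_mapping(create_tables=True)` are Pony ORM calls, so `connect(credentials_file)` expects the JSON file to hold `db.bind` keyword arguments. A minimal usage sketch, assuming these helpers live in the `database` module that supervisor.py now imports; the file name and its contents are hypothetical:

```python
# Hypothetical credentials.json holding Pony ORM bind kwargs, e.g.:
# {"provider": "postgres", "host": "localhost", "user": "minigrida",
#  "password": "secret", "database": "minigrida"}
import database

database.connect_testing()               # throwaway in-memory SQLite for tests
# database.connect('credentials.json')   # or a real backend from JSON
```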
minigrida/logging.yaml (new file)
@@ -0,0 +1,40 @@
+version: 1
+disable_existing_loggers: False
+formatters:
+    simple:
+        format: "%(asctime)s %(levelname)s %(name)s: %(message)s"
+
+handlers:
+    console:
+        class: logging.StreamHandler
+        level: INFO
+        formatter: simple
+        stream: ext://sys.stdout
+
+    info_file_handler:
+        class: logging.handlers.RotatingFileHandler
+        level: INFO
+        formatter: simple
+        filename: Logs/info.log
+        maxBytes: 10485760 # 10MB
+        backupCount: 20
+        encoding: utf8
+
+    error_file_handler:
+        class: logging.handlers.RotatingFileHandler
+        level: ERROR
+        formatter: simple
+        filename: Logs/errors.log
+        maxBytes: 10485760 # 10MB
+        backupCount: 20
+        encoding: utf8
+
+loggers:
+    my_module:
+        level: DEBUG
+        handlers: [console]
+        propagate: no
+
+root:
+    level: DEBUG
+    handlers: [console, info_file_handler, error_file_handler]
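The new config is in `logging.config.dictConfig` format. A sketch of how such a file is typically loaded; `setup_logging` here is illustrative, not necessarily what the repo's `logger` module does. Note that `RotatingFileHandler` fails at startup if the `Logs/` directory does not exist:

```python
import logging.config
from pathlib import Path

import yaml


def setup_logging(path='minigrida/logging.yaml'):
    Path('Logs').mkdir(exist_ok=True)  # the file handlers won't create it
    with open(path) as f:
        logging.config.dictConfig(yaml.safe_load(f))
```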
minigrida/protocols/jurse2.py (new file)
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+# file jurse2.py
+# author Florent Guiotte <florent.guiotte@irisa.fr>
+# version 0.0
+# date 26 May 2020
+"""Abstract
+
+doc.
+"""
+
+import hashlib
+import importlib
+from collections import OrderedDict
+
+import numpy as np
+import pandas as pd
+import rasterio
+import triskele
+from sklearn import metrics
+
+from .protocol import Protocol, TestError
+
+
+class Jurse2(Protocol):
+    """Second JURSE test protocol for LiDAR classification with 2D maps."""
+
+    def __init__(self, expe):
+        super().__init__(expe, self.__class__.__name__)
+
+    def _get_hashes(self):
+        hashes = OrderedDict()
+        glob = hashlib.sha1()
+
+        for k in ['ground_truth', 'descriptors_script',
+                  'cross_validation', 'classifier']:
+            val = str(self._expe[k]).encode('utf-8')
+            hashes[k] = hashlib.sha1(val).hexdigest()
+            glob.update(val)
+        hashes['global'] = glob.hexdigest()
+
+        return hashes
+
+    def _run(self):
+        self._log.info('Compute descriptors')
+        try:
+            descriptors = self._compute_descriptors()
+        except Exception:
+            raise TestError('Error occurred during description')
+        self._time('description')
+
+        self._log.info('Classify data')
+        try:
+            classification = self._compute_classification(descriptors)
+        except Exception:
+            raise TestError('Error occurred during classification')
+        self._time('classification')
+
+        self._log.info('Run metrics')
+        metric_results = self._run_metrics(classification, descriptors)
+        self._time('metrics')
+
+        cmap = str(self._results_base_name) + '.tif'
+        self._log.info('Saving classification map {}'.format(cmap))
+        triskele.write(cmap, classification)
+
+        results = OrderedDict()
+        results['classification'] = cmap
+        results['metrics'] = metric_results
+        self._results = results
+
+    def _compute_descriptors(self):
+        script = self._expe['descriptors_script']
+
+        desc = importlib.import_module(script['name'])
+        att = desc.run(**script['parameters'])
+
+        return att
+
+    def _compute_classification(self, descriptors):
+        # Ground truth
+        gt = self._get_ground_truth()
+
+        # Cross-validation and classifier, resolved dynamically from the expe
+        cv = self._expe['cross_validation']
+        cl = self._expe['classifier']
+
+        cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
+        classifier = getattr(importlib.import_module(cl['package']), cl['name'])
+
+        prediction = np.zeros_like(gt, dtype=np.uint8)
+
+        for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
+            rfc = classifier(**cl['parameters'])
+            rfc.fit(xt, yt)
+
+            ypred = rfc.predict(xv)
+
+            prediction[ti] = ypred
+
+        return prediction
+
+    def _get_ground_truth(self):
+        gt_expe = self._expe['ground_truth']
+        gt = triskele.read(gt_expe['raster'])
+
+        # Meta labeling: optionally remap raw labels to meta-class labels
+        idx_map = np.arange(gt.max() + 1)
+
+        if 'meta_labels' in gt_expe:
+            meta_idx = pd.read_csv(gt_expe['meta_labels'])
+            idx = np.array(meta_idx['index'])
+            midx = np.array(meta_idx['metaclass_index'])
+            idx_map[idx] = midx
+
+        return idx_map[gt]
+
+    def _get_results(self):
+        return self._results
+
+    def _run_metrics(self, classification, descriptors):
+        gt = self._get_ground_truth()
+
+        # Score only pixels that received a prediction (non-zero)
+        f = np.nonzero(classification)
+        pred = classification[f].ravel()
+        gt = gt[f].ravel()
+
+        results = OrderedDict()
+        results['dimensions'] = descriptors.shape[-1]
+        results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
+        results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
+
+        return results
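Everything `Jurse2` does is driven by the `expe` mapping: `_compute_descriptors` imports the module named in `descriptors_script` and calls its `run()`, while `_compute_classification` resolves the cross-validator and classifier by `package`/`name` through `importlib`. A hypothetical `expe` showing the expected shape; only the key names come from the code, every value below is illustrative:

```python
expe = {
    'ground_truth': {
        'raster': 'Data/gt.tif',
        # optional CSV with columns: index, metaclass_index
        'meta_labels': 'Data/meta_labels.csv',
    },
    'descriptors_script': {
        'name': 'descriptors.dfc',           # module exposing run(**parameters)
        'parameters': {'scales': [1, 2, 4]},
    },
    'cross_validation': {
        # must yield (x_train, x_val, y_train, y_val, val_indices) tuples
        'package': 'crossval',
        'name': 'spatial_kfold',
        'parameters': {'n_splits': 5},
    },
    'classifier': {
        'package': 'sklearn.ensemble',
        'name': 'RandomForestClassifier',
        'parameters': {'n_estimators': 100, 'n_jobs': -1},
    },
}
```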
@@ -10,16 +10,12 @@
 
 import logging
 import time
-from collections import OrderedDict
-
 
 
 class Protocol:
     def __init__(self, expe, name=None):
         self._log = logging.getLogger(name)
         self._expe = expe
         self._name = name
-        self._times = OrderedDict()
-        self._results_base_name = None
         self._log.debug('expe loaded: {}'.format(self._expe))
 
     def get_hashes(self):
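With `_times` and `_results_base_name` gone from `__init__`, the base class is down to logging and storing the `expe`; note that `Jurse2` above still calls `self._time()` and reads `self._results_base_name`, consistent with the WIP commit message. For reference, the contract a subclass implements, judging from `Jurse2` (a sketch, not project code):

```python
from protocols.protocol import Protocol


class DummyProtocol(Protocol):
    """Hypothetical no-op protocol illustrating the subclass contract."""

    def _get_hashes(self):
        # one digest per expe section plus a 'global' digest in Jurse2
        return {'global': 'deadbeef'}

    def _run(self):
        self._log.info('Nothing to compute')
        self._results = {}

    def _get_results(self):
        return self._results
```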
@@ -3,15 +3,13 @@
 # \file supervisor.py
 # \brief TODO
 # \author Florent Guiotte <florent.guiotte@gmail.com>
-# \version 0.1
+# \version 2.1
 # \date 07 sept. 2018
 #
 # TODO details
 
-import yaml
 import importlib
-import hashlib
-from collections import OrderedDict
+import json
 import time
 import os
 import datetime
@@ -21,34 +19,10 @@ import traceback
 import logging
 import logger
 from protocols.protocol import TestError
+import database
 
-ENRICHMENT_DIR = Path('./Enrichment/')
-TEST_DIR = ENRICHMENT_DIR / 'Tests'
-STAGING_DIR = ENRICHMENT_DIR / 'Staging'
-RESULT_DIR = ENRICHMENT_DIR / 'Results'
-FAILED_DIR = ENRICHMENT_DIR / 'Failed'
-
 log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
 
-def setup_yaml():
-    """ Keep yaml ordered, newline string
-    from https://stackoverflow.com/a/8661021
-    """
-    represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
-    yaml.add_representer(OrderedDict, represent_dict_order)
-
-    """ https://stackoverflow.com/a/24291536 """
-    yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
-    yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)
-
-
-def repr_str(dumper, data):
-    if '\n' in data:
-        return dumper.represent_scalar(u'tag:yaml.org,2002:str',
-                                       data, style='|')
-    return dumper.org_represent_str(data)
-
-
 def update_queue():
     tmp_queue = list()
     for child in TEST_DIR.iterdir():
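The removed `setup_yaml`/`repr_str` pair customized PyYAML output: `OrderedDict` dumped as a plain mapping (the default `Dumper` otherwise emits a `!!python/object/apply:collections.OrderedDict` tag) and multi-line strings emitted in literal `|` block style. A standalone demonstration of the first representer, for reference:

```python
import yaml
from collections import OrderedDict

# Same representer the removed setup_yaml() installed:
yaml.add_representer(
    OrderedDict,
    lambda dumper, data: dumper.represent_mapping(
        'tag:yaml.org,2002:map', data.items()))

print(yaml.dump(OrderedDict([('b', 1), ('a', 2)]), default_flow_style=False))
# b: 1
# a: 2    (insertion order kept, no Python-specific tag)
```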
@@ -188,28 +162,12 @@ class ExpePath:
     def get_result_path(self):
         return Path(RESULT_DIR) / self._get_hash_name()
 
-    def _check_hash(self, expe):
-        if self._hash is None:
-            if 'hashes' in expe:
-                self._hash = expe['hashes']['global']
-
-    def _write(self, path, expe):
-        new_path = Path(path) / self._get_complete_name()
-        with open(new_path, 'w') as of:
-            yaml.dump(expe, of,
-                      default_flow_style=False,
-                      encoding=None,
-                      allow_unicode=True)
-        self._actual.unlink()
-        self._actual = new_path
-
-
-def watch_folder():
-    log.info('Waiting for test')
-    while not list(TEST_DIR.glob('*.yml')):
-        time.sleep(3)
+def get_database():
 
 def main():
+    print('Hello again')
+    database
+    return
     while(True):
         try:
             queue = update_queue()
@@ -239,5 +197,4 @@ if __name__ == '__main__':
     logger.setup_logging()
     log.info('Starting supervisor')
 
-    setup_yaml()
     main()