Refactoring in progress on Supervisor and Jurse

This commit is contained in:
Florent Guiotte 2018-09-09 16:37:02 +02:00
parent 12d0bcd2c8
commit 5eb4180bae
9 changed files with 267 additions and 138 deletions

4
.gitignore vendored
View File

@ -1 +1,3 @@
Enrichment Enrichment/
__pycache__/
Logs/

51
logger.py Normal file
View File

@ -0,0 +1,51 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file %filename%.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 24 avril 2018
#
# from https://fangpenlin.com/posts/2012/08/26/good-logging-practice-in-python/
import os
import logging.config
from pathlib import Path
import yaml
def setup_logging(
default_path='logging.yaml',
default_level=logging.WARN,
env_key='LOG_CFG'
):
"""Setup logging configuration
"""
path = default_path
value = os.getenv(env_key, None)
if value:
path = value
if os.path.exists(path):
with open(path, 'rt') as f:
config = yaml.safe_load(f.read())
makedirs(config)
logging.config.dictConfig(config)
else:
logging.basicConfig(level=default_level)
def makedirs(dic):
files = finddirs(dic)
for f in files:
d = Path(*f.parts[:-1])
d.mkdir(parents=True, exist_ok=True)
def finddirs(dic, key='filename'):
r = list()
value = dic.get(key)
if value : r.append(Path(value))
for k, v in dic.items():
if isinstance(v, dict):
r.extend(finddirs(v))
return r

40
logging.yaml Normal file
View File

@ -0,0 +1,40 @@
version: 1
disable_existing_loggers: False
formatters:
simple:
format: "%(asctime)s %(levelname)s %(name)s: %(message)s"
handlers:
console:
class: logging.StreamHandler
level: INFO
formatter: simple
stream: ext://sys.stdout
info_file_handler:
class: logging.handlers.RotatingFileHandler
level: INFO
formatter: simple
filename: Logs/info.log
maxBytes: 10485760 # 10MB
backupCount: 20
encoding: utf8
error_file_handler:
class: logging.handlers.RotatingFileHandler
level: ERROR
formatter: simple
filename: Logs/errors.log
maxBytes: 10485760 # 10MB
backupCount: 20
encoding: utf8
loggers:
my_module:
level: DEBUG
handlers: [console]
propagate: no
root:
level: DEBUG
handlers: [console, info_file_handler, error_file_handler]

View File

@ -7,7 +7,7 @@ detail: |
générique. Il faut ajouter le chargement dynamique du protocole puis générique. Il faut ajouter le chargement dynamique du protocole puis
réusiner le fonctionnement du supervisor pour respecter l'esprit universel réusiner le fonctionnement du supervisor pour respecter l'esprit universel
de minigrida. de minigrida.
protocol: JurseSF protocol: Jurse
expe: expe:
ground_truth: ground_truth:
raster: ./Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif raster: ./Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif

9
protocols/__init__.py Normal file
View File

@ -0,0 +1,9 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file __init__.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 09 sept. 2018
#
# TODO details

122
protocols/jurse.py Normal file
View File

@ -0,0 +1,122 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file jurse.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 07 sept. 2018
#
# TODO details
import hashlib
import importlib
from collections import OrderedDict
import numpy as np
import pandas as pd
from sklearn import metrics
# TODO: create package, use dev
import sys
sys.path.append('../triskele/python')
import triskele
from .protocol import Protocol, TestError
class Jurse(Protocol):
"""First JURSE test protocol for LiDAR classification with 2D maps.
This first protocol compute attribute profiles on the whole scene then
split in train and test for a random forest classifier.
"""
def __init__(self, expe):
super().__init__(expe, self.__class__.__name__)
def _get_hashes(self):
hashes = OrderedDict()
glob = hashlib.sha1()
for k in ['ground_truth', 'descriptors_script',
'cross_validation', 'classifier']:
val = str(self._expe[k]).encode('utf-8')
hashes[k] = hashlib.sha1(val).hexdigest()
glob.update(val)
hashes['global'] = glob.hexdigest()
return hashes
def _run(self):
self._log.info('Compute descriptors')
try:
descriptors = self._compute_descriptors()
except Exception:
raise TestError('Error occured during description')
self._log.info('Classify data')
try:
classification = self._compute_classificatin(descriptors)
except Exception:
raise TestError('Error occured during classification')
self._log.info('Run metrics')
self._metrics = self._run_metrics(classification, descriptors)
def _compute_descriptors(self):
script = self._expe['descriptors_script']
desc = importlib.import_module(script['name'])
att = desc.run(**script['parameters'])
return att
def _compute_classification(self, descriptors):
# Ground truth
gt = self._get_ground_truth()
# CrossVal and ML
cv = self._expe['cross_validation']
cl = self._expe['classifier']
cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
classifier = getattr(importlib.import_module(cl['package']), cl['name'])
prediction = np.zeros_like(gt)
for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
rfc = classifier(**cl['parameters'])
rfc.fit(xt, yt)
ypred = rfc.predict(xv)
prediction[ti] = ypred
return prediction
def _get_ground_truth(self):
gt_expe = self._expe['ground_truth']
gt = triskele.read(gt_expe['raster'])
# Meta labeling
idx_map = np.arange(gt.max() + 1)
if 'meta_labels' in self._expe:
meta_idx = pd.read_csv(gt_expe['meta_labels'])
idx = np.array(meta_idx['index'])
midx = np.array(meta_idx['metaclass_index'])
idx_map[idx] = midx
return idx_map[gt]
def _run_metrics(self, classification, descriptors):
gt = self._get_ground_truth()
f = np.nonzero(classification)
pred = classification[f].ravel()
gt = gt[f].ravel()
results = OrderedDict()
results['dimension'] = descriptors.shape[-1]
results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
return results

View File

@ -1,30 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file jurse_default.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 07 sept. 2018
#
# TODO details
import hashlib
from protocol import Protocol
class Jurse(Protocol):
def __init__(self, expe):
super().__init__(expe, self.__class__.__name__)
def _get_hashes(self):
hashes = OrderedDict()
hashes['global'] = 'Protocol did not override _get_hashes()'
glob = hashlib.sha1()
for k in ['ground_truth', 'descriptors_script', 'cross_validation', 'classifier']:
v = str(expe[k]).encode('utf-8')
hashes[k] = hashlib.sha1(v).hexdigest()
glob.update(v)
hashes['global'] = glob.hexdigest()
return hashes

View File

@ -12,21 +12,23 @@ import logging
import time import time
from collections import OrderedDict from collections import OrderedDict
class Protocol: class Protocol:
def __init__(self, expe, name=None): def __init__(self, expe, name=None):
self._log = logging.getLogger(name) self._log = logging.getLogger(name)
self._expe = expe self._expe = expe
self._name = name self._name = name
self._times = OrderedDict() self._times = OrderedDict()
self._pt = time.process_time() self._log.debug('expe loaded: {}'.format(self._expe))
def get_hashes(self): def get_hashes(self):
self._log.info('Computing hashes') self._log.info('Computing hashes')
return(self._get_hashes()) return(self._get_hashes())
def run(self): def run(self):
self._kronos = Kronos() self._pt = time.process_time()
self._run() self._run()
# TODO: Strop process timer
def get_results(self): def get_results(self):
self._get_results() self._get_results()
@ -60,4 +62,3 @@ class Protocol:
class TestError(Exception): class TestError(Exception):
pass pass

View File

@ -20,25 +20,7 @@ from operator import itemgetter
import traceback import traceback
import logging import logging
import logger import logger
from protocols.protocol import TestError
log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
### Keep yaml ordered, newline string
def setup_yaml():
""" https://stackoverflow.com/a/8661021 """
represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
yaml.add_representer(OrderedDict, represent_dict_order)
""" https://stackoverflow.com/a/24291536 """
yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)
def repr_str(dumper, data):
if '\n' in data:
return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')
return dumper.org_represent_str(data)
ENRICHMENT_DIR = Path('./Enrichment/') ENRICHMENT_DIR = Path('./Enrichment/')
TEST_DIR = ENRICHMENT_DIR / 'Tests' TEST_DIR = ENRICHMENT_DIR / 'Tests'
@ -46,30 +28,61 @@ STAGING_DIR = ENRICHMENT_DIR / 'Staging'
RESULT_DIR = ENRICHMENT_DIR / 'Results' RESULT_DIR = ENRICHMENT_DIR / 'Results'
FAILED_DIR = ENRICHMENT_DIR / 'Failed' FAILED_DIR = ENRICHMENT_DIR / 'Failed'
log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
def setup_yaml():
""" Keep yaml ordered, newline string
from https://stackoverflow.com/a/8661021
"""
represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
yaml.add_representer(OrderedDict, represent_dict_order)
""" https://stackoverflow.com/a/24291536 """
yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)
def repr_str(dumper, data):
if '\n' in data:
return dumper.represent_scalar(u'tag:yaml.org,2002:str',
data, style='|')
return dumper.org_represent_str(data)
def update_queue(): def update_queue():
tmp_queue = list() tmp_queue = list()
for child in TEST_DIR.iterdir(): for child in TEST_DIR.iterdir():
if child.is_file() and child.suffix == '.yml': if child.is_file() and child.suffix == '.yml':
tmp_queue.append({'expe_file': child, 'priority': get_priority(child)}) tmp_queue.append({'expe_file': child,
'priority': get_priority(child)})
queue = sorted(tmp_queue, key=itemgetter('priority')) queue = sorted(tmp_queue, key=itemgetter('priority'))
return queue return queue
def get_priority(yml_file): def get_priority(yml_file):
with open(yml_file) as f: with open(yml_file) as f:
expe = OrderedDict(yaml.safe_load(f)) expe = OrderedDict(yaml.safe_load(f))
return expe['priority'] return expe['priority']
def run(expe_file): def run(expe_file):
log.info('Run test {}'.format(expe_file)) log.info('Run test {}'.format(expe_file))
with open(expe_file) as f: with open(expe_file) as f:
test = OrderedDict(yaml.safe_load(f)) test = OrderedDict(yaml.safe_load(f))
### Stage test ### Stage test
### Load protocol ### Load protocol
protocol = getattr(importlib.import_module('protocols.jurse'), test['protocol'])
experience = protocol(test['expe'])
log.info('{} test protocol loaded'.format(experience))
### Write hahes ### Write hahes
hashes = experience.get_hashes()
log.info(hashes)
### Run test ### Run test
@ -80,12 +93,6 @@ def run(expe_file):
### End of test ### End of test
return return
### Keep track of time
kronos = Kronos()
### Compute hashes
log.info('Computing hashes')
expe_hashes = compute_hashes(expe)
### Create output names ### Create output names
oname = '{}_{}'.format(expe_file.stem, expe_hashes['global'][:6]) oname = '{}_{}'.format(expe_file.stem, expe_hashes['global'][:6])
@ -179,79 +186,6 @@ def compute_hashes(expe):
expe_hashes['global'] = glob.hexdigest() expe_hashes['global'] = glob.hexdigest()
return expe_hashes return expe_hashes
def compute_descriptors(expe):
"""Compute descriptors from a standard expe recipe"""
script = expe['descriptors_script']
desc = importlib.import_module(script['name'])
#importlib.reload(Descriptors)
att = desc.run(**script['parameters'])
return att
def get_ground_truth(expe):
gt = triskele.read(expe['ground_truth'])
# Meta labeling
idx_map = np.arange(gt.max() + 1)
if 'meta_labels' in expe:
meta_idx = pd.read_csv(expe['meta_labels'])
idx = np.array(meta_idx['index'])
midx = np.array(meta_idx['metaclass_index'])
idx_map[idx] = midx
return idx_map[gt]
def compute_classification(expe, descriptors):
"""Read a standard expe recipe and descriptors, return the result classification"""
# Ground truth
gt = get_ground_truth(expe)
# CrossVal and ML
cv = expe['cross_validation']
cl = expe['classifier']
cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
classifier = getattr(importlib.import_module(cl['package']), cl['name'])
prediction = np.zeros_like(gt)
for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
rfc = classifier(**cl['parameters'])
rfc.fit(xt, yt)
ypred = rfc.predict(xv)
prediction[ti] = ypred
return prediction
def compute_metrics(ground_truth, classification, descriptors):
"""Return dict of metrics for ground_truth and classification prediction in parameters"""
f = np.nonzero(classification)
pred = classification[f].ravel()
gt = ground_truth[f].ravel()
results = OrderedDict()
results['dimension'] = descriptors.shape[-1]
results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
return results
def run_metrics(expe, classification, descriptors):
"""Compute the metrics from a standard expe recipe and an given classification"""
### Extensible: meta-classes
gt = get_ground_truth(expe)
return compute_metrics(gt, classification, descriptors)
def create_report(kronos): def create_report(kronos):
expe_report = OrderedDict() expe_report = OrderedDict()
@ -279,7 +213,7 @@ def main():
log.error('Critical exception while updating work queue') log.error('Critical exception while updating work queue')
log.error(traceback.format_exc()) log.error(traceback.format_exc())
log.warning('Resuming') log.warning('Resuming')
continue break # continue
if not queue: if not queue:
watch_folder() watch_folder()
continue continue
@ -291,7 +225,7 @@ def main():
log.error('Critical exception while running test. Resuming') log.error('Critical exception while running test. Resuming')
log.error(traceback.format_exc()) log.error(traceback.format_exc())
log.warning('Resuming') log.warning('Resuming')
continue break # continue
if __name__ == '__main__': if __name__ == '__main__':
logger.setup_logging() logger.setup_logging()