Refactoring in progress on Supervisor and Jurse

Florent Guiotte 2018-09-09 16:37:02 +02:00
parent 12d0bcd2c8
commit 5eb4180bae
9 changed files with 267 additions and 138 deletions

4
.gitignore vendored

@ -1 +1,3 @@
Enrichment
Enrichment/
__pycache__/
Logs/

51
logger.py Normal file

@ -0,0 +1,51 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file logger.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 24 April 2018
#
# from https://fangpenlin.com/posts/2012/08/26/good-logging-practice-in-python/
import os
import logging.config
from pathlib import Path
import yaml
def setup_logging(
default_path='logging.yaml',
default_level=logging.WARN,
env_key='LOG_CFG'
):
"""Setup logging configuration
"""
path = default_path
value = os.getenv(env_key, None)
if value:
path = value
if os.path.exists(path):
with open(path, 'rt') as f:
config = yaml.safe_load(f.read())
makedirs(config)
logging.config.dictConfig(config)
else:
logging.basicConfig(level=default_level)
def makedirs(dic):
files = finddirs(dic)
for f in files:
d = Path(*f.parts[:-1])
d.mkdir(parents=True, exist_ok=True)
def finddirs(dic, key='filename'):
r = list()
value = dic.get(key)
if value : r.append(Path(value))
for k, v in dic.items():
if isinstance(v, dict):
r.extend(finddirs(v))
return r
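
For reference, a minimal sketch of what the two helpers above do; the dict literal is illustrative (shaped like the handlers section of logging.yaml below), not the committed config:

from logger import finddirs, makedirs  # the module above

# finddirs() recursively collects every nested 'filename' value;
# makedirs() then creates each file's parent directory (here Logs/)
# so the rotating file handlers can be opened by dictConfig().
config = {
    'handlers': {
        'info_file_handler': {'filename': 'Logs/info.log'},
        'error_file_handler': {'filename': 'Logs/errors.log'},
    }
}
print(finddirs(config))  # -> [PosixPath('Logs/info.log'), PosixPath('Logs/errors.log')]
makedirs(config)         # creates Logs/ (parents=True, exist_ok=True), safe to call again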

40
logging.yaml Normal file

@ -0,0 +1,40 @@
version: 1
disable_existing_loggers: False
formatters:
simple:
format: "%(asctime)s %(levelname)s %(name)s: %(message)s"
handlers:
console:
class: logging.StreamHandler
level: INFO
formatter: simple
stream: ext://sys.stdout
info_file_handler:
class: logging.handlers.RotatingFileHandler
level: INFO
formatter: simple
filename: Logs/info.log
maxBytes: 10485760 # 10MB
backupCount: 20
encoding: utf8
error_file_handler:
class: logging.handlers.RotatingFileHandler
level: ERROR
formatter: simple
filename: Logs/errors.log
maxBytes: 10485760 # 10MB
backupCount: 20
encoding: utf8
loggers:
my_module:
level: DEBUG
handlers: [console]
propagate: no
root:
level: DEBUG
handlers: [console, info_file_handler, error_file_handler]
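
Taken together with logger.py, the intended usage is roughly the following sketch; the logger names and handler behaviour come from the two files above, and the 'worker' logger name is an arbitrary placeholder:

# setup_logging() loads logging.yaml, or the file named by the LOG_CFG
# environment variable, creates Logs/ and installs the handlers above;
# without a config file it falls back to basicConfig at WARN level.
import logging
import logger

logger.setup_logging()
logging.getLogger('my_module').info('hello')  # console only: DEBUG logger, propagation disabled
logging.getLogger('worker').error('boom')     # root logger: console + Logs/info.log + Logs/errors.log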

View File

@ -7,7 +7,7 @@ detail: |
generic. Dynamic loading of the protocol still needs to be added, then the
supervisor's workings refactored to respect the universal spirit of minigrida.
protocol: JurseSF
protocol: Jurse
expe:
ground_truth:
raster: ./Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif
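
The detail field above mentions adding dynamic protocol loading; a hypothetical sketch of what that could look like (the load_protocol helper and its lowercase-module naming convention are assumptions, not part of this commit, which still imports protocols.jurse explicitly in the supervisor):

import importlib

def load_protocol(name):
    """Resolve a protocol class from the protocols package by its class name."""
    module = importlib.import_module('protocols.{}'.format(name.lower()))
    return getattr(module, name)

protocol = load_protocol('Jurse')  # -> protocols.jurse.Jurse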

9
protocols/__init__.py Normal file

@ -0,0 +1,9 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file __init__.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 09 sept. 2018
#
# TODO details

122
protocols/jurse.py Normal file

@ -0,0 +1,122 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file jurse.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 07 sept. 2018
#
# TODO details
import hashlib
import importlib
from collections import OrderedDict
import numpy as np
import pandas as pd
from sklearn import metrics
# TODO: create package, use dev
import sys
sys.path.append('../triskele/python')
import triskele
from .protocol import Protocol, TestError
class Jurse(Protocol):
"""First JURSE test protocol for LiDAR classification with 2D maps.
This first protocol computes attribute profiles on the whole scene, then
splits the samples into train and test sets for a random forest classifier.
"""
def __init__(self, expe):
super().__init__(expe, self.__class__.__name__)
def _get_hashes(self):
hashes = OrderedDict()
glob = hashlib.sha1()
for k in ['ground_truth', 'descriptors_script',
'cross_validation', 'classifier']:
val = str(self._expe[k]).encode('utf-8')
hashes[k] = hashlib.sha1(val).hexdigest()
glob.update(val)
hashes['global'] = glob.hexdigest()
return hashes
def _run(self):
self._log.info('Compute descriptors')
try:
descriptors = self._compute_descriptors()
except Exception:
raise TestError('Error occurred during description')
self._log.info('Classify data')
try:
classification = self._compute_classification(descriptors)
except Exception:
raise TestError('Error occurred during classification')
self._log.info('Run metrics')
self._metrics = self._run_metrics(classification, descriptors)
def _compute_descriptors(self):
script = self._expe['descriptors_script']
desc = importlib.import_module(script['name'])
att = desc.run(**script['parameters'])
return att
def _compute_classification(self, descriptors):
# Ground truth
gt = self._get_ground_truth()
# CrossVal and ML
cv = self._expe['cross_validation']
cl = self._expe['classifier']
cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
classifier = getattr(importlib.import_module(cl['package']), cl['name'])
prediction = np.zeros_like(gt)
for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
rfc = classifier(**cl['parameters'])
rfc.fit(xt, yt)
ypred = rfc.predict(xv)
prediction[ti] = ypred
return prediction
def _get_ground_truth(self):
gt_expe = self._expe['ground_truth']
gt = triskele.read(gt_expe['raster'])
# Meta labeling
idx_map = np.arange(gt.max() + 1)
if 'meta_labels' in gt_expe:
meta_idx = pd.read_csv(gt_expe['meta_labels'])
idx = np.array(meta_idx['index'])
midx = np.array(meta_idx['metaclass_index'])
idx_map[idx] = midx
return idx_map[gt]
def _run_metrics(self, classification, descriptors):
gt = self._get_ground_truth()
f = np.nonzero(classification)
pred = classification[f].ravel()
gt = gt[f].ravel()
results = OrderedDict()
results['dimension'] = descriptors.shape[-1]
results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
return results
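
For orientation, a sketch of the expe mapping this class consumes and how a protocol instance is driven; the key layout and the raster path come from this commit, while the descriptors_script and cross_validation values are hypothetical placeholders and the scikit-learn classifier is only implied by the docstring:

from protocols.jurse import Jurse

# Values marked hypothetical are illustrative; only the key layout is Jurse's.
expe = {
    'ground_truth': {
        'raster': './Data/ground_truth/2018_IEEE_GRSS_DFC_GT_TR.tif',
        # 'meta_labels': 'meta.csv',  # optional CSV with 'index' / 'metaclass_index' columns
    },
    'descriptors_script': {'name': 'descriptors', 'parameters': {}},                 # hypothetical module
    'cross_validation': {'package': 'crossval', 'name': 'folds', 'parameters': {}},  # hypothetical
    'classifier': {'package': 'sklearn.ensemble', 'name': 'RandomForestClassifier',
                   'parameters': {'n_estimators': 100}},
}

experience = Jurse(expe)
hashes = experience.get_hashes()  # per-section SHA-1 digests plus a 'global' one
experience.run()                  # descriptors -> cross-validated classification -> metrics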

View File

@ -1,30 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file jurse_default.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 07 sept. 2018
#
# TODO details
import hashlib
from protocol import Protocol
class Jurse(Protocol):
def __init__(self, expe):
super().__init__(expe, self.__class__.__name__)
def _get_hashes(self):
hashes = OrderedDict()
hashes['global'] = 'Protocol did not override _get_hashes()'
glob = hashlib.sha1()
for k in ['ground_truth', 'descriptors_script', 'cross_validation', 'classifier']:
v = str(expe[k]).encode('utf-8')
hashes[k] = hashlib.sha1(v).hexdigest()
glob.update(v)
hashes['global'] = glob.hexdigest()
return hashes

View File

@ -12,21 +12,23 @@ import logging
import time
from collections import OrderedDict
class Protocol:
def __init__(self, expe, name=None):
self._log = logging.getLogger(name)
self._expe = expe
self._name = name
self._times = OrderedDict()
self._pt = time.process_time()
self._log.debug('expe loaded: {}'.format(self._expe))
def get_hashes(self):
self._log.info('Computing hashes')
return(self._get_hashes())
def run(self):
self._kronos = Kronos()
self._pt = time.process_time()
self._run()
# TODO: Stop process timer
def get_results(self):
self._get_results()
@ -60,4 +62,3 @@ class Protocol:
class TestError(Exception):
pass

View File

@ -20,25 +20,7 @@ from operator import itemgetter
import traceback
import logging
import logger
log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
### Keep yaml ordered, newline string
def setup_yaml():
""" https://stackoverflow.com/a/8661021 """
represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
yaml.add_representer(OrderedDict, represent_dict_order)
""" https://stackoverflow.com/a/24291536 """
yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)
def repr_str(dumper, data):
if '\n' in data:
return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')
return dumper.org_represent_str(data)
from protocols.protocol import TestError
ENRICHMENT_DIR = Path('./Enrichment/')
TEST_DIR = ENRICHMENT_DIR / 'Tests'
@ -46,30 +28,61 @@ STAGING_DIR = ENRICHMENT_DIR / 'Staging'
RESULT_DIR = ENRICHMENT_DIR / 'Results'
FAILED_DIR = ENRICHMENT_DIR / 'Failed'
log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
def setup_yaml():
""" Keep yaml ordered, newline string
from https://stackoverflow.com/a/8661021
"""
represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
yaml.add_representer(OrderedDict, represent_dict_order)
""" https://stackoverflow.com/a/24291536 """
yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)
def repr_str(dumper, data):
if '\n' in data:
return dumper.represent_scalar(u'tag:yaml.org,2002:str',
data, style='|')
return dumper.org_represent_str(data)
def update_queue():
tmp_queue = list()
for child in TEST_DIR.iterdir():
if child.is_file() and child.suffix == '.yml':
tmp_queue.append({'expe_file': child, 'priority': get_priority(child)})
tmp_queue.append({'expe_file': child,
'priority': get_priority(child)})
queue = sorted(tmp_queue, key=itemgetter('priority'))
return queue
def get_priority(yml_file):
with open(yml_file) as f:
expe = OrderedDict(yaml.safe_load(f))
return expe['priority']
def run(expe_file):
log.info('Run test {}'.format(expe_file))
with open(expe_file) as f:
test = OrderedDict(yaml.safe_load(f))
### Stage test
### Load protocol
protocol = getattr(importlib.import_module('protocols.jurse'), test['protocol'])
experience = protocol(test['expe'])
log.info('{} test protocol loaded'.format(experience))
### Write hashes
hashes = experience.get_hashes()
log.info(hashes)
### Run test
@ -80,12 +93,6 @@ def run(expe_file):
### End of test
return
### Keep track of time
kronos = Kronos()
### Compute hashes
log.info('Computing hashes')
expe_hashes = compute_hashes(expe)
### Create output names
oname = '{}_{}'.format(expe_file.stem, expe_hashes['global'][:6])
@ -179,79 +186,6 @@ def compute_hashes(expe):
expe_hashes['global'] = glob.hexdigest()
return expe_hashes
def compute_descriptors(expe):
"""Compute descriptors from a standard expe recipe"""
script = expe['descriptors_script']
desc = importlib.import_module(script['name'])
#importlib.reload(Descriptors)
att = desc.run(**script['parameters'])
return att
def get_ground_truth(expe):
gt = triskele.read(expe['ground_truth'])
# Meta labeling
idx_map = np.arange(gt.max() + 1)
if 'meta_labels' in expe:
meta_idx = pd.read_csv(expe['meta_labels'])
idx = np.array(meta_idx['index'])
midx = np.array(meta_idx['metaclass_index'])
idx_map[idx] = midx
return idx_map[gt]
def compute_classification(expe, descriptors):
"""Read a standard expe recipe and descriptors, return the result classification"""
# Ground truth
gt = get_ground_truth(expe)
# CrossVal and ML
cv = expe['cross_validation']
cl = expe['classifier']
cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
classifier = getattr(importlib.import_module(cl['package']), cl['name'])
prediction = np.zeros_like(gt)
for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
rfc = classifier(**cl['parameters'])
rfc.fit(xt, yt)
ypred = rfc.predict(xv)
prediction[ti] = ypred
return prediction
def compute_metrics(ground_truth, classification, descriptors):
"""Return dict of metrics for ground_truth and classification prediction in parameters"""
f = np.nonzero(classification)
pred = classification[f].ravel()
gt = ground_truth[f].ravel()
results = OrderedDict()
results['dimension'] = descriptors.shape[-1]
results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
return results
def run_metrics(expe, classification, descriptors):
"""Compute the metrics from a standard expe recipe and an given classification"""
### Extensible: meta-classes
gt = get_ground_truth(expe)
return compute_metrics(gt, classification, descriptors)
def create_report(kronos):
expe_report = OrderedDict()
@ -279,7 +213,7 @@ def main():
log.error('Critical exception while updating work queue')
log.error(traceback.format_exc())
log.warning('Resuming')
continue
break # continue
if not queue:
watch_folder()
continue
@ -291,7 +225,7 @@ def main():
log.error('Critical exception while running test. Resuming')
log.error(traceback.format_exc())
log.warning('Resuming')
continue
break # continue
if __name__ == '__main__':
logger.setup_logging()