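#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file supervisor.py
# \brief Experiment supervisor: runs the YAML test recipes dropped in ./Enrichment/Tests.
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 25 July 2018
#
# Each recipe is staged in ./Enrichment/Staging while it runs, then its report
# and prediction are written to ./Enrichment/Results on success; failures
# caught by run() are recorded in ./Enrichment/Failed.
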
import datetime
import hashlib
import importlib
import logging
import os
import sys
import time
import traceback
from collections import OrderedDict
from operator import itemgetter
from pathlib import Path

import numpy as np
import pandas as pd
import yaml
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier

# triskele's Python bindings are loaded from ./triskele/python, relative to
# the working directory, so this must run before `import triskele`.
sys.path.append('./triskele/python')
import triskele

import logger


# Module logger; its name embeds this machine's host name,
# e.g. 'Supervisor [myhost]'.
log = logging.getLogger('Supervisor [%s]' % os.uname().nodename)


### Keep yaml ordered, newline string
def setup_yaml():
    """Register the custom YAML representers used when dumping reports.

    * ``OrderedDict`` values are dumped as plain YAML mappings that keep
      their insertion order (https://stackoverflow.com/a/8661021).
    * ``str`` values are routed through ``repr_str`` so multi-line strings use
      the ``|`` block style (https://stackoverflow.com/a/24291536). The stock
      string representer is saved as ``yaml.Dumper.org_represent_str`` first,
      so ``repr_str`` can fall back to it for single-line strings.
    """
    def represent_dict_order(dumper, data):
        return dumper.represent_mapping('tag:yaml.org,2002:map', data.items())

    yaml.add_representer(OrderedDict, represent_dict_order)

    yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
    yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)


def repr_str(dumper, data):
    """YAML representer for ``str`` values.

    Strings containing a newline are emitted as ``|`` block literals; every
    other string is handed to ``dumper.org_represent_str``, the stock string
    representer saved by ``setup_yaml``.
    """
    if '\n' not in data:
        return dumper.org_represent_str(data)
    return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')


# Register the custom YAML representers once, at import time, before any
# test report is written.
setup_yaml()


# Working directories, relative to the current working directory: recipes
# are picked up from Tests, kept in Staging while they run, then written to
# Results or Failed.
enrichment_dir = Path('./Enrichment/')
test_dir, staging_dir, result_dir, failed_dir = (
    enrichment_dir / sub for sub in ('Tests', 'Staging', 'Results', 'Failed'))


class TestError(Exception):
    """Raised by run() once a failed step has been reported to failed_dir."""


def update_queue():
    """Build the work queue from the ``*.yml`` recipe files in ``test_dir``.

    Returns a list of ``{'expe_file': Path, 'priority': ...}`` dicts sorted by
    ascending priority (``main`` pops from the end, so the highest priority
    runs first). Errors raised by ``get_priority`` propagate to the caller.
    """
    recipes = [path for path in test_dir.iterdir()
               if path.is_file() and path.suffix == '.yml']
    entries = [{'expe_file': path, 'priority': get_priority(path)}
               for path in recipes]
    entries.sort(key=itemgetter('priority'))
    return entries


def get_priority(yml_file):
    """Return the ``priority`` of the recipe stored in ``yml_file``.

    The file must hold a top-level ``expe`` mapping with a ``priority`` key;
    missing keys (KeyError) and YAML errors propagate to the caller.
    """
    # Decode as UTF-8 explicitly so reading does not depend on the locale of
    # the machine running the supervisor.
    with open(yml_file, encoding='utf-8') as f:
        return yaml.safe_load(f)['expe']['priority']


def run(expe_file):
    """Run the test recipe stored in ``expe_file`` end to end.

    Steps:
    1. Hash the recipe.
    2. Stage a partial report in ``staging_dir`` and delete ``expe_file``.
    3. Compute descriptors.
    4. Classify.
    5. Compute metrics.
    6. Write the prediction ``.tif`` and the complete report to ``result_dir``.

    If description, classification or metrics raise, the staged report is
    replaced by an error report in ``failed_dir``. ``TestError`` is then
    raised, chained to the original exception.

    :param expe_file: ``pathlib.Path`` of a ``.yml`` file whose ``expe``
        mapping holds the recipe.
    """
    log.info('Run test {}'.format(expe_file))
    with open(expe_file, encoding='utf-8') as f:
        expe = OrderedDict(yaml.safe_load(f)['expe'])

    ### Keep track of time
    kronos = Kronos()

    ### Compute hashes
    log.info('Computing hashes')
    expe_hashes = compute_hashes(expe)

    ### Create output names
    oname = '{}_{}'.format(expe_file.stem, expe_hashes['global'][:6])
    oname_yml = oname + '.yml'
    oname_tif = oname + '.tif'

    ### Create partial report (no step timed yet)
    expe_report = create_report(kronos)

    ### Stage expe, then remove it from the test folder
    log.info('Staging test')
    write_expe_file(staging_dir / oname_yml, expe, expe_hashes, expe_report)
    expe_file.unlink()

    ### Compute descriptors
    log.info('Compute descriptors')
    try:
        descriptors = compute_descriptors(expe)
    except Exception as e:
        raise _fail_stage(kronos, oname_yml, expe, expe_hashes, 'description', e) from e
    kronos.time('description')

    ### Compute classification
    log.info('Classify data')
    try:
        classification = compute_classification(expe, descriptors)
    except Exception as e:
        raise _fail_stage(kronos, oname_yml, expe, expe_hashes, 'classification', e) from e
    kronos.time('classification')

    ### Metrics: handled like the other steps, so a failure is recorded in
    ### failed_dir instead of leaving an orphan report in staging_dir.
    log.info('Run initial metrics')
    try:
        expe_metrics = run_metrics(expe, classification, descriptors)
    except Exception as e:
        raise _fail_stage(kronos, oname_yml, expe, expe_hashes, 'metrics', e) from e
    kronos.time('metrics')

    ### Create complete report
    log.info('Write complete report')
    expe_report = create_report(kronos)

    ### Name and write prediction
    triskele.write(result_dir / oname_tif, classification)

    ### Write report and results (the staged copy is removed first)
    (staging_dir / oname_yml).unlink()
    write_expe_file(result_dir / oname_yml, expe, expe_hashes, expe_report,
                    oname_tif, expe_metrics)

    log.info('Test complete')


def _fail_stage(kronos, oname_yml, expe, expe_hashes, when, e):
    """Record the failure of step ``when`` and return the TestError to raise.

    Must be called from the ``except`` block handling ``e``. It times the
    failed step, deletes the staged report, and writes an error report with
    the same file name to ``failed_dir``.
    """
    kronos.time(when)
    expe_report = create_report(kronos)
    (staging_dir / oname_yml).unlink()
    write_error(failed_dir / oname_yml, expe, expe_hashes, expe_report, when, e)
    return TestError('Error occured during {}'.format(when))


def write_error(file, expe, hashes=None, report=None, when='', e=Exception):
    """Write the error report of a failed test to ``file`` as YAML (UTF-8).

    The document holds the recipe (``expe``), its hashes and report, plus an
    ``expe_error`` mapping:
    - ``when``: the failing step.
    - ``what``: ``str(e)``.
    - ``where``: the formatted traceback.

    When ``e`` is an exception instance carrying a traceback, that traceback
    is formatted. The report is therefore correct even outside an ``except``
    block. Otherwise the exception currently being handled is formatted.
    """
    error = OrderedDict()
    error['when'] = when
    error['what'] = str(e)
    if isinstance(e, BaseException) and e.__traceback__ is not None:
        error['where'] = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
    else:
        error['where'] = traceback.format_exc()

    # Explicit UTF-8: with allow_unicode=True, non-ASCII text (e.g. the 'à'
    # in the report dates) is written as-is and must not depend on the locale.
    with open(file, 'w', encoding='utf-8') as of:
        yaml.dump(OrderedDict({'expe': expe,
                               'expe_hashes': hashes,
                               'expe_report': report,
                               'expe_error': error}),
                  of, default_flow_style=False, encoding=None, allow_unicode=True)


def write_expe_file(file, expe, hashes=None, report=None, classification=None, results=None):
    """Write a test record to ``file`` as a YAML document (UTF-8).

    :param expe: the recipe mapping.
    :param hashes: hashes of the recipe (see ``compute_hashes``).
    :param report: run report (see ``create_report``).
    :param classification: file name of the prediction image, if any.
    :param results: metrics mapping, if any.
    """
    # Explicit UTF-8: with allow_unicode=True, non-ASCII text (e.g. the 'à'
    # in the report dates) is written as-is and must not depend on the locale.
    with open(file, 'w', encoding='utf-8') as of:
        yaml.dump(OrderedDict({'expe': expe,
                               'expe_hashes': hashes,
                               'expe_report': report,
                               'expe_classification': classification,
                               'expe_results': results}),
                  of, default_flow_style=False, encoding=None, allow_unicode=True)


def compute_hashes(expe):
    """Return SHA-1 hex digests of the recipe parts that define a test.

    There is one digest per key (``ground_truth``, ``descriptors_script``,
    ``cross_validation``, ``classifier``), computed over ``str(value)``
    encoded as UTF-8. A final ``global`` digest covers those encodings
    concatenated in the same order.
    """
    keys = ('ground_truth', 'descriptors_script', 'cross_validation', 'classifier')
    encoded = [(key, str(expe[key]).encode('utf-8')) for key in keys]

    hashes = OrderedDict((key, hashlib.sha1(blob).hexdigest()) for key, blob in encoded)
    hashes['global'] = hashlib.sha1(b''.join(blob for _, blob in encoded)).hexdigest()
    return hashes


def compute_descriptors(expe):
    """Compute descriptors from a standard expe recipe.

    ``expe['descriptors_script']`` names an importable module (``name``).
    That module's ``run`` function is called with the keyword arguments
    under ``parameters``, and its return value is returned unchanged.
    """
    recipe = expe['descriptors_script']
    module = importlib.import_module(recipe['name'])
    return module.run(**recipe['parameters'])


def get_ground_truth(expe):
    """Return the ground truth of ``expe``, remapped to meta-classes if requested.

    ``expe['ground_truth']`` is read with ``triskele.read``, and its labels go
    through a lookup table (``idx_map``). The table is the identity unless the
    recipe has a ``meta_labels`` entry. In that case it names a CSV file with
    ``index`` and ``metaclass_index`` columns, and each label listed in
    ``index`` is replaced by the matching ``metaclass_index``.

    The ground truth is used as an integer index into the lookup table, so
    the result is the array ``idx_map[gt]``.
    """
    gt = triskele.read(expe['ground_truth'])

    # Meta labeling: idx_map[label] is the value reported for `label`.
    idx_map = np.arange(gt.max() + 1)

    if 'meta_labels' in expe:
        meta_idx = pd.read_csv(expe['meta_labels'])
        idx = np.array(meta_idx['index'])
        midx = np.array(meta_idx['metaclass_index'])
        # A meta-label file may list classes that do not occur in this ground
        # truth. Such labels are never looked up, and assigning them would
        # index past the table (IndexError), so they are skipped.
        in_table = idx < idx_map.size
        idx_map[idx[in_table]] = midx[in_table]

    return idx_map[gt]


def compute_classification(expe, descriptors):
    """Read a standard expe recipe and descriptors, return the result classification.

    ``expe['cross_validation']`` and ``expe['classifier']`` each name a
    callable by ``package`` and ``name``, with keyword ``parameters``. The
    cross-validation callable is called as
    ``cross_val(gt, descriptors, **parameters)``. The loop unpacks each item
    it yields as ``(x_train, x_test, y_train, y_test, test_index)``. For
    every fold, a fresh classifier is fitted on the training part, and its
    predictions are stored at ``test_index`` in an array built with
    ``np.zeros_like(gt)``.
    """
    gt = get_ground_truth(expe)

    cv_spec = expe['cross_validation']
    clf_spec = expe['classifier']

    def load(spec):
        return getattr(importlib.import_module(spec['package']), spec['name'])

    cross_val = load(cv_spec)
    make_classifier = load(clf_spec)

    prediction = np.zeros_like(gt)
    folds = cross_val(gt, descriptors, **cv_spec['parameters'])
    for x_train, x_test, y_train, _y_test, test_index in folds:
        model = make_classifier(**clf_spec['parameters'])
        model.fit(x_train, y_train)
        prediction[test_index] = model.predict(x_test)

    return prediction


def compute_metrics(ground_truth, classification, descriptors):
    """Return a dict of metrics comparing ``classification`` to ``ground_truth``.

    Only pixels with a non-zero prediction are scored. The returned
    ``OrderedDict`` holds:
    - ``dimension``: size of the last axis of ``descriptors``.
    - ``overall_accuracy``: as a float.
    - ``cohen_kappa``: as a float.
    """
    scored = np.nonzero(classification)
    y_pred = classification[scored].ravel()
    y_true = ground_truth[scored].ravel()

    return OrderedDict([
        ('dimension', descriptors.shape[-1]),
        ('overall_accuracy', float(metrics.accuracy_score(y_true, y_pred))),
        ('cohen_kappa', float(metrics.cohen_kappa_score(y_true, y_pred))),
    ])


def run_metrics(expe, classification, descriptors):
    """Compute the metrics of ``classification`` for a standard expe recipe.

    The reference is ``get_ground_truth(expe)``, which already applies the
    recipe's meta-class remapping when one is configured.
    """
    ### Extensible: meta-classes
    return compute_metrics(get_ground_truth(expe), classification, descriptors)


def create_report(kronos):
    """Build the run report of a test from its ``Kronos`` timer.

    Returns an ``OrderedDict`` with:
    - ``supervisor``: this machine's host name.
    - ``start_date`` / ``end_date``: local times formatted as
      ``'Le %d/%m/%Y à %H:%M:%S'``, or ``None`` when not set.
    - ``ressources``: a copy of the timer's recorded times plus ``ram``,
      which is not measured and always ``None``.
    """
    expe_report = OrderedDict()
    expe_report['supervisor'] = os.uname()[1]

    dates = (('start_date', kronos.get_start_date()), ('end_date', kronos.get_end_date()))
    for key, timestamp in dates:
        if timestamp is None:
            expe_report[key] = None
        else:
            expe_report[key] = datetime.datetime.fromtimestamp(timestamp).strftime('Le %d/%m/%Y à %H:%M:%S')

    # Work on a copy: get_times() may hand back the timer's own dict, and
    # adding 'ram' to it would leak into the timer and into later reports.
    ressources = OrderedDict(kronos.get_times())
    ressources['ram'] = None

    expe_report['ressources'] = ressources
    return expe_report


class Kronos(object):
    """Wall-clock and CPU timer for the steps of a test run.

    The wall-clock start is taken at construction. ``time(name)`` records the
    process CPU time spent since the previous mark (or since construction)
    under ``name + '_process_time'``, and moves the wall-clock end date.
    """

    def __init__(self):
        self._pt = time.process_time()  # CPU time at the last mark
        self._times = OrderedDict()     # '<step>_process_time' -> seconds, in step order
        self._stime = time.time()       # wall-clock start, seconds since the epoch
        self._etime = None              # wall-clock time of the last mark; None before any

    def time(self, name):
        """Record the CPU time spent since the previous mark under ``name + '_process_time'``."""
        # Read the clock once, so the time between "measure" and "reset" is
        # not lost from the next step.
        now = time.process_time()
        self._times[name + '_process_time'] = now - self._pt
        self._pt = now
        self._etime = time.time()

    def get_times(self):
        """Return a copy of the recorded process times.

        A copy is returned so callers (e.g. a report adding extra keys) cannot
        alter the timer's own record.
        """
        return OrderedDict(self._times)

    def get_start_date(self):
        """Return the wall-clock start time, in seconds since the epoch."""
        return self._stime

    def get_end_date(self):
        """Return the wall-clock time of the last mark, or ``None`` before the first one."""
        return self._etime


def watch_folder():
    """Block until at least one ``*.yml`` file is present in ``test_dir``.

    The folder is polled every 10 seconds.
    """
    log.info('Waiting for test')
    while next(test_dir.glob('*.yml'), None) is None:
        time.sleep(10)


def main(retry_delay=10):
    """Supervisor loop: run queued tests forever.

    Each iteration:
    - Rebuilds the queue from ``test_dir``.
    - Runs the entry popped from its end, which is the highest priority since
      ``update_queue`` sorts ascending.
    - If the queue is empty, waits for new recipe files instead.

    :param retry_delay: seconds to wait after a critical (unexpected)
        exception before trying again.
    """
    while True:
        try:
            queue = update_queue()
        except Exception:
            log.error('Critical exception while updating work queue')
            log.error(traceback.format_exc())
            log.warning('Resuming')
            # The offending file is normally still in test_dir. Retrying at
            # once would spin at full CPU and flood the log with the same error.
            time.sleep(retry_delay)
            continue

        if not queue:
            watch_folder()
            continue

        try:
            run(queue.pop()['expe_file'])
        except TestError:
            log.warning('Test failed, error logged. Resuming')
        except Exception:
            log.error('Critical exception while running test. Resuming')
            log.error(traceback.format_exc())
            log.warning('Resuming')
            # run() can fail before it removes the recipe from test_dir (e.g.
            # a recipe missing a hashed key). The same file is then picked up
            # again on the next turn, so back off instead of spinning.
            time.sleep(retry_delay)


if __name__ == '__main__':
    # Configure logging through the `logger` module imported at the top of
    # the file, before the first message is emitted; then loop forever.
    logger.setup_logging()
    log.info('Starting supervisor')
    main()