Florent Guiotte 2020-05-27 08:50:54 +02:00
parent d356ff1dd5
commit bbd62654f8
11 changed files with 231 additions and 251 deletions

.gitignore
@@ -1,5 +1,8 @@
 Enrichment/
 __pycache__/
-Logs/
+[Ll]ogs/
-[Dd]ata/
+[Dd]ata
+[Cc]ache
 *.egg-info/
+credentials*

@@ -15,7 +15,7 @@ handlers:
         class: logging.handlers.RotatingFileHandler
         level: INFO
         formatter: simple
-        filename: Logs/info.log
+        filename: logs/info.log
         maxBytes: 10485760 # 10MB
         backupCount: 20
         encoding: utf8
@@ -24,7 +24,7 @@ handlers:
         class: logging.handlers.RotatingFileHandler
         level: ERROR
         formatter: simple
-        filename: Logs/errors.log
+        filename: logs/errors.log
         maxBytes: 10485760 # 10MB
         backupCount: 20
         encoding: utf8

database.py
@@ -12,8 +12,11 @@ from pony import orm
 from datetime import datetime, date
 import json
 import hashlib
+import logging
-from .design import *
+from .design import Session, Experiment, Project, db
 
+log = logging.getLogger()
+
 
 def compute_expe_hash(experiment):
     return hashlib.md5(json.dumps(experiment, sort_keys=True).encode('utf-8')).hexdigest()
@@ -31,7 +34,7 @@ def create_experiment(session_name, protocol, expe, urgency=1):
         e = q.first()
         e.sessions.add(session)
     else:
-        Experiment(sessions=session, protocol=protocol, expe=experiment, expe_hash=expe_hash)
+        Experiment(sessions=session, protocol=protocol, expe=expe, expe_hash=expe_hash)
 
 @orm.db_session
 def create_project(name):
@@ -50,13 +53,46 @@ def create_session(name, desc, project_name, urgency=1):
     else:
         print('Session "{}" already exists.'.format(name))
 
+@orm.db_session
+def pending_experiments():
+    return Experiment.select(lambda x: x.status == 'pending').exists()
+
+@orm.db_session
+def next_experiment():
+    # TODO: take session urgency into account
+    expe = orm.select(e for e in Experiment
+                      if e.status == 'pending'
+                      and e.urgency == orm.max(e.urgency for e in Experiment
+                                               if e.status == 'pending')
+                      ).random(1)
+    if expe:
+        expe = expe[0]
+        expe.status = 'staging'
+    return expe
+
+def update_experiment(expe, **params):
+    try:
+        _update_experiment(expe, **params)
+    except orm.DatabaseError as e:
+        log.error(e)
+        log.info('Retry update')
+        _update_experiment(expe, **params)
+
+@orm.db_session
+def _update_experiment(expe, **params):
+    e = Experiment.select(lambda x: x.id == expe.id).first()
+    for k, v in params.items():
+        setattr(e, k, v)
+
 def connect_testing():
     db.bind('sqlite', ':memory:')
     db.generate_mapping(create_tables=True)
 
 def connect(credentials_file):
     with open(credentials_file) as f:
         credentials = json.load(f)
     db.bind(**credentials)
     db.generate_mapping(create_tables=True)

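Taken together, the new helpers turn the database into a small work queue: the create_* functions fill it, next_experiment pops a random entry among the most urgent pending ones and marks it staging, and update_experiment writes results back with one retry on DatabaseError. A minimal smoke-test sketch, assuming the module is importable as database (as supervisor.py below does) and using the in-memory SQLite binding so no credentials are needed; the project, session and expe payload values are placeholders:

import database

database.connect_testing()                      # in-memory SQLite, schema created on the fly
database.create_project('demo')
database.create_session('run-1', 'first dry run', 'demo', urgency=2)
database.create_experiment('run-1', 'Jurse2', {'param': 1}, urgency=2)

while database.pending_experiments():
    expe = database.next_experiment()           # random pick among most urgent pending expes
    # ... run the experiment here ...
    database.update_experiment(expe, status='complete')
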
dfc_base.py
@@ -1,4 +1,4 @@
-#!/usr/bin/python
+j#!/usr/bin/python
 # -*- coding: utf-8 -*-
 # \file dfc_base.py
 # \brief TODO

pixel.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python
+# file pixel.py
+# author Florent Guiotte <florent.guiotte@irisa.fr>
+# version 0.0
+# date 26 mai 2020
+"""Abstract
+
+doc.
+"""
+
+import numpy as np
+
+
+def run(gt, rasters, remove=None):
+    X = []
+    y = []
+    groups = []
+
+    for i, (gti, rastersi) in enumerate(zip(gt, rasters)):
+        # Create vectors
+        X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1)
+        y_raw = gti
+
+        # Remove unwanted label X, y
+        lbl = np.ones_like(y_raw, dtype=np.bool)
+        for l in remove if remove else []:
+            lbl &= y_raw != l
+
+        X += [X_raw[lbl]]
+        y += [y_raw[lbl]]
+        groups += [np.repeat(i, lbl.sum())]
+
+    X = np.concatenate(X)
+    y = np.concatenate(y)
+    groups = np.concatenate(groups)
+
+    return X, y, groups

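The new descriptor script stacks the raw raster bands of each tile into per-pixel feature vectors and records which tile every pixel came from. An illustrative call on two tiny synthetic tiles follows; the import path is an assumption, and note that np.bool above is the deprecated NumPy alias, so the module as committed expects an older NumPy:

import numpy as np
import pixel  # hypothetical import path for the module above

gt = [np.array([[0, 1], [2, 2]]), np.array([[1, 1], [0, 2]])]
rasters = [{'intensity': np.random.rand(2, 2), 'dsm': np.random.rand(2, 2)} for _ in gt]

X, y, groups = pixel.run(gt, rasters, remove=[0])   # drop pixels labelled 0
# X: (n_pixels, n_rasters) features, y: (n_pixels,) labels, groups: tile index of each pixel
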
__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python
+# file __init__.py
+# author Florent Guiotte <florent.guiotte@irisa.fr>
+# version 0.0
+# date 26 mai 2020
+"""Abstract
+
+doc.
+"""

minigrida/loader/tiles.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# file tiles.py
+# author Florent Guiotte <florent.guiotte@irisa.fr>
+# version 0.0
+# date 26 mai 2020
+"""Abstract
+
+doc.
+"""
+
+from pathlib import Path
+import rasterio as rio
+
+
+def run(tiles_count, gt_path, gt_name, rasters_path, rasters_name):
+    gt = []
+    rasters = []
+
+    for i in range(tiles_count):
+        gt += [load_tif(raster_path(gt_path, gt_name, i))]
+        rasters += [{Path(n).stem: load_tif(raster_path(rasters_path, n, i))
+                     for n in rasters_name}]
+
+    return gt, rasters
+
+
+def load_tif(path):
+    return rio.open(str(path)).read()[0]
+
+
+def raster_path(path, name, i):
+    return Path(path) / '{}_{}_{}.tif'.format(Path(name).stem, 0, i)

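The loader expects one GeoTIFF per tile and per layer, named '<stem>_0_<tile index>.tif' (the middle index is hard-coded to 0), and keeps only the first band of each file. An illustrative call with placeholder paths:

import tiles  # hypothetical import path for the module above

# Expects e.g. data/gt/gt_0_0.tif, data/gt/gt_0_1.tif,
#              data/rasters/intensity_0_0.tif, data/rasters/dsm_0_1.tif, ...
gt, rasters = tiles.run(tiles_count=2,
                        gt_path='data/gt', gt_name='gt.tif',
                        rasters_path='data/rasters',
                        rasters_name=['intensity.tif', 'dsm.tif'])
# gt: list of 2 ground-truth arrays; rasters: list of 2 dicts {'intensity': ..., 'dsm': ...}
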
protocols/__init__.py
@@ -7,3 +7,6 @@
 # \date 09 sept. 2018
 #
 # TODO details
+
+#from .jurse import Jurse
+from .jurse2 import Jurse2

protocols/jurse2.py
@@ -11,7 +11,6 @@ doc.
 import importlib
 import numpy as np
 from sklearn import metrics
-import rasterio
 
 from .protocol import Protocol, TestError
@@ -23,58 +22,50 @@ class Jurse2(Protocol):
     def __init__(self, expe):
         super().__init__(expe, self.__class__.__name__)
 
-    def _get_hashes(self):
-        hashes = OrderedDict()
-        glob = hashlib.sha1()
-        for k in ['ground_truth', 'descriptors_script',
-                  'cross_validation', 'classifier']:
-            val = str(self._expe[k]).encode('utf-8')
-            hashes[k] = hashlib.sha1(val).hexdigest()
-            glob.update(val)
-        hashes['global'] = glob.hexdigest()
-        return hashes
-
     def _run(self):
+        self._log.info('Load data')
+        try:
+            data = self._load_data()
+        except Exception:
+            raise TestError('Error occured during data loading')
+
         self._log.info('Compute descriptors')
         try:
-            descriptors = self._compute_descriptors()
+            descriptors = self._compute_descriptors(data)
         except Exception:
             raise TestError('Error occured during description')
-        self._time('description')
 
-        self._log.info('Classify data')
+        self._log.info('Classify descriptors')
         try:
             classification = self._compute_classification(descriptors)
         except Exception:
             raise TestError('Error occured during classification')
-        self._time('classification')
 
         self._log.info('Run metrics')
         metrics = self._run_metrics(classification, descriptors)
-        self._time('metrics')
 
-        cmap = str(self._results_base_name) + '.tif'
-        self._log.info('Saving classification map {}'.format(cmap))
-        triskele.write(cmap, classification)
-
-        results = OrderedDict()
-        results['classification'] = cmap
+        results = {}
         results['metrics'] = metrics
 
         self._results = results
 
-    def _compute_descriptors(self):
+    def _load_data(self):
+        data_loader = self._expe['data_loader']
+        loader = importlib.import_module(data_loader['name'])
+        data = loader.run(**data_loader['parameters'])
+        return data
+
+    def _compute_descriptors(self, data):
         script = self._expe['descriptors_script']
         desc = importlib.import_module(script['name'])
-        att = desc.run(**script['parameters'])
+        att = desc.run(*data, **script['parameters'])
         return att
 
     def _compute_classification(self, descriptors):
-        # Ground truth
-        gt = self._get_ground_truth()
+        X, y, groups = descriptors
 
         # CrossVal and ML
         cv = self._expe['cross_validation']
@@ -83,46 +74,40 @@ class Jurse2(Protocol):
         cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
         classifier = getattr(importlib.import_module(cl['package']), cl['name'])
 
-        prediction = np.zeros_like(gt, dtype=np.uint8)
-        for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
-            rfc = classifier(**cl['parameters'])
-            rfc.fit(xt, yt)
-            ypred = rfc.predict(xv)
-
-            prediction[ti] = ypred
-
-        return prediction
+        y_pred = np.zeros_like(y)
+        cvi = cross_val(**cv['parameters'])
+        for train_index, test_index in cvi.split(X, y, groups):
+            cli = classifier(**cl['parameters'])
+            self._log.info(' - fit')
+            cli.fit(X[train_index], y[train_index])
+            self._log.info(' - predict')
+            y_pred[test_index] = cli.predict(X[test_index])
+
+        return y_pred
 
-    def _get_ground_truth(self):
-        gt_expe = self._expe['ground_truth']
-        gt = triskele.read(gt_expe['raster'])
-
-        # Meta labeling
-        idx_map = np.arange(gt.max() + 1)
-        if 'meta_labels' in gt_expe:
-            meta_idx = pd.read_csv(gt_expe['meta_labels'])
-            idx = np.array(meta_idx['index'])
-            midx = np.array(meta_idx['metaclass_index'])
-            idx_map[idx] = midx
-
-        return idx_map[gt]
-
     def _get_results(self):
         return self._results
 
     def _run_metrics(self, classification, descriptors):
-        gt = self._get_ground_truth()
-
-        f = np.nonzero(classification)
-        pred = classification[f].ravel()
-        gt = gt[f].ravel()
-
-        results = OrderedDict()
-        results['dimensions'] = descriptors.shape[-1]
-        results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
-        results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
+        X, y_true, groups = descriptors
+        y_pred = classification
+
+        self._log.info(' - Scores')
+        self.oa = metrics.accuracy_score(y_true, y_pred)
+        self.aa = metrics.balanced_accuracy_score(y_true, y_pred)
+        self.k = metrics.cohen_kappa_score(y_true, y_pred)
+
+        self._log.info(' - Additional results')
+        p, r, f, s = metrics.precision_recall_fscore_support(y_true, y_pred)
+        cm = metrics.confusion_matrix(y_true, y_pred)
+
+        results = {'dimensions': X.shape[-1],
+                   'precision': p.tolist(),
+                   'recall': r.tolist(),
+                   'f1score': f.tolist(),
+                   'support': s.tolist(),
+                   'confusion': cm.tolist()}
 
         return results

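With the file-based ground truth gone, everything Jurse2 needs now comes from the expe document stored in the database: a data loader, a descriptors script, and sklearn-style cross-validation and classifier entries resolved through importlib. One plausible expe payload, to be queued with database.create_experiment(session_name, 'Jurse2', expe); the loader and descriptor module paths are assumptions, while the sklearn entries are standard:

expe = {
    'data_loader': {
        'name': 'loader.tiles',              # hypothetical dotted path to the tiles loader
        'parameters': {'tiles_count': 2,
                       'gt_path': 'data/gt', 'gt_name': 'gt.tif',
                       'rasters_path': 'data/rasters',
                       'rasters_name': ['intensity.tif', 'dsm.tif']},
    },
    'descriptors_script': {
        'name': 'descriptors.pixel',         # hypothetical dotted path to the pixel script
        'parameters': {'remove': [0]},
    },
    'cross_validation': {
        'package': 'sklearn.model_selection',
        'name': 'GroupKFold',
        'parameters': {'n_splits': 2},
    },
    'classifier': {
        'package': 'sklearn.ensemble',
        'name': 'RandomForestClassifier',
        'parameters': {'n_estimators': 100},
    },
}
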
protocols/protocol.py
@@ -11,52 +11,41 @@
 import logging
 import time
 
 
 class Protocol:
 
     def __init__(self, expe, name=None):
         self._log = logging.getLogger(name)
         self._expe = expe
         self._name = name
         self._log.debug('expe loaded: {}'.format(self._expe))
+        self.k = None
+        self.oa = None
+        self.aa = None
 
-    def get_hashes(self):
-        self._log.info('Computing hashes')
-        return(self._get_hashes())
-
-    def set_results_base_name(self, base_name):
-        self._results_base_name = base_name
-
     def run(self):
-        self._pt = time.process_time()
+        spt = time.process_time()
         self._run()
+        self._pt = time.process_time() - spt
 
     def get_results(self):
         return self._get_results()
 
     def get_process_time(self):
-        return self._times
-
-    def _time(self, process):
-        self._times[process] = time.process_time() - self._pt
-        self._pt = time.process_time()
-
-    def _get_hashes(self):
-        self._log.warning('Protocol did not override _get_hashes()')
-        hashes = OrderedDict()
-        hashes['global'] = 'Protocol did not override _get_hashes()'
-        return hashes
+        return self._pt
 
     def _run(self):
         self._log.error('Protocol did not override _run()')
-        raise NotImplementedError('Protocol {} did not override _run()'.format(self))
+        raise NotImplementedError('Protocol {} did not override _run()'
+                                  .format(self))
 
     def _get_results(self):
         self._log.warning('Protocol did not override _get_results()')
-        results = OrderedDict()
+        results = {}
         results['global'] = 'Protocol did not override _get_results()'
         return results
 
     def __str__(self):
-        return('{}'.format(self._name))
+        return '{}'.format(self._name)
 
 
 class TestError(Exception):

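After this cleanup a protocol only has to provide _run() and _get_results(); hashes, result paths and per-step timing are gone, and oa/aa/k are plain attributes the supervisor reads back. A minimal no-op subclass sketch (the class and result names are illustrative):

from protocols.protocol import Protocol  # import path as used by the supervisor


class Dummy(Protocol):
    def _run(self):
        # pretend the experiment ran and was classified perfectly
        self.oa = self.aa = self.k = 1.0

    def _get_results(self):
        return {'metrics': {'overall_accuracy': self.oa}}


test = Dummy({'note': 'no-op experiment'})
test.run()                      # also fills the value returned by get_process_time()
print(test.get_results())
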
supervisor.py
@@ -9,190 +9,76 @@
 # TODO details
 
 import importlib
-import json
 import time
 import os
-import datetime
-from pathlib import Path
-from operator import itemgetter
+from datetime import datetime
 import traceback
 import logging
 
 import logger
 from protocols.protocol import TestError
 import database
 
-log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
+host = os.uname()[1]
+log = logging.getLogger('Supervisor [{}]'.format(host))
 
-def update_queue():
-    tmp_queue = list()
-    for child in TEST_DIR.iterdir():
-        if child.is_file() and child.suffix == '.yml':
-            tmp_queue.append({'expe_file': ExpePath(child),
-                              'priority': get_priority(child)})
-    queue = sorted(tmp_queue, key=itemgetter('priority'))
-    return queue
-
-def get_priority(yml_file):
-    with open(yml_file) as f:
-        expe = OrderedDict(yaml.safe_load(f))
-    return expe['priority']
-
-def run(expe_file):
-    start_time = time.time()
-    log.info('Run test {}'.format(expe_file))
-    test = expe_file.read()
-
-    ### Stage experience
-    expe_file.stage(test)
-
-    ### Load protocol
+def run(expe):
+    database.update_experiment(expe, worker=host, start_date=datetime.now())
+
+    # Load protocol
+    log.info('Load protocol {}'.format(expe.protocol))
+    protocol_module = importlib.import_module('protocols')
+    importlib.reload(protocol_module)
+    protocol = getattr(protocol_module, expe.protocol)
+    test = protocol(expe.expe)
+
+    # Run test
     try:
-        #protocol = getattr(importlib.import_module(test['protocol']['package']), test['protocol']['name'])
-        protocol_module = importlib.import_module(test['protocol']['package'])
-        importlib.reload(protocol_module)
-        protocol = getattr(protocol_module, test['protocol']['name'])
-        experience = protocol(test['expe'])
-    except Exception as e:
-        err = 'Could not load protocol from test {}'.format(expe_file)
-        log.warning(err)
-        expe_file.error(test, 'loading protocol', e)
-        raise TestError(err)
-    log.info('{} test protocol loaded'.format(experience))
-
-    ### Get hashes
-    test['hashes'] = experience.get_hashes()
-    test['report'] = create_report(experience, start_time)
-
-    ### Stage experience
-    expe_file.stage(test)
-    experience.set_results_base_name(expe_file.get_result_path())
-
-    ### Run test
-    try:
-        experience.run()
+        test.run()
     except Exception as e:
         err = 'Experience error'
         log.warning(err)
-        expe_file.error(test, 'testing', e)
-        raise TestError(err)
+        report = {'error': {'name': str(err),
+                            'trace': traceback.format_exc()}}
+        database.update_experiment(expe, report=report, status='error')
+        raise TestError(err)
 
-    end_time = time.time()
+    # Write report
+    log.info('Write report')
+    database.update_experiment(expe,
+                               end_date=datetime.now(),
+                               oa=test.oa,
+                               aa=test.aa,
+                               k=test.k,
+                               report=test.get_results(),
+                               status='complete')
 
-    ### Get complete report
-    test['report'] = create_report(experience, start_time, end_time)
-
-    ### Get results
-    test['results'] = experience.get_results()
-
-    ### Write experience
-    expe_file.result(test)
-
-    ### End of test
-    log.info('Test complete')
+    # End of test
+    log.info('Expe {} complete'.format(expe.expe_hash))
 
-def create_report(experience, stime=None, etime=None):
-    expe_report = OrderedDict()
-    host = os.getenv("HOST")
-    expe_report['supervisor'] = host if host is not None else os.uname()[1]
-
-    # Dates
-    for datek, timev in zip(('start_date', 'end_date'), (stime, etime)):
-        expe_report[datek] = datetime.datetime.fromtimestamp(timev).strftime('Le %d/%m/%Y à %H:%M:%S') if timev is not None else None
-
-    # Ressources
-    ressources = OrderedDict()
-    ressources['ram'] = None
-    ressources['proccess_time'] = experience.get_process_time()
-    expe_report['ressources'] = ressources
-
-    return expe_report
-
-class ExpePath:
-    """Utility wrapper for expe files.
-
-    Extend pathlib Path with staging, result and errors function to move the
-    test report through the Enrichment center.
-    """
-
-    def __init__(self, path, hash_length=6):
-        self._actual = Path(path)
-        self._base_name = self._actual.stem
-        self._hash_length = hash_length
-        self._hash = None
-
-    def __str__(self):
-        return self._get_complete_name()
-
-    def read(self):
-        with open(self._actual) as f:
-            return OrderedDict(yaml.safe_load(f))
-
-    def _get_hash_name(self):
-        return '{}{}'.format(self._base_name,
-            '_' + self._hash[:self._hash_length] if self._hash is not None else '')
-
-    def _get_complete_name(self):
-        return self._get_hash_name() + '.yml'
-
-    def exists(self):
-        return self._actual.exists()
-
-    def stage(self, expe):
-        log.info('Staging {}'.format(self._base_name))
-        self._check_hash(expe)
-        self._write(STAGING_DIR, expe)
-
-    def result(self, expe):
-        log.info('Write results for test {}'.format(self._base_name))
-        self._check_hash(expe)
-        self._write(RESULT_DIR, expe)
-
-    def error(self, expe, when='', e=Exception):
-        error = OrderedDict()
-        error['when'] = when
-        error['what'] = str(e)
-        error['where'] = traceback.format_exc()
-        expe['error'] = error
-        self._write(FAILED_DIR, expe)
-
-    def get_result_path(self):
-        return Path(RESULT_DIR) / self._get_hash_name()
-
-def get_database():
 
 def main():
-    print('Hello again')
-    database
-    return
+    log.info('Connecting to database')
+    database.connect('credentials.json')
 
     while(True):
-        try:
-            queue = update_queue()
-        except Exception:
-            log.error('Critical exception while updating work queue')
-            log.error(traceback.format_exc())
-            log.warning('Resuming')
-            continue
-
-        if not queue:
-            watch_folder()
-            continue
-
-        try:
-            expe_file = queue.pop()['expe_file']
-            while(not expe_file.exists() and queue):
-                expe_file = queue.pop()['expe_file']
-            if expe_file.exists():
-                run(expe_file)
-        except TestError:
-            log.warning('Test failed, error logged. Resuming')
-        except Exception:
-            log.error('Critical exception while running test. Resuming')
-            log.error(traceback.format_exc())
-            log.warning('Resuming')
-            continue
+        if not database.pending_experiments():
+            log.info('No pending experiments, waiting...')
+            time.sleep(30)
+
+        else:
+            log.info('Loading next experiment')
+            expe = database.next_experiment()
+            if not expe:
+                continue
+            log.info('Expe {} loaded'.format(expe.expe_hash))
+
+            try:
+                run(expe)
+            except Exception as e:
+                log.error(e)
+                log.error('Error occured on expe {}'.format(expe.id))
 
 
 if __name__ == '__main__':
     logger.setup_logging()
     log.info('Starting supervisor')
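
main() now reads credentials.json and hands its content straight to Pony's db.bind() through database.connect(), which is why the .gitignore above starts ignoring credentials*. A placeholder file for a PostgreSQL backend could be generated like this; all values are made up, and any provider accepted by Pony works:

import json

credentials = {'provider': 'postgres',
               'user': 'minigrida',
               'password': 'secret',
               'host': 'localhost',
               'database': 'minigrida'}

with open('credentials.json', 'w') as f:
    json.dump(credentials, f, indent=2)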