This commit is contained in:
Florent Guiotte 2020-05-27 08:50:54 +02:00
parent d356ff1dd5
commit bbd62654f8
11 changed files with 231 additions and 251 deletions

7
.gitignore vendored
View File

@@ -1,5 +1,8 @@
Enrichment/
__pycache__/
Logs/
[Dd]ata/
[Ll]ogs/
[Dd]ata
[Cc]ache
*.egg-info/
credentials*

View File

@@ -15,7 +15,7 @@ handlers:
        class: logging.handlers.RotatingFileHandler
        level: INFO
        formatter: simple
        filename: Logs/info.log
        filename: logs/info.log
        maxBytes: 10485760 # 10MB
        backupCount: 20
        encoding: utf8
@@ -24,7 +24,7 @@ handlers:
        class: logging.handlers.RotatingFileHandler
        level: ERROR
        formatter: simple
        filename: Logs/errors.log
        filename: logs/errors.log
        maxBytes: 10485760 # 10MB
        backupCount: 20
        encoding: utf8
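# Note: RotatingFileHandler opens its file at configuration time and does
# not create missing directories, so logs/ must exist before setup (it is
# kept out of version control via the new [Ll]ogs entry in .gitignore).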

View File

@@ -12,8 +12,11 @@ from pony import orm
from datetime import datetime, date
import json
import hashlib
import logging

from .design import *
from .design import Session, Experiment, Project, db

log = logging.getLogger()


def compute_expe_hash(experiment):
    return hashlib.md5(json.dumps(experiment, sort_keys=True).encode('utf-8')).hexdigest()
@@ -31,7 +34,7 @@ def create_experiment(session_name, protocol, expe, urgency=1):
        e = q.first()
        e.sessions.add(session)
    else:
        Experiment(sessions=session, protocol=protocol, expe=experiment, expe_hash=expe_hash)
        Experiment(sessions=session, protocol=protocol, expe=expe, expe_hash=expe_hash)
@orm.db_session
def create_project(name):
@@ -50,13 +53,46 @@ def create_session(name, desc, project_name, urgency=1):
    else:
        print('Session "{}" already exists.'.format(name))


@orm.db_session
def pending_experiments():
    return Experiment.select(lambda x: x.status == 'pending').exists()


@orm.db_session
def next_experiment():
    # TODO: take session urgency into account
    # Draw one pending experiment at random among those with the highest
    # urgency, then mark it as staging so no other worker picks it up.
    expe = orm.select(e for e in Experiment
                      if e.status == 'pending'
                      and e.urgency == orm.max(e.urgency for e in Experiment
                                               if e.status == 'pending')
                      ).random(1)
    if expe:
        expe = expe[0]
        expe.status = 'staging'
    return expe


def update_experiment(expe, **params):
    # Retry once on transient database errors (e.g. dropped connection).
    try:
        _update_experiment(expe, **params)
    except orm.DatabaseError as e:
        log.error(e)
        log.info('Retry update')
        _update_experiment(expe, **params)


@orm.db_session
def _update_experiment(expe, **params):
    e = Experiment.select(lambda x: x.id == expe.id).first()
    for k, v in params.items():
        setattr(e, k, v)


def connect_testing():
    db.bind('sqlite', ':memory:')
    db.generate_mapping(create_tables=True)


def connect(credentials_file):
    with open(credentials_file) as f:
        credentials = json.load(f)
    db.bind(**credentials)
    db.generate_mapping(create_tables=True)
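# Illustrative sketch (not part of the commit): driving the queue API above.
# credentials.json must hold pony db.bind() keywords; the provider and the
# names/values below are assumptions, e.g.
#   {"provider": "postgres", "user": "minigrida", "password": "...",
#    "host": "localhost", "database": "experiments"}
import database

database.connect('credentials.json')
database.create_project('jurse')
database.create_session('baseline', 'RF on pixel descriptors', 'jurse')
database.create_experiment(
    'baseline',                                           # session to attach to
    'Jurse2',                                             # protocol class name
    {'classifier': {'name': 'RandomForestClassifier'}},   # expe dict, hashed for dedup
    urgency=2)

# A worker then drains the queue (highest urgency first, random tie-break);
# the returned entity has already been flipped to status 'staging'.
expe = database.next_experiment()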

View File

@@ -1,4 +1,4 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file dfc_base.py
# \brief TODO

View File

@@ -0,0 +1,37 @@
#!/usr/bin/env python
# file pixel.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 May 2020
"""Pixel descriptors.

Stacks the rasters of each tile into per-pixel feature vectors and
returns them with the matching labels and tile-group indices.
"""

import numpy as np


def run(gt, rasters, remove=None):
    X = []
    y = []
    groups = []
    for i, (gti, rastersi) in enumerate(zip(gt, rasters)):
        # Create vectors: one (H, W, n_rasters) feature cube per tile
        X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1)
        y_raw = gti

        # Remove unwanted label X, y
        lbl = np.ones_like(y_raw, dtype=bool)
        for label in remove if remove else []:
            lbl &= y_raw != label

        X += [X_raw[lbl]]
        y += [y_raw[lbl]]
        groups += [np.repeat(i, lbl.sum())]

    X = np.concatenate(X)
    y = np.concatenate(y)
    groups = np.concatenate(groups)
    return X, y, groups
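# Illustrative usage sketch (not part of the commit); shapes assumed:
# two 4x5 tiles, three rasters per tile, labels 0-3 with 0 as no-data.
import numpy as np

import pixel  # import path is an assumption

gt = [np.random.randint(0, 4, (4, 5)) for _ in range(2)]
rasters = [{'intensity': np.random.rand(4, 5),
            'elevation': np.random.rand(4, 5),
            'density': np.random.rand(4, 5)} for _ in range(2)]

X, y, groups = pixel.run(gt, rasters, remove=[0])
# X: (N, 3) float features, y: (N,) labels, groups: (N,) tile indices,
# where N counts only the pixels whose label survived `remove`.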

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env python
# file __init__.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 May 2020
"""Package initialization.
"""

32
minigrida/loader/tiles.py Normal file
View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python
# file tiles.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 May 2020
"""Tile loader.

Reads ground-truth and raster tiles from GeoTIFF files and returns
them as lists indexed by tile number.
"""

from pathlib import Path

import rasterio as rio


def run(tiles_count, gt_path, gt_name, rasters_path, rasters_name):
    gt = []
    rasters = []
    for i in range(tiles_count):
        gt += [load_tif(raster_path(gt_path, gt_name, i))]
        rasters += [{Path(n).stem: load_tif(raster_path(rasters_path, n, i))
                     for n in rasters_name}]
    return gt, rasters


def load_tif(path):
    # Context manager so the dataset handle is closed after reading band 1.
    with rio.open(str(path)) as src:
        return src.read()[0]


def raster_path(path, name, i):
    return Path(path) / '{}_{}_{}.tif'.format(Path(name).stem, 0, i)
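# Illustrative usage sketch (not part of the commit). Directory and file
# names below are invented; raster_path() hard-codes the middle index to 0,
# so tile i of raster 'intensity.tif' is expected at intensity_0_<i>.tif.
import tiles  # import path is an assumption

gt, rasters = tiles.run(tiles_count=2,
                        gt_path='data/gt', gt_name='gt.tif',
                        rasters_path='data/rasters',
                        rasters_name=['intensity.tif', 'elevation.tif'])
# gt[i] is the ground-truth array of tile i; rasters[i] maps raster stems
# to arrays, e.g. rasters[0]['elevation'].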

View File

@@ -7,3 +7,6 @@
# \date 09 sept. 2018
#
# TODO details
#from .jurse import Jurse
from .jurse2 import Jurse2
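# The supervisor resolves protocols by class name at run time
# (importlib.import_module('protocols') then getattr), so a new protocol
# only needs to be re-exported from this package __init__.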

View File

@@ -11,7 +11,6 @@ doc.
import importlib
import numpy as np
from sklearn import metrics
import rasterio
from .protocol import Protocol, TestError
@@ -23,58 +22,50 @@ class Jurse2(Protocol):
    def __init__(self, expe):
        super().__init__(expe, self.__class__.__name__)

    def _get_hashes(self):
        hashes = OrderedDict()
        glob = hashlib.sha1()
        for k in ['ground_truth', 'descriptors_script',
                  'cross_validation', 'classifier']:
            val = str(self._expe[k]).encode('utf-8')
            hashes[k] = hashlib.sha1(val).hexdigest()
            glob.update(val)
        hashes['global'] = glob.hexdigest()
        return hashes

    def _run(self):
        self._log.info('Load data')
        try:
            data = self._load_data()
        except Exception:
            raise TestError('Error occurred during data loading')

        self._log.info('Compute descriptors')
        try:
            descriptors = self._compute_descriptors()
            descriptors = self._compute_descriptors(data)
        except Exception:
            raise TestError('Error occurred during description')
        self._time('description')

        self._log.info('Classify data')
        self._log.info('Classify descriptors')
        try:
            classification = self._compute_classification(descriptors)
        except Exception:
            raise TestError('Error occurred during classification')
        self._time('classification')

        self._log.info('Run metrics')
        metrics = self._run_metrics(classification, descriptors)
        self._time('metrics')

        cmap = str(self._results_base_name) + '.tif'
        self._log.info('Saving classification map {}'.format(cmap))
        triskele.write(cmap, classification)

        results = OrderedDict()
        results['classification'] = cmap
        results = {}
        results['metrics'] = metrics
        self._results = results

    def _compute_descriptors(self):
    def _load_data(self):
        data_loader = self._expe['data_loader']
        loader = importlib.import_module(data_loader['name'])
        data = loader.run(**data_loader['parameters'])
        return data

    def _compute_descriptors(self, data):
        script = self._expe['descriptors_script']
        desc = importlib.import_module(script['name'])
        att = desc.run(**script['parameters'])
        att = desc.run(*data, **script['parameters'])
        return att

    def _compute_classification(self, descriptors):
        # Ground truth
        gt = self._get_ground_truth()
        X, y, groups = descriptors

        # CrossVal and ML
        cv = self._expe['cross_validation']
@@ -83,46 +74,40 @@ class Jurse2(Protocol):
        cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
        classifier = getattr(importlib.import_module(cl['package']), cl['name'])

        prediction = np.zeros_like(gt, dtype=np.uint8)
        y_pred = np.zeros_like(y)
        for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
            rfc = classifier(**cl['parameters'])
            rfc.fit(xt, yt)
        # split() receives the per-tile group labels, so group-aware schemes
        # (e.g. GroupKFold) keep each tile whole on one side of a split.
        cvi = cross_val(**cv['parameters'])
        for train_index, test_index in cvi.split(X, y, groups):
            cli = classifier(**cl['parameters'])
            ypred = rfc.predict(xv)
            self._log.info(' - fit')
            cli.fit(X[train_index], y[train_index])
            prediction[ti] = ypred
            self._log.info(' - predict')
            y_pred[test_index] = cli.predict(X[test_index])
        return prediction

    def _get_ground_truth(self):
        gt_expe = self._expe['ground_truth']
        gt = triskele.read(gt_expe['raster'])

        # Meta labeling
        idx_map = np.arange(gt.max() + 1)
        if 'meta_labels' in gt_expe:
            meta_idx = pd.read_csv(gt_expe['meta_labels'])
            idx = np.array(meta_idx['index'])
            midx = np.array(meta_idx['metaclass_index'])
            idx_map[idx] = midx
        return idx_map[gt]
        return y_pred

    def _get_results(self):
        return self._results

    def _run_metrics(self, classification, descriptors):
        gt = self._get_ground_truth()
        X, y_true, groups = descriptors
        y_pred = classification
        f = np.nonzero(classification)
        pred = classification[f].ravel()
        gt = gt[f].ravel()

        self._log.info(' - Scores')
        self.oa = metrics.accuracy_score(y_true, y_pred)
        self.aa = metrics.balanced_accuracy_score(y_true, y_pred)
        self.k = metrics.cohen_kappa_score(y_true, y_pred)

        results = OrderedDict()
        results['dimensions'] = descriptors.shape[-1]
        results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
        results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))

        self._log.info(' - Additional results')
        p, r, f, s = metrics.precision_recall_fscore_support(y_true, y_pred)
        cm = metrics.confusion_matrix(y_true, y_pred)
        results = {'dimensions': X.shape[-1],
                   'precision': p.tolist(),
                   'recall': r.tolist(),
                   'f1score': f.tolist(),
                   'support': s.tolist(),
                   'confusion': cm.tolist()}
        return results
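# Illustrative sketch (not part of the commit): the expe dict Jurse2 reads,
# reconstructed from the keys used above. Module paths and parameter values
# are assumptions; cross_validation must expose split(X, y, groups),
# sklearn-style.
expe = {
    'data_loader': {'name': 'loader.tiles',      # module exposing run()
                    'parameters': {'tiles_count': 2,
                                   'gt_path': 'data/gt', 'gt_name': 'gt.tif',
                                   'rasters_path': 'data/rasters',
                                   'rasters_name': ['intensity.tif']}},
    'descriptors_script': {'name': 'pixel',      # run(*data, **parameters)
                           'parameters': {'remove': [0]}},
    'cross_validation': {'package': 'sklearn.model_selection',
                         'name': 'GroupKFold',
                         'parameters': {'n_splits': 2}},
    'classifier': {'package': 'sklearn.ensemble',
                   'name': 'RandomForestClassifier',
                   'parameters': {'n_estimators': 100}}}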

View File

@@ -11,52 +11,41 @@
import logging
import time


class Protocol:
    def __init__(self, expe, name=None):
        self._log = logging.getLogger(name)
        self._expe = expe
        self._name = name
        self._log.debug('expe loaded: {}'.format(self._expe))

    def get_hashes(self):
        self._log.info('Computing hashes')
        return self._get_hashes()

    def set_results_base_name(self, base_name):
        self._results_base_name = base_name
        self.k = None
        self.oa = None
        self.aa = None

    def run(self):
        self._pt = time.process_time()
        spt = time.process_time()
        self._run()
        self._pt = time.process_time() - spt

    def get_results(self):
        return self._get_results()

    def get_process_time(self):
        return self._times

    def _time(self, process):
        self._times[process] = time.process_time() - self._pt
        self._pt = time.process_time()

    def _get_hashes(self):
        self._log.warning('Protocol did not override _get_hashes()')
        hashes = OrderedDict()
        hashes['global'] = 'Protocol did not override _get_hashes()'
        return hashes
        return self._pt

    def _run(self):
        self._log.error('Protocol did not override _run()')
        raise NotImplementedError('Protocol {} did not override _run()'.format(self))
        raise NotImplementedError('Protocol {} did not override _run()'
                                  .format(self))

    def _get_results(self):
        self._log.warning('Protocol did not override _get_results()')
        results = OrderedDict()
        results = {}
        results['global'] = 'Protocol did not override _get_results()'
        return results

    def __str__(self):
        return('{}'.format(self._name))
        return '{}'.format(self._name)


class TestError(Exception):
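# Illustrative sketch (not part of the commit): the minimum a concrete
# protocol must now provide under the slimmed-down base class.
class Dummy(Protocol):
    def _run(self):
        # A real protocol would load data, fit and score here.
        self.oa = self.aa = self.k = 1.0
        self._results = {'overall_accuracy': self.oa}

    def _get_results(self):
        return self._results

d = Dummy({}, 'Dummy')
d.run()                    # run() brackets _run() with time.process_time()
print(d.get_results(), d.get_process_time())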

View File

@@ -9,190 +9,76 @@
# TODO details

import importlib
import json
import time
import os
import datetime
from pathlib import Path
from operator import itemgetter
from datetime import datetime
import traceback
import logging

import logger
from protocols.protocol import TestError
import database

log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))


def update_queue():
    tmp_queue = list()
    for child in TEST_DIR.iterdir():
        if child.is_file() and child.suffix == '.yml':
            tmp_queue.append({'expe_file': ExpePath(child),
                              'priority': get_priority(child)})
    queue = sorted(tmp_queue, key=itemgetter('priority'))
    return queue

host = os.uname()[1]
log = logging.getLogger('Supervisor [{}]'.format(host))


def get_priority(yml_file):
    with open(yml_file) as f:
        expe = OrderedDict(yaml.safe_load(f))
    return expe['priority']
def run(expe):
    database.update_experiment(expe, worker=host, start_date=datetime.now())

    # Load protocol
    log.info('Load protocol {}'.format(expe.protocol))
    protocol_module = importlib.import_module('protocols')
    importlib.reload(protocol_module)
    protocol = getattr(protocol_module, expe.protocol)
    test = protocol(expe.expe)

def run(expe_file):
    start_time = time.time()
    log.info('Run test {}'.format(expe_file))
    test = expe_file.read()

    ### Stage experience
    expe_file.stage(test)

    ### Load protocol
    # Run test
    try:
        #protocol = getattr(importlib.import_module(test['protocol']['package']), test['protocol']['name'])
        protocol_module = importlib.import_module(test['protocol']['package'])
        importlib.reload(protocol_module)
        protocol = getattr(protocol_module, test['protocol']['name'])
        experience = protocol(test['expe'])
    except Exception as e:
        err = 'Could not load protocol from test {}'.format(expe_file)
        log.warning(err)
        expe_file.error(test, 'loading protocol', e)
        raise TestError(err)
    log.info('{} test protocol loaded'.format(experience))

    ### Get hashes
    test['hashes'] = experience.get_hashes()
    test['report'] = create_report(experience, start_time)

    ### Stage experience
    expe_file.stage(test)
    experience.set_results_base_name(expe_file.get_result_path())

    ### Run test
    try:
        experience.run()
        test.run()
    except Exception as e:
        err = 'Experience error'
        log.warning(err)
        expe_file.error(test, 'testing', e)
        raise TestError(err)
        report = {'error': {'name': str(err),
                            'trace': traceback.format_exc()}}
        database.update_experiment(expe, report=report, status='error')
        raise TestError(err)
    end_time = time.time()

    # Write report
    log.info('Write report')
    database.update_experiment(expe,
                               end_date=datetime.now(),
                               oa=test.oa,
                               aa=test.aa,
                               k=test.k,
                               report=test.get_results(),
                               status='complete')

    ### Get complete report
    test['report'] = create_report(experience, start_time, end_time)

    # End of test
    log.info('Expe {} complete'.format(expe.expe_hash))

    ### Get results
    test['results'] = experience.get_results()

    ### Write experience
    expe_file.result(test)

    ### End of test
    log.info('Test complete')
def create_report(experience, stime=None, etime=None):
    expe_report = OrderedDict()
    host = os.getenv("HOST")
    expe_report['supervisor'] = host if host is not None else os.uname()[1]

    # Dates
    for datek, timev in zip(('start_date', 'end_date'), (stime, etime)):
        expe_report[datek] = datetime.datetime.fromtimestamp(timev).strftime('Le %d/%m/%Y à %H:%M:%S') if timev is not None else None

    # Resources
    ressources = OrderedDict()
    ressources['ram'] = None
    ressources['proccess_time'] = experience.get_process_time()
    expe_report['ressources'] = ressources

    return expe_report
class ExpePath:
    """Utility wrapper for expe files.

    Extends pathlib's Path with staging, result and error functions that
    move the test report through the Enrichment center.
    """
    def __init__(self, path, hash_length=6):
        self._actual = Path(path)
        self._base_name = self._actual.stem
        self._hash_length = hash_length
        self._hash = None

    def __str__(self):
        return self._get_complete_name()

    def read(self):
        with open(self._actual) as f:
            return OrderedDict(yaml.safe_load(f))

    def _get_hash_name(self):
        return '{}{}'.format(self._base_name,
                             '_' + self._hash[:self._hash_length] if self._hash is not None else '')

    def _get_complete_name(self):
        return self._get_hash_name() + '.yml'

    def exists(self):
        return self._actual.exists()

    def stage(self, expe):
        log.info('Staging {}'.format(self._base_name))
        self._check_hash(expe)
        self._write(STAGING_DIR, expe)

    def result(self, expe):
        log.info('Write results for test {}'.format(self._base_name))
        self._check_hash(expe)
        self._write(RESULT_DIR, expe)

    def error(self, expe, when='', e=Exception):
        error = OrderedDict()
        error['when'] = when
        error['what'] = str(e)
        error['where'] = traceback.format_exc()
        expe['error'] = error
        self._write(FAILED_DIR, expe)

    def get_result_path(self):
        return Path(RESULT_DIR) / self._get_hash_name()
def get_database():
def main():
    print('Hello again')
    database
    return

    log.info('Connecting to database')
    database.connect('credentials.json')

    while(True):
        try:
            queue = update_queue()
        except Exception:
            log.error('Critical exception while updating work queue')
            log.error(traceback.format_exc())
            log.warning('Resuming')
            continue
        if not queue:
            watch_folder()
            continue
        try:
            expe_file = queue.pop()['expe_file']
            while(not expe_file.exists() and queue):
                expe_file = queue.pop()['expe_file']
            if expe_file.exists():
                run(expe_file)
        except TestError:
            log.warning('Test failed, error logged. Resuming')
        except Exception:
            log.error('Critical exception while running test. Resuming')
            log.error(traceback.format_exc())
            log.warning('Resuming')
            continue

        if not database.pending_experiments():
            log.info('No pending experiments, waiting...')
            time.sleep(30)
        else:
            log.info('Loading next experiment')
            expe = database.next_experiment()
            if not expe:
                continue
            log.info('Expe {} loaded'.format(expe.expe_hash))
            try:
                run(expe)
            except Exception as e:
                log.error(e)
                log.error('Error occurred on expe {}'.format(expe.id))


if __name__ == '__main__':
    logger.setup_logging()
    log.info('Starting supervisor')