Working
This commit is contained in:
parent
d356ff1dd5
commit
bbd62654f8
7
.gitignore
vendored
7
.gitignore
vendored
@ -1,5 +1,8 @@
|
|||||||
Enrichment/
|
Enrichment/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
Logs/
|
[Ll]ogs/
|
||||||
[Dd]ata/
|
[Dd]ata
|
||||||
|
[Cc]ache
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
|
credentials*
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ handlers:
|
|||||||
class: logging.handlers.RotatingFileHandler
|
class: logging.handlers.RotatingFileHandler
|
||||||
level: INFO
|
level: INFO
|
||||||
formatter: simple
|
formatter: simple
|
||||||
filename: Logs/info.log
|
filename: logs/info.log
|
||||||
maxBytes: 10485760 # 10MB
|
maxBytes: 10485760 # 10MB
|
||||||
backupCount: 20
|
backupCount: 20
|
||||||
encoding: utf8
|
encoding: utf8
|
||||||
@ -24,7 +24,7 @@ handlers:
|
|||||||
class: logging.handlers.RotatingFileHandler
|
class: logging.handlers.RotatingFileHandler
|
||||||
level: ERROR
|
level: ERROR
|
||||||
formatter: simple
|
formatter: simple
|
||||||
filename: Logs/errors.log
|
filename: logs/errors.log
|
||||||
maxBytes: 10485760 # 10MB
|
maxBytes: 10485760 # 10MB
|
||||||
backupCount: 20
|
backupCount: 20
|
||||||
encoding: utf8
|
encoding: utf8
|
||||||
|
|||||||
@ -12,8 +12,11 @@ from pony import orm
|
|||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import logging
|
||||||
|
|
||||||
from .design import *
|
from .design import Session, Experiment, Project, db
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
def compute_expe_hash(experiment):
|
def compute_expe_hash(experiment):
|
||||||
return hashlib.md5(json.dumps(experiment, sort_keys=True).encode('utf-8')).hexdigest()
|
return hashlib.md5(json.dumps(experiment, sort_keys=True).encode('utf-8')).hexdigest()
|
||||||
@ -31,7 +34,7 @@ def create_experiment(session_name, protocol, expe, urgency=1):
|
|||||||
e = q.first()
|
e = q.first()
|
||||||
e.sessions.add(session)
|
e.sessions.add(session)
|
||||||
else:
|
else:
|
||||||
Experiment(sessions=session, protocol=protocol, expe=experiment, expe_hash=expe_hash)
|
Experiment(sessions=session, protocol=protocol, expe=expe, expe_hash=expe_hash)
|
||||||
|
|
||||||
@orm.db_session
|
@orm.db_session
|
||||||
def create_project(name):
|
def create_project(name):
|
||||||
@ -50,13 +53,46 @@ def create_session(name, desc, project_name, urgency=1):
|
|||||||
else:
|
else:
|
||||||
print('Session "{}" already exists.'.format(name))
|
print('Session "{}" already exists.'.format(name))
|
||||||
|
|
||||||
|
@orm.db_session
|
||||||
|
def pending_experiments():
|
||||||
|
return Experiment.select(lambda x: x.status == 'pending').exists()
|
||||||
|
|
||||||
|
@orm.db_session
|
||||||
|
def next_experiment():
|
||||||
|
# TODO: take session urgency into account
|
||||||
|
expe = orm.select(e for e in Experiment
|
||||||
|
if e.status == 'pending'
|
||||||
|
and e.urgency == orm.max(e.urgency for e in Experiment
|
||||||
|
if e.status == 'pending')
|
||||||
|
).random(1)
|
||||||
|
if expe:
|
||||||
|
expe = expe[0]
|
||||||
|
expe.status = 'staging'
|
||||||
|
return expe
|
||||||
|
|
||||||
|
|
||||||
|
def update_experiment(expe, **params):
|
||||||
|
try:
|
||||||
|
_update_experiment(expe, **params)
|
||||||
|
except orm.DatabaseError as e:
|
||||||
|
log.error(e)
|
||||||
|
log.info('Retry update')
|
||||||
|
_update_experiment(expe, **params)
|
||||||
|
|
||||||
|
@orm.db_session
|
||||||
|
def _update_experiment(expe, **params):
|
||||||
|
e = Experiment.select(lambda x: x.id == expe.id).first()
|
||||||
|
for k, v in params.items():
|
||||||
|
setattr(e, k, v)
|
||||||
|
|
||||||
|
|
||||||
def connect_testing():
|
def connect_testing():
|
||||||
db.bind('sqlite', ':memory:')
|
db.bind('sqlite', ':memory:')
|
||||||
db.generate_mapping(create_tables=True)
|
db.generate_mapping(create_tables=True)
|
||||||
|
|
||||||
|
|
||||||
def connect(credentials_file):
|
def connect(credentials_file):
|
||||||
with open(credentials_file) as f:
|
with open(credentials_file) as f:
|
||||||
credentials = json.load(f)
|
credentials = json.load(f)
|
||||||
db.bind(**credentials)
|
db.bind(**credentials)
|
||||||
db.generate_mapping(create_tables=True)
|
db.generate_mapping(create_tables=True)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
#!/usr/bin/python
|
j#!/usr/bin/python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# \file dfc_base.py
|
# \file dfc_base.py
|
||||||
# \brief TODO
|
# \brief TODO
|
||||||
|
|||||||
37
minigrida/descriptors/pixel.py
Normal file
37
minigrida/descriptors/pixel.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# file pixel.py
|
||||||
|
# author Florent Guiotte <florent.guiotte@irisa.fr>
|
||||||
|
# version 0.0
|
||||||
|
# date 26 mai 2020
|
||||||
|
"""Abstract
|
||||||
|
|
||||||
|
doc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def run(gt, rasters, remove=None):
|
||||||
|
X = []
|
||||||
|
y = []
|
||||||
|
groups = []
|
||||||
|
|
||||||
|
for i, (gti, rastersi) in enumerate(zip(gt, rasters)):
|
||||||
|
# Create vectors
|
||||||
|
X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1)
|
||||||
|
y_raw = gti
|
||||||
|
|
||||||
|
# Remove unwanted label X, y
|
||||||
|
lbl = np.ones_like(y_raw, dtype=np.bool)
|
||||||
|
for l in remove if remove else []:
|
||||||
|
lbl &= y_raw != l
|
||||||
|
|
||||||
|
X += [X_raw[lbl]]
|
||||||
|
y += [y_raw[lbl]]
|
||||||
|
groups += [np.repeat(i, lbl.sum())]
|
||||||
|
|
||||||
|
X = np.concatenate(X)
|
||||||
|
y = np.concatenate(y)
|
||||||
|
groups = np.concatenate(groups)
|
||||||
|
|
||||||
|
return X, y, groups
|
||||||
9
minigrida/loader/__init__.py
Normal file
9
minigrida/loader/__init__.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# file __init__.py
|
||||||
|
# author Florent Guiotte <florent.guiotte@irisa.fr>
|
||||||
|
# version 0.0
|
||||||
|
# date 26 mai 2020
|
||||||
|
"""Abstract
|
||||||
|
|
||||||
|
doc.
|
||||||
|
"""
|
||||||
32
minigrida/loader/tiles.py
Normal file
32
minigrida/loader/tiles.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# file tiles.py
|
||||||
|
# author Florent Guiotte <florent.guiotte@irisa.fr>
|
||||||
|
# version 0.0
|
||||||
|
# date 26 mai 2020
|
||||||
|
"""Abstract
|
||||||
|
|
||||||
|
doc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import rasterio as rio
|
||||||
|
|
||||||
|
|
||||||
|
def run(tiles_count, gt_path, gt_name, rasters_path, rasters_name):
|
||||||
|
gt = []
|
||||||
|
rasters = []
|
||||||
|
|
||||||
|
for i in range(tiles_count):
|
||||||
|
gt += [load_tif(raster_path(gt_path, gt_name, i))]
|
||||||
|
rasters += [{Path(n).stem: load_tif(raster_path(rasters_path, n, i))
|
||||||
|
for n in rasters_name}]
|
||||||
|
|
||||||
|
return gt, rasters
|
||||||
|
|
||||||
|
|
||||||
|
def load_tif(path):
|
||||||
|
return rio.open(str(path)).read()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def raster_path(path, name, i):
|
||||||
|
return Path(path) / '{}_{}_{}.tif'.format(Path(name).stem, 0, i)
|
||||||
@ -7,3 +7,6 @@
|
|||||||
# \date 09 sept. 2018
|
# \date 09 sept. 2018
|
||||||
#
|
#
|
||||||
# TODO details
|
# TODO details
|
||||||
|
|
||||||
|
#from .jurse import Jurse
|
||||||
|
from .jurse2 import Jurse2
|
||||||
|
|||||||
@ -11,7 +11,6 @@ doc.
|
|||||||
import importlib
|
import importlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
import rasterio
|
|
||||||
from .protocol import Protocol, TestError
|
from .protocol import Protocol, TestError
|
||||||
|
|
||||||
|
|
||||||
@ -23,58 +22,50 @@ class Jurse2(Protocol):
|
|||||||
def __init__(self, expe):
|
def __init__(self, expe):
|
||||||
super().__init__(expe, self.__class__.__name__)
|
super().__init__(expe, self.__class__.__name__)
|
||||||
|
|
||||||
def _get_hashes(self):
|
|
||||||
hashes = OrderedDict()
|
|
||||||
glob = hashlib.sha1()
|
|
||||||
|
|
||||||
for k in ['ground_truth', 'descriptors_script',
|
|
||||||
'cross_validation', 'classifier']:
|
|
||||||
val = str(self._expe[k]).encode('utf-8')
|
|
||||||
hashes[k] = hashlib.sha1(val).hexdigest()
|
|
||||||
glob.update(val)
|
|
||||||
hashes['global'] = glob.hexdigest()
|
|
||||||
|
|
||||||
return hashes
|
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
|
self._log.info('Load data')
|
||||||
|
try:
|
||||||
|
data = self._load_data()
|
||||||
|
except Exception:
|
||||||
|
raise TestError('Error occured during data loading')
|
||||||
|
|
||||||
self._log.info('Compute descriptors')
|
self._log.info('Compute descriptors')
|
||||||
try:
|
try:
|
||||||
descriptors = self._compute_descriptors()
|
descriptors = self._compute_descriptors(data)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise TestError('Error occured during description')
|
raise TestError('Error occured during description')
|
||||||
self._time('description')
|
|
||||||
|
|
||||||
self._log.info('Classify data')
|
self._log.info('Classify descriptors')
|
||||||
try:
|
try:
|
||||||
classification = self._compute_classification(descriptors)
|
classification = self._compute_classification(descriptors)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise TestError('Error occured during classification')
|
raise TestError('Error occured during classification')
|
||||||
self._time('classification')
|
|
||||||
|
|
||||||
self._log.info('Run metrics')
|
self._log.info('Run metrics')
|
||||||
metrics = self._run_metrics(classification, descriptors)
|
metrics = self._run_metrics(classification, descriptors)
|
||||||
self._time('metrics')
|
|
||||||
|
|
||||||
cmap = str(self._results_base_name) + '.tif'
|
results = {}
|
||||||
self._log.info('Saving classification map {}'.format(cmap))
|
|
||||||
triskele.write(cmap, classification)
|
|
||||||
|
|
||||||
results = OrderedDict()
|
|
||||||
results['classification'] = cmap
|
|
||||||
results['metrics'] = metrics
|
results['metrics'] = metrics
|
||||||
self._results = results
|
self._results = results
|
||||||
|
|
||||||
def _compute_descriptors(self):
|
def _load_data(self):
|
||||||
|
data_loader = self._expe['data_loader']
|
||||||
|
|
||||||
|
loader = importlib.import_module(data_loader['name'])
|
||||||
|
data = loader.run(**data_loader['parameters'])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _compute_descriptors(self, data):
|
||||||
script = self._expe['descriptors_script']
|
script = self._expe['descriptors_script']
|
||||||
|
|
||||||
desc = importlib.import_module(script['name'])
|
desc = importlib.import_module(script['name'])
|
||||||
att = desc.run(**script['parameters'])
|
att = desc.run(*data, **script['parameters'])
|
||||||
|
|
||||||
return att
|
return att
|
||||||
|
|
||||||
def _compute_classification(self, descriptors):
|
def _compute_classification(self, descriptors):
|
||||||
# Ground truth
|
X, y, groups = descriptors
|
||||||
gt = self._get_ground_truth()
|
|
||||||
|
|
||||||
# CrossVal and ML
|
# CrossVal and ML
|
||||||
cv = self._expe['cross_validation']
|
cv = self._expe['cross_validation']
|
||||||
@ -83,46 +74,40 @@ class Jurse2(Protocol):
|
|||||||
cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
|
cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
|
||||||
classifier = getattr(importlib.import_module(cl['package']), cl['name'])
|
classifier = getattr(importlib.import_module(cl['package']), cl['name'])
|
||||||
|
|
||||||
prediction = np.zeros_like(gt, dtype=np.uint8)
|
y_pred = np.zeros_like(y)
|
||||||
|
|
||||||
for xt, xv, yt, yv, ti in cross_val(gt, descriptors, **cv['parameters']):
|
cvi = cross_val(**cv['parameters'])
|
||||||
rfc = classifier(**cl['parameters'])
|
for train_index, test_index in cvi.split(X, y, groups):
|
||||||
rfc.fit(xt, yt)
|
cli = classifier(**cl['parameters'])
|
||||||
|
|
||||||
ypred = rfc.predict(xv)
|
self._log.info(' - fit')
|
||||||
|
cli.fit(X[train_index], y[train_index])
|
||||||
|
|
||||||
prediction[ti] = ypred
|
self._log.info(' - predict')
|
||||||
|
y_pred[test_index] = cli.predict(X[test_index])
|
||||||
|
|
||||||
return prediction
|
return y_pred
|
||||||
|
|
||||||
def _get_ground_truth(self):
|
|
||||||
gt_expe = self._expe['ground_truth']
|
|
||||||
gt = triskele.read(gt_expe['raster'])
|
|
||||||
|
|
||||||
# Meta labeling
|
|
||||||
idx_map = np.arange(gt.max() + 1)
|
|
||||||
|
|
||||||
if 'meta_labels' in gt_expe:
|
|
||||||
meta_idx = pd.read_csv(gt_expe['meta_labels'])
|
|
||||||
idx = np.array(meta_idx['index'])
|
|
||||||
midx = np.array(meta_idx['metaclass_index'])
|
|
||||||
idx_map[idx] = midx
|
|
||||||
|
|
||||||
return idx_map[gt]
|
|
||||||
|
|
||||||
def _get_results(self):
|
def _get_results(self):
|
||||||
return self._results
|
return self._results
|
||||||
|
|
||||||
def _run_metrics(self, classification, descriptors):
|
def _run_metrics(self, classification, descriptors):
|
||||||
gt = self._get_ground_truth()
|
X, y_true, groups = descriptors
|
||||||
|
y_pred = classification
|
||||||
|
|
||||||
f = np.nonzero(classification)
|
self._log.info(' - Scores')
|
||||||
pred = classification[f].ravel()
|
self.oa = metrics.accuracy_score(y_true, y_pred)
|
||||||
gt = gt[f].ravel()
|
self.aa = metrics.balanced_accuracy_score(y_true, y_pred)
|
||||||
|
self.k = metrics.cohen_kappa_score(y_true, y_pred)
|
||||||
|
|
||||||
results = OrderedDict()
|
self._log.info(' - Additional results')
|
||||||
results['dimensions'] = descriptors.shape[-1]
|
p, r, f, s = metrics.precision_recall_fscore_support(y_true, y_pred)
|
||||||
results['overall_accuracy'] = float(metrics.accuracy_score(gt, pred))
|
cm = metrics.confusion_matrix(y_true, y_pred)
|
||||||
results['cohen_kappa'] = float(metrics.cohen_kappa_score(gt, pred))
|
results = {'dimensions': X.shape[-1],
|
||||||
|
'precision': p.tolist(),
|
||||||
|
'recall': r.tolist(),
|
||||||
|
'f1score': f.tolist(),
|
||||||
|
'support': s.tolist(),
|
||||||
|
'confusion': cm.tolist()}
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@ -11,52 +11,41 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
class Protocol:
|
class Protocol:
|
||||||
def __init__(self, expe, name=None):
|
def __init__(self, expe, name=None):
|
||||||
self._log = logging.getLogger(name)
|
self._log = logging.getLogger(name)
|
||||||
self._expe = expe
|
self._expe = expe
|
||||||
self._name = name
|
self._name = name
|
||||||
self._log.debug('expe loaded: {}'.format(self._expe))
|
self._log.debug('expe loaded: {}'.format(self._expe))
|
||||||
|
self.k = None
|
||||||
def get_hashes(self):
|
self.oa = None
|
||||||
self._log.info('Computing hashes')
|
self.aa = None
|
||||||
return(self._get_hashes())
|
|
||||||
|
|
||||||
def set_results_base_name(self, base_name):
|
|
||||||
self._results_base_name = base_name
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self._pt = time.process_time()
|
spt = time.process_time()
|
||||||
self._run()
|
self._run()
|
||||||
|
self._pt = time.process_time() - spt
|
||||||
|
|
||||||
def get_results(self):
|
def get_results(self):
|
||||||
return self._get_results()
|
return self._get_results()
|
||||||
|
|
||||||
def get_process_time(self):
|
def get_process_time(self):
|
||||||
return self._times
|
return self._pt
|
||||||
|
|
||||||
def _time(self, process):
|
|
||||||
self._times[process] = time.process_time() - self._pt
|
|
||||||
self._pt = time.process_time()
|
|
||||||
|
|
||||||
def _get_hashes(self):
|
|
||||||
self._log.warning('Protocol did not override _get_hashes()')
|
|
||||||
hashes = OrderedDict()
|
|
||||||
hashes['global'] = 'Protocol did not override _get_hashes()'
|
|
||||||
return hashes
|
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
self._log.error('Protocol did not override _run()')
|
self._log.error('Protocol did not override _run()')
|
||||||
raise NotImplementedError('Protocol {} did not override _run()'.format(self))
|
raise NotImplementedError('Protocol {} did not override _run()'
|
||||||
|
.format(self))
|
||||||
|
|
||||||
def _get_results(self):
|
def _get_results(self):
|
||||||
self._log.warning('Protocol did not override _get_results()')
|
self._log.warning('Protocol did not override _get_results()')
|
||||||
results = OrderedDict()
|
results = {}
|
||||||
results['global'] = 'Protocol did not override _get_results()'
|
results['global'] = 'Protocol did not override _get_results()'
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return('{}'.format(self._name))
|
return '{}'.format(self._name)
|
||||||
|
|
||||||
|
|
||||||
class TestError(Exception):
|
class TestError(Exception):
|
||||||
|
|||||||
@ -9,189 +9,75 @@
|
|||||||
# TODO details
|
# TODO details
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
from operator import itemgetter
|
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
import logger
|
import logger
|
||||||
from protocols.protocol import TestError
|
from protocols.protocol import TestError
|
||||||
import database
|
import database
|
||||||
|
|
||||||
log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))
|
host = os.uname()[1]
|
||||||
|
log = logging.getLogger('Supervisor [{}]'.format(host))
|
||||||
def update_queue():
|
|
||||||
tmp_queue = list()
|
|
||||||
for child in TEST_DIR.iterdir():
|
|
||||||
if child.is_file() and child.suffix == '.yml':
|
|
||||||
tmp_queue.append({'expe_file': ExpePath(child),
|
|
||||||
'priority': get_priority(child)})
|
|
||||||
|
|
||||||
queue = sorted(tmp_queue, key=itemgetter('priority'))
|
|
||||||
return queue
|
|
||||||
|
|
||||||
|
|
||||||
def get_priority(yml_file):
|
def run(expe):
|
||||||
with open(yml_file) as f:
|
database.update_experiment(expe, worker=host, start_date=datetime.now())
|
||||||
expe = OrderedDict(yaml.safe_load(f))
|
|
||||||
return expe['priority']
|
|
||||||
|
|
||||||
|
# Load protocol
|
||||||
|
log.info('Load protocol {}'.format(expe.protocol))
|
||||||
|
protocol_module = importlib.import_module('protocols')
|
||||||
|
importlib.reload(protocol_module)
|
||||||
|
protocol = getattr(protocol_module, expe.protocol)
|
||||||
|
test = protocol(expe.expe)
|
||||||
|
|
||||||
def run(expe_file):
|
# Run test
|
||||||
start_time = time.time()
|
|
||||||
log.info('Run test {}'.format(expe_file))
|
|
||||||
test = expe_file.read()
|
|
||||||
|
|
||||||
### Stage experience
|
|
||||||
expe_file.stage(test)
|
|
||||||
|
|
||||||
### Load protocol
|
|
||||||
try:
|
try:
|
||||||
#protocol = getattr(importlib.import_module(test['protocol']['package']), test['protocol']['name'])
|
test.run()
|
||||||
protocol_module = importlib.import_module(test['protocol']['package'])
|
|
||||||
importlib.reload(protocol_module)
|
|
||||||
protocol = getattr(protocol_module, test['protocol']['name'])
|
|
||||||
experience = protocol(test['expe'])
|
|
||||||
except Exception as e:
|
|
||||||
err = 'Could not load protocol from test {}'.format(expe_file)
|
|
||||||
log.warning(err)
|
|
||||||
expe_file.error(test, 'loading protocol', e)
|
|
||||||
raise TestError(err)
|
|
||||||
log.info('{} test protocol loaded'.format(experience))
|
|
||||||
|
|
||||||
### Get hashes
|
|
||||||
test['hashes'] = experience.get_hashes()
|
|
||||||
test['report'] = create_report(experience, start_time)
|
|
||||||
|
|
||||||
### Stage experience
|
|
||||||
expe_file.stage(test)
|
|
||||||
|
|
||||||
experience.set_results_base_name(expe_file.get_result_path())
|
|
||||||
|
|
||||||
### Run test
|
|
||||||
try:
|
|
||||||
experience.run()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err = 'Experience error'
|
err = 'Experience error'
|
||||||
log.warning(err)
|
log.warning(err)
|
||||||
expe_file.error(test, 'testing', e)
|
report = {'error': {'name': str(err),
|
||||||
|
'trace': traceback.format_exc()}}
|
||||||
|
database.update_experiment(expe, report=report, status='error')
|
||||||
raise TestError(err)
|
raise TestError(err)
|
||||||
|
|
||||||
end_time = time.time()
|
# Write report
|
||||||
|
log.info('Write report')
|
||||||
|
database.update_experiment(expe,
|
||||||
|
end_date=datetime.now(),
|
||||||
|
oa=test.oa,
|
||||||
|
aa=test.aa,
|
||||||
|
k=test.k,
|
||||||
|
report=test.get_results(),
|
||||||
|
status='complete')
|
||||||
|
|
||||||
### Get complete report
|
# End of test
|
||||||
test['report'] = create_report(experience, start_time, end_time)
|
log.info('Expe {} complete'.format(expe.expe_hash))
|
||||||
|
|
||||||
### Get results
|
|
||||||
test['results'] = experience.get_results()
|
|
||||||
|
|
||||||
### Write experience
|
|
||||||
expe_file.result(test)
|
|
||||||
|
|
||||||
### End of test
|
|
||||||
log.info('Test complete')
|
|
||||||
|
|
||||||
def create_report(experience, stime=None, etime=None):
|
|
||||||
expe_report = OrderedDict()
|
|
||||||
host = os.getenv("HOST")
|
|
||||||
expe_report['supervisor'] = host if host is not None else os.uname()[1]
|
|
||||||
|
|
||||||
# Dates
|
|
||||||
for datek, timev in zip(('start_date', 'end_date'), (stime, etime)):
|
|
||||||
expe_report[datek] = datetime.datetime.fromtimestamp(timev).strftime('Le %d/%m/%Y à %H:%M:%S') if timev is not None else None
|
|
||||||
|
|
||||||
# Ressources
|
|
||||||
ressources = OrderedDict()
|
|
||||||
ressources['ram'] = None
|
|
||||||
ressources['proccess_time'] = experience.get_process_time()
|
|
||||||
expe_report['ressources'] = ressources
|
|
||||||
|
|
||||||
return expe_report
|
|
||||||
|
|
||||||
class ExpePath:
|
|
||||||
"""Utility wrapper for expe files.
|
|
||||||
|
|
||||||
Extend pathlib Path with staging, result and errors function to move the
|
|
||||||
test report through the Enrichment center.
|
|
||||||
|
|
||||||
"""
|
|
||||||
def __init__(self, path, hash_length=6):
|
|
||||||
self._actual = Path(path)
|
|
||||||
self._base_name = self._actual.stem
|
|
||||||
self._hash_length = hash_length
|
|
||||||
self._hash = None
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self._get_complete_name()
|
|
||||||
|
|
||||||
def read(self):
|
|
||||||
with open(self._actual) as f:
|
|
||||||
return OrderedDict(yaml.safe_load(f))
|
|
||||||
|
|
||||||
def _get_hash_name(self):
|
|
||||||
return '{}{}'.format(self._base_name,
|
|
||||||
'_' + self._hash[:self._hash_length] if self._hash is not None else '')
|
|
||||||
|
|
||||||
def _get_complete_name(self):
|
|
||||||
return self._get_hash_name() + '.yml'
|
|
||||||
|
|
||||||
def exists(self):
|
|
||||||
return self._actual.exists()
|
|
||||||
|
|
||||||
def stage(self, expe):
|
|
||||||
log.info('Staging {}'.format(self._base_name))
|
|
||||||
self._check_hash(expe)
|
|
||||||
self._write(STAGING_DIR, expe)
|
|
||||||
|
|
||||||
def result(self, expe):
|
|
||||||
log.info('Write results for test {}'.format(self._base_name))
|
|
||||||
self._check_hash(expe)
|
|
||||||
self._write(RESULT_DIR, expe)
|
|
||||||
|
|
||||||
def error(self, expe, when='', e=Exception):
|
|
||||||
error = OrderedDict()
|
|
||||||
error['when'] = when
|
|
||||||
error['what'] = str(e)
|
|
||||||
error['where'] = traceback.format_exc()
|
|
||||||
expe['error'] = error
|
|
||||||
self._write(FAILED_DIR, expe)
|
|
||||||
|
|
||||||
def get_result_path(self):
|
|
||||||
return Path(RESULT_DIR) / self._get_hash_name()
|
|
||||||
|
|
||||||
def get_database():
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print('Hello again')
|
log.info('Connecting to database')
|
||||||
database
|
database.connect('credentials.json')
|
||||||
return
|
|
||||||
while(True):
|
while(True):
|
||||||
try:
|
if not database.pending_experiments():
|
||||||
queue = update_queue()
|
log.info('No pending experiments, waiting...')
|
||||||
except Exception:
|
time.sleep(30)
|
||||||
log.error('Critical exception while updating work queue')
|
|
||||||
log.error(traceback.format_exc())
|
else:
|
||||||
log.warning('Resuming')
|
log.info('Loading next experiment')
|
||||||
continue
|
expe = database.next_experiment()
|
||||||
if not queue:
|
if not expe:
|
||||||
watch_folder()
|
continue
|
||||||
continue
|
log.info('Expe {} loaded'.format(expe.expe_hash))
|
||||||
try:
|
try:
|
||||||
expe_file = queue.pop()['expe_file']
|
run(expe)
|
||||||
while(not expe_file.exists() and queue):
|
except Exception as e:
|
||||||
expe_file = queue.pop()['expe_file']
|
log.error(e)
|
||||||
if expe_file.exists():
|
log.error('Error occured on expe {}'.format(expe.id))
|
||||||
run(expe_file)
|
|
||||||
except TestError:
|
|
||||||
log.warning('Test failed, error logged. Resuming')
|
|
||||||
except Exception:
|
|
||||||
log.error('Critical exception while running test. Resuming')
|
|
||||||
log.error(traceback.format_exc())
|
|
||||||
log.warning('Resuming')
|
|
||||||
continue
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logger.setup_logging()
|
logger.setup_logging()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user