Merge branch 'develop'

commit 0f75ebe1ea
8 .gitignore vendored
@@ -1,4 +1,8 @@
Enrichment/
__pycache__/
Logs/
[Dd]ata/
[Ll]ogs/
[Dd]ata
[Cc]ache
*.egg-info/
credentials*
3 config.json Normal file
@@ -0,0 +1,3 @@
{
    "process_count": 2
}
@@ -15,7 +15,7 @@ handlers:
        class: logging.handlers.RotatingFileHandler
        level: INFO
        formatter: simple
        filename: Logs/info.log
        filename: logs/info.log
        maxBytes: 10485760 # 10MB
        backupCount: 20
        encoding: utf8
@@ -24,7 +24,7 @@ handlers:
        class: logging.handlers.RotatingFileHandler
        level: ERROR
        formatter: simple
        filename: Logs/errors.log
        filename: logs/errors.log
        maxBytes: 10485760 # 10MB
        backupCount: 20
        encoding: utf8
10 minigrida/__init__.py Normal file
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# file __init__.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 22 mai 2020
"""Abstract

doc.
"""
12 minigrida/database/__init__.py Normal file
@@ -0,0 +1,12 @@
#!/usr/bin/env python
# file __init__.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 22 mai 2020
"""Abstract

doc.
"""

from .design import *
from .helpers import *
46 minigrida/database/design.py Normal file
@@ -0,0 +1,46 @@
#!/usr/bin/env python
# file design.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 22 mai 2020
"""Abstract

doc.
"""

from datetime import date
from pony import orm

db = orm.Database()

class Experiment(db.Entity):
    sessions = orm.Set('Session')
    urgency = orm.Required(int, default=1)
    status = orm.Required(str, default='pending')

    protocol = orm.Required(str)
    expe = orm.Required(orm.Json)
    expe_hash = orm.Required(str, 32, unique=True)

    start_date = orm.Optional(date)
    end_date = orm.Optional(date)
    worker = orm.Optional(str)
    ressources = orm.Optional(orm.Json)

    report = orm.Optional(orm.Json)
    oa = orm.Optional(float)
    aa = orm.Optional(float)
    k = orm.Optional(float)

class Session(db.Entity):
    project = orm.Required('Project')
    date = orm.Required(date)
    name = orm.PrimaryKey(str)
    desc = orm.Optional(str)
    urgency = orm.Required(int, default=1)
    experiments = orm.Set('Experiment')

class Project(db.Entity):
    name = orm.PrimaryKey(str)
    sessions = orm.Set('Session')
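The three entities above are the whole grid schema: a Project owns Sessions, a Session groups Experiments, and each Experiment stores the protocol name, the expe payload with its 32-character hash, and the bookkeeping fields the workers fill in (worker, dates, scores, report). A minimal sketch of using the schema directly with Pony; the in-memory binding mirrors connect_testing() below, the flat import assumes the code is run from inside the minigrida/ directory as the new supervisor does, and all values are illustrative:

from datetime import date
from pony import orm
from database.design import db, Project, Session, Experiment

db.bind('sqlite', ':memory:')                  # same binding as connect_testing()
db.generate_mapping(create_tables=True)

with orm.db_session:
    p = Project(name='demo')                                 # name is the primary key
    s = Session(project=p, date=date.today(), name='s01')    # desc and urgency use defaults
    Experiment(sessions=[s], protocol='Jurse2',
               expe={'note': 'illustrative payload'},
               expe_hash='0' * 32)                           # 32-char hash, unique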
100 minigrida/database/helpers.py Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/env python
# file helpers.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 22 mai 2020
"""Abstract

doc.
"""

from pony import orm
from datetime import datetime, date
import json
import hashlib
import logging

from .design import Session, Experiment, Project, db

log = logging.getLogger()

def compute_expe_hash(experiment):
    return hashlib.md5(json.dumps(experiment, sort_keys=True).encode('utf-8')).hexdigest()

@orm.db_session
def create_experiment(session_name, protocol, expe, urgency=1):
    session = Session.select(lambda x: x.name == session_name)
    if not session.exists():
        raise ValueError('Session "{}" does not exist'.format(session_name))

    expe_hash = compute_expe_hash(expe)
    q = Experiment.select(lambda x: x.expe_hash == expe_hash)

    if q.exists():
        e = q.first()
        e.sessions.add(session)
    else:
        Experiment(sessions=session, protocol=protocol, expe=expe, expe_hash=expe_hash)

@orm.db_session
def create_project(name):
    if not Project.select(lambda x: x.name == name).exists():
        Project(name=name)
    else:
        print('Project "{}" already exists.'.format(name))

@orm.db_session
def create_session(name, desc, project_name, urgency=1):
    project = Project[project_name]

    if not Session.select(lambda x: x.name == name).exists():
        Session(project=project, date=datetime.now(),
                name=name, desc=desc, urgency=1)
    else:
        print('Session "{}" already exists.'.format(name))

@orm.db_session
def pending_experiments():
    return Experiment.select(lambda x: x.status == 'pending').exists()


@orm.db_session(serializable=True, optimistic=False)
def next_experiment():
    # TODO: take session urgency into account
    expe = orm.select(e for e in Experiment
                      if e.status == 'pending'
                      and e.urgency == orm.max(e.urgency for e in Experiment
                                               if e.status == 'pending')
                      ).random(1)
    if expe:
        expe = expe[0]
        expe.status = 'staging'
        return expe


def update_experiment(expe, **params):
    try:
        _update_experiment(expe, **params)
    except orm.DatabaseError as e:
        log.error(e)
        log.info('Retry update')
        _update_experiment(expe, **params)


@orm.db_session
def _update_experiment(expe, **params):
    e = Experiment.get_for_update(id=expe.id)
    for k, v in params.items():
        setattr(e, k, v)


def connect_testing():
    db.bind('sqlite', ':memory:')
    db.generate_mapping(create_tables=True)


def connect(credentials_file):
    with open(credentials_file) as f:
        credentials = json.load(f)
    db.bind(**credentials)
    db.generate_mapping(create_tables=True)
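Together these helpers are the queue API the supervisors rely on: the create_* functions enqueue work, pending_experiments()/next_experiment() pull it, and update_experiment() writes results back with one retry on DatabaseError. A minimal end-to-end sketch against the in-memory test database; function names and signatures are the ones defined above, argument values are illustrative, and the flat import assumes minigrida/ is the working directory:

import database

database.connect_testing()                    # sqlite ':memory:' + generate_mapping
database.create_project('demo')
database.create_session('s01', 'first batch', 'demo')
database.create_experiment('s01', 'Jurse2', {'classifier': {'name': 'SVC'}})

print(database.pending_experiments())         # True
expe = database.next_experiment()             # atomically flips one pending experiment to 'staging'
database.update_experiment(expe, status='complete', oa=0.9)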
@@ -1,4 +1,4 @@
#!/usr/bin/python
j#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file dfc_base.py
# \brief TODO
37 minigrida/descriptors/pixel.py Normal file
@@ -0,0 +1,37 @@
#!/usr/bin/env python
# file pixel.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 mai 2020
"""Abstract

doc.
"""

import numpy as np


def run(gt, rasters, remove=None):
    X = []
    y = []
    groups = []

    for i, (gti, rastersi) in enumerate(zip(gt, rasters)):
        # Create vectors
        X_raw = np.moveaxis(np.array(list(rastersi.values())), 0, -1)
        y_raw = gti

        # Remove unwanted label X, y
        lbl = np.ones_like(y_raw, dtype=np.bool)
        for l in remove if remove else []:
            lbl &= y_raw != l

        X += [X_raw[lbl]]
        y += [y_raw[lbl]]
        groups += [np.repeat(i, lbl.sum())]

    X = np.concatenate(X)
    y = np.concatenate(y)
    groups = np.concatenate(groups)

    return X, y, groups
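The descriptor script flattens per-tile rasters into a plain sample matrix: each tile's raster dict becomes one band-last pixel array, pixels whose label is listed in remove are masked out, and groups records the tile index of every kept pixel so a group-aware cross-validation can split by tile. A small illustrative call (note that the np.bool alias used above has been removed from recent NumPy releases, so this file needs an older NumPy or np.bool_):

import numpy as np
from descriptors import pixel    # flat import, assumed to be run from inside minigrida/

# Two 2x2 tiles with two named rasters each; label 0 stands for "unlabelled".
gt = [np.array([[0, 1], [2, 1]]), np.array([[1, 1], [0, 2]])]
rasters = [{'intensity': np.full((2, 2), 0.5), 'height': np.full((2, 2), 1.0)}
           for _ in gt]

X, y, groups = pixel.run(gt, rasters, remove=[0])
print(X.shape, y.shape, groups)  # (6, 2) (6,) [0 0 0 1 1 1]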
9 minigrida/loader/__init__.py Normal file
@@ -0,0 +1,9 @@
#!/usr/bin/env python
# file __init__.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 mai 2020
"""Abstract

doc.
"""
32 minigrida/loader/tiles.py Normal file
@@ -0,0 +1,32 @@
#!/usr/bin/env python
# file tiles.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 mai 2020
"""Abstract

doc.
"""

from pathlib import Path
import rasterio as rio


def run(tiles_count, gt_path, gt_name, rasters_path, rasters_name):
    gt = []
    rasters = []

    for i in range(tiles_count):
        gt += [load_tif(raster_path(gt_path, gt_name, i))]
        rasters += [{Path(n).stem: load_tif(raster_path(rasters_path, n, i))
                     for n in rasters_name}]

    return gt, rasters


def load_tif(path):
    return rio.open(str(path)).read()[0]


def raster_path(path, name, i):
    return Path(path) / '{}_{}_{}.tif'.format(Path(name).stem, 0, i)
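The loader expects tiles on disk named <stem>_0_<i>.tif (the middle index is hard-coded to 0 by raster_path), reads the first band of each file with rasterio, and returns parallel lists: one ground-truth array per tile and one dict of rasters keyed by file stem. A hypothetical call with made-up paths and file names:

from loader import tiles         # flat import, assumed to be run from inside minigrida/

# Assumed layout: Data/gt/gt_0_0.tif ... gt_0_3.tif and
#                 Data/rasters/intensity_0_0.tif, height_0_0.tif, ...
gt, rasters = tiles.run(tiles_count=4,
                        gt_path='Data/gt', gt_name='gt.tif',
                        rasters_path='Data/rasters',
                        rasters_name=['intensity.tif', 'height.tif'])
# gt[0] is a 2D label array; rasters[0] is {'intensity': array, 'height': array}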
40 minigrida/logging.yaml Normal file
@@ -0,0 +1,40 @@
version: 1
disable_existing_loggers: False
formatters:
    simple:
        format: "%(asctime)s %(levelname)s %(name)s: %(message)s"

handlers:
    console:
        class: logging.StreamHandler
        level: INFO
        formatter: simple
        stream: ext://sys.stdout

    info_file_handler:
        class: logging.handlers.RotatingFileHandler
        level: INFO
        formatter: simple
        filename: Logs/info.log
        maxBytes: 10485760 # 10MB
        backupCount: 20
        encoding: utf8

    error_file_handler:
        class: logging.handlers.RotatingFileHandler
        level: ERROR
        formatter: simple
        filename: Logs/errors.log
        maxBytes: 10485760 # 10MB
        backupCount: 20
        encoding: utf8

loggers:
    my_module:
        level: DEBUG
        handlers: [console]
        propagate: no

root:
    level: DEBUG
    handlers: [console, info_file_handler, error_file_handler]
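Both supervisors call logger.setup_logging(), which is not part of this commit. A minimal sketch of what such a helper could look like for the configuration above (the module name, default path, and Logs/ directory handling are assumptions):

# Hypothetical logger.py sketch, not part of this commit.
import logging.config
from pathlib import Path
import yaml

def setup_logging(path='logging.yaml'):
    Path('Logs').mkdir(exist_ok=True)    # the rotating file handlers write into Logs/
    with open(path) as f:
        logging.config.dictConfig(yaml.safe_load(f))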
@@ -7,3 +7,6 @@
# \date 09 sept. 2018
#
# TODO details

#from .jurse import Jurse
from .jurse2 import Jurse2
113 minigrida/protocols/jurse2.py Normal file
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# file jurse2.py
# author Florent Guiotte <florent.guiotte@irisa.fr>
# version 0.0
# date 26 mai 2020
"""Abstract

doc.
"""

import importlib
import numpy as np
from sklearn import metrics
from .protocol import Protocol, TestError


class Jurse2(Protocol):
    """Second JURSE test protocol for LiDAR classification with 2D maps.

    """

    def __init__(self, expe):
        super().__init__(expe, self.__class__.__name__)

    def _run(self):
        self._log.info('Load data')
        try:
            data = self._load_data()
        except Exception:
            raise TestError('Error occured during data loading')

        self._log.info('Compute descriptors')
        try:
            descriptors = self._compute_descriptors(data)
        except Exception:
            raise TestError('Error occured during description')

        self._log.info('Classify descriptors')
        try:
            classification = self._compute_classification(descriptors)
        except Exception:
            raise TestError('Error occured during classification')

        self._log.info('Run metrics')
        metrics = self._run_metrics(classification, descriptors)

        results = {}
        results['metrics'] = metrics
        self._results = results

    def _load_data(self):
        data_loader = self._expe['data_loader']

        loader = importlib.import_module(data_loader['name'])
        data = loader.run(**data_loader['parameters'])

        return data

    def _compute_descriptors(self, data):
        script = self._expe['descriptors_script']

        desc = importlib.import_module(script['name'])
        att = desc.run(*data, **script['parameters'])

        return att

    def _compute_classification(self, descriptors):
        X, y, groups = descriptors

        # CrossVal and ML
        cv = self._expe['cross_validation']
        cl = self._expe['classifier']

        cross_val = getattr(importlib.import_module(cv['package']), cv['name'])
        classifier = getattr(importlib.import_module(cl['package']), cl['name'])

        y_pred = np.zeros_like(y)

        cvi = cross_val(**cv['parameters'])
        for train_index, test_index in cvi.split(X, y, groups):
            cli = classifier(**cl['parameters'])

            self._log.info(' - fit')
            cli.fit(X[train_index], y[train_index])

            self._log.info(' - predict')
            y_pred[test_index] = cli.predict(X[test_index])

        return y_pred

    def _get_results(self):
        return self._results

    def _run_metrics(self, classification, descriptors):
        X, y_true, groups = descriptors
        y_pred = classification

        self._log.info(' - Scores')
        self.oa = metrics.accuracy_score(y_true, y_pred)
        self.aa = metrics.balanced_accuracy_score(y_true, y_pred)
        self.k = metrics.cohen_kappa_score(y_true, y_pred)

        self._log.info(' - Additional results')
        p, r, f, s = metrics.precision_recall_fscore_support(y_true, y_pred)
        cm = metrics.confusion_matrix(y_true, y_pred)
        results = {'dimensions': X.shape[-1],
                   'precision': p.tolist(),
                   'recall': r.tolist(),
                   'f1score': f.tolist(),
                   'support': s.tolist(),
                   'confusion': cm.tolist()}

        return results
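Jurse2 is driven entirely by the expe document: data_loader and descriptors_script name modules that are imported with importlib and called as run(**parameters) and run(*data, **parameters), while cross_validation and classifier name a package plus a class that is instantiated from its parameters and used in a group-aware cross-validation loop. A hypothetical payload wired to the loader and descriptor modules added in this commit (all values are illustrative):

expe = {
    'data_loader': {'name': 'loader.tiles',
                    'parameters': {'tiles_count': 4,
                                   'gt_path': 'Data/gt', 'gt_name': 'gt.tif',
                                   'rasters_path': 'Data/rasters',
                                   'rasters_name': ['intensity.tif', 'height.tif']}},
    'descriptors_script': {'name': 'descriptors.pixel',
                           'parameters': {'remove': [0]}},
    'cross_validation': {'package': 'sklearn.model_selection',
                         'name': 'GroupKFold',
                         'parameters': {'n_splits': 4}},
    'classifier': {'package': 'sklearn.ensemble',
                   'name': 'RandomForestClassifier',
                   'parameters': {'n_estimators': 100}},
}

Such a payload would then be queued with database.create_experiment('s01', 'Jurse2', expe) and picked up by the supervisor.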
@@ -10,7 +10,6 @@

import logging
import time
from collections import OrderedDict


class Protocol:
@@ -18,49 +17,35 @@ class Protocol:
        self._log = logging.getLogger(name)
        self._expe = expe
        self._name = name
        self._times = OrderedDict()
        self._results_base_name = None
        self._log.debug('expe loaded: {}'.format(self._expe))

    def get_hashes(self):
        self._log.info('Computing hashes')
        return(self._get_hashes())

    def set_results_base_name(self, base_name):
        self._results_base_name = base_name
        self.k = None
        self.oa = None
        self.aa = None

    def run(self):
        self._pt = time.process_time()
        spt = time.process_time()
        self._run()
        self._pt = time.process_time() - spt

    def get_results(self):
        return self._get_results()

    def get_process_time(self):
        return self._times

    def _time(self, process):
        self._times[process] = time.process_time() - self._pt
        self._pt = time.process_time()

    def _get_hashes(self):
        self._log.warning('Protocol did not override _get_hashes()')
        hashes = OrderedDict()
        hashes['global'] = 'Protocol did not override _get_hashes()'
        return hashes
        return self._pt

    def _run(self):
        self._log.error('Protocol did not override _run()')
        raise NotImplementedError('Protocol {} did not override _run()'.format(self))
        raise NotImplementedError('Protocol {} did not override _run()'
                                  .format(self))

    def _get_results(self):
        self._log.warning('Protocol did not override _get_results()')
        results = OrderedDict()
        results = {}
        results['global'] = 'Protocol did not override _get_results()'
        return results

    def __str__(self):
        return('{}'.format(self._name))
        return '{}'.format(self._name)


class TestError(Exception):
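After this refactor a protocol only has to set oa, aa, k and its results: the hash and results_base_name machinery of the old file-based supervisor is gone, run() simply times _run(), and get_process_time() returns that single duration. A hypothetical minimal subclass, mirroring the structure of Jurse2 above (everything except the base class is invented for illustration):

from protocols.protocol import Protocol

class Dummy(Protocol):
    def __init__(self, expe):
        super().__init__(expe, self.__class__.__name__)

    def _run(self):
        # scores read back by the supervisor after run()
        self.oa = self.aa = self.k = 1.0
        self._results = {'note': 'nothing to do'}

    def _get_results(self):
        return self._results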
100 minigrida/supervisor.py Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file supervisor.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 2.1
# \date 07 sept. 2018
#
# TODO details

import importlib
import time
import os
from datetime import datetime
import traceback
import logging
import logger
from protocols.protocol import TestError
import database
from multiprocessing import Process
import json

host = os.uname()[1]
log = logging.getLogger('Supervisor [{}]'.format(host))


def run(expe, hostpid=host):
    database.update_experiment(expe, worker=hostpid, start_date=datetime.now())

    # Load protocol
    log.info('Load protocol {}'.format(expe.protocol))
    protocol_module = importlib.import_module('protocols')
    importlib.reload(protocol_module)
    protocol = getattr(protocol_module, expe.protocol)
    test = protocol(expe.expe)

    # Run test
    try:
        test.run()
    except Exception as e:
        err = 'Experience error'
        log.warning(err)
        report = {'error': {'name': str(err),
                            'trace': traceback.format_exc()}}
        database.update_experiment(expe, report=report, status='error')
        raise TestError(err)

    # Write report
    log.info('Write report')
    database.update_experiment(expe,
                               end_date=datetime.now(),
                               oa=test.oa,
                               aa=test.aa,
                               k=test.k,
                               report=test.get_results(),
                               status='complete')

    # End of test
    log.info('Expe {} complete'.format(expe.expe_hash))


def main(pid=None):
    hostpid = host + '_' + str(pid) if pid is not None else host
    log.name = 'Supervisor [{}]'.format(hostpid)

    log.info('Connecting to database')
    database.connect('credentials.json')

    while(True):
        if not database.pending_experiments():
            log.info('No pending experiments, waiting...')
            time.sleep(30)

        else:
            log.info('Loading next experiment')
            expe = database.next_experiment()
            if not expe:
                continue
            log.info('Expe {} loaded'.format(expe.expe_hash))
            try:
                run(expe, hostpid)
            except Exception as e:
                log.error(e)
                log.error('Error occured on expe {}'.format(expe.id))
                time.sleep(60)


if __name__ == '__main__':
    logger.setup_logging()
    log.info('Starting supervisor')
    try:
        with open('config.json') as f:
            config = json.load(f)
        process_count = config['process_count']
    except Exception as e:
        log.warning(e)
        process_count = 1

    for i in range(process_count):
        Process(target=main, args=(i,)).start()
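The new supervisor polls the shared database instead of watching a folder: each worker process connects with credentials.json, pulls the next pending experiment, instantiates its protocol, and writes scores and report back through database.update_experiment(); process_count from config.json sets how many worker processes are started. credentials.json is passed straight to Pony's db.bind(**credentials), so its keys must match a Pony provider's bind arguments. A hypothetical example for a local SQLite file (a PostgreSQL or MySQL provider would follow the same pattern with its own keys):

{
    "provider": "sqlite",
    "filename": "minigrida.sqlite",
    "create_db": true
}

With that file and config.json in the working directory, python supervisor.py starts process_count workers.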
4 setup.py
@@ -11,10 +11,10 @@
from distutils.core import setup

setup(name='minigrida',
      version='1.11',
      version='2.0',
      description='Simple and decentralized computing grid',
      author='Florent Guiotte',
      author_email='florent.guiotte@uhb.fr',
      url='https://git.guiotte.fr/Florent/minigrida',
      packages=['cvgenerators', 'descriptors', 'protocols'],
      packages=['minigrida'],#'cvgenerators', 'descriptors', 'protocols', 'database'],
      )
243 supervisor.py
@@ -1,243 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# \file supervisor.py
# \brief TODO
# \author Florent Guiotte <florent.guiotte@gmail.com>
# \version 0.1
# \date 07 sept. 2018
#
# TODO details

import yaml
import importlib
import hashlib
from collections import OrderedDict
import time
import os
import datetime
from pathlib import Path
from operator import itemgetter
import traceback
import logging
import logger
from protocols.protocol import TestError

ENRICHMENT_DIR = Path('./Enrichment/')
TEST_DIR = ENRICHMENT_DIR / 'Tests'
STAGING_DIR = ENRICHMENT_DIR / 'Staging'
RESULT_DIR = ENRICHMENT_DIR / 'Results'
FAILED_DIR = ENRICHMENT_DIR / 'Failed'

log = logging.getLogger('Supervisor [{}]'.format(os.uname()[1]))

def setup_yaml():
    """ Keep yaml ordered, newline string
    from https://stackoverflow.com/a/8661021
    """
    represent_dict_order = lambda self, data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
    yaml.add_representer(OrderedDict, represent_dict_order)

    """ https://stackoverflow.com/a/24291536 """
    yaml.Dumper.org_represent_str = yaml.Dumper.represent_str
    yaml.add_representer(str, repr_str, Dumper=yaml.Dumper)


def repr_str(dumper, data):
    if '\n' in data:
        return dumper.represent_scalar(u'tag:yaml.org,2002:str',
                                       data, style='|')
    return dumper.org_represent_str(data)


def update_queue():
    tmp_queue = list()
    for child in TEST_DIR.iterdir():
        if child.is_file() and child.suffix == '.yml':
            tmp_queue.append({'expe_file': ExpePath(child),
                              'priority': get_priority(child)})

    queue = sorted(tmp_queue, key=itemgetter('priority'))
    return queue


def get_priority(yml_file):
    with open(yml_file) as f:
        expe = OrderedDict(yaml.safe_load(f))
    return expe['priority']


def run(expe_file):
    start_time = time.time()
    log.info('Run test {}'.format(expe_file))
    test = expe_file.read()

    ### Stage experience
    expe_file.stage(test)

    ### Load protocol
    try:
        #protocol = getattr(importlib.import_module(test['protocol']['package']), test['protocol']['name'])
        protocol_module = importlib.import_module(test['protocol']['package'])
        importlib.reload(protocol_module)
        protocol = getattr(protocol_module, test['protocol']['name'])
        experience = protocol(test['expe'])
    except Exception as e:
        err = 'Could not load protocol from test {}'.format(expe_file)
        log.warning(err)
        expe_file.error(test, 'loading protocol', e)
        raise TestError(err)
    log.info('{} test protocol loaded'.format(experience))

    ### Get hashes
    test['hashes'] = experience.get_hashes()
    test['report'] = create_report(experience, start_time)

    ### Stage experience
    expe_file.stage(test)

    experience.set_results_base_name(expe_file.get_result_path())

    ### Run test
    try:
        experience.run()
    except Exception as e:
        err = 'Experience error'
        log.warning(err)
        expe_file.error(test, 'testing', e)
        raise TestError(err)

    end_time = time.time()

    ### Get complete report
    test['report'] = create_report(experience, start_time, end_time)

    ### Get results
    test['results'] = experience.get_results()

    ### Write experience
    expe_file.result(test)

    ### End of test
    log.info('Test complete')

def create_report(experience, stime=None, etime=None):
    expe_report = OrderedDict()
    host = os.getenv("HOST")
    expe_report['supervisor'] = host if host is not None else os.uname()[1]

    # Dates
    for datek, timev in zip(('start_date', 'end_date'), (stime, etime)):
        expe_report[datek] = datetime.datetime.fromtimestamp(timev).strftime('Le %d/%m/%Y à %H:%M:%S') if timev is not None else None

    # Ressources
    ressources = OrderedDict()
    ressources['ram'] = None
    ressources['proccess_time'] = experience.get_process_time()
    expe_report['ressources'] = ressources

    return expe_report

class ExpePath:
    """Utility wrapper for expe files.

    Extend pathlib Path with staging, result and errors function to move the
    test report through the Enrichment center.

    """
    def __init__(self, path, hash_length=6):
        self._actual = Path(path)
        self._base_name = self._actual.stem
        self._hash_length = hash_length
        self._hash = None

    def __str__(self):
        return self._get_complete_name()

    def read(self):
        with open(self._actual) as f:
            return OrderedDict(yaml.safe_load(f))

    def _get_hash_name(self):
        return '{}{}'.format(self._base_name,
                             '_' + self._hash[:self._hash_length] if self._hash is not None else '')

    def _get_complete_name(self):
        return self._get_hash_name() + '.yml'

    def exists(self):
        return self._actual.exists()

    def stage(self, expe):
        log.info('Staging {}'.format(self._base_name))
        self._check_hash(expe)
        self._write(STAGING_DIR, expe)

    def result(self, expe):
        log.info('Write results for test {}'.format(self._base_name))
        self._check_hash(expe)
        self._write(RESULT_DIR, expe)

    def error(self, expe, when='', e=Exception):
        error = OrderedDict()
        error['when'] = when
        error['what'] = str(e)
        error['where'] = traceback.format_exc()
        expe['error'] = error
        self._write(FAILED_DIR, expe)

    def get_result_path(self):
        return Path(RESULT_DIR) / self._get_hash_name()

    def _check_hash(self, expe):
        if self._hash is None:
            if 'hashes' in expe:
                self._hash = expe['hashes']['global']

    def _write(self, path, expe):
        new_path = Path(path) / self._get_complete_name()
        with open(new_path, 'w') as of:
            yaml.dump(expe, of,
                      default_flow_style=False,
                      encoding=None,
                      allow_unicode=True)
        self._actual.unlink()
        self._actual = new_path


def watch_folder():
    log.info('Waiting for test')
    while not list(TEST_DIR.glob('*.yml')):
        time.sleep(3)

def main():
    while(True):
        try:
            queue = update_queue()
        except Exception:
            log.error('Critical exception while updating work queue')
            log.error(traceback.format_exc())
            log.warning('Resuming')
            continue
        if not queue:
            watch_folder()
            continue
        try:
            expe_file = queue.pop()['expe_file']
            while(not expe_file.exists() and queue):
                expe_file = queue.pop()['expe_file']
            if expe_file.exists():
                run(expe_file)
        except TestError:
            log.warning('Test failed, error logged. Resuming')
        except Exception:
            log.error('Critical exception while running test. Resuming')
            log.error(traceback.format_exc())
            log.warning('Resuming')
            continue


if __name__ == '__main__':
    logger.setup_logging()
    log.info('Starting supervisor')

    setup_yaml()
    main()