diff --git a/minigrida/protocols/__init__.py b/minigrida/protocols/__init__.py index d9b5df2..b6f86b3 100644 --- a/minigrida/protocols/__init__.py +++ b/minigrida/protocols/__init__.py @@ -10,3 +10,4 @@ #from .jurse import Jurse from .jurse2 import Jurse2 +from .jurse3 import Jurse3 diff --git a/minigrida/protocols/jurse3.py b/minigrida/protocols/jurse3.py new file mode 100644 index 0000000..ca68511 --- /dev/null +++ b/minigrida/protocols/jurse3.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# file jurse3.py +# author Florent Guiotte +# version 0.0 +# date 17 sept. 2021 + +import importlib + +from joblib import Memory +from pathlib import Path + +from . import Jurse2 + +CACHE = './cache' + + +class Jurse3(Jurse2): + """Jurse2 protocl with cache + + Same as Jurse2 but enable caching results to speed up + hyperparameters tunning. + """ + def __init__(self, expe): + super().__init__(expe) + + self.memory = Memory(CACHE if Path(CACHE).exists else None, verbose=0) + + def _compute_descriptors(self, data): + script = self._expe['descriptors_script'] + + desc = importlib.import_module(script['name']) + run = self.memory.cache(desc.run) + att = run(*data, **script['parameters']) + + return att