diff --git a/minigrida/descriptors/pixel.py b/minigrida/descriptors/pixel.py index ff37bef..ba96ac7 100644 --- a/minigrida/descriptors/pixel.py +++ b/minigrida/descriptors/pixel.py @@ -33,5 +33,6 @@ def run(gt, rasters, coords, remove=None): X = np.concatenate(X) y = np.concatenate(y) groups = np.concatenate(groups) + Xn = rasters[0].keys() - return X, y, groups + return X, y, groups, Xn diff --git a/minigrida/protocols/jurse2.py b/minigrida/protocols/jurse2.py index 8238b16..8fa908a 100644 --- a/minigrida/protocols/jurse2.py +++ b/minigrida/protocols/jurse2.py @@ -21,6 +21,7 @@ class Jurse2(Protocol): def __init__(self, expe): super().__init__(expe, self.__class__.__name__) + self._results = {} def _run(self): self._log.info('Load data') @@ -44,9 +45,7 @@ class Jurse2(Protocol): self._log.info('Run metrics') metrics = self._run_metrics(classification, descriptors) - results = {} - results['metrics'] = metrics - self._results = results + self._results['metrics'] = metrics def _load_data(self): data_loader = self._expe['data_loader'] @@ -65,7 +64,7 @@ class Jurse2(Protocol): return att def _compute_classification(self, descriptors): - X, y, groups = descriptors + X, y, groups, Xn = descriptors # CrossVal and ML cv = self._expe['cross_validation'] @@ -76,6 +75,7 @@ class Jurse2(Protocol): y_pred = np.zeros_like(y) + cl_feature_importances = [] cvi = cross_val(**cv['parameters']) for train_index, test_index in cvi.split(X, y, groups): cli = classifier(**cl['parameters']) @@ -86,6 +86,14 @@ class Jurse2(Protocol): self._log.info(' - predict') y_pred[test_index] = cli.predict(X[test_index]) + cl_feature_importances += [classifier.feature_importances_.copy()] + + cl_feature_importances = np.array(cl_feature_importances) + self._results['features'] = { + 'name': Xn, + 'importance': cl_feature_importances.tolist() + } + return y_pred def _get_results(self):