diff --git a/LICENSE b/LICENSE index 2071b23..651aec2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) +Copyright (c) 2022 Florent Guiotte Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/README.md b/README.md index 0a700c5..452e348 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,11 @@ # geo-fastai -Geographic data and remote sensing image processing helpers for Fastai. \ No newline at end of file +Geographic data and remote sensing image processing helpers for [fastai]. + +[fastai]: https://docs.fast.ai/ + +## Installation + +``` +pip install git+https://git.guiotte.fr/florent/geo-fastai.git +``` \ No newline at end of file diff --git a/geofastai/__init__.py b/geofastai/__init__.py new file mode 100644 index 0000000..b7ce7ce --- /dev/null +++ b/geofastai/__init__.py @@ -0,0 +1,3 @@ +from .metrics import * +from .raster import * +from .utils import * \ No newline at end of file diff --git a/geofastai/metrics.py b/geofastai/metrics.py new file mode 100644 index 0000000..b583301 --- /dev/null +++ b/geofastai/metrics.py @@ -0,0 +1,44 @@ +import fastai.vision.all as fai + + +class DiceMultiWithoutUndef(fai.DiceMulti): + """Averaged Dice metric (Macro F1) for multiclass target in segmentation + + Ignore the undifined class (labeled 0). + + """ + def accumulate(self, learn): + pred, targ = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y) + for c in range(1, learn.pred.shape[self.axis]): + p = fai.torch.where(pred == c, 1, 0) + t = fai.torch.where(targ == c, 1, 0) + c_inter = (p*t).float().sum().item() + c_union = (p+t).float().sum().item() + if c in self.inter: + self.inter[c] += c_inter + self.union[c] += c_union + else: + self.inter[c] = c_inter + self.union[c] = c_union + + +class F1Multi(fai.Metric): + def __init__(self, labels, axis=1): + self.axis = axis + self.labels = labels + self.metric = fai.F1ScoreMulti(average=None, labels=labels) + + def reset(self): + self.pred = [] + self.target = [] + + def accumulate(self, learn): + pred, target = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y) + self.pred += [pred] + self.target += [target] + + @property + def value(self): + pred = fai.torch.cat(self.pred) + target = fai.torch.cat(self.target) + return self.metric(pred, target) \ No newline at end of file diff --git a/geofastai/raster.py b/geofastai/raster.py new file mode 100644 index 0000000..6d46304 --- /dev/null +++ b/geofastai/raster.py @@ -0,0 +1,40 @@ +import fastai.vision.all as fai +import rasterio as rio +import warnings + +warnings.filterwarnings('ignore', 'Dataset has no geotransform', module='rasterio') + + +class RasterBase(fai.TensorImage): + @classmethod + def create(cls, fname, **kwargs): + data = rio.open(fname).read() + return cls(data) + + +class RasterVNIR(RasterBase): + def show(self, ctx=None, bands=[3,2,1], **kwargs): + im = self[bands] + im = im.to(fai.torch.float) + im = im / im.quantile(.9) + im = im.clip(0, 1) + return fai.show_image(im, ctx=ctx, vmax=300, **kwargs) + + +class RasterToFloatTensor(fai.DisplayedTransform): + order = 10 + def __init__(self, div=None): + if div is None: + div = np.iinfo(np.uint16).max + + self.div = div + + def encodes(self, o:fai.TensorImage): + return o.float().div_(self.div) + + def decodes(self, o:fai.TensorImage): + return (o.clamp(0., 1.) * self.div).long() + + +def RasterBlock(cls=RasterBase): + return fai.TransformBlock(type_tfms=cls.create)#, batch_tfms=RasterToFloatTensor) \ No newline at end of file diff --git a/geofastai/utils.py b/geofastai/utils.py new file mode 100644 index 0000000..59881b9 --- /dev/null +++ b/geofastai/utils.py @@ -0,0 +1,26 @@ +import pandas as pd +import fastai.vision.all as fai + + +class DataframeLogger(fai.Callback): + order = 60 + + def before_fit(self): + if hasattr(self, "gather_preds"): return + self.df = pd.DataFrame(columns=self.recorder.metric_names) + self.old_logger, self.learn.logger = self.logger, self._record_line + + def _record_line(self, log): + if self._is_training(log): + self.df.loc[log[0]] = log + self.old_logger(log) + + def _is_training(self, log): + return self.df.columns.size == len(log) + + def read_log(self): + return self.df + + def after_fit(self): + if hasattr(self, "gather_preds"): return + self.learn.logger = self.old_logger \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7fd26b9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..1361174 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,10 @@ +[metadata] +name = geofastai +version = 0.0.2 +long_description = file: README.md + +[options] +packages = geofastai +install_requires = + fastai + rasterio \ No newline at end of file