Add metrics, raster and utils

This commit is contained in:
Florent Guiotte 2022-04-19 11:39:47 +00:00
parent e414724bc8
commit 54d0a10132
8 changed files with 136 additions and 2 deletions

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) <year> <copyright holders>
Copyright (c) 2022 Florent Guiotte
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

View File

@ -1,3 +1,11 @@
# geo-fastai
Geographic data and remote sensing image processing helpers for Fastai.
Geographic data and remote sensing image processing helpers for [fastai].
[fastai]: https://docs.fast.ai/
## Installation
```
pip install git+https://git.guiotte.fr/florent/geo-fastai.git
```

3
geofastai/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .metrics import *
from .raster import *
from .utils import *

44
geofastai/metrics.py Normal file
View File

@ -0,0 +1,44 @@
import fastai.vision.all as fai
class DiceMultiWithoutUndef(fai.DiceMulti):
"""Averaged Dice metric (Macro F1) for multiclass target in segmentation
Ignore the undifined class (labeled 0).
"""
def accumulate(self, learn):
pred, targ = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
for c in range(1, learn.pred.shape[self.axis]):
p = fai.torch.where(pred == c, 1, 0)
t = fai.torch.where(targ == c, 1, 0)
c_inter = (p*t).float().sum().item()
c_union = (p+t).float().sum().item()
if c in self.inter:
self.inter[c] += c_inter
self.union[c] += c_union
else:
self.inter[c] = c_inter
self.union[c] = c_union
class F1Multi(fai.Metric):
def __init__(self, labels, axis=1):
self.axis = axis
self.labels = labels
self.metric = fai.F1ScoreMulti(average=None, labels=labels)
def reset(self):
self.pred = []
self.target = []
def accumulate(self, learn):
pred, target = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
self.pred += [pred]
self.target += [target]
@property
def value(self):
pred = fai.torch.cat(self.pred)
target = fai.torch.cat(self.target)
return self.metric(pred, target)

40
geofastai/raster.py Normal file
View File

@ -0,0 +1,40 @@
import fastai.vision.all as fai
import rasterio as rio
import warnings
warnings.filterwarnings('ignore', 'Dataset has no geotransform', module='rasterio')
class RasterBase(fai.TensorImage):
@classmethod
def create(cls, fname, **kwargs):
data = rio.open(fname).read()
return cls(data)
class RasterVNIR(RasterBase):
def show(self, ctx=None, bands=[3,2,1], **kwargs):
im = self[bands]
im = im.to(fai.torch.float)
im = im / im.quantile(.9)
im = im.clip(0, 1)
return fai.show_image(im, ctx=ctx, vmax=300, **kwargs)
class RasterToFloatTensor(fai.DisplayedTransform):
order = 10
def __init__(self, div=None):
if div is None:
div = np.iinfo(np.uint16).max
self.div = div
def encodes(self, o:fai.TensorImage):
return o.float().div_(self.div)
def decodes(self, o:fai.TensorImage):
return (o.clamp(0., 1.) * self.div).long()
def RasterBlock(cls=RasterBase):
return fai.TransformBlock(type_tfms=cls.create)#, batch_tfms=RasterToFloatTensor)

26
geofastai/utils.py Normal file
View File

@ -0,0 +1,26 @@
import pandas as pd
import fastai.vision.all as fai
class DataframeLogger(fai.Callback):
order = 60
def before_fit(self):
if hasattr(self, "gather_preds"): return
self.df = pd.DataFrame(columns=self.recorder.metric_names)
self.old_logger, self.learn.logger = self.logger, self._record_line
def _record_line(self, log):
if self._is_training(log):
self.df.loc[log[0]] = log
self.old_logger(log)
def _is_training(self, log):
return self.df.columns.size == len(log)
def read_log(self):
return self.df
def after_fit(self):
if hasattr(self, "gather_preds"): return
self.learn.logger = self.old_logger

3
pyproject.toml Normal file
View File

@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

10
setup.cfg Normal file
View File

@ -0,0 +1,10 @@
[metadata]
name = geofastai
version = 0.0.2
long_description = file: README.md
[options]
packages = geofastai
install_requires =
fastai
rasterio