Add metrics, raster and utils
This commit is contained in:
parent
e414724bc8
commit
54d0a10132
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) <year> <copyright holders>
|
||||
Copyright (c) 2022 Florent Guiotte
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
|
10
README.md
10
README.md
@ -1,3 +1,11 @@
|
||||
# geo-fastai
|
||||
|
||||
Geographic data and remote sensing image processing helpers for Fastai.
|
||||
Geographic data and remote sensing image processing helpers for [fastai].
|
||||
|
||||
[fastai]: https://docs.fast.ai/
|
||||
|
||||
## Installation
|
||||
|
||||
```
|
||||
pip install git+https://git.guiotte.fr/florent/geo-fastai.git
|
||||
```
|
3
geofastai/__init__.py
Normal file
3
geofastai/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .metrics import *
|
||||
from .raster import *
|
||||
from .utils import *
|
44
geofastai/metrics.py
Normal file
44
geofastai/metrics.py
Normal file
@ -0,0 +1,44 @@
|
||||
import fastai.vision.all as fai
|
||||
|
||||
|
||||
class DiceMultiWithoutUndef(fai.DiceMulti):
|
||||
"""Averaged Dice metric (Macro F1) for multiclass target in segmentation
|
||||
|
||||
Ignore the undifined class (labeled 0).
|
||||
|
||||
"""
|
||||
def accumulate(self, learn):
|
||||
pred, targ = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
|
||||
for c in range(1, learn.pred.shape[self.axis]):
|
||||
p = fai.torch.where(pred == c, 1, 0)
|
||||
t = fai.torch.where(targ == c, 1, 0)
|
||||
c_inter = (p*t).float().sum().item()
|
||||
c_union = (p+t).float().sum().item()
|
||||
if c in self.inter:
|
||||
self.inter[c] += c_inter
|
||||
self.union[c] += c_union
|
||||
else:
|
||||
self.inter[c] = c_inter
|
||||
self.union[c] = c_union
|
||||
|
||||
|
||||
class F1Multi(fai.Metric):
|
||||
def __init__(self, labels, axis=1):
|
||||
self.axis = axis
|
||||
self.labels = labels
|
||||
self.metric = fai.F1ScoreMulti(average=None, labels=labels)
|
||||
|
||||
def reset(self):
|
||||
self.pred = []
|
||||
self.target = []
|
||||
|
||||
def accumulate(self, learn):
|
||||
pred, target = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
|
||||
self.pred += [pred]
|
||||
self.target += [target]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
pred = fai.torch.cat(self.pred)
|
||||
target = fai.torch.cat(self.target)
|
||||
return self.metric(pred, target)
|
40
geofastai/raster.py
Normal file
40
geofastai/raster.py
Normal file
@ -0,0 +1,40 @@
|
||||
import fastai.vision.all as fai
|
||||
import rasterio as rio
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings('ignore', 'Dataset has no geotransform', module='rasterio')
|
||||
|
||||
|
||||
class RasterBase(fai.TensorImage):
|
||||
@classmethod
|
||||
def create(cls, fname, **kwargs):
|
||||
data = rio.open(fname).read()
|
||||
return cls(data)
|
||||
|
||||
|
||||
class RasterVNIR(RasterBase):
|
||||
def show(self, ctx=None, bands=[3,2,1], **kwargs):
|
||||
im = self[bands]
|
||||
im = im.to(fai.torch.float)
|
||||
im = im / im.quantile(.9)
|
||||
im = im.clip(0, 1)
|
||||
return fai.show_image(im, ctx=ctx, vmax=300, **kwargs)
|
||||
|
||||
|
||||
class RasterToFloatTensor(fai.DisplayedTransform):
|
||||
order = 10
|
||||
def __init__(self, div=None):
|
||||
if div is None:
|
||||
div = np.iinfo(np.uint16).max
|
||||
|
||||
self.div = div
|
||||
|
||||
def encodes(self, o:fai.TensorImage):
|
||||
return o.float().div_(self.div)
|
||||
|
||||
def decodes(self, o:fai.TensorImage):
|
||||
return (o.clamp(0., 1.) * self.div).long()
|
||||
|
||||
|
||||
def RasterBlock(cls=RasterBase):
|
||||
return fai.TransformBlock(type_tfms=cls.create)#, batch_tfms=RasterToFloatTensor)
|
26
geofastai/utils.py
Normal file
26
geofastai/utils.py
Normal file
@ -0,0 +1,26 @@
|
||||
import pandas as pd
|
||||
import fastai.vision.all as fai
|
||||
|
||||
|
||||
class DataframeLogger(fai.Callback):
|
||||
order = 60
|
||||
|
||||
def before_fit(self):
|
||||
if hasattr(self, "gather_preds"): return
|
||||
self.df = pd.DataFrame(columns=self.recorder.metric_names)
|
||||
self.old_logger, self.learn.logger = self.logger, self._record_line
|
||||
|
||||
def _record_line(self, log):
|
||||
if self._is_training(log):
|
||||
self.df.loc[log[0]] = log
|
||||
self.old_logger(log)
|
||||
|
||||
def _is_training(self, log):
|
||||
return self.df.columns.size == len(log)
|
||||
|
||||
def read_log(self):
|
||||
return self.df
|
||||
|
||||
def after_fit(self):
|
||||
if hasattr(self, "gather_preds"): return
|
||||
self.learn.logger = self.old_logger
|
3
pyproject.toml
Normal file
3
pyproject.toml
Normal file
@ -0,0 +1,3 @@
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
Loading…
Reference in New Issue
Block a user