# geo-fastai/geofastai/metrics.py
# (original file: 44 lines, 1.4 KiB, Python)

import fastai.vision.all as fai
import numpy as np
class DiceMultiWithoutUndef(fai.DiceMulti):
    """Averaged Dice metric (macro F1) for multiclass segmentation, skipping class 0.

    Class 0 is the "undefined" class: no intersection/union entry is ever
    recorded for it, so it takes no part in the averaging done by the
    inherited ``DiceMulti.value``. This class only overrides ``accumulate``.

    Note that pixels whose *target* is 0 are not masked out. Predicting class
    ``c`` on such a pixel still adds to ``union[c]`` and lowers that class's
    score.
    """

    def accumulate(self, learn):
        "Add this batch's per-class intersection/union counts for classes 1..n-1."
        n_classes = learn.pred.shape[self.axis]
        pred, targ = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        for cls in range(1, n_classes):  # start at 1: class 0 is never recorded
            in_pred = (pred == cls).long()
            in_targ = (targ == cls).long()
            batch_inter = (in_pred * in_targ).float().sum().item()
            # "union" is |pred==cls| + |targ==cls| (the Dice denominator), not a set union
            batch_union = (in_pred + in_targ).float().sum().item()
            self.inter[cls] = self.inter.get(cls, 0) + batch_inter
            self.union[cls] = self.union.get(cls, 0) + batch_union
class F1Multi(fai.Metric):
    """Per-class F1 score for multiclass segmentation targets.

    Predictions are the argmax of ``learn.pred`` along ``axis`` and are compared
    pixel-wise with ``learn.y``. For each class ``c`` the score is
    ``2*TP / (2*TP + FP + FN)``. This is what scikit-learn's
    ``f1_score(y_true, y_pred, labels=labels, average=None)`` returns, and
    classes absent from both predictions and targets score 0.0.

    Only per-class counts are kept between batches, so memory does not grow
    with the size of the validation set.

    Args:
        labels: class indices to score, in the order the scores are returned.
            If ``None``, every class ``0 .. n_classes-1`` that occurs in the
            predictions or the targets is scored, in ascending order.
        axis: dimension of ``learn.pred`` that holds the class scores.
    """

    def __init__(self, labels, axis=1):
        self.axis = axis
        self.labels = labels
        self.reset()  # make `value` safe even if fastai never called reset()

    def reset(self):
        "Clear the per-class counts accumulated so far."
        # inter[c] = TP; union[c] = |pred == c| + |targ == c| = 2*TP + FP + FN
        self.inter, self.union = {}, {}

    def accumulate(self, learn):
        "Add this batch's per-class counts."
        pred, targ = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        classes = self.labels if self.labels is not None else range(learn.pred.shape[self.axis])
        for c in classes:
            p, t = pred == c, targ == c
            # integer counts (.item() of a bool sum) stay exact regardless of size
            self.inter[c] = self.inter.get(c, 0) + (p & t).sum().item()
            self.union[c] = self.union.get(c, 0) + p.sum().item() + t.sum().item()

    @property
    def value(self):
        "Per-class F1 scores as a NumPy array, or ``None`` if nothing was accumulated."
        if not self.union:
            return None
        if self.labels is None:
            classes = sorted(c for c, u in self.union.items() if u > 0)
        else:
            classes = list(self.labels)
        # union == 0 means the class is absent from both pred and targ: score 0.0
        return np.array([2. * self.inter[c] / self.union[c] if self.union[c] > 0 else 0.
                         for c in classes])