44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import fastai.vision.all as fai
|
|
|
|
|
|
class DiceMultiWithoutUndef(fai.DiceMulti):
    """Averaged Dice metric (macro F1) for multiclass segmentation targets.

    The undefined class (label 0) is ignored: intersection and union counts
    are only collected for class indices starting at 1.
    """

    def accumulate(self, learn):
        """Add one batch's per-class intersection/union counts to the totals.

        Predicted labels are the argmax of ``learn.pred`` along ``self.axis``;
        they are flattened together with ``learn.y`` before being compared.
        """
        upper = learn.pred.shape[self.axis]
        hard_labels = learn.pred.argmax(dim=self.axis)
        pred, targ = fai.flatten_check(hard_labels, learn.y)

        # Begin at 1 so that label 0 (undefined) never gets an entry.
        for cls in range(1, upper):
            pred_mask = fai.torch.where(pred == cls, 1, 0)
            targ_mask = fai.torch.where(targ == cls, 1, 0)
            overlap = (pred_mask * targ_mask).float().sum().item()
            combined = (pred_mask + targ_mask).float().sum().item()
            # A class seen for the first time starts from zero.
            self.inter[cls] = self.inter.get(cls, 0.0) + overlap
            self.union[cls] = self.union.get(cls, 0.0) + combined
|
|
|
|
|
|
class F1Multi(fai.Metric):
    """F1 score computed once over every batch seen since the last reset.

    Each batch's predictions (argmax of ``learn.pred`` along ``axis``) and
    targets are flattened and buffered; reading ``value`` concatenates the
    buffers and evaluates ``fai.F1ScoreMulti`` on the result.

    Args:
        labels: Class labels, forwarded to ``fai.F1ScoreMulti`` as ``labels``.
        axis: Dimension of ``learn.pred`` over which the argmax is taken.
    """

    def __init__(self, labels, axis=1):
        self.axis = axis
        self.labels = labels
        # average=None is forwarded to the underlying scorer; per the
        # scikit-learn f1_score docs this yields one score per label instead
        # of a single averaged number.
        self.metric = fai.F1ScoreMulti(average=None, labels=labels)
        # Create the buffers right away so `accumulate` and `value` are safe
        # to use even if `reset()` has not been called yet.
        self.reset()

    def reset(self):
        """Drop all buffered predictions and targets."""
        self.pred = []
        self.target = []

    def accumulate(self, learn):
        """Buffer the flattened argmax predictions and targets of one batch."""
        pred, target = fai.flatten_check(learn.pred.argmax(dim=self.axis), learn.y)
        # Store detached CPU copies: the buffers live for a whole epoch and
        # would otherwise keep every batch's tensors on the compute device.
        self.pred.append(pred.detach().cpu())
        self.target.append(target.detach().cpu())

    @property
    def value(self):
        """Return the score over all buffered batches, or ``None`` if empty."""
        if not self.pred:
            # torch.cat raises on an empty list; report "no value" instead.
            return None
        pred = fai.torch.cat(self.pred)
        target = fai.torch.cat(self.target)
        return self.metric(pred, target)