import fastai.vision.all as fai


class DiceMultiWithoutUndef(fai.DiceMulti):
    """Averaged Dice metric (Macro F1) for multiclass target in segmentation.

    Ignore the undefined class (labeled 0).
    """

    def accumulate(self, learn):
        """Add the current batch's per-class counts to the running totals.

        Predicted class indices are obtained with an argmax over
        ``self.axis`` of ``learn.pred`` and compared against ``learn.y``
        after both have been passed through ``fai.flatten_check``.

        ``self.axis``, ``self.inter`` and ``self.union`` are not created
        here; they are presumably provided by ``fai.DiceMulti`` (its
        constructor / ``reset``) -- confirm against the base class.
        """
        pred, targ = fai.flatten_check(
            learn.pred.argmax(dim=self.axis), learn.y
        )
        # Start at 1 so that class 0 (undefined) never contributes.
        for c in range(1, learn.pred.shape[self.axis]):
            # 0/1 masks of the elements predicted as / labelled with class c.
            p = fai.torch.where(pred == c, 1, 0)
            t = fai.torch.where(targ == c, 1, 0)
            c_inter = (p * t).float().sum().item()
            # Sum of both mask sizes (the Dice denominator), not a set union.
            c_union = (p + t).float().sum().item()
            if c in self.inter:
                self.inter[c] += c_inter
                self.union[c] += c_union
            else:
                self.inter[c] = c_inter
                self.union[c] = c_union


class F1Multi(fai.Metric):
    """F1 score computed over every batch seen since the last ``reset``.

    Per-batch class-index predictions (argmax over ``axis``) and targets are
    buffered, then concatenated and passed to
    ``fai.F1ScoreMulti(average=None, labels=labels)`` when ``value`` is read.
    """

    def __init__(self, labels, axis=1):
        """Store the configuration and start with empty buffers.

        Args:
            labels: Forwarded unchanged to ``fai.F1ScoreMulti`` as ``labels``.
            axis: Dimension of ``learn.pred`` over which the argmax is taken.
        """
        self.axis = axis
        self.labels = labels
        self.metric = fai.F1ScoreMulti(average=None, labels=labels)
        # Create the buffers right away so the metric is usable even when
        # nothing has called `reset` yet.
        self.reset()

    def reset(self):
        """Discard all buffered predictions and targets."""
        self.pred = []
        self.target = []

    def accumulate(self, learn):
        """Buffer the flattened predictions and targets of the current batch."""
        pred, target = fai.flatten_check(
            learn.pred.argmax(dim=self.axis), learn.y
        )
        self.pred.append(pred)
        self.target.append(target)

    @property
    def value(self):
        """Return the wrapped metric evaluated on all buffered batches.

        Returns ``None`` when nothing has been accumulated, because
        ``torch.cat`` cannot concatenate an empty list.
        """
        if not self.pred:
            return None
        pred = fai.torch.cat(self.pred)
        target = fai.torch.cat(self.target)
        return self.metric(pred, target)