# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import weight_reduce_loss


def _reduce_dice_loss(loss, pred, weight, reduction, avg_factor):
    """Validate the optional per-sample weight, then reduce ``loss``.

    Shared tail of :func:`dice_loss` and :func:`naive_dice_loss`.

    Args:
        loss (torch.Tensor): Per-sample loss, one value per row of ``pred``.
        pred (torch.Tensor): The prediction the loss was computed from; only
            its length is used, to validate ``weight``.
        weight (torch.Tensor | None): Per-sample weight with the same number
            of dims as ``loss`` and ``len(pred)`` entries.
        reduction (str): Reduction mode forwarded to ``weight_reduce_loss``.
        avg_factor (int | None): Average factor forwarded to
            ``weight_reduce_loss``.

    Returns:
        torch.Tensor: Result of ``weight_reduce_loss(loss, weight, reduction,
        avg_factor)``.
    """
    if weight is not None:
        assert weight.ndim == loss.ndim, (
            f'weight must have {loss.ndim} dim(s), got {weight.ndim}')
        assert len(weight) == len(pred), (
            f'weight has {len(weight)} entries but pred has {len(pred)}')
    return weight_reduce_loss(loss, weight, reduction, avg_factor)


def dice_loss(pred,
              target,
              weight=None,
              eps=1e-3,
              reduction='mean',
              avg_factor=None):
    """Calculate the dice loss proposed in `V-Net: Fully Convolutional Neural
    Networks for Volumetric Medical Image Segmentation
    <https://arxiv.org/abs/1606.04797>`_.

    Each sample is flattened to a vector, and its loss is::

        1 - 2 * sum(p * t) / (sum(p * p) + sum(t * t) + 2 * eps)

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape as pred. Cast to float internally.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Added to each squared-sum term of the denominator to
            avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The reduced loss (shape (n,) when reduction is "none").
    """
    pred_flat = pred.flatten(1)
    target_flat = target.flatten(1).float()

    intersection = torch.sum(pred_flat * target_flat, 1)
    # eps is added to both terms, so the denominator carries 2 * eps.
    pred_sq_sum = torch.sum(pred_flat * pred_flat, 1) + eps
    target_sq_sum = torch.sum(target_flat * target_flat, 1) + eps
    dice_coef = (2 * intersection) / (pred_sq_sum + target_sq_sum)
    loss = 1 - dice_coef

    return _reduce_dice_loss(loss, pred, weight, reduction, avg_factor)


def naive_dice_loss(pred,
                    target,
                    weight=None,
                    eps=1e-3,
                    reduction='mean',
                    avg_factor=None):
    """Calculate the naive dice loss.

    In this variant, the terms in the denominator are raised to the first
    power instead of the second. Each sample is flattened to a vector, and
    its loss is::

        1 - (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape as pred. Cast to float internally.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Smoothing term added to numerator and denominator to
            avoid dividing by zero. Default: 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The reduced loss (shape (n,) when reduction is "none").
    """
    pred_flat = pred.flatten(1)
    target_flat = target.flatten(1).float()

    intersection = torch.sum(pred_flat * target_flat, 1)
    pred_sum = torch.sum(pred_flat, 1)
    target_sum = torch.sum(target_flat, 1)
    dice_coef = (2 * intersection + eps) / (pred_sum + target_sum + eps)
    loss = 1 - dice_coef

    return _reduce_dice_loss(loss, pred, weight, reduction, avg_factor)


@LOSSES.register_module(force=True)
class DiceLoss(nn.Module):
    """Dice loss module wrapping :func:`dice_loss` / :func:`naive_dice_loss`.

    Two forms of dice loss are supported:

    - the one proposed in `V-Net: Fully Convolutional Neural Networks for
      Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_;
    - a "naive" dice loss, in which the terms in the denominator are raised
      to the first power instead of the second.
    """

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 eps=1e-3):
        """Initialize the loss.

        Args:
            use_sigmoid (bool, optional): Whether sigmoid (rather than
                softmax) is the activation applied to the prediction. Only
                sigmoid is implemented; with ``activate=True`` and
                ``use_sigmoid=False``, ``forward`` raises
                NotImplementedError. Defaults to True.
            activate (bool): Whether to apply the activation to the
                prediction inside ``forward``. Set to False if the
                prediction is already activated. Defaults to True.
            reduction (str, optional): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If False, use the dice loss defined
                in the V-Net paper; otherwise use the naive dice loss, whose
                denominator terms are raised to the first power instead of
                the second. Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
        """
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps
        self.activate = activate

    def forward(self,
                pred,
                target,
                weight=None,
                reduction_override=None,
                avg_factor=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction,
                shape (n, *), same shape as pred.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.

        Raises:
            NotImplementedError: If ``activate`` is True and
                ``use_sigmoid`` is False (softmax is not implemented).
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        if self.activate:
            if self.use_sigmoid:
                pred = pred.sigmoid()
            else:
                raise NotImplementedError(
                    'DiceLoss only implements sigmoid activation; set '
                    'use_sigmoid=True, or activate=False if pred is '
                    'already activated.')

        loss_fn = naive_dice_loss if self.naive_dice else dice_loss
        loss = self.loss_weight * loss_fn(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor)
        return loss