# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from .utils import weight_reduce_loss


# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
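

# Illustrative usage sketch, not part of the original file: the shapes and
# values below are assumptions chosen for demonstration. It shows the common
# call pattern where `pred` holds raw logits, `target` is a binary label
# tensor of the same shape, and a per-prior `weight` is broadcast over the
# class axis inside the function.
def _demo_py_sigmoid_focal_loss():
    num_priors, num_classes = 8, 4
    pred = torch.randn(num_priors, num_classes)  # raw logits from a cls head
    # Binary (one-hot style) targets of the same shape as `pred`.
    target = torch.randint(0, 2, (num_priors, num_classes)).float()
    weight = torch.ones(num_priors)  # reshaped to (num_priors, 1) internally
    return py_sigmoid_focal_loss(pred, target, weight, gamma=2.0, alpha=0.25)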


def py_focal_loss_with_prob(pred,
                            target,
                            weight=None,
                            gamma=2.0,
                            alpha=0.25,
                            reduction='mean',
                            avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Different from `py_sigmoid_focal_loss`, this function accepts probability
    as input.

    Args:
        pred (torch.Tensor): The prediction probability with shape (N, C),
            C is the number of classes.
        target (torch.Tensor): The learning label of the prediction, given
            as class indices; a label equal to C denotes background and is
            dropped after the one-hot encoding.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    num_classes = pred.size(1)
    target = F.one_hot(target, num_classes=num_classes + 1)
    target = target[:, :num_classes]
    target = target.type_as(pred)
    pt = (1 - pred) * target + pred * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
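

# Illustrative usage sketch, not part of the original file: values below are
# assumptions for demonstration. `pred` holds activated scores in [0, 1],
# and integer labels in [0, num_classes] are one-hot encoded inside the
# function, with `num_classes` acting as the background index.
def _demo_py_focal_loss_with_prob():
    num_priors, num_classes = 8, 4
    pred = torch.rand(num_priors, num_classes)  # probabilities, not logits
    target = torch.randint(0, num_classes + 1, (num_priors,))
    return py_focal_loss_with_prob(pred, target, gamma=2.0, alpha=0.25)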


def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the CUDA version of `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
                               alpha, None, 'none')
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened of shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
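

# Illustrative usage sketch, not part of the original file: the fused mmcv
# kernel expects logits plus integer class labels and only runs on GPU,
# hence the guard. Shapes and label range are assumptions for demonstration.
def _demo_sigmoid_focal_loss():
    if not torch.cuda.is_available():
        return None
    num_priors, num_classes = 8, 4
    pred = torch.randn(num_priors, num_classes, device='cuda')
    target = torch.randint(0, num_classes, (num_priors,), device='cuda')
    return sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25)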


class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether to use sigmoid (instead of
                softmax) on the prediction. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'. Options are "none",
                "mean" and "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            activated (bool, optional): Whether the input is activated.
                If True, it means the input has been activated and can be
                treated as probabilities. Else, it should be treated as
                logits. Defaults to False.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = py_focal_loss_with_prob
            else:
                if torch.cuda.is_available() and pred.is_cuda:
                    calculate_loss_func = sigmoid_focal_loss
                else:
                    num_classes = pred.size(1)
                    target = F.one_hot(target, num_classes=num_classes + 1)
                    target = target[:, :num_classes]
                    calculate_loss_func = py_sigmoid_focal_loss

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls
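

# Minimal usage sketch, not part of the original file: construct the module
# and run it on CPU logits with integer labels, which exercises the
# `py_sigmoid_focal_loss` fallback path in `forward`. Shapes are assumptions.
def _demo_focal_loss_module():
    num_priors, num_classes = 8, 4
    criterion = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25)
    pred = torch.randn(num_priors, num_classes)
    # Labels in [0, num_classes]; `num_classes` is the background index.
    target = torch.randint(0, num_classes + 1, (num_priors,))
    return criterion(pred, target)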