Spaces:
Build error
Build error
| import mmcv | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ..builder import LOSSES | |
| from .utils import weight_reduce_loss | |
def varifocal_loss(pred,
                   target,
                   weight=None,
                   alpha=0.75,
                   gamma=2.0,
                   iou_weighted=True,
                   reduction='mean',
                   avg_factor=None):
    """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_

    Element-wise binary cross-entropy on the logits ``pred``, re-weighted
    asymmetrically:

    * entries with ``target > 0`` (positives) are weighted by the target
      value itself when ``iou_weighted`` is True, otherwise by 1;
    * entries with ``target <= 0`` (negatives) are weighted by
      ``alpha * |sigmoid(pred) - target| ** gamma``.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes
        target (torch.Tensor): The learning target of the iou-aware
            classification score with shape (N, C), C is the number of classes.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction. Defaults to None.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal Loss.
            Defaults to 0.75.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        iou_weighted (bool, optional): Whether to weight the loss of the
            positive example with the iou target. Defaults to True.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The weighted loss after ``weight_reduce_loss`` has
        applied ``weight``, ``reduction`` and ``avg_factor``.
    """
    # The weighting below is element-wise, so both tensors must match.
    assert pred.size() == target.size()
    prob = pred.sigmoid()
    target = target.type_as(pred)

    is_pos = (target > 0.0).float()
    is_neg = (target <= 0.0).float()

    # Negatives are down-weighted by how far the predicted probability is
    # from the target, raised to gamma and scaled by alpha.
    neg_term = alpha * (prob - target).abs().pow(gamma) * is_neg
    # Positives keep full weight, or are scaled by their IoU target.
    pos_term = target * is_pos if iou_weighted else is_pos
    focal_weight = pos_term + neg_term

    elementwise = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none')
    return weight_reduce_loss(
        elementwise * focal_weight, weight, reduction, avg_factor)
@LOSSES.register_module()
class VarifocalLoss(nn.Module):
    """``nn.Module`` wrapper around :func:`varifocal_loss`.

    Holds the loss hyper-parameters and multiplies the result of
    :func:`varifocal_loss` by ``loss_weight``. The class is registered in
    the ``LOSSES`` registry so it can be built from a config by name.
    """

    def __init__(self,
                 use_sigmoid=True,
                 alpha=0.75,
                 gamma=2.0,
                 iou_weighted=True,
                 reduction='mean',
                 loss_weight=1.0):
        """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                used for sigmoid or softmax. Defaults to True.
            alpha (float, optional): A balance factor for the negative part of
                Varifocal Loss, which is different from the alpha of Focal
                Loss. Defaults to 0.75.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            iou_weighted (bool, optional): Whether to weight the loss of the
                positive examples with the iou target. Defaults to True.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.

        Raises:
            AssertionError: If ``use_sigmoid`` is not True or ``alpha`` is
                negative.
        """
        super().__init__()
        # Only the sigmoid formulation is implemented (see forward).
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0
        self.use_sigmoid = use_sigmoid
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: The calculated loss

        Raises:
            NotImplementedError: If ``use_sigmoid`` has been set to a
                non-True value after construction.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the configured reduction.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if not self.use_sigmoid:
            raise NotImplementedError
        return self.loss_weight * varifocal_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            iou_weighted=self.iou_weighted,
            reduction=reduction,
            avg_factor=avg_factor)