Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Optional, Union | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from mmdet.registry import MODELS | |
| from .utils import weight_reduce_loss, weighted_loss | |
def gaussian_focal_loss(pred: Tensor,
                        gaussian_target: Tensor,
                        alpha: float = 2.0,
                        gamma: float = 4.0,
                        pos_weight: float = 1.0,
                        neg_weight: float = 1.0) -> Tensor:
    """Element-wise `Focal Loss <https://arxiv.org/abs/1708.02002>`_ against
    a gaussian heatmap target.

    Entries of ``gaussian_target`` that are exactly 1 are the positive
    samples. Every entry additionally contributes a negative term scaled by
    ``(1 - gaussian_target) ** gamma``.

    Args:
        pred (torch.Tensor): The prediction. ``log(pred)`` and
            ``log(1 - pred)`` are taken, so values are presumably
            probabilities in [0, 1].
        gaussian_target (torch.Tensor): The learning target of the prediction
            in gaussian distribution.
        alpha (float, optional): Exponent of the prediction-based modulating
            factor, i.e. ``(1 - pred) ** alpha`` for the positive term and
            ``pred ** alpha`` for the negative term. Defaults to 2.0.
        gamma (float, optional): Exponent applied to
            ``1 - gaussian_target`` when weighting the negative term.
            Defaults to 4.0.
        pos_weight (float): Positive sample loss weight. Defaults to 1.0.
        neg_weight (float): Negative sample loss weight. Defaults to 1.0.

    Returns:
        torch.Tensor: The loss per element; no reduction is applied here.
    """
    eps = 1e-12  # offset added inside log() so log(0) is never evaluated

    is_positive = gaussian_target.eq(1)
    negative_scale = (1 - gaussian_target).pow(gamma)

    log_hit = (pred + eps).log()
    log_miss = (1 - pred + eps).log()

    positive_term = -log_hit * (1 - pred).pow(alpha) * is_positive
    negative_term = -log_miss * pred.pow(alpha) * negative_scale

    return pos_weight * positive_term + neg_weight * negative_term
def gaussian_focal_loss_with_pos_inds(
        pred: Tensor,
        gaussian_target: Tensor,
        pos_inds: Tensor,
        pos_labels: Tensor,
        alpha: float = 2.0,
        gamma: float = 4.0,
        pos_weight: float = 1.0,
        neg_weight: float = 1.0,
        reduction: str = 'mean',
        avg_factor: Optional[Union[int, float]] = None) -> Tensor:
    """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for gaussian targets
    with explicitly supplied positive samples.

    Unlike ``gaussian_focal_loss``, which treats the entries of
    ``gaussian_target`` equal to 1 as positives, this function takes the
    positives from ``pos_inds`` / ``pos_labels``. The negative term is still
    computed over every entry of ``pred``.

    Args:
        pred (torch.Tensor): The prediction. The shape is (N, num_classes).
        gaussian_target (torch.Tensor): The learning target of the prediction
            in gaussian distribution. The shape is (N, num_classes).
        pos_inds (torch.Tensor): The positive sample index, used to select
            rows of ``pred``. The shape is (M, ).
        pos_labels (torch.Tensor): The label corresponding to the positive
            sample index, used to select one column per selected row.
            The shape is (M, ).
        alpha (float, optional): Exponent of the prediction-based modulating
            factor. Defaults to 2.0.
        gamma (float, optional): Exponent applied to
            ``1 - gaussian_target`` when weighting the negative term.
            Defaults to 4.0.
        pos_weight (float): Positive sample loss weight. Defaults to 1.0.
        neg_weight (float): Negative sample loss weight. Defaults to 1.0.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        avg_factor (int, float, optional): Average factor that is used to
            average the loss. Defaults to None.

    Returns:
        torch.Tensor: ``pos_weight * positive + neg_weight * negative`` where
        both terms were passed through ``weight_reduce_loss`` separately.

    Note:
        Before reduction the positive term has shape (M, 1) while the
        negative term has the shape of ``pred``. Presumably
        ``reduction='none'`` leaves those shapes as they are, in which case
        the final sum relies on broadcasting -- TODO confirm against
        ``weight_reduce_loss``.
    """
    eps = 1e-12  # offset added inside log() so log(0) is never evaluated
    negative_scale = (1 - gaussian_target).pow(gamma)

    # Positive term: one score per positive sample (row pos_inds[i],
    # column pos_labels[i]).
    picked_rows = pred[pos_inds]
    label_column = pos_labels.unsqueeze(1)
    picked_scores = picked_rows.gather(1, label_column)
    positive_term = (
        -(picked_scores + eps).log() * (1 - picked_scores).pow(alpha))
    positive_term = weight_reduce_loss(positive_term, None, reduction,
                                       avg_factor)

    # Negative term: computed over the full prediction map.
    negative_term = -(1 - pred + eps).log() * pred.pow(alpha) * negative_scale
    negative_term = weight_reduce_loss(negative_term, None, reduction,
                                       avg_factor)

    return pos_weight * positive_term + neg_weight * negative_term
class GaussianFocalLoss(nn.Module):
    """GaussianFocalLoss is a variant of focal loss.

    More details can be found in the `paper
    <https://arxiv.org/abs/1808.01244>`_
    Code is modified from `kp_utils.py
    <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_  # noqa: E501
    Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
    not 0/1 binary target.

    Args:
        alpha (float): Power of prediction. Defaults to 2.0.
        gamma (float): Power of target for negative samples.
            Defaults to 4.0.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        loss_weight (float): Loss weight of current loss. Defaults to 1.0.
        pos_weight (float): Positive sample loss weight. Defaults to 1.0.
        neg_weight (float): Negative sample loss weight. Defaults to 1.0.
    """

    def __init__(self,
                 alpha: float = 2.0,
                 gamma: float = 4.0,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 pos_weight: float = 1.0,
                 neg_weight: float = 1.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                pos_inds: Optional[Tensor] = None,
                pos_labels: Optional[Tensor] = None,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[Union[int, float]] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        If you want to manually determine which positions are
        positive samples, you can set the ``pos_inds`` and ``pos_labels``
        parameters. Currently, only the CenterNet update version uses
        these parameters.

        Args:
            pred (torch.Tensor): The prediction. The shape is (N, num_classes).
            target (torch.Tensor): The learning target of the prediction
                in gaussian distribution. The shape is (N, num_classes).
            pos_inds (torch.Tensor): The positive sample index.
                Defaults to None.
            pos_labels (torch.Tensor): The label corresponding to the positive
                sample index. Required when ``pos_inds`` is given.
                Defaults to None.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Only used when ``pos_inds`` is None; it is not
                forwarded on the ``pos_inds`` path. Defaults to None.
            avg_factor (int, float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The loss scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if pos_inds is not None:
            assert pos_labels is not None
            # Only used by centernet update version
            loss_reg = self.loss_weight * gaussian_focal_loss_with_pos_inds(
                pred,
                target,
                pos_inds,
                pos_labels,
                alpha=self.alpha,
                gamma=self.gamma,
                pos_weight=self.pos_weight,
                neg_weight=self.neg_weight,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            # ``gaussian_focal_loss`` as defined in this file only returns the
            # element-wise loss and accepts no ``weight`` / ``reduction`` /
            # ``avg_factor`` arguments (passing them raised TypeError), so
            # the weighting and reduction are applied here explicitly.
            # NOTE(review): ``weighted_loss`` is imported by this file but is
            # not applied to ``gaussian_focal_loss``; if that decorator is
            # ever added there, drop this explicit reduction -- confirm.
            elementwise_loss = gaussian_focal_loss(
                pred,
                target,
                alpha=self.alpha,
                gamma=self.gamma,
                pos_weight=self.pos_weight,
                neg_weight=self.neg_weight)
            loss_reg = self.loss_weight * weight_reduce_loss(
                elementwise_loss, weight, reduction, avg_factor)
        return loss_reg