# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.registry import MODELS
from .utils import weight_reduce_loss


def _expand_onehot_labels(labels, label_weights, label_channels):
    """Expand class-index labels to one-hot binary labels and broadcast
    the per-sample label weights to the same shape."""
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(
        (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights
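
# A worked example of the expansion above (values are illustrative):
#   labels = tensor([1, 0, 2]), label_channels = 3
#   -> bin_labels = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]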


# TODO: code refactoring to make it consistent with other losses
@MODELS.register_module()
class GHMC(nn.Module):
| """GHM Classification Loss. | |
| Details of the theorem can be viewed in the paper | |
| `Gradient Harmonized Single-stage Detector | |
| <https://arxiv.org/abs/1811.05181>`_. | |
| Args: | |
| bins (int): Number of the unit regions for distribution calculation. | |
| momentum (float): The parameter for moving average. | |
| use_sigmoid (bool): Can only be true for BCE based loss now. | |
| loss_weight (float): The weight of the total GHM-C loss. | |
| reduction (str): Options are "none", "mean" and "sum". | |
| Defaults to "mean" | |
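
    Example:
        >>> # Minimal usage sketch; shapes and values below are
        >>> # illustrative, not taken from a real classification head.
        >>> loss_cls = GHMC(bins=10, momentum=0.75)
        >>> pred = torch.randn(4, 20)         # raw logits, [batch_num, class_num]
        >>> target = torch.randint(0, 2, (4, 20)).float()
        >>> label_weight = torch.ones(4, 20)  # 1 = valid sample
        >>> loss = loss_cls(pred, target, label_weight)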
| """ | |

    def __init__(self,
                 bins=10,
                 momentum=0,
                 use_sigmoid=True,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        self.edges[-1] += 1e-6
        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                target,
                label_weight,
                reduction_override=None,
                **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_onehot_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
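        # Harmonize gradients: examples whose gradient length g falls in a
        # densely populated bin are down-weighted by the bin's population
        # (or its exponential moving average when momentum > 0), so easy
        # examples and extreme outliers no longer dominate the loss.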
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none')
        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight


# TODO: code refactoring to make it consistent with other losses
@MODELS.register_module()
class GHMR(nn.Module):
| """GHM Regression Loss. | |
| Details of the theorem can be viewed in the paper | |
| `Gradient Harmonized Single-stage Detector | |
| <https://arxiv.org/abs/1811.05181>`_. | |
| Args: | |
| mu (float): The parameter for the Authentic Smooth L1 loss. | |
| bins (int): Number of the unit regions for distribution calculation. | |
| momentum (float): The parameter for moving average. | |
| loss_weight (float): The weight of the total GHM-R loss. | |
| reduction (str): Options are "none", "mean" and "sum". | |
| Defaults to "mean" | |
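
    Example:
        >>> # Minimal usage sketch; shapes and values below are
        >>> # illustrative, not taken from a real bbox head.
        >>> loss_bbox = GHMR(mu=0.02, bins=10)
        >>> pred = torch.randn(6, 4)         # box deltas, [batch_num, 4]
        >>> target = torch.randn(6, 4)
        >>> label_weight = torch.ones(6, 4)  # 1 = valid sample
        >>> loss = loss_bbox(pred, target, label_weight)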
| """ | |

    def __init__(self,
                 mu=0.02,
                 bins=10,
                 momentum=0,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMR, self).__init__()
        self.mu = mu
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        self.edges[-1] = 1e3
        self.momentum = momentum
        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.loss_weight = loss_weight
        self.reduction = reduction

    # TODO: support reduction parameter
    def forward(self,
                pred,
                target,
                label_weight,
                avg_factor=None,
                reduction_override=None):
        """Calculate the GHM-R loss.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of box regression layer. Channel number can
                be 4 or 4 * class_num depending on whether it is
                class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size of pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored.
            avg_factor (int, optional): Accepted for API consistency with
                other losses; currently unused, as the loss is averaged over
                the number of valid samples. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        mu = self.mu
        edges = self.edges
        mmt = self.momentum

        # ASL1 loss
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        tot = max(label_weight.float().sum().item(), 1.0)
        n = 0  # n: valid bins
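        # Harmonize gradients: the ASL1 gradient magnitude is
        # |d/dx sqrt(x^2 + mu^2)| = |x| / sqrt(x^2 + mu^2), which lies in
        # [0, 1); examples in densely populated bins are down-weighted by
        # the bin's population (or its EMA when momentum > 0).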
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
        if n > 0:
            weights /= n

        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight