# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """cross_entropy. The wrapper function for :func:`F.cross_entropy`

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes, or (N, C, H, W) for dense prediction.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
            Default: None.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Default: None.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradients. When
            ``avg_non_ignore`` is ``True`` and ``reduction`` is
            ``'mean'``, the loss is averaged over non-ignored targets.
            Default: -100.
        avg_non_ignore (bool): If True, the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """
    # class_weight is a manual rescaling weight given to each class.
    # If given, it has to be a Tensor of size C. The loss is computed
    # element-wise here (reduction='none'); the reduction is applied
    # below by `weight_reduce_loss`.
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # apply weights and do the reduction
    # PyTorch's official cross_entropy averages the loss over
    # non-ignored elements, see
    # https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
    if (avg_factor is None) and reduction == 'mean':
        if class_weight is None:
            if avg_non_ignore:
                avg_factor = label.numel() - (label
                                              == ignore_index).sum().item()
            else:
                avg_factor = label.numel()
        else:
            # the average factor should take the class weights into
            # account: each target pixel contributes its class weight,
            # and ignored pixels contribute 0 when avg_non_ignore is set
            label_weights = torch.stack([class_weight[cls] for cls in label
                                         ]).to(device=class_weight.device)

            if avg_non_ignore:
                label_weights[label == ignore_index] = 0
            avg_factor = label_weights.sum()

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
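

# Illustrative usage sketch (not part of the module): calling `cross_entropy`
# on a toy 4-class batch. The shapes and the ignore value 255 are assumptions
# for the example, not requirements of the API.
#
# >>> pred = torch.randn(2, 4, 8, 8)            # (N, C, H, W) logits
# >>> label = torch.randint(0, 4, (2, 8, 8))    # (N, H, W) class indices
# >>> label[0, 0, 0] = 255                      # mark one pixel as ignored
# >>> loss = cross_entropy(
# ...     pred, label, ignore_index=255, avg_non_ignore=True)
# >>> loss.dim()                                # scalar after 'mean' reduction
# 0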
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand index labels to one-hot labels matching the prediction size."""
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
        bin_label_weights = bin_label_weights * valid_mask

    return bin_labels, bin_label_weights, valid_mask
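

# Illustrative sketch of the helper above (assumed shapes, not part of the
# module): expanding (N, H, W) index labels into (N, C, H, W) one-hot labels,
# with ignored pixels left all-zero.
#
# >>> labels = torch.tensor([[[0, 1], [255, 2]]])    # (1, 2, 2), 255 ignored
# >>> onehot, w, valid = _expand_onehot_labels(
# ...     labels, None, (1, 3, 2, 2), ignore_index=255)
# >>> onehot[0, :, 1, 0]                             # ignored pixel stays 0
# tensor([0, 0, 0])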
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False,
                         **kwargs):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1, H, W) for
            binary segmentation, or (N, C, H, W) / (N, C) in general.
        label (torch.Tensor): The learning label of the prediction.
            Note: In bce loss, label < 0 is invalid.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): If True, the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.size(1) == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        # Since ignore_index is often set to 255, the binary class
        # label check must mask out ignore_index first.
        assert label[label != ignore_index].max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        pred = pred.squeeze(1)
    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (
                pred.dim() == 4 and label.dim() == 3), \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        # `weight` returned from `_expand_onehot_labels`
        # has already been masked to valid (non-ignored) pixels
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.shape, ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            weight = weight * valid_mask
        else:
            weight = valid_mask

    # average loss over non-ignored and valid elements
    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)
    return loss
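

# Illustrative sketch (assumed shapes, not part of the module): binary
# segmentation with single-channel logits and a {0, 1, 255} label map,
# where 255 marks ignored pixels.
#
# >>> pred = torch.randn(2, 1, 8, 8)            # (N, 1, H, W) logits
# >>> label = torch.randint(0, 2, (2, 8, 8))    # (N, H, W) in {0, 1}
# >>> label[0, 0, 0] = 255                      # ignored pixel
# >>> loss = binary_cross_entropy(
# ...     pred, label, ignore_index=255, avg_non_ignore=True)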
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the
            mask's corresponding object. It is used to select the mask of
            the class the object belongs to when the mask prediction is
            not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder to keep the signature consistent
            with the other losses. Default: None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    # pick, for each ROI, the mask logits of its own class
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]
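

# Illustrative sketch (assumed shapes, not part of the module): per-ROI
# mask loss where each ROI's mask logits are selected by its class label.
#
# >>> pred = torch.randn(3, 5, 28, 28)     # (num_rois, C, H, W) mask logits
# >>> target = torch.randint(0, 2, (3, 28, 28)).float()  # binary mask targets
# >>> label = torch.tensor([0, 2, 4])      # class of each ROI
# >>> loss = mask_cross_entropy(pred, target, label)     # shape (1,)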
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            instead of softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If
            given as a str, it is taken as a file path to read the weights
            from. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_ce'.
        avg_non_ignore (bool): If True, the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_ce',
                 avg_non_ignore=False):
        super().__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        if not self.avg_non_ignore and self.reduction == 'mean':
            warnings.warn(
                'Default ``avg_non_ignore`` is False. If you would like to '
                'ignore a certain label and average the loss only over '
                'non-ignored labels, which matches the behavior of '
                "PyTorch's official cross_entropy, set "
                '``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy
        self._loss_name = loss_name

    def extra_repr(self):
        """Extra repr."""
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s
    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=-100,
                **kwargs):
        """Forward function."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None
        # Note: for BCE loss, label < 0 is invalid.
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            avg_non_ignore=self.avg_non_ignore,
            ignore_index=ignore_index,
            **kwargs)
        return loss_cls
    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
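

# Illustrative usage sketch (the config values below are assumptions for the
# example): building the loss directly, and via an mmseg-style config dict
# through the MODELS registry.
#
# >>> criterion = CrossEntropyLoss(use_sigmoid=False, avg_non_ignore=True)
# >>> logits = torch.randn(2, 4, 8, 8)
# >>> label = torch.randint(0, 4, (2, 8, 8))
# >>> loss = criterion(logits, label, ignore_index=255)
# >>> criterion.loss_name
# 'loss_ce'
#
# >>> cfg = dict(
# ...     type='CrossEntropyLoss', loss_weight=1.0, avg_non_ignore=True)
# >>> criterion = MODELS.build(cfg)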