Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from mmseg.registry import MODELS | |
| from .utils import weight_reduce_loss | |
def silog_loss(pred: Tensor,
               target: Tensor,
               weight: Optional[Tensor] = None,
               eps: float = 1e-4,
               reduction: Union[str, None] = 'mean',
               avg_factor: Optional[int] = None) -> Tensor:
    """Computes the Scale-Invariant Logarithmic (SI-Log) loss between
    prediction and target.

    Dim 0 is treated as the sample axis and every remaining dim is
    flattened, so one loss value is computed per sample::

        d    = log(target) - log(pred)          # valid elements only
        loss = sqrt(mean(d ** 2) - 0.5 * mean(d) ** 2)

    An element is valid when ``target > eps`` and its log-difference is not
    NaN. Invalid elements are excluded from both means; a sample with no
    valid element gets a loss of 0. The per-sample losses are then weighted
    and reduced by ``weight_reduce_loss``.

    Args:
        pred (Tensor): Predicted output. Dim 0 is the sample axis.
        target (Tensor): Ground truth. Expected to match ``pred``'s shape
            (this function does not check it; ``SiLogLoss.forward`` does).
        weight (Optional[Tensor]): Optional weight applied to the
            per-sample losses; cast to float before use. Presumably
            broadcastable to ``(N,)`` -- confirm against callers.
        eps (float): Targets ``<= eps`` are ignored, and both ``pred`` and
            ``target`` are clamped to at least ``eps`` before the log.
        reduction (Union[str, None]): Reduction forwarded to
            ``weight_reduce_loss`` ('none', 'mean' or 'sum').
        avg_factor (Optional[int]): Optional average factor forwarded to
            ``weight_reduce_loss``.

    Returns:
        Tensor: The calculated SI-Log loss, reduced per ``reduction``.
    """
    pred, target = pred.flatten(1), target.flatten(1)

    diff_log = torch.log(target.clamp(min=eps)) - torch.log(
        pred.clamp(min=eps))

    # Drop elements without a usable target, and any NaN log-difference
    # (e.g. from a NaN input), so they cannot poison the sums below.
    valid_mask = (target > eps) & (~torch.isnan(diff_log))
    diff_log = diff_log.masked_fill(~valid_mask, 0.0)
    valid_mask = valid_mask.float()

    # Valid counts are integral: a floor of 1 leaves every non-empty sample
    # untouched and only avoids 0 / 0 for samples with no valid element
    # (whose numerators are 0 too).
    num_valid = valid_mask.sum(dim=1).clamp(min=1.0)
    diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(dim=1) / num_valid
    diff_log_mean = (diff_log * valid_mask).sum(dim=1) / num_valid

    loss_sq = diff_log_sq_mean - 0.5 * diff_log_mean.pow(2)

    # d/dx sqrt(x) is infinite at x == 0 (perfect prediction or no valid
    # element), which backpropagates as inf * 0 = NaN. Take the sqrt of a
    # harmless stand-in there and write 0 back: the forward value equals
    # sqrt(loss_sq) everywhere, and the gradient at 0 becomes 0, not NaN.
    is_zero = loss_sq == 0
    safe_loss_sq = torch.where(is_zero, torch.ones_like(loss_sq), loss_sq)
    loss = torch.where(is_zero, torch.zeros_like(loss_sq),
                       torch.sqrt(safe_loss_sq))

    if weight is not None:
        weight = weight.float()

    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
# NOTE(review): ``MODELS`` is imported in this file but is not used to
# register this class. If it is meant to be built from configs
# (``type='SiLogLoss'``), it needs ``@MODELS.register_module()`` -- verify.
class SiLogLoss(nn.Module):
    """Compute SiLog loss.

    ``nn.Module`` wrapper around ``silog_loss`` that stores the reduction
    mode, ``eps`` and a scalar loss weight.

    Args:
        reduction (str, optional): The method used
            to reduce the loss. Options are "none",
            "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        eps (float): Targets ``<= eps`` are ignored, and ``eps`` is the lower
            clamp applied before taking logarithms. Defaults to 1e-6.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must
            be the prefix of the name. Defaults to 'loss_silog'.
    """

    def __init__(self,
                 reduction='mean',
                 loss_weight=1.0,
                 eps=1e-6,
                 loss_name='loss_silog'):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.eps = eps
        self._loss_name = loss_name

    def forward(
        self,
        pred,
        target,
        weight=None,
        avg_factor=None,
        reduction_override=None,
    ):
        """Compute the SI-Log loss scaled by ``self.loss_weight``.

        Args:
            pred (Tensor): Predicted output.
            target (Tensor): Ground truth; must have the same shape as
                ``pred``.
            weight (Tensor, optional): Weight forwarded to ``silog_loss``.
            avg_factor (int, optional): Average factor forwarded to
                ``silog_loss``.
            reduction_override (str, optional): One of None, 'none', 'mean'
                or 'sum'. When not None it replaces ``self.reduction`` for
                this call only.

        Returns:
            Tensor: ``self.loss_weight * silog_loss(...)``.

        Raises:
            AssertionError: If the shapes differ or ``reduction_override``
                is not an accepted value.
        """
        assert pred.shape == target.shape, 'the shapes of pred ' \
            f'({pred.shape}) and target ({target.shape}) are mismatch'
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * silog_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
        )
        return loss

    # NOTE(review): this is a plain method, so callers must call
    # ``loss_name()``. If callers read ``loss_name`` as an attribute (as
    # with a ``@property``), confirm and decorate accordingly.
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name