Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from mmseg.registry import MODELS | |
| class BoundaryLoss(nn.Module): | |
| """Boundary loss. | |
| This function is modified from | |
| `PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122>`_. # noqa | |
| Licensed under the MIT License. | |
| Args: | |
| loss_weight (float): Weight of the loss. Defaults to 1.0. | |
| loss_name (str): Name of the loss item. If you want this loss | |
| item to be included into the backward graph, `loss_` must be the | |
| prefix of the name. Defaults to 'loss_boundary'. | |
| """ | |
| def __init__(self, | |
| loss_weight: float = 1.0, | |
| loss_name: str = 'loss_boundary'): | |
| super().__init__() | |
| self.loss_weight = loss_weight | |
| self.loss_name_ = loss_name | |
| def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor: | |
| """Forward function. | |
| Args: | |
| bd_pre (Tensor): Predictions of the boundary head. | |
| bd_gt (Tensor): Ground truth of the boundary. | |
| Returns: | |
| Tensor: Loss tensor. | |
| """ | |
| log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1) | |
| target_t = bd_gt.view(1, -1).float() | |
| pos_index = (target_t == 1) | |
| neg_index = (target_t == 0) | |
| weight = torch.zeros_like(log_p) | |
| pos_num = pos_index.sum() | |
| neg_num = neg_index.sum() | |
| sum_num = pos_num + neg_num | |
| weight[pos_index] = neg_num * 1.0 / sum_num | |
| weight[neg_index] = pos_num * 1.0 / sum_num | |
| loss = F.binary_cross_entropy_with_logits( | |
| log_p, target_t, weight, reduction='mean') | |
| return self.loss_weight * loss | |
| def loss_name(self): | |
| return self.loss_name_ | |