Spaces:
Runtime error
Runtime error
| # ------------------------------------------------------------------------------ | |
| # Adapted from https://github.com/jiawei-ren/BalancedMSE | |
| # Original licence: Copyright (c) 2022 Jiawei Ren, under the MIT License. | |
| # ------------------------------------------------------------------------------ | |
| from typing import Optional, Union | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn.functional as F | |
| from mmcv.runner import get_dist_info | |
| from torch.nn.modules.loss import _Loss | |
| from .utils import weighted_loss | |
@weighted_loss
def bmc_loss_md(pred: torch.Tensor, target: torch.Tensor,
                noise_var: torch.Tensor, all_gather: bool,
                loss_mse_weight: float,
                loss_debias_weight: float) -> torch.Tensor:
    """Balanced MSE loss (multi-dimensional targets, batch Monte-Carlo).

    NOTE: the ``@weighted_loss`` decorator is required — ``BMCLossMD.forward``
    calls this function with the extra ``weight``/``reduction``/``avg_factor``
    keywords that only the decorator's wrapper accepts; without it every call
    raises ``TypeError``.

    Args:
        pred (torch.Tensor): The prediction. Shape should be (N, L).
        target (torch.Tensor): The learning target of the prediction.
        noise_var (torch.Tensor): Noise var of ground truth distribution.
        all_gather (bool): Whether gather tensors across all sub-processes.
            Only used in DDP training scheme.
        loss_mse_weight (float): The weight of the mse term.
        loss_debias_weight (float): The weight of the debiased term.

    Returns:
        torch.Tensor: The calculated per-sample loss, shape (N,).
    """
    N = pred.shape[0]
    L = pred.shape[1]
    device = pred.device
    # Plain per-sample MSE, scaled by the (learnable) noise variance.
    loss_mse = F.mse_loss(pred, target, reduction='none').sum(-1)
    loss_mse = loss_mse / noise_var
    if all_gather:
        # Gather targets from every rank so the Monte-Carlo estimate of the
        # target distribution uses the global batch. Per-rank batch sizes may
        # differ, so pad every rank's targets to the max size before
        # all_gather, then strip the padding afterwards.
        _, world_size = get_dist_info()
        bs, length = target.shape
        all_bs = [torch.zeros(1).to(device) for _ in range(world_size)]
        dist.all_gather(all_bs, torch.Tensor([bs]).to(device))
        all_bs_int = [int(v.item()) for v in all_bs]
        max_bs_int = max(all_bs_int)
        # NOTE(review): padding buffer uses the default dtype (float32);
        # targets of another dtype are cast on assignment — confirm targets
        # are float32 before relying on the DDP path.
        target_padding = torch.zeros(max_bs_int, length).to(device)
        target_padding[:bs] = target
        all_tensor = [
            torch.zeros(max_bs_int, length).type_as(target)
            for _ in range(world_size)
        ]
        dist.all_gather(all_tensor, target_padding)
        # remove padding
        for i in range(world_size):
            all_tensor[i] = all_tensor[i][:all_bs_int[i]]
        target = torch.cat(all_tensor, dim=0)
    # Debias term: logsumexp over all candidate targets of the scaled
    # negative squared distance (soft assignment of each pred to every
    # target in the (gathered) batch).
    target = target.unsqueeze(0).repeat(N, 1, 1)
    pred = pred.unsqueeze(1).expand_as(target)
    debias_term = F.mse_loss(pred, target, reduction='none').sum(-1)
    debias_term = -0.5 * debias_term / noise_var
    loss_debias = torch.logsumexp(debias_term, dim=1).squeeze(-1)
    loss = loss_mse * loss_mse_weight + loss_debias * loss_debias_weight
    # recover loss scale of mse_loss; detach so the rescaling itself does
    # not feed gradients back into the learnable noise variance.
    loss = loss / L * noise_var.detach()
    return loss
class BMCLossMD(_Loss):
    """Balanced MSE loss, use batch monte-carlo to estimate distribution.

    https://arxiv.org/abs/2203.16427.

    Args:
        init_noise_sigma (float, optional): The initial value of noise sigma.
            This sigma is used to represent ground truth distribution.
            Defaults to 1.0.
        all_gather (bool, optional): Whether gather tensors across all
            sub-processes. If set True, BMC will have more precise estimation
            with more time cost. Default: False.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum". ``None`` is
            accepted as an alias for "none".
        loss_mse_weight (float, optional): The weight of the mse term.
            Defaults to 1.0.
        loss_debias_weight (float, optional): The weight of the debiased term.
            Defaults to 1.0.
    """

    def __init__(self,
                 init_noise_sigma: float = 1.0,
                 all_gather: bool = False,
                 reduction: Optional[str] = 'mean',
                 loss_mse_weight: float = 1.0,
                 loss_debias_weight: float = 1.0):
        super().__init__()
        # Learnable noise sigma; its square is the noise variance consumed
        # by ``bmc_loss_md`` and is optimized jointly with the model.
        self.noise_sigma = torch.nn.Parameter(
            torch.tensor(init_noise_sigma).float())
        self.all_gather = all_gather
        assert reduction in (None, 'none', 'mean', 'sum')
        self.reduction = 'none' if reduction is None else reduction
        self.loss_mse_weight = loss_mse_weight
        self.loss_debias_weight = loss_debias_weight

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                weight: Optional[torch.Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> torch.Tensor:
        """Forward function of loss.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): Weight of the loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override
                     if reduction_override else self.reduction)
        noise_var = (self.noise_sigma**2).type_as(pred)
        # Flatten each sample so the loss operates on (N, L) tensors.
        pred = pred.view(pred.shape[0], -1)
        target = target.view(target.shape[0], -1)
        # ``weight``/``reduction``/``avg_factor`` are handled by the
        # ``weighted_loss`` wrapper expected around ``bmc_loss_md``.
        loss = bmc_loss_md(pred,
                           target,
                           noise_var=noise_var,
                           all_gather=self.all_gather,
                           loss_mse_weight=self.loss_mse_weight,
                           loss_debias_weight=self.loss_debias_weight,
                           weight=weight,
                           reduction=reduction,
                           avg_factor=avg_factor)
        return loss