Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmseg.registry import MODELS | |
| class KLDivLoss(nn.Module): | |
| def __init__(self, | |
| temperature: float = 1.0, | |
| reduction: str = 'mean', | |
| loss_name: str = 'loss_kld'): | |
| """Kullback-Leibler divergence Loss. | |
| <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence> | |
| Args: | |
| temperature (float, optional): Temperature param | |
| reduction (str, optional): The method to reduce the loss into a | |
| scalar. Default is "mean". Options are "none", "sum", | |
| and "mean" | |
| """ | |
| assert isinstance(temperature, (float, int)), \ | |
| 'Expected temperature to be' \ | |
| f'float or int, but got {temperature.__class__.__name__} instead' | |
| assert temperature != 0., 'Temperature must not be zero' | |
| assert reduction in ['mean', 'none', 'sum'], \ | |
| 'Reduction must be one of the options ("mean", ' \ | |
| f'"sum", "none"), but got {reduction}' | |
| super().__init__() | |
| self.temperature = temperature | |
| self.reduction = reduction | |
| self._loss_name = loss_name | |
| def forward(self, input: torch.Tensor, target: torch.Tensor): | |
| """Forward function. Calculate KL divergence Loss. | |
| Args: | |
| input (Tensor): Logit tensor, | |
| the data type is float32 or float64. | |
| The shape is (N, C) where N is batchsize and C is number of | |
| channels. | |
| If there more than 2 dimensions, shape is (N, C, D1, D2, ... | |
| Dk), k>= 1 | |
| target (Tensor): Logit tensor, | |
| the data type is float32 or float64. | |
| input and target must be with the same shape. | |
| Returns: | |
| (Tensor): Reduced loss. | |
| """ | |
| assert isinstance(input, torch.Tensor), 'Expected input to' \ | |
| f'be Tensor, but got {input.__class__.__name__} instead' | |
| assert isinstance(target, torch.Tensor), 'Expected target to' \ | |
| f'be Tensor, but got {target.__class__.__name__} instead' | |
| assert input.shape == target.shape, 'Input and target ' \ | |
| 'must have same shape,' \ | |
| f'but got shapes {input.shape} and {target.shape}' | |
| input = F.softmax(input / self.temperature, dim=1) | |
| target = F.softmax(target / self.temperature, dim=1) | |
| loss = F.kl_div(input, target, reduction='none', log_target=False) | |
| loss = loss * self.temperature**2 | |
| batch_size = input.shape[0] | |
| if self.reduction == 'sum': | |
| # Change view to calculate instance-wise sum | |
| loss = loss.view(batch_size, -1) | |
| return torch.sum(loss, dim=1) | |
| elif self.reduction == 'mean': | |
| # Change view to calculate instance-wise mean | |
| loss = loss.view(batch_size, -1) | |
| return torch.mean(loss, dim=1) | |
| return loss | |
| def loss_name(self): | |
| """Loss Name. | |
| This function must be implemented and will return the name of this | |
| loss function. This name will be used to combine different loss items | |
| by simple sum operation. In addition, if you want this loss item to be | |
| included into the backward graph, `loss_` must be the prefix of the | |
| name. | |
| Returns: | |
| str: The name of this loss item. | |
| """ | |
| return self._loss_name | |