Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class GANLoss(nn.Module):
    r"""N+1 label GAN loss.

    The discriminator output ``pred`` carries one more channel than the
    ``label`` map along dim 1. The first ``label.size(1)`` channels are scored
    against ``label`` for the "real" objective, and the extra last channel is
    the one scored for the "fake" objective. Channel 0 is treated as an
    ignore label: its label entries and its logits are both set to zero
    before the log-softmax is taken.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        r"""GAN loss constructor.

        Args:
            target_real_label (float): Desired output label for the real images.
            target_fake_label (float): Desired output label for the fake images.
        """
        super().__init__()
        # These attributes are only stored here; the loss computation in this
        # class does not read them. They are kept so external code that
        # inspects them keeps working.
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None

    def forward(self, input_x, t_real, weight=None,
                reduce_dim=True, dis_update=True):
        r"""GAN loss computation.

        Args:
            input_x (dict or list): A single discriminator output (a dict with
                ``'pred'`` and ``'label'`` tensors), or a list of them. A list
                entry may itself be a list, in which case only its last
                element is used.
            t_real (boolean): Is this output value for real images.
            weight (float): Weight to scale the loss value.
            reduce_dim (boolean): Whether we reduce the dimensions first. This
                makes a difference when we use multi-resolution discriminators.
            dis_update (boolean): Updating the discriminator or the generator.
        Returns:
            loss (tensor): Loss value. For a list input this is the average of
            the per-entry losses (shape ``(1,)`` when each entry's loss is a
            scalar); otherwise it is whatever :meth:`loss` returns.
        Raises:
            ValueError: If ``input_x`` is an empty list.
        """
        if not isinstance(input_x, list):
            return self.loss(input_x, t_real, weight, reduce_dim, dis_update)

        if not input_x:
            # Without this guard the average below divides by zero.
            raise ValueError(
                "GANLoss received an empty list of discriminator outputs.")

        total = 0
        for pred_i in input_x:
            if isinstance(pred_i, list):
                # Only the final element of a nested list is scored.
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, t_real, weight,
                                    reduce_dim, dis_update)
            # A 0-dim loss is treated as a batch of one so that the result
            # always has a leading batch dimension.
            bs = 1 if loss_tensor.dim() == 0 else loss_tensor.size(0)
            # Out-of-place accumulation: avoids mutating a graph tensor.
            total = total + torch.mean(loss_tensor.view(bs, -1), dim=1)
        return total / len(input_x)

    def loss(self, input_x, t_real, weight=None,
             reduce_dim=True, dis_update=True):
        r"""N+1 label GAN loss computation.

        Args:
            input_x (dict): Holds ``'pred'`` (4-D tensor whose dim 1 has
                ``label.size(1) + 1`` channels) and ``'label'`` (presumably a
                one-hot / soft label map of matching spatial size -- confirm
                against the caller). Neither tensor is modified.
            t_real (boolean): Is this output value for real images.
            weight (float): Weight to scale the loss value.
            reduce_dim (boolean): Must be ``True``; per-sample reduction is
                not supported by this loss.
            dis_update (boolean): Updating the discriminator or the generator.
        Returns:
            loss (tensor): Scalar loss value (mean over all elements).
        Raises:
            ValueError: If ``reduce_dim`` is not ``True``, if ``pred`` does not
                have exactly one more channel than ``label``, or if a
                generator update is requested with ``t_real`` false.
        """
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if reduce_dim is not True:
            raise ValueError("GANLoss only supports reduce_dim=True.")
        if input_x['pred'].size(1) != input_x['label'].size(1) + 1:
            raise ValueError(
                "pred must have exactly one more channel than label, got "
                "{} and {}.".format(input_x['pred'].size(1),
                                    input_x['label'].size(1)))
        if not dis_update and not t_real:
            raise ValueError("GAN loss must be aiming for real.")

        # Work on copies so the caller's tensors are not modified in place.
        pred = input_x['pred'].clone()
        label = input_x['label'].clone()

        # ignore label 0
        label[:, 0, ...] = 0
        pred[:, 0, ...] = 0
        log_prob = F.log_softmax(pred, dim=1)

        if dis_update and not t_real:
            # Fake target: negative log-probability of the extra (last)
            # channel; ``None`` keeps a singleton channel dimension.
            loss = -log_prob[:, -1, None, :, :]  # N plus 1
        else:
            # Real target (discriminator on real data, or generator update):
            # cross-entropy between ``label`` and the first N channels.
            loss = torch.sum(-label * log_prob[:, :-1, :, :],
                             dim=1, keepdim=True)

        if weight is not None:
            loss = loss * weight
        return torch.mean(loss)