Spaces:
Runtime error
Runtime error
| """ The code is based on https://github.com/apple/ml-gsn/ with adaption. """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import autograd | |
| from lib.net.Discriminator import StyleDiscriminator | |
def hinge_loss(fake_pred, real_pred, mode):
    """Compute the hinge GAN loss for a discriminator or generator update.

    Args:
        fake_pred: Tensor of discriminator outputs for generated samples.
        real_pred: Tensor of discriminator outputs for real samples. It is
            not read when ``mode == 'g'``.
        mode: ``'d'`` for the discriminator update, ``'g'`` for the
            generator update.

    Returns:
        A scalar tensor. Each term is averaged over all elements of the
        corresponding prediction tensor.

    Raises:
        ValueError: If ``mode`` is neither ``'d'`` nor ``'g'``.
    """
    if mode == 'd':
        # Discriminator update: penalise fake scores above -1 and real
        # scores below +1.
        d_loss_fake = torch.mean(F.relu(1.0 + fake_pred))
        d_loss_real = torch.mean(F.relu(1.0 - real_pred))
        return d_loss_fake + d_loss_real
    if mode == 'g':
        # Generator update: raise the scores given to generated samples.
        return -torch.mean(fake_pred)
    # An unknown mode used to fall through to an unbound local on return;
    # fail with an explicit, descriptive error instead.
    raise ValueError(f"mode must be 'd' or 'g', got {mode!r}")
def logistic_loss(fake_pred, real_pred, mode):
    """Compute the logistic (softplus) GAN loss for one update step.

    Args:
        fake_pred: Tensor of discriminator outputs for generated samples.
        real_pred: Tensor of discriminator outputs for real samples. It is
            not read when ``mode == 'g'``.
        mode: ``'d'`` for the discriminator update, ``'g'`` for the
            generator update.

    Returns:
        A scalar tensor. Each term is averaged over all elements of the
        corresponding prediction tensor.

    Raises:
        ValueError: If ``mode`` is neither ``'d'`` nor ``'g'``.
    """
    if mode == 'd':
        # Discriminator update: softplus(fake) + softplus(-real).
        d_loss_fake = torch.mean(F.softplus(fake_pred))
        d_loss_real = torch.mean(F.softplus(-real_pred))
        return d_loss_fake + d_loss_real
    if mode == 'g':
        # Generator update: softplus(-fake).
        return torch.mean(F.softplus(-fake_pred))
    # An unknown mode used to fall through to an unbound local on return;
    # fail with an explicit, descriptive error instead.
    raise ValueError(f"mode must be 'd' or 'g', got {mode!r}")
def r1_loss(real_pred, real_img):
    """Return the mean squared gradient of ``real_pred`` w.r.t. ``real_img``.

    The gradient of ``real_pred.sum()`` with respect to ``real_img`` is
    squared, summed over every dimension except the first, and then averaged
    over the first dimension.

    Args:
        real_pred: Tensor computed from ``real_img`` (must be part of its
            autograd graph).
        real_img: Tensor that requires grad and has at least one dimension.

    Returns:
        A scalar tensor. The gradient is taken with ``create_graph=True``,
        so the result can itself be back-propagated through.
    """
    [grad_wrt_img] = autograd.grad(
        outputs=real_pred.sum(),
        inputs=real_img,
        create_graph=True,
    )
    leading_dim = grad_wrt_img.shape[0]
    sq_norm_per_sample = grad_wrt_img.pow(2).reshape(leading_dim, -1).sum(1)
    return sq_norm_per_sample.mean()
class GANLoss(nn.Module):
    """Discriminator-side GAN loss wrapping a ``StyleDiscriminator``.

    Args:
        opt: Configuration object. Only ``opt.gan`` is read; it must provide
            ``img_res`` (passed to the discriminator) and ``lambda_gan``
            (weight applied to the returned loss).
        disc_loss: Name of the loss formulation, ``'hinge'`` or
            ``'logistic'``.

    Raises:
        ValueError: If ``disc_loss`` is not one of the supported names.
    """

    def __init__(
        self,
        opt,
        disc_loss='logistic',
    ):
        super().__init__()

        self.opt = opt.gan

        # NOTE(review): presumably the channel count of the discriminator
        # input (3-channel images) -- confirm against StyleDiscriminator.
        input_dim = 3
        self.discriminator = StyleDiscriminator(input_dim, self.opt.img_res)

        loss_fns = {
            'hinge': hinge_loss,
            'logistic': logistic_loss,
        }
        if disc_loss not in loss_fns:
            # Previously an unknown name left ``self.disc_loss`` unset and the
            # failure only surfaced later as an AttributeError in forward().
            raise ValueError(
                f"disc_loss must be one of {sorted(loss_fns)}, got {disc_loss!r}"
            )
        self.disc_loss = loss_fns[disc_loss]

    def forward(self, input):
        """Score real and fake inputs and return the weighted 'd'-mode loss.

        Args:
            input: Mapping with keys ``'norm_real'`` and ``'norm_fake'``,
                each holding a tensor accepted by the discriminator.

        Returns:
            A ``(loss, log)`` tuple. ``loss`` is the discriminator loss
            multiplied by ``opt.gan.lambda_gan``. ``log`` holds detached
            values: the unweighted loss and the mean real/fake logits.
        """
        logits_real = self.discriminator(input['norm_real'])
        # NOTE(review): the fake input is not detached here, so gradients can
        # reach whatever produced it -- confirm this is intended by callers.
        logits_fake = self.discriminator(input['norm_fake'])

        loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')

        log = {
            "disc_loss": loss.detach(),
            "logits_real": logits_real.mean().detach(),
            "logits_fake": logits_fake.mean().detach(),
        }

        return loss * self.opt.lambda_gan, log