"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import models.networks as networks
import utils.inference.util as util
import random


class Pix2PixModel(torch.nn.Module):
    """Bundle of generator, discriminator and optional encoder with losses.

    The networks are built by ``models.networks``. The discriminator is only
    created when ``opt.isTrain`` is set and the encoder only when
    ``opt.use_vae`` is set; otherwise the corresponding attribute is ``None``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Let ``models.networks`` register its options and return ``parser``."""
        networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        """Build the networks and, when training, the loss modules.

        Args:
            opt: options object. Attributes read here: ``gpu_ids``,
                ``isTrain``, ``gan_mode``, ``no_vgg_loss`` and ``use_vae``
                (plus whatever ``initialize_networks`` reads).
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # Loss modules are only needed (and only created) for training.
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(
                opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

    # Entry point for all calls involving forward pass
    # of deep networks. We used this approach since DataParallel module
    # can't parallelize custom functions, we branch to different
    # routines based on |mode|.
    def forward(self, data, mode):
        """Dispatch to the routine selected by ``mode``.

        Args:
            data: dict with ``'label'`` and ``'image'`` entries; both are
                read (and moved to the GPU if configured) for every mode.
            mode: one of ``'generator'``, ``'discriminator'``,
                ``'inference'`` or ``'inference2'``.

        Returns:
            ``'generator'``: ``(loss_dict, generated_image)``.
            ``'discriminator'``: ``loss_dict``.
            ``'inference'`` / ``'inference2'``: the generated image, computed
            without gradient tracking.

        Raises:
            ValueError: if ``mode`` is not one of the values above.
        """
        input_semantics, real_image = self.preprocess_input(data)

        if mode == 'generator':
            g_loss, generated = self.compute_generator_loss(
                input_semantics, real_image)
            return g_loss, generated
        if mode == 'discriminator':
            return self.compute_discriminator_loss(
                input_semantics, real_image)
        if mode == 'inference':
            with torch.no_grad():
                fake_image = self.generate_fake(input_semantics)
            return fake_image
        if mode == 'inference2':
            with torch.no_grad():
                fake_image = self.netG(input_semantics)
            return fake_image
        raise ValueError("|mode| is invalid")

    def preprocess_input(self, data):
        """Move the label and image tensors to the GPU when one is configured.

        NOTE: ``data`` is updated in place, so after this call the caller's
        dict holds the CUDA tensors as well.

        Args:
            data: dict with ``'label'`` and ``'image'`` entries.

        Returns:
            ``(data['label'], data['image'])``.
        """
        if self.use_gpu():
            data['label'] = data['label'].cuda()
            data['image'] = data['image'].cuda()
        return data['label'], data['image']

    def compute_generator_loss(self, input_semantics, real_image):
        """Compute the generator-side losses for one batch.

        Returns:
            ``(G_losses, fake_image)`` where ``G_losses`` always has
            ``'GAN'`` and may additionally have ``'GAN_Feat'`` (unless
            ``opt.no_ganFeat_loss``) and ``'VGG'`` (unless
            ``opt.no_vgg_loss`` or the generated image is smaller than 64
            pixels on its shorter side).
        """
        G_losses = {}

        fake_image = self.generate_fake(input_semantics)

        pred_fake, pred_real = self.discriminate(
            input_semantics, fake_image, real_image)

        G_losses['GAN'] = self.criterionGAN(pred_fake, True,
                                            for_discriminator=False)

        if not self.opt.no_ganFeat_loss:
            num_D = len(pred_fake)
            GAN_Feat_loss = self.FloatTensor(1).fill_(0)
            # One entry per discriminator.
            for fake_outputs, real_outputs in zip(pred_fake, pred_real):
                # The last output is the final prediction, so it is excluded;
                # only the intermediate layer outputs are matched.
                for fake_feat, real_feat in zip(fake_outputs[:-1],
                                                real_outputs[:-1]):
                    unweighted_loss = self.criterionFeat(
                        fake_feat, real_feat.detach())
                    GAN_Feat_loss += (
                        unweighted_loss * self.opt.lambda_feat / num_D)
            G_losses['GAN_Feat'] = GAN_Feat_loss

        h, w = fake_image.shape[-2:]
        # The VGG term is skipped for images under 64 px on the short side
        # (presumably a minimum size for the VGG features -- confirm).
        if not self.opt.no_vgg_loss and min(w, h) >= 64:
            G_losses['VGG'] = self.criterionVGG(fake_image, real_image) \
                * self.opt.lambda_vgg

        return G_losses, fake_image

    def compute_discriminator_loss(self, input_semantics, real_image):
        """Compute the discriminator-side losses for one batch.

        Returns:
            dict with the keys ``'D_Fake'`` and ``'D_real'``.
        """
        D_losses = {}
        # The generator is not trained in this step: run it without building
        # a graph, then turn its output into a fresh leaf tensor.
        with torch.no_grad():
            fake_image = self.generate_fake(input_semantics)
            fake_image = fake_image.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(
            input_semantics, fake_image, real_image)

        D_losses['D_Fake'] = self.criterionGAN(pred_fake, False,
                                               for_discriminator=True)
        D_losses['D_real'] = self.criterionGAN(pred_real, True,
                                               for_discriminator=True)

        return D_losses

    def generate_fake(self, input_semantics):
        """Run the generator on ``input_semantics`` and return its output."""
        fake_image = self.netG(input_semantics)
        return fake_image

    def discriminate(self, input_semantics, fake_image, real_image):
        """Run the discriminator on the fake and the real image together.

        When the generated image's spatial size differs from that of
        ``input_semantics``, both ``input_semantics`` and ``real_image`` are
        resized to the generated image's size first.

        Returns:
            ``(pred_fake, pred_real)`` as split by ``divide_pred``.
        """
        h, w = fake_image.shape[-2:]
        if fake_image.shape[-2:] != input_semantics.shape[-2:]:
            semantics = torch.nn.functional.interpolate(
                input_semantics, (h, w))
            real = torch.nn.functional.interpolate(real_image, (h, w))
        else:
            semantics = input_semantics
            real = real_image

        fake_concat = torch.cat([semantics, fake_image], dim=1)
        real_concat = torch.cat([semantics, real], dim=1)

        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real)

        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real

    def encode_z(self, real_image):
        """Encode ``real_image`` with ``netE`` and sample a latent from it.

        Returns:
            ``(z, mu, logvar)``.
        """
        mu, logvar = self.netE(real_image)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def create_optimizers(self, opt):
        """Create the Adam optimizers for the generator and discriminator.

        The generator optimizer also covers the encoder parameters when
        ``opt.use_vae`` is set. Unless ``opt.no_TTUR`` is set, the generator
        uses ``opt.lr / 2`` and the discriminator ``opt.lr * 2``.

        Returns:
            ``(optimizer_G, optimizer_D)``. ``optimizer_D`` is ``None`` when
            ``opt.isTrain`` is false, because no discriminator parameters
            are collected in that case.
        """
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())

        beta1, beta2 = opt.beta1, opt.beta2
        if opt.no_TTUR:
            G_lr, D_lr = opt.lr, opt.lr
        else:
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))

        # Only build the discriminator optimizer when training; previously
        # this path failed with an unbound local variable.
        optimizer_D = None
        if opt.isTrain:
            D_params = list(self.netD.parameters())
            optimizer_D = torch.optim.Adam(
                D_params, lr=D_lr, betas=(beta1, beta2))

        return optimizer_G, optimizer_D

    def save(self, epoch):
        """Save the generator, the discriminator and (with VAE) the encoder."""
        util.save_network(self.netG, 'G', epoch, self.opt)
        util.save_network(self.netD, 'D', epoch, self.opt)
        if self.opt.use_vae:
            util.save_network(self.netE, 'E', epoch, self.opt)

    ############################################################################
    # Private helper methods
    ############################################################################

    def initialize_networks(self, opt):
        """Define the networks and load saved weights where requested.

        Weights are loaded (for epoch ``opt.which_epoch``) when not training
        or when ``opt.continue_train`` is set.

        Returns:
            ``(netG, netD, netE)``; ``netD`` is ``None`` unless
            ``opt.isTrain`` and ``netE`` is ``None`` unless ``opt.use_vae``.
        """
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        """Split a combined prediction into its fake and real halves.

        The first half of the batch dimension is the fake part and the
        second half the real part, matching the concatenation order used in
        ``discriminate``.

        Args:
            pred: either a tensor, or a list whose items are iterables of
                tensors (the intermediate outputs of a multiscale
                discriminator).

        Returns:
            ``(fake, real)`` with the same nesting as ``pred``.
        """
        # the prediction contains the intermediate outputs of multiscale GAN,
        # so it's usually a list
        if isinstance(pred, list):
            fake = []
            real = []
            for p in pred:
                fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
                real.append([tensor[tensor.size(0) // 2:] for tensor in p])
        else:
            fake = pred[:pred.size(0) // 2]
            real = pred[pred.size(0) // 2:]

        return fake, real

    def get_edges(self, t):
        """Mark positions whose value differs from a direct neighbour.

        ``t`` is indexed with four dimensions; neighbours are compared along
        the last two. Both members of every differing pair are marked.

        Returns:
            A float tensor of the same size as ``t`` holding the edge mask.
        """
        edge = self.ByteTensor(t.size()).zero_()
        # Differences between neighbours along the last / second-to-last axis.
        diff_last = t[:, :, :, 1:] != t[:, :, :, :-1]
        diff_prev = t[:, :, 1:, :] != t[:, :, :-1, :]
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | diff_last
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | diff_last
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | diff_prev
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | diff_prev
        return edge.float()

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with standard-normal ``eps``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu

    def use_gpu(self):
        """Return True when ``opt.gpu_ids`` is non-empty."""
        return len(self.opt.gpu_ids) > 0