venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn
from imaginaire.discriminators.fpse import FPSEDiscriminator
from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
from imaginaire.utils.data import (get_paired_input_image_channel_number,
get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print
class Discriminator(nn.Module):
    r"""Multi-resolution patch discriminator with an optional FPSE branch.

    A list of ``NLayerPatchDiscriminator`` networks is applied to the
    channel-wise concatenation of the label and the image. Each successive
    network receives a copy of the previous input bilinearly downsampled by
    a factor of two. When ``dis_cfg.use_fpse`` is true (the default), an
    ``FPSEDiscriminator`` is built as well and its three predictions are
    placed ahead of the patch discriminator outputs.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        print('Multi-resolution patch discriminator initialization.')
        image_channels = getattr(dis_cfg, 'image_channels', None)
        if image_channels is None:
            # Fall back to the channel count derived from the data config.
            image_channels = get_paired_input_image_channel_number(data_cfg)
        num_labels = getattr(dis_cfg, 'num_labels', None)
        if num_labels is None:
            # Calculate number of channels in the input label when not
            # specified.
            num_labels = get_paired_input_label_channel_number(data_cfg)

        # Build the discriminator.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)

        # The patch discriminators consume label and image stacked on the
        # channel axis (see ``_single_forward``).
        num_input_channels = image_channels + num_labels
        self.discriminators = nn.ModuleList()
        for _ in range(num_discriminators):
            net_discriminator = NLayerPatchDiscriminator(
                kernel_size,
                num_input_channels,
                num_filters,
                num_layers,
                max_num_filters,
                activation_norm_type,
                weight_norm_type)
            self.discriminators.append(net_discriminator)
        print('Done with the Multi-resolution patch discriminator '
              'initialization.')

        self.use_fpse = getattr(dis_cfg, 'use_fpse', True)
        if self.use_fpse:
            fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
            fpse_activation_norm_type = getattr(dis_cfg,
                                                'fpse_activation_norm_type',
                                                'none')
            self.fpse_discriminator = FPSEDiscriminator(
                image_channels,
                num_labels,
                num_filters,
                fpse_kernel_size,
                weight_norm_type,
                fpse_activation_norm_type)

    def _single_forward(self, input_label, input_image):
        r"""Run every discriminator on one (label, image) pair.

        Args:
            input_label (tensor): Semantic representation. It is concatenated
                with ``input_image`` along dim 1, so all other dimensions
                must match.
            input_image (tensor): Real or generated image.
        Returns:
            (tuple):
              - output_list (list): The three FPSE predictions (only when
                ``self.use_fpse`` is true) followed by one output per patch
                discriminator, ordered from full resolution downwards.
              - features_list (list): One entry per patch discriminator
                holding the features it returned. The FPSE branch
                contributes nothing to this list.
        """
        # Compute discriminator outputs and intermediate features from input
        # images and semantic labels.
        scaled_input = torch.cat((input_label, input_image), 1)
        if self.use_fpse:
            pred2, pred3, pred4 = self.fpse_discriminator(input_image,
                                                          input_label)
            output_list = [pred2, pred3, pred4]
        else:
            output_list = []
        features_list = []

        last_index = len(self.discriminators) - 1
        for index, net_discriminator in enumerate(self.discriminators):
            output, features = net_discriminator(scaled_input)
            output_list.append(output)
            features_list.append(features)
            # Only downsample when another discriminator will consume the
            # result. Halving after the last scale is wasted work and fails
            # once the spatial size would drop to zero.
            if index < last_index:
                scaled_input = nn.functional.interpolate(
                    scaled_input, scale_factor=0.5, mode='bilinear',
                    align_corners=True)
        return output_list, features_list

    def forward(self, data, net_G_output):
        r"""SPADE discriminator forward.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor) : Ground truth images.
              - label (N x C2 x H x W tensor) : Semantic representations.
              Any other keys are not read by this method.
            net_G_output (dict):
              - fake_images (N x C1 x H x W tensor) : Fake images.
        Returns:
            (dict):
              - real_outputs (list): list of output tensors produced by
                individual patch discriminators for real images.
              - real_features (list): list of lists of features produced by
                individual patch discriminators for real images.
              - fake_outputs (list): list of output tensors produced by
                individual patch discriminators for fake images.
              - fake_features (list): list of lists of features produced by
                individual patch discriminators for fake images.
        """
        output_x = dict()
        # Real and fake images are both scored against the same label map.
        output_x['real_outputs'], output_x['real_features'] = \
            self._single_forward(data['label'], data['images'])
        output_x['fake_outputs'], output_x['fake_features'] = \
            self._single_forward(data['label'], net_G_output['fake_images'])
        return output_x