Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
import torch | |
import torch.nn as nn | |
from imaginaire.discriminators.fpse import FPSEDiscriminator | |
from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator | |
from imaginaire.utils.data import (get_paired_input_image_channel_number, | |
get_paired_input_label_channel_number) | |
from imaginaire.utils.distributed import master_only_print as print | |
class Discriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Holds ``num_discriminators`` patch discriminators that are applied to
    progressively 2x-downsampled copies of the concatenated (label, image)
    input and, when ``dis_cfg.use_fpse`` is true (the default), one
    additional FPSE discriminator applied at the original resolution.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, dis_cfg, data_cfg):
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # Channel counts come from the discriminator config when given there;
        # otherwise they are derived from the data config.
        image_channels = getattr(dis_cfg, 'image_channels', None)
        if image_channels is None:
            image_channels = get_paired_input_image_channel_number(data_cfg)
        num_labels = getattr(dis_cfg, 'num_labels', None)
        if num_labels is None:
            # Calculate number of channels in the input label when not
            # specified.
            num_labels = get_paired_input_label_channel_number(data_cfg)

        # Build the discriminator.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)

        # Labels and images are concatenated along the channel axis before
        # being fed to the patch discriminators (see ``_single_forward``).
        num_input_channels = image_channels + num_labels
        self.discriminators = nn.ModuleList()
        for _ in range(num_discriminators):
            net_discriminator = NLayerPatchDiscriminator(
                kernel_size,
                num_input_channels,
                num_filters,
                num_layers,
                max_num_filters,
                activation_norm_type,
                weight_norm_type)
            self.discriminators.append(net_discriminator)
        print('Done with the Multi-resolution patch discriminator '
              'initialization.')

        self.use_fpse = getattr(dis_cfg, 'use_fpse', True)
        if self.use_fpse:
            fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
            fpse_activation_norm_type = getattr(dis_cfg,
                                                'fpse_activation_norm_type',
                                                'none')
            self.fpse_discriminator = FPSEDiscriminator(
                image_channels,
                num_labels,
                num_filters,
                fpse_kernel_size,
                weight_norm_type,
                fpse_activation_norm_type)

    def _single_forward(self, input_label, input_image):
        r"""Run every discriminator on one (label, image) pair.

        Args:
            input_label (tensor): Semantic representations.
            input_image (tensor): Images; concatenated with ``input_label``
                along dimension 1.
        Returns:
            (tuple):
              - output_list (list): The three FPSE predictions first (only
                when ``self.use_fpse`` is true), followed by one output per
                patch discriminator, ordered from full to lowest resolution.
              - features_list (list): One entry per patch discriminator
                holding the features it returned. The FPSE discriminator
                contributes no features.
        """
        # Compute discriminator outputs and intermediate features from input
        # images and semantic labels.
        input_x = torch.cat((input_label, input_image), 1)
        output_list = []
        features_list = []
        if self.use_fpse:
            # Note the argument order: the FPSE discriminator takes the image
            # first and the label second.
            pred2, pred3, pred4 = self.fpse_discriminator(
                input_image, input_label)
            output_list = [pred2, pred3, pred4]

        input_downsampled = input_x
        last_index = len(self.discriminators) - 1
        for index, net_discriminator in enumerate(self.discriminators):
            output, features = net_discriminator(input_downsampled)
            output_list.append(output)
            features_list.append(features)
            # Only downsample when another discriminator will consume the
            # result. Doing it after the last one is wasted work and fails
            # when the current scale is already a single pixel.
            if index < last_index:
                input_downsampled = nn.functional.interpolate(
                    input_downsampled, scale_factor=0.5, mode='bilinear',
                    align_corners=True)
        return output_list, features_list

    def forward(self, data, net_G_output):
        r"""SPADE discriminator forward.

        The same labels (``data['label']``) are paired once with the real
        images and once with the generated images.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor) : Ground truth images.
              - label (N x C2 x H x W tensor) : Semantic representations.
              - z (N x style_dims tensor): Gaussian random noise. Not read
                by this method.
            net_G_output (dict):
              - fake_images (N x C1 x H x W tensor) : Fake images.
        Returns:
            (dict):
              - real_outputs (list): list of output tensors produced by
                individual patch discriminators for real images.
              - real_features (list): list of lists of features produced by
                individual patch discriminators for real images.
              - fake_outputs (list): list of output tensors produced by
                individual patch discriminators for fake images.
              - fake_features (list): list of lists of features produced by
                individual patch discriminators for fake images.
        """
        output_x = dict()
        output_x['real_outputs'], output_x['real_features'] = \
            self._single_forward(data['label'], data['images'])
        output_x['fake_outputs'], output_x['fake_features'] = \
            self._single_forward(data['label'], net_G_output['fake_images'])
        return output_x