venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import warnings
import torch
from torch import nn
from imaginaire.layers import Conv2dBlock, Res2dBlock
class Discriminator(nn.Module):
r"""Discriminator in the improved FUNIT baseline in the COCO-FUNIT paper.
Args:
dis_cfg (obj): Discriminator definition part of the yaml config file.
data_cfg (obj): Data definition part of the yaml config file.
"""
def __init__(self, dis_cfg, data_cfg):
super().__init__()
self.model = ResDiscriminator(**vars(dis_cfg))
def forward(self, data, net_G_output, recon=True):
r"""Improved FUNIT discriminator forward function.
Args:
data (dict): Training data at the current iteration.
net_G_output (dict): Fake data generated at the current iteration.
recon (bool): If ``True``, also classifies reconstructed images.
"""
source_labels = data['labels_content']
target_labels = data['labels_style']
fake_out_trans, fake_features_trans = \
self.model(net_G_output['images_trans'], target_labels)
output = dict(fake_out_trans=fake_out_trans,
fake_features_trans=fake_features_trans)
real_out_style, real_features_style = \
self.model(data['images_style'], target_labels)
output.update(dict(real_out_style=real_out_style,
real_features_style=real_features_style))
if recon:
fake_out_recon, fake_features_recon = \
self.model(net_G_output['images_recon'], source_labels)
output.update(dict(fake_out_recon=fake_out_recon,
fake_features_recon=fake_features_recon))
return output
class ResDiscriminator(nn.Module):
r"""Residual discriminator architecture used in the FUNIT paper."""
def __init__(self,
image_channels=3,
num_classes=119,
num_filters=64,
max_num_filters=1024,
num_layers=6,
padding_mode='reflect',
weight_norm_type='',
**kwargs):
super().__init__()
for key in kwargs:
if key != 'type':
warnings.warn(
"Discriminator argument {} is not used".format(key))
conv_params = dict(padding_mode=padding_mode,
activation_norm_type='none',
weight_norm_type=weight_norm_type,
bias=[True, True, True],
nonlinearity='leakyrelu',
order='NACNAC')
first_kernel_size = 7
first_padding = (first_kernel_size - 1) // 2
model = [Conv2dBlock(image_channels, num_filters,
first_kernel_size, 1, first_padding,
padding_mode=padding_mode,
weight_norm_type=weight_norm_type)]
for i in range(num_layers):
num_filters_prev = num_filters
num_filters = min(num_filters * 2, max_num_filters)
model += [Res2dBlock(num_filters_prev, num_filters_prev,
**conv_params),
Res2dBlock(num_filters_prev, num_filters,
**conv_params)]
if i != num_layers - 1:
model += [nn.ReflectionPad2d(1),
nn.AvgPool2d(3, stride=2)]
self.model = nn.Sequential(*model)
self.classifier = Conv2dBlock(num_filters, 1, 1, 1, 0,
nonlinearity='leakyrelu',
weight_norm_type=weight_norm_type,
order='NACNAC')
self.embedder = nn.Embedding(num_classes, num_filters)
def forward(self, images, labels=None):
r"""Forward function of the projection discriminator.
Args:
images (image tensor): Images inputted to the discriminator.
labels (long int tensor): Class labels of the images.
"""
assert (images.size(0) == labels.size(0))
features = self.model(images)
outputs = self.classifier(features)
features_1x1 = features.mean(3).mean(2)
if labels is None:
return features_1x1
embeddings = self.embedder(labels)
outputs += torch.sum(embeddings * features_1x1, dim=1,
keepdim=True).view(images.size(0), 1, 1, 1)
return outputs, features_1x1