"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
# NOTE: Blur and ConvLayer below rely on the StyleGAN2 fused ops (upfirdn2d,
# FusedLeakyReLU); the exact import path here is an assumption about this
# repo's layout.
from models.networks.op import FusedLeakyReLU, upfirdn2d
import utils.inference.util as util

class MultiscaleDiscriminator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--netD_subarch', type=str, default='n_layer',
                            help='architecture of each discriminator')
        parser.add_argument('--num_D', type=int, default=2,
                            help='number of discriminators to be used in multiscale')
        opt, _ = parser.parse_known_args()

        # define properties of each discriminator of the multiscale discriminator
        subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator',
                                            'models.networks.discriminator')
        subnetD.modify_commandline_options(parser, is_train)

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

        for i in range(opt.num_D):
            subnetD = self.create_single_discriminator(opt)
            self.add_module('discriminator_%d' % i, subnetD)

    def create_single_discriminator(self, opt):
        subarch = opt.netD_subarch
        if subarch == 'n_layer':
            netD = NLayerDiscriminator(opt)
        else:
            raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
        return netD

    def downsample(self, input):
        return F.avg_pool2d(input, kernel_size=3,
                            stride=2, padding=1,
                            count_include_pad=False)

    # Returns a list of lists of discriminator outputs: one entry per
    # discriminator (opt.num_D of them), each holding either all of that
    # discriminator's intermediate feature maps (when the GAN feature-matching
    # loss is enabled) or just its final patch prediction.
    def forward(self, input):
        result = []
        get_intermediate_features = not self.opt.no_ganFeat_loss
        for name, D in self.named_children():
            out = D(input)
            if not get_intermediate_features:
                out = [out]
            result.append(out)
            input = self.downsample(input)
        return result
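
# Illustrative output structure (assuming num_D=2, n_layers_D=4, and the GAN
# feature-matching loss enabled): forward() returns
#     [[D0_feat0, ..., D0_feat4], [D1_feat0, ..., D1_feat4]]
# where the second discriminator sees the input average-pooled to half size.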

# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--n_layers_D', type=int, default=4,
                            help='# layers in each discriminator')
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = opt.ndf
        input_nc = self.compute_D_input_nc(opt)

        norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, False)]]

        for n in range(1, opt.n_layers_D):
            nf_prev = nf
            nf = min(nf * 2, 512)
            stride = 1 if n == opt.n_layers_D - 1 else 2
            sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
                                               stride=stride, padding=padw)),
                          nn.LeakyReLU(0.2, False)]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

    def compute_D_input_nc(self, opt):
        input_nc = opt.label_nc + opt.output_nc
        if opt.contain_dontcare_label:
            input_nc += 1
        if not opt.no_instance:
            input_nc += 1
        return input_nc
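
    # Example channel budget: with label_nc=18, output_nc=3, no dontcare label,
    # and instance maps enabled (no_instance=False), the discriminator consumes
    # 18 + 3 + 1 = 22 input channels (semantic map + image + instance edges).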

    def forward(self, input):
        results = [input]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        get_intermediate_features = not self.opt.no_ganFeat_loss
        if get_intermediate_features:
            return results[1:]
        else:
            return results[-1]
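
# With the defaults (ndf=64, n_layers_D=4) the per-scale schedule is:
#   model0: 4x4 conv s2 -> 64ch   model1: s2 -> 128ch   model2: s2 -> 256ch
#   model3: s1 -> 512ch           model4: s1 -> 1ch patch logits
# so each output logit scores roughly a 70x70 patch of the input (the usual
# PatchGAN receptive field for this stride pattern).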

class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)
        return out * math.sqrt(2)
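
# The sqrt(2) gain keeps the activation's output variance roughly equal to its
# input variance (the StyleGAN2 convention): leaky_relu with slope 0.2 roughly
# halves the second moment, and multiplying by sqrt(2) compensates.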

def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k
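
# e.g. make_kernel([1, 3, 3, 1]) returns the separable 4x4 binomial blur kernel:
# the outer product of [1, 3, 3, 1] with itself, normalized to sum to 1.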

class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        kernel = make_kernel(kernel)
        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)
        self.register_buffer('kernel', kernel)
        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)
        return out
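
# upfirdn2d pads the input and applies the FIR blur kernel; when the blur sits
# on an upsampling path, the kernel was pre-scaled by upsample_factor ** 2 in
# __init__ to preserve the overall signal magnitude.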

class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, padding=0, bias=True):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        return out
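
# Equalized learning rate: weights are sampled from N(0, 1) and rescaled by
# 1 / sqrt(fan_in) on every forward pass rather than once at initialization,
# so all parameters share the same dynamic range (the StyleGAN trick).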

class ConvLayer(nn.Sequential):
    def __init__(self, in_channel, out_channel, kernel_size,
                 downsample=False, blur_kernel=[1, 3, 3, 1],
                 bias=True, activate=True):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(in_channel, out_channel, kernel_size,
                        padding=self.padding, stride=stride,
                        bias=bias and not activate)
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)
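
if __name__ == '__main__':
    # Minimal smoke test for the SPADE-style discriminators above. The option
    # values are illustrative assumptions, not the repo's training defaults.
    from argparse import Namespace

    opt = Namespace(num_D=2, netD_subarch='n_layer', n_layers_D=4, ndf=64,
                    label_nc=18, output_nc=3, contain_dontcare_label=False,
                    no_instance=True, norm_D='spectralinstance',
                    no_ganFeat_loss=False)
    netD = MultiscaleDiscriminator(opt)

    x = torch.randn(1, netD.discriminator_0.compute_D_input_nc(opt), 256, 256)
    outs = netD(x)
    # Expect num_D inner lists, each with n_layers_D + 1 feature maps; the
    # second discriminator operates on a half-resolution copy of x.
    print(len(outs), [tuple(f.shape) for f in outs[0]])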