# ghost/models/networks/normalization.py
# Jagrut Thakare
# v1
# 9be8aa9
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
def get_nonspade_norm_layer(opt, norm_type='instance'):
    """Return a function that wraps a layer with the requested normalization.

    Args:
        opt: Not read by this function; kept so existing call sites keep
            working.
        norm_type: Normalization spec. An optional ``'spectral'`` prefix
            applies spectral normalization to the layer itself. The rest of
            the string selects the norm appended after the layer:
            ``'batch'``, ``'sync_batch'``, ``'instance'``, or ``'none'`` /
            empty for no extra normalization.

    Returns:
        A callable ``add_norm_layer(layer)`` that returns either the
        (possibly spectral-normalized) layer itself, or an ``nn.Sequential``
        of the layer followed by the selected normalization layer. It raises
        ``ValueError`` when the norm name is not recognized.
    """

    def get_out_channel(layer):
        # Number of output channels of the previous layer.
        if hasattr(layer, 'out_channels'):
            return layer.out_channels
        return layer.weight.size(0)

    # This function will be returned.
    def add_norm_layer(layer):
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # Without the 'spectral' prefix the whole string names the norm.
            # Previously this case left subnorm_type unassigned and crashed
            # with UnboundLocalError (even for the default 'instance').
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # Remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
class InstanceNorm2d(nn.Module):
    """Parameter-free rescaling by the root-mean-square over dims 2 and 3.

    Each slice along dims (2, 3) is divided by ``sqrt(mean(x ** 2) + epsilon)``.
    The input is not mean-centred (the centring step is disabled), so this
    layer only rescales.

    Args:
        epsilon: Constant added to the mean square before the reciprocal
            square root.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` scaled by the inverse RMS taken over dims 2 and 3."""
        mean_square = torch.mean(x * x, dim=(2, 3), keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.epsilon)
        return x * inv_rms
class SPADE(nn.Module):
    """Normalization layer whose scale and bias are predicted from a map.

    The input is first passed through a parameter-free normalization, then
    multiplied by ``gamma`` and shifted by ``beta``, both produced by small
    convolutional heads applied to the (resized) conditioning map.

    Args:
        config_text: Spec of the form ``'spade<norm><k>x<k>'``, e.g.
            ``'spadeinstance3x3'``. ``<norm>`` is one of ``'instance'``,
            ``'syncbatch'`` or ``'batch'``; the first digit is used as the
            convolution kernel size.
        norm_nc: Number of channels of the activations being normalized.
        label_nc: Number of channels of the conditioning map.

    Raises:
        ValueError: If ``config_text`` does not match the expected pattern
            or names an unknown parameter-free norm type.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        # Raw string: '\D' and '\d' are invalid escapes in a normal literal.
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        if parsed is None:
            # Previously this surfaced as AttributeError on None.group(...).
            raise ValueError(
                "SPADE config %r does not match 'spade<norm><k>x<k>'" % config_text)
        param_free_norm_type = parsed.group(1)
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            # NOTE(review): the InstanceNorm2d class in this file takes
            # `epsilon` as its first positional parameter, so norm_nc is bound
            # to epsilon here. Left as-is because changing it would alter the
            # numerical output of existing weights -- confirm intent.
            self.param_free_norm = InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space: capped at 128.
        nhidden = min(norm_nc, 128)

        # "Same" padding for odd kernel sizes.
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with parameters derived from ``segmap``.

        ``segmap`` is resized with nearest-neighbour interpolation to the
        size of ``x`` along dims 2 and onward before the heads are applied.
        """
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * gamma + beta
        return out