# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import collections
import functools

import torch
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm
from torch.nn.utils.spectral_norm import SpectralNorm, \
    SpectralNormStateDictHook, SpectralNormLoadStateDictPreHook

from .conv import LinearBlock


class WeightDemodulation(nn.Module):
    r"""Weight demodulation in
    "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.

    Args:
        conv (torch.nn.Modules): Convolutional layer.
        cond_dims (int): The number of channels in the conditional input.
        eps (float, optional, default=1e-8): a value added to the
            denominator for numerical stability.
        adaptive_bias (bool, optional, default=False): If ``True``, adaptively
            predicts bias from the conditional input.
        demod (bool, optional, default=True): If ``True``, performs
            weight demodulation.
    """

    def __init__(self, conv, cond_dims, eps=1e-8,
                 adaptive_bias=False, demod=True):
        super().__init__()
        self.conv = conv
        self.adaptive_bias = adaptive_bias
        if adaptive_bias:
            self.conv.register_parameter('bias', None)
            self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
        self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
        self.eps = eps
        self.demod = demod
        self.conditional = True

    def forward(self, x, y, **_kwargs):
        r"""Weight demodulation forward"""
        b, c, h, w = x.size()
        self.conv.groups = b

        # Modulate the convolution weight with the conditional input.
        gamma = self.fc_gamma(y)
        gamma = gamma[:, None, :, None, None]
        weight = self.conv.weight[None, :, :, :, :] * gamma

        if self.demod:
            # Demodulation: rescale each output filter by the inverse of its
            # L2 norm (computed over input channels and kernel positions).
            d = torch.rsqrt(
                (weight ** 2).sum(
                    dim=(2, 3, 4), keepdim=True) + self.eps)
            weight = weight * d

        # Fold the batch into the channel dimension and run a single grouped
        # convolution with one group per sample.
        x = x.reshape(1, -1, h, w)
        _, _, *ws = weight.shape
        weight = weight.reshape(b * self.conv.out_channels, *ws)
        # Note: PyTorch >= 1.8 requires the bias to be passed explicitly.
        x = self.conv._conv_forward(x, weight, self.conv.bias)

        x = x.reshape(-1, self.conv.out_channels, h, w)
        if self.adaptive_bias:
            x += self.fc_beta(y)[:, :, None, None]
        return x


def weight_demod(
        conv, cond_dims=256, eps=1e-8, adaptive_bias=False, demod=True):
    r"""Weight demodulation."""
    return WeightDemodulation(conv, cond_dims, eps, adaptive_bias, demod)
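
# Usage sketch (illustrative, not part of the original module): wrap a
# convolution with weight demodulation and drive it with a conditional code
# ``y``. Shapes are examples; ``bias=False`` avoids a shape mismatch in the
# grouped convolution when ``adaptive_bias`` is off.
#
#     conv = weight_demod(nn.Conv2d(64, 128, 3, padding=1, bias=False),
#                         cond_dims=256)
#     x = torch.randn(4, 64, 32, 32)   # image features
#     y = torch.randn(4, 256)          # conditional code (e.g., a style vector)
#     out = conv(x, y)                 # -> (4, 128, 32, 32)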


class ScaledLR(object):
    r"""Rescales a module's weight (and bias) by a constant factor in a
    forward pre-hook, implementing the scaled / equalized learning rate
    trick of StyleGAN2."""

    def __init__(self, weight_name, bias_name):
        self.weight_name = weight_name
        self.bias_name = bias_name

    def compute_weight(self, module):
        weight = getattr(module, self.weight_name + '_ori')
        return weight * module.weight_scale

    def compute_bias(self, module):
        bias = getattr(module, self.bias_name + '_ori')
        if bias is not None:
            return bias * module.bias_scale
        else:
            return None

    @staticmethod
    def apply(module, weight_name, bias_name, lr_mul, equalized):
        assert weight_name == 'weight'
        assert bias_name == 'bias'
        fn = ScaledLR(weight_name, bias_name)
        module.register_forward_pre_hook(fn)

        if hasattr(module, bias_name):
            # module.bias is a parameter (can be None).
            bias = getattr(module, bias_name)
            delattr(module, bias_name)
            module.register_parameter(bias_name + '_ori', bias)
        else:
            # module.bias does not exist.
            bias = None
            setattr(module, bias_name + '_ori', bias)
        if bias is not None:
            setattr(module, bias_name, bias.data)
        else:
            setattr(module, bias_name, None)
        module.register_buffer('bias_scale', torch.tensor(lr_mul))

        if hasattr(module, weight_name + '_orig'):
            # The module has been wrapped with spectral normalization.
            # We only want to keep a single weight parameter.
            weight = getattr(module, weight_name + '_orig')
            delattr(module, weight_name + '_orig')
            module.register_parameter(weight_name + '_ori', weight)
            setattr(module, weight_name + '_orig', weight.data)
            # Put this hook before the spectral norm hook.
            module._forward_pre_hooks = collections.OrderedDict(
                reversed(list(module._forward_pre_hooks.items()))
            )
            module.use_sn = True
        else:
            weight = getattr(module, weight_name)
            delattr(module, weight_name)
            module.register_parameter(weight_name + '_ori', weight)
            setattr(module, weight_name, weight.data)
            module.use_sn = False

        # assert weight.dim() == 4 or weight.dim() == 2
        if equalized:
            fan_in = weight.data.size(1) * weight.data[0][0].numel()
            # Theoretically, the gain should be sqrt(2) instead of 1.
            # The official StyleGAN2 uses 1 for some reason.
            module.register_buffer(
                'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5))
            )
        else:
            module.register_buffer('weight_scale', torch.tensor(lr_mul))

        module.lr_mul = module.weight_scale
        module.base_lr_mul = lr_mul
        return fn

    def remove(self, module):
        with torch.no_grad():
            weight = self.compute_weight(module)
        delattr(module, self.weight_name + '_ori')

        if module.use_sn:
            setattr(module, self.weight_name + '_orig', weight.detach())
        else:
            delattr(module, self.weight_name)
            module.register_parameter(self.weight_name,
                                      torch.nn.Parameter(weight.detach()))

        with torch.no_grad():
            bias = self.compute_bias(module)
        delattr(module, self.bias_name)
        delattr(module, self.bias_name + '_ori')
        if bias is not None:
            module.register_parameter(self.bias_name,
                                      torch.nn.Parameter(bias.detach()))
        else:
            module.register_parameter(self.bias_name, None)

        module.lr_mul = 1.0
        module.base_lr_mul = 1.0

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        if module.use_sn:
            # The following spectral norm hook will compute the SN of
            # "module.weight_orig" and store the normalized weight in
            # "module.weight".
            setattr(module, self.weight_name + '_orig', weight)
        else:
            setattr(module, self.weight_name, weight)
        bias = self.compute_bias(module)
        setattr(module, self.bias_name, bias)


def remove_weight_norms(module, weight_name='weight', bias_name='bias'):
    if hasattr(module, 'weight_ori') or hasattr(module, 'weight_orig'):
        for k in list(module._forward_pre_hooks.keys()):
            hook = module._forward_pre_hooks[k]
            if isinstance(hook, ScaledLR) or isinstance(hook, SpectralNorm):
                hook.remove(module)
                del module._forward_pre_hooks[k]

        for k, hook in module._state_dict_hooks.items():
            if isinstance(hook, SpectralNormStateDictHook) and \
                    hook.fn.name == weight_name:
                del module._state_dict_hooks[k]
                break

        for k, hook in module._load_state_dict_pre_hooks.items():
            if isinstance(hook, SpectralNormLoadStateDictPreHook) and \
                    hook.fn.name == weight_name:
                del module._load_state_dict_pre_hooks[k]
                break

    return module
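
# Usage sketch (illustrative, not part of the original module): strip all
# weight-norm wrappers from a layer, e.g. before exporting it for inference.
#
#     layer = spectral_norm(nn.Linear(256, 256))
#     layer = remove_weight_norms(layer)   # back to a plain nn.Linear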


def remove_equalized_lr(module, weight_name='weight', bias_name='bias'):
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, ScaledLR) and hook.weight_name == weight_name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError("Equalized learning rate not found")

    return module


def scaled_lr(
        module, weight_name='weight', bias_name='bias', lr_mul=1.,
        equalized=False,
):
    ScaledLR.apply(module, weight_name, bias_name, lr_mul, equalized)
    return module
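
# Usage sketch (illustrative, not part of the original module): apply the
# equalized learning rate trick, which rescales the weight by
# lr_mul / sqrt(fan_in) at every forward pass, and later undo it.
#
#     layer = scaled_lr(nn.Linear(512, 512), lr_mul=0.01, equalized=True)
#     # ... train ...
#     layer = remove_equalized_lr(layer)   # bake the scaling into the weight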


def get_weight_norm_layer(norm_type, **norm_params):
    r"""Return weight normalization.

    Args:
        norm_type (str):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``, ``'weight_demod'``,
            ``'equalized_lr'``, ``'scaled_lr'``, ``'equalized_lr_spectral'``
            or ``'scaled_lr_spectral'``.
        norm_params: Arbitrary keyword arguments that will be used to
            initialize the weight normalization.
    """
    if norm_type == 'none' or norm_type == '':  # no normalization
        return lambda x: x
    elif norm_type == 'spectral':  # spectral normalization
        return functools.partial(spectral_norm, **norm_params)
    elif norm_type == 'weight':  # weight normalization
        return functools.partial(weight_norm, **norm_params)
    elif norm_type == 'weight_demod':  # weight demodulation
        return functools.partial(weight_demod, **norm_params)
    elif norm_type == 'equalized_lr':  # equalized learning rate
        return functools.partial(scaled_lr, equalized=True, **norm_params)
    elif norm_type == 'scaled_lr':  # scaled learning rate
        return functools.partial(scaled_lr, **norm_params)
    elif norm_type == 'equalized_lr_spectral':
        # Equalized learning rate combined with spectral normalization.
        lr_mul = norm_params.pop('lr_mul', 1.0)
        return lambda x: functools.partial(
            scaled_lr, equalized=True, lr_mul=lr_mul)(
            functools.partial(spectral_norm, **norm_params)(x)
        )
    elif norm_type == 'scaled_lr_spectral':
        # Scaled learning rate combined with spectral normalization.
        lr_mul = norm_params.pop('lr_mul', 1.0)
        return lambda x: functools.partial(
            scaled_lr, lr_mul=lr_mul)(
            functools.partial(spectral_norm, **norm_params)(x)
        )
    else:
        raise ValueError(
            'Weight norm layer %s is not recognized' % norm_type)
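
# Usage sketch (illustrative, not part of the original module): build a
# weight-norm wrapper from a config string and apply it to a layer.
#
#     weight_norm_fn = get_weight_norm_layer('spectral')
#     layer = weight_norm_fn(nn.Conv2d(3, 64, 3, padding=1))
#
#     weight_norm_fn = get_weight_norm_layer('equalized_lr', lr_mul=0.01)
#     layer = weight_norm_fn(nn.Linear(512, 512))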