| """ | |
| Spectral Normalization from https://arxiv.org/abs/1802.05957 | |
| """ | |
import torch
from torch.nn.functional import normalize


class SpectralNorm(object):
    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced
    _version = 1
    # At version 1:
    #   made  `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)
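
    # For example, a Conv2d weight of shape (out_ch, in_ch, kH, kW) with the
    # default dim=0 is flattened to a matrix of shape
    # (out_ch, in_ch * kH * kW); for ConvTranspose*d (dim=1), the
    # out_channels axis is first permuted to the front before flattening.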

    def compute_weight(self, module, do_power_iteration):
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        #   1. `DataParallel` doesn't clone storage if the broadcast tensor
        #      is already on the correct device; and it makes sure that the
        #      parallelized module is already on `device[0]`.
        #   2. If the out tensor in the `out=` kwarg has the correct shape,
        #      it will just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern
        # in GAN training: loss = D(real) - D(fake). Otherwise, the autograd
        # engine will complain that variables needed for backward through the
        # first forward pass (i.e., the `u` and `v` vectors) were changed in
        # the second forward pass.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # The spectral norm of the weight equals `u^T W v`, where
                    # `u` and `v` are the first left and right singular
                    # vectors. This power iteration produces approximations
                    # of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u),
                                  dim=0,
                                  eps=self.eps,
                                  out=v)
                    u = normalize(torch.mv(weight_mat, v),
                                  dim=0,
                                  eps=self.eps,
                                  out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name,
                                  torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs):
        setattr(
            module, self.name,
            self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # `(W^T W)^+ W^T u` is the least-squares solution of `W @ v = u`;
        # the pseudo-inverse is used in case `W^T W` is not invertible.
        v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
                               weight_mat.t(), u.unsqueeze(1)).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(fn))
        return fn
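

# A minimal illustrative sketch (not part of the original module's API) of
# why the power iteration in `SpectralNorm.compute_weight` approximates the
# spectral norm: alternately applying `W^T` and `W` to a random vector
# converges to the top singular pair, so `u^T W v` converges to the largest
# singular value of `W`.
def _power_iteration_demo(w, n_iterations=50, eps=1e-12):
    # `w` is assumed to be a 2D tensor; the loop mirrors compute_weight.
    u = normalize(w.new_empty(w.size(0)).normal_(0, 1), dim=0, eps=eps)
    v = normalize(w.new_empty(w.size(1)).normal_(0, 1), dim=0, eps=eps)
    with torch.no_grad():
        for _ in range(n_iterations):
            v = normalize(torch.mv(w.t(), u), dim=0, eps=eps)
            u = normalize(torch.mv(w, v), dim=0, eps=eps)
        sigma = torch.dot(u, torch.mv(w, v))
        # For comparison, the exact largest singular value
        # (torch.svd returns singular values in descending order).
        exact = torch.svd(w).S[0]
    return sigma, exact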


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    # For state_dict with version None (assuming that it has gone through at
    # least one training forward), we have
    #
    #     u = normalize(W_orig @ v)
    #     W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #     v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        fn = self.fn
        version = local_metadata.get('spectral_norm',
                                     {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            with torch.no_grad():
                weight_orig = state_dict[prefix + fn.name + '_orig']
                weight = state_dict.pop(prefix + fn.name)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[prefix + fn.name + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[prefix + fn.name + '_v'] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError(
                "Unexpected key in metadata['spectral_norm']: {}".format(key))
        local_metadata['spectral_norm'][key] = self.fn._version
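

# Hypothetical usage sketch (not part of the original file) showing how the
# two hooks above cooperate: `state_dict()` records the spectral_norm version
# in the metadata, and `load_state_dict()` migrates old checkpoints that
# predate the `v` buffer. Assuming an `nn.Linear` module:
#
#     m = spectral_norm(torch.nn.Linear(20, 40))
#     sd = m.state_dict()
#     sorted(sd.keys())      # ['bias', 'weight_orig', 'weight_u', 'weight_v']
#     m.load_state_dict(sd)  # round-trips through both hooks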


def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated
    using the power iteration method. If the dimension of the weight tensor is
    greater than 2, it is reshaped to 2D in the power iteration to compute the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(module,
                      (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
                       torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
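

# Illustration of the `dim` default above (an added note, not in the original
# file): ConvTranspose*d stores its weight as (in_channels, out_channels,
# *kernel), so the output dimension is 1 rather than 0:
#
#     m = spectral_norm(torch.nn.ConvTranspose2d(16, 32, 3))
#     m.weight_u.size()  # torch.Size([32]) -- dim=1 picks out_channels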


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))


def use_spectral_norm(module, use_sn=False):
    """Convenience wrapper: applies spectral normalization to ``module``'s
    default ``weight`` parameter when ``use_sn`` is True; otherwise returns
    the module unchanged."""
    if use_sn:
        return spectral_norm(module)
    return module
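

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module: apply the
    # hook, run a forward pass in train and eval mode, then remove the hook.
    import torch.nn as nn

    m = spectral_norm(nn.Linear(20, 40))
    x = torch.randn(8, 20)
    y_train = m(x)               # training mode: power iteration updates u, v
    m.eval()
    y_eval = m(x)                # eval mode: reuses the stored u and v
    print(m.weight_u.size())     # torch.Size([40])
    m = remove_spectral_norm(m)  # restores a plain `weight` Parameter
    print(type(m.weight))        # <class 'torch.nn.parameter.Parameter'>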