| r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" | |
| from torch.nn.parameter import Parameter, UninitializedParameter | |
| from torch import norm_except_dim | |
| from typing import Any, TypeVar | |
| import warnings | |
| from torch.nn.modules import Module | |
| import torch | |
class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim
    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        # Normalize v over every dimension except `self.dim`, working in float32
        # for numerical stability, then rescale by the magnitude g. The added
        # epsilon makes this an approximation of w = g * v / ||v||.
        input_dtype = v.dtype
        v = v.to(torch.float32)
        reduce_dims = list(range(v.dim()))
        reduce_dims.pop(self.dim)
        variance = v.pow(2).sum(reduce_dims, keepdim=True)
        v = v * torch.rsqrt(variance + 1e-6)
        return g * v.to(input_dtype)
    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        warnings.warn(
            "torch.nn.utils.weight_norm is deprecated in favor of "
            "torch.nn.utils.parametrizations.weight_norm."
        )

        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(
                    f"Cannot register two weight_norm hooks on the same parameter {name}"
                )

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')

        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn
    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module))
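
# Illustrative sketch of what the hook above does, assuming a Conv1d(4, 8, 3)
# whose weight has shape (out_channels, in_channels, kernel_size):
#   WeightNorm.apply(conv, 'weight', dim=0) splits the parameter into
#     conv.weight_g  -- magnitude, shape torch.Size([8, 1, 1])
#     conv.weight_v  -- direction, shape torch.Size([8, 4, 3])
#   and registers the WeightNorm instance as a forward pre-hook, so a call like
#     y = conv(torch.randn(1, 4, 16))
#   first rebuilds conv.weight ~ g * v / ||v|| via compute_weight() above.
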
T_module = TypeVar('T_module', bound=Module)


def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
| r"""Apply weight normalization to a parameter in the given module. | |
| .. math:: | |
| \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} | |
| Weight normalization is a reparameterization that decouples the magnitude | |
| of a weight tensor from its direction. This replaces the parameter specified | |
| by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude | |
| (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``). | |
| Weight normalization is implemented via a hook that recomputes the weight | |
| tensor from the magnitude and direction before every :meth:`~Module.forward` | |
| call. | |
| By default, with ``dim=0``, the norm is computed independently per output | |
| channel/plane. To compute a norm over the entire weight tensor, use | |
| ``dim=None``. | |
| See https://arxiv.org/abs/1602.07868 | |
| .. warning:: | |
| This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm` | |
| which uses the modern parametrization API. The new ``weight_norm`` is compatible | |
| with ``state_dict`` generated from old ``weight_norm``. | |
| Migration guide: | |
| * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed | |
| as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1`` | |
| respectively. If this is bothering you, please comment on | |
| https://github.com/pytorch/pytorch/issues/102999 | |
| * To remove the weight normalization reparametrization, use | |
| :func:`torch.nn.utils.parametrize.remove_parametrizations`. | |
| * The weight is no longer recomputed once at module forward; instead, it will | |
| be recomputed on every access. To restore the old behavior, use | |
| :func:`torch.nn.utils.parametrize.cached` before invoking the module | |
| in question. | |
| Args: | |
| module (Module): containing module | |
| name (str, optional): name of weight parameter | |
| dim (int, optional): dimension over which to compute the norm | |
| Returns: | |
| The original module with the weight norm hook | |
| Example:: | |
| >>> m = weight_norm(nn.Linear(20, 40), name='weight') | |
| >>> m | |
| Linear(in_features=20, out_features=40, bias=True) | |
| >>> m.weight_g.size() | |
| torch.Size([40, 1]) | |
| >>> m.weight_v.size() | |
| torch.Size([40, 20]) | |
| """ | |
    WeightNorm.apply(module, name, dim)
    return module
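

# Minimal usage sketch, assuming this module is run directly with a recent PyTorch
# build that ships both this deprecated hook-based API and the parametrization-based
# replacement referenced in the docstring above.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.nn.utils import parametrize
    from torch.nn.utils.parametrizations import weight_norm as parametrized_weight_norm

    # Deprecated hook-based API defined in this file (emits the deprecation warning).
    m = weight_norm(nn.Linear(20, 40), name='weight')
    print(m.weight_g.size())  # torch.Size([40, 1])
    print(m.weight_v.size())  # torch.Size([40, 20])

    # Recommended replacement: the magnitude and direction live under
    # parametrizations.weight.original0 / original1 instead of weight_g / weight_v.
    m2 = parametrized_weight_norm(nn.Linear(20, 40), name='weight')
    print(m2.parametrizations.weight.original0.size())  # torch.Size([40, 1])
    print(m2.parametrizations.weight.original1.size())  # torch.Size([40, 20])

    # Undo the reparametrization, baking the current weight back into m2.weight.
    parametrize.remove_parametrizations(m2, 'weight')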