# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Ways to make the model stronger.""" import random import torch def power_iteration(m, niters=1, bs=1): """This is the power method. batch size is used to try multiple starting point in parallel.""" assert m.dim() == 2 assert m.shape[0] == m.shape[1] dim = m.shape[0] b = torch.randn(dim, bs, device=m.device, dtype=m.dtype) for _ in range(niters): n = m.mm(b) norm = n.norm(dim=0, keepdim=True) b = n / (1e-10 + norm) return norm.mean() # We need a shared RNG to make sure all the distributed worker will skip the penalty together, # as otherwise we wouldn't get any speed up. penalty_rng = random.Random(1234) def svd_penalty(model, min_size=0.1, dim=1, niters=2, powm=False, convtr=True, proba=1, conv_only=False, exact=False, bs=1): """ Penalty on the largest singular value for a layer. Args: - model: model to penalize - min_size: minimum size in MB of a layer to penalize. - dim: projection dimension for the svd_lowrank. Higher is better but slower. - niters: number of iterations in the algorithm used by svd_lowrank. - powm: use power method instead of lowrank SVD, my own experience is that it is both slower and less stable. - convtr: when True, differentiate between Conv and Transposed Conv. this is kept for compatibility with older experiments. - proba: probability to apply the penalty. - conv_only: only apply to conv and conv transposed, not LSTM (might not be reliable for other models than Demucs). - exact: use exact SVD (slow but useful at validation). - bs: batch_size for power method. """ total = 0 if penalty_rng.random() > proba: return 0. for m in model.modules(): for name, p in m.named_parameters(recurse=False): if p.numel() / 2**18 < min_size: continue if convtr: if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)): if p.dim() in [3, 4]: p = p.transpose(0, 1).contiguous() if p.dim() == 3: p = p.view(len(p), -1) elif p.dim() == 4: p = p.view(len(p), -1) elif p.dim() == 1: continue elif conv_only: continue assert p.dim() == 2, (name, p.shape) if exact: estimate = torch.svd(p, compute_uv=False)[1].pow(2).max() elif powm: a, b = p.shape if a < b: n = p.mm(p.t()) else: n = p.t().mm(p) estimate = power_iteration(n, niters, bs) else: estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2) total += estimate return total / proba