Spaces:
Runtime error
Runtime error
| import gin | |
| import torch | |
| import torch.fft | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .dynamic import FiLM, TimeDistributedMLP | |
class Sine(nn.Module):
    """Elementwise sine activation wrapped as an ``nn.Module``.

    Lets ``sin`` be passed anywhere a module class is expected (for example
    as the ``nonlinearity`` / ``final_nonlinearity`` factory of
    ``TrainableNonlinearity``).
    """

    def forward(self, x: torch.Tensor):
        # Tensor-method form of torch.sin(x)
        return x.sin()
class TrainableNonlinearity(nn.Module):
    """Learned per-channel scalar function made of grouped 1x1 convolutions.

    Every one of the ``channels`` input channels gets its own small pointwise
    network: ``depth`` ``Conv1d`` layers with kernel size 1 and
    ``groups=channels``, whose hidden layers are ``channels * width`` wide.
    Each hidden layer is followed by ``nonlinearity()``; the last layer is
    followed by ``final_nonlinearity()``. The input is first multiplied by a
    learned per-channel scale.

    Args:
        channels: number of independent channels (also the conv group count).
        width: hidden multiplier per channel.
        nonlinearity: module factory used after every layer but the last.
        final_nonlinearity: module factory used after the last layer.
        depth: number of convolution layers.
    """

    def __init__(
        self, channels, width, nonlinearity=nn.ReLU, final_nonlinearity=Sine, depth=3
    ):
        super().__init__()
        # Learned input gain of shape (1, channels, 1), drawn from N(0, 10^2)
        self.input_scale = nn.Parameter(torch.randn(1, channels, 1) * 10)

        hidden = channels * width
        modules = []
        for layer in range(depth):
            is_first = layer == 0
            is_last = layer == depth - 1
            in_ch = channels if is_first else hidden
            out_ch = channels if is_last else hidden
            # groups=channels keeps the channels from mixing with each other
            modules.append(nn.Conv1d(in_ch, out_ch, 1, groups=channels))
            modules.append(final_nonlinearity() if is_last else nonlinearity())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Scale ``x`` per channel, then run the grouped pointwise network."""
        return self.net(self.input_scale * x)
class NEWT(nn.Module):
    """Neural waveshaping unit: FiLM-conditioned learned nonlinearities.

    A control embedding is mapped by an MLP to four FiLM parameter sets, each
    with ``n_waveshapers`` channels. The exciter is FiLM-modulated (the
    "index" pair), passed through a bank of learned shaping functions, FiLM-
    modulated again (the "normalising" pair), and mixed down to
    ``out_channels`` by a 1x1 convolution.

    Args:
        n_waveshapers: number of parallel shaping functions.
        control_embedding_size: channel size of the control embedding (used as
            both input and hidden size of the MLP).
        shaping_fn_size: hidden width multiplier of each shaping function.
        out_channels: number of output channels after mixing.
    """

    def __init__(
        self,
        n_waveshapers: int,
        control_embedding_size: int,
        shaping_fn_size: int = 16,
        out_channels: int = 1,
    ):
        super().__init__()
        self.n_waveshapers = n_waveshapers
        # One MLP predicts all four FiLM parameter sets at once (split in forward)
        self.mlp = TimeDistributedMLP(
            control_embedding_size, control_embedding_size, n_waveshapers * 4, depth=4
        )
        self.waveshaping_index = FiLM()
        self.shaping_fn = TrainableNonlinearity(
            n_waveshapers, shaping_fn_size, nonlinearity=Sine
        )
        self.normalising_coeff = FiLM()
        self.mixer = nn.Sequential(
            nn.Conv1d(n_waveshapers, out_channels, 1),
        )

    def forward(self, exciter, control_embedding):
        """Waveshape ``exciter`` under the control of ``control_embedding``.

        NOTE(review): shapes are presumably ``exciter`` = (batch,
        n_waveshapers or broadcastable, time) and ``control_embedding`` =
        (batch, control_embedding_size, frames); confirm against
        TimeDistributedMLP / FiLM in ``.dynamic``.
        """
        film_params = self.mlp(control_embedding)
        # Resample the frame-rate FiLM parameters to the exciter's length.
        # F.upsample is deprecated; F.interpolate is its drop-in replacement
        # (upsample forwarded align_corners=None, which means False here).
        film_params = F.interpolate(
            film_params, size=exciter.shape[-1], mode="linear", align_corners=False
        )
        gamma_index, beta_index, gamma_norm, beta_norm = torch.split(
            film_params, self.n_waveshapers, 1
        )
        x = self.waveshaping_index(exciter, gamma_index, beta_index)
        x = self.shaping_fn(x)
        x = self.normalising_coeff(x, gamma_norm, beta_norm)
        return self.mixer(x)
class FastNEWT(NEWT):
    """Lookup-table approximation of a trained :class:`NEWT`.

    Reuses ``newt``'s control MLP, FiLM layers and output mixer (the same
    module objects, not copies). Its learned shaping network is replaced by a
    table holding ``table_size`` samples per waveshaper on a uniform grid over
    ``[table_min, table_max]``. The table is evaluated by linear
    interpolation; inputs outside the grid are clamped to the edge values.
    ``forward`` is inherited from :class:`NEWT` and calls the tabulated
    ``shaping_fn`` defined here.

    Args:
        newt: trained NEWT to approximate.
        table_size: number of grid points per waveshaper (must be >= 2).
        table_min: lower end of the tabulated input range.
        table_max: upper end of the tabulated input range.

    Raises:
        ValueError: if ``table_size < 2`` or ``table_max <= table_min``.
    """

    def __init__(
        self,
        newt: NEWT,
        table_size: int = 4096,
        table_min: float = -3.0,
        table_max: float = 3.0,
    ):
        # Skip NEWT.__init__: it has required arguments (a bare super().__init__()
        # raised TypeError) and would build a fresh network that is discarded.
        # Only the nn.Module machinery is needed before borrowing newt's parts.
        super(NEWT, self).__init__()
        if table_size < 2:
            raise ValueError("table_size must be at least 2")
        if table_max <= table_min:
            raise ValueError("table_max must be greater than table_min")
        self.table_size = table_size
        self.table_min = table_min
        self.table_max = table_max

        self.n_waveshapers = newt.n_waveshapers
        self.mlp = newt.mlp
        self.waveshaping_index = newt.waveshaping_index
        self.normalising_coeff = newt.normalising_coeff
        self.mixer = newt.mixer

        self.lookup_table = self._init_lookup_table(
            newt, table_size, self.n_waveshapers, table_min, table_max
        )
        self.to(next(newt.parameters()).device)

    def _init_lookup_table(
        self,
        newt: NEWT,
        table_size: int,
        n_waveshapers: int,
        table_min: float,
        table_max: float,
    ):
        """Sample ``newt.shaping_fn`` on a uniform grid.

        Returns an ``(n_waveshapers, table_size)`` parameter. Column ``i`` holds
        every shaping function evaluated at
        ``table_min + i * (table_max - table_min) / (table_size - 1)``.
        """
        reference = next(newt.parameters())
        grid = torch.linspace(
            table_min,
            table_max,
            table_size,
            device=reference.device,
            dtype=reference.dtype,  # match the source model's precision
        ).expand(1, n_waveshapers, table_size)
        # The table is a one-off snapshot; don't record an autograd graph.
        with torch.no_grad():
            lookup_table = newt.shaping_fn(grid)[0]
        return nn.Parameter(lookup_table)

    def _lookup(self, idx):
        """Gather table entries: ``out[b, s, ...] = lookup_table[s, idx[b, s, ...]]``.

        ``idx`` must be an integer tensor with at least two dims; dim 1 selects
        the waveshaper (table row).
        """
        # A per-shaper row index shaped (1, n, 1, ..., 1) broadcasts against idx,
        # so each channel reads its own table row in one vectorised gather.
        rows = torch.arange(idx.shape[1], device=idx.device).view(
            1, -1, *([1] * (idx.dim() - 2))
        )
        return self.lookup_table[rows, idx]

    def shaping_fn(self, x):
        """Evaluate the tabulated shaping functions at ``x`` by linear interpolation.

        Inputs below ``table_min`` / above ``table_max`` return the first / last
        table value. NOTE(review): ``x`` presumably has shape
        (batch, n_waveshapers, time), as produced inside ``NEWT.forward``.
        """
        span = self.table_max - self.table_min
        last = self.table_size - 1
        # Grid point i sits at table_min + i * span / (table_size - 1), so the
        # fractional index scales by (table_size - 1), not table_size.
        idx = (last * (x - self.table_min) / span).clamp(0, last)
        lower = idx.floor().long()
        upper = (lower + 1).clamp(max=last)
        fract = idx - lower  # in [0, 1] thanks to the clamp above
        lower_v = self._lookup(lower)
        upper_v = self._lookup(upper)
        return lower_v + (upper_v - lower_v) * fract
class Reverb(nn.Module):
    """Learned convolution reverb added on top of the dry signal.

    The impulse response has ``round(sr * length_in_seconds)`` taps. The first
    tap is a fixed zero buffer, so the wet path has no zero-lag component (the
    dry signal is added separately). The remaining taps are learned and
    initialised with tiny Gaussian noise (std 1e-6).

    ``forward(x)`` returns ``x + conv(x, ir)[..., :x.shape[-1]]``, where
    ``conv`` is causal *linear* convolution along the last axis.

    Args:
        length_in_seconds: impulse-response length in seconds (may be fractional).
        sr: sample rate in Hz.
    """

    def __init__(self, length_in_seconds, sr):
        super().__init__()
        # round() allows fractional lengths; torch.randn needs an integer size
        n_taps = int(round(sr * length_in_seconds))
        self.ir = nn.Parameter(torch.randn(1, n_taps - 1) * 1e-6)
        self.register_buffer("initial_zero", torch.zeros(1, 1))

    def forward(self, x):
        """Return ``x`` plus ``x`` convolved with the learned IR, same length as ``x``."""
        ir = torch.cat((self.initial_zero, self.ir), dim=-1)
        n_samples = x.shape[-1]
        # Pad both signals to the full linear-convolution length, so the FFT
        # product is linear rather than circular convolution. Otherwise the
        # reverb tail wraps round onto the start of the signal. Passing n to
        # irfft also keeps odd lengths exact (irfft's default length is even).
        n_fft = n_samples + ir.shape[-1] - 1
        wet = torch.fft.irfft(
            torch.fft.rfft(x, n=n_fft) * torch.fft.rfft(ir, n=n_fft), n=n_fft
        )
        return x + wet[..., :n_samples]