from functools import reduce
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.backends.cuda import sdp_kernel
from packaging import version

from dac.nn.layers import Snake1d


class ResidualBlock(nn.Module):
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)


class ResConvBlock(ResidualBlock):
    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
        skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
        super().__init__([
            nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size // 2, bias=conv_bias),
            nn.GroupNorm(1, c_mid),
            Snake1d(c_mid) if use_snake else nn.GELU(),
            nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size // 2, bias=conv_bias),
            nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
            (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
        ], skip)


class SelfAttention1d(nn.Module):
    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(self, input):
        n, c, s = input.shape
        qkv = self.qkv_proj(self.norm(input))
        qkv = qkv.view([n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        scale = k.shape[3] ** -0.25

        if self.use_flash:
            with sdp_kernel(*self.sdp_kernel_config):
                # Transpose back to (n, heads, head_dim, s) before flattening so the
                # flash path produces the same channel layout as the manual path below.
                y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
                y = y.transpose(2, 3).contiguous().view([n, c, s])
        else:
            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
            y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])

        return input + self.dropout(self.out_proj(y))


class SkipBlock(nn.Module):
    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        # Concatenate the block's output with its input along the channel axis.
        return torch.cat([self.main(input), input], dim=1)


class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)


def expand_to_planes(input, shape):
    return input[..., None].repeat([1, 1, shape[2]])
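
# --- Illustrative sketch (not part of the original module) ---
# A hedged example of how FourierFeatures and expand_to_planes are commonly
# combined to turn a scalar conditioning value (e.g. a diffusion timestep)
# into a per-sample conditioning plane. The helper name and all shapes below
# are assumptions chosen for illustration only; nothing here is called at
# import time.
def _timestep_embedding_example():
    t = torch.rand([4, 1])                         # (batch, 1) scalar timesteps
    embed = FourierFeatures(1, 16)                 # random Fourier projection: 8 cos + 8 sin features
    emb = embed(t)                                 # (4, 16)
    x = torch.randn(4, 16, 128)                    # (batch, channels, length) feature map
    return expand_to_planes(emb, x.shape)          # (4, 16, 128): embedding repeated along the length axis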
_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
         0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}


class Downsample1d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 2, self.pad_mode)
        # Build a filter bank that applies the same 1-D lowpass kernel to each
        # channel independently (off-diagonal weights stay zero).
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv1d(x, weight, stride=2)


class Upsample1d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        # The kernel is scaled by 2 to compensate for the zeros inserted by the
        # stride-2 transposed convolution, preserving the signal's amplitude.
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer('kernel', kernel_1d)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)


def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * kernel_multiplier + 1,
        stride=factor,
        padding=factor * (kernel_multiplier // 2),
    )


def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    if factor == 1:
        return nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    if use_nearest:
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode="nearest"),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
        )
    else:
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + factor % 2,
            output_padding=factor % 2,
        )


def zero_init(layer):
    nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer


def rms_norm(x, scale, eps):
    # Compute the norm in at least float32 for numerical stability, then cast back.
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)


class AdaRMSNorm(nn.Module):
    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialized projection: the conditioned scale starts at exactly 1,
        # so the layer behaves as a plain RMSNorm at the start of training.
        self.linear = zero_init(nn.Linear(cond_features, features, bias=False))

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
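

# --- Illustrative usage (not part of the original module) ---
# A minimal smoke test sketching how the resampling, attention, and
# normalization blocks above are typically exercised. The shapes and
# hyperparameters are assumptions chosen for illustration, not values used by
# any particular model.
if __name__ == "__main__":
    x = torch.randn(2, 8, 64)                      # (batch, channels, length)

    down = Downsample1d('linear')
    up = Upsample1d('linear')
    assert down(x).shape == (2, 8, 32)             # FIR downsample halves the length
    assert up(down(x)).shape == (2, 8, 64)         # transposed FIR upsample restores it

    attn = SelfAttention1d(8, n_head=2)
    assert attn(x).shape == x.shape                # residual self-attention preserves the shape

    norm = AdaRMSNorm(features=8, cond_features=16)
    cond = torch.randn(2, 16)
    y = norm(x.transpose(1, 2), cond)              # AdaRMSNorm normalizes over the last dim
    assert y.shape == (2, 64, 8)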