# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License # License can be found in LICENSES/LICENSE_ADP.txt from inspect import isfunction from math import ceil, floor, log, pi, log2 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union from packaging import version import torch import torch.nn as nn from einops import rearrange, reduce, repeat from einops.layers.torch import Rearrange from einops_exts import rearrange_many from torch import Tensor, einsum from torch.backends.cuda import sdp_kernel from torch.nn import functional as F from dac.nn.layers import Snake1d from audiocraft.modules.conv import get_extra_padding_for_conv1d, pad1d, unpad1d """ Utils """ class ConditionedSequential(nn.Module): def __init__(self, *modules): super().__init__() self.module_list = nn.ModuleList(*modules) def forward(self, x: Tensor, mapping: Optional[Tensor] = None): for module in self.module_list: x = module(x, mapping) return x T = TypeVar("T") def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: if exists(val): return val return d() if isfunction(d) else d def exists(val: Optional[T]) -> T: return val is not None def closest_power_2(x: float) -> int: exponent = log2(x) distance_fn = lambda z: abs(x - 2 ** z) # noqa exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) return 2 ** int(exponent_closest) def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: return_dicts: Tuple[Dict, Dict] = ({}, {}) for key in d.keys(): no_prefix = int(not key.startswith(prefix)) return_dicts[no_prefix][key] = d[key] return return_dicts def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) if keep_prefix: return kwargs_with_prefix, kwargs kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} return kwargs_no_prefix, kwargs """ Convolutional Blocks """ class Conv1d(nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x: Tensor, causal=False) -> Tensor: kernel_size = self.kernel_size[0] stride = self.stride[0] dilation = self.dilation[0] kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations padding_total = kernel_size - stride extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) if causal: # Left padding for causal x = pad1d(x, (padding_total, extra_padding)) else: # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right x = pad1d(x, (padding_left, padding_right + extra_padding)) return super().forward(x) class ConvTranspose1d(nn.ConvTranspose1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x: Tensor, causal=False) -> Tensor: kernel_size = self.kernel_size[0] stride = self.stride[0] padding_total = kernel_size - stride y = super().forward(x) # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be # removed at the very end, when keeping only the right length for the output, # as removing it here would require also passing the length at the matching layer # in the encoder. if causal: padding_right = ceil(padding_total) padding_left = padding_total - padding_right y = unpad1d(y, (padding_left, padding_right)) else: # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right y = unpad1d(y, (padding_left, padding_right)) return y def Downsample1d( in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 ) -> nn.Module: assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" # print(f'downsample getting in_channel: {in_channels}, out_channels: {out_channels}, factor:{factor}') return Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=factor * kernel_multiplier + 1, stride=factor ) def Upsample1d( in_channels: int, out_channels: int, factor: int, use_nearest: bool = False ) -> nn.Module: # print(f'Upsample1d getting in_channel: {in_channels}, out_channel: {out_channels}, factor:{factor}') if factor == 1: return Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=3 ) if use_nearest: return nn.Sequential( nn.Upsample(scale_factor=factor, mode="nearest"), Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=3 ), ) else: return ConvTranspose1d( in_channels=in_channels, out_channels=out_channels, kernel_size=factor * 2, stride=factor ) class ConvBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, kernel_size: int = 3, stride: int = 1, dilation: int = 1, num_groups: int = 8, use_norm: bool = True, use_snake: bool = False ) -> None: super().__init__() self.groupnorm = ( nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) if use_norm else nn.Identity() ) if use_snake: self.activation = Snake1d(in_channels) else: self.activation = nn.SiLU() self.project = Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, ) def forward( self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False ) -> Tensor: x = self.groupnorm(x) if exists(scale_shift): scale, shift = scale_shift x = x * (scale + 1) + shift x = self.activation(x) return self.project(x, causal=causal) class MappingToScaleShift(nn.Module): def __init__( self, features: int, channels: int, ): super().__init__() self.to_scale_shift = nn.Sequential( nn.SiLU(), nn.Linear(in_features=features, out_features=channels * 2), ) def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: scale_shift = self.to_scale_shift(mapping) scale_shift = rearrange(scale_shift, "b c -> b c 1") scale, shift = scale_shift.chunk(2, dim=1) return scale, shift class ResnetBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, kernel_size: int = 3, stride: int = 1, dilation: int = 1, use_norm: bool = True, use_snake: bool = False, num_groups: int = 8, context_mapping_features: Optional[int] = None, ) -> None: super().__init__() self.use_mapping = exists(context_mapping_features) self.block1 = ConvBlock1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, use_norm=use_norm, num_groups=num_groups, use_snake=use_snake ) if self.use_mapping: assert exists(context_mapping_features) self.to_scale_shift = MappingToScaleShift( features=context_mapping_features, channels=out_channels ) self.block2 = ConvBlock1d( in_channels=out_channels, out_channels=out_channels, use_norm=use_norm, num_groups=num_groups, use_snake=use_snake ) self.to_out = ( Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: # print(f"ResnetBlock1d input shape {x.shape}") assert_message = "context mapping required if context_mapping_features > 0" assert not (self.use_mapping ^ exists(mapping)), assert_message h = self.block1(x, causal=causal) scale_shift = None if self.use_mapping: scale_shift = self.to_scale_shift(mapping) h = self.block2(h, scale_shift=scale_shift, causal=causal) # print(f"ResnetBlock1d output shape {h.shape}") return h + self.to_out(x) class Patcher(nn.Module): def __init__( self, in_channels: int, out_channels: int, patch_size: int, context_mapping_features: Optional[int] = None, use_snake: bool = False, ): super().__init__() assert_message = f"out_channels must be divisible by patch_size ({patch_size})" assert out_channels % patch_size == 0, assert_message self.patch_size = patch_size self.block = ResnetBlock1d( in_channels=in_channels, out_channels=out_channels // patch_size, num_groups=1, context_mapping_features=context_mapping_features, use_snake=use_snake ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: # print(f"Patcher input shape: {x.shape}") x = self.block(x, mapping, causal=causal) x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) # print(f"Patcher output shape {x.shape}") return x class Unpatcher(nn.Module): def __init__( self, in_channels: int, out_channels: int, patch_size: int, context_mapping_features: Optional[int] = None, use_snake: bool = False ): super().__init__() assert_message = f"in_channels must be divisible by patch_size ({patch_size})" assert in_channels % patch_size == 0, assert_message self.patch_size = patch_size self.block = ResnetBlock1d( in_channels=in_channels // patch_size, out_channels=out_channels, num_groups=1, context_mapping_features=context_mapping_features, use_snake=use_snake ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: # print(f"Unpatcher input shape: {x.shape}") x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) x = self.block(x, mapping, causal=causal) # print(f"Unpatcher output shape: {x.shape}") return x """ Attention Components """ def FeedForward(features: int, multiplier: int) -> nn.Module: # print(f'feed forward getting multipler {multiplier}') mid_features = features * multiplier return nn.Sequential( nn.Linear(in_features=features, out_features=mid_features), nn.GELU(), nn.Linear(in_features=mid_features, out_features=features), ) def add_mask(sim: Tensor, mask: Tensor) -> Tensor: b, ndim = sim.shape[0], mask.ndim if ndim == 3: mask = rearrange(mask, "b n m -> b 1 n m") if ndim == 2: mask = repeat(mask, "n m -> b 1 n m", b=b) max_neg_value = -torch.finfo(sim.dtype).max sim = sim.masked_fill(~mask, max_neg_value) return sim def causal_mask(q: Tensor, k: Tensor) -> Tensor: b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) mask = repeat(mask, "n m -> b n m", b=b) return mask class AttentionBase(nn.Module): def __init__( self, features: int, *, head_features: int, num_heads: int, out_features: Optional[int] = None, ): super().__init__() self.scale = head_features**-0.5 self.num_heads = num_heads mid_features = head_features * num_heads out_features = default(out_features, features) self.to_out = nn.Linear( in_features=mid_features, out_features=out_features ) self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') if not self.use_flash: return device_properties = torch.cuda.get_device_properties(torch.device('cuda')) if device_properties.major == 8 and device_properties.minor == 0: # Use flash attention for A100 GPUs self.sdp_kernel_config = (True, False, False) else: # Don't use flash attention for other GPUs self.sdp_kernel_config = (False, True, True) def forward( self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False ) -> Tensor: # Split heads q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) if not self.use_flash: if is_causal and not mask: # Mask out future tokens for causal attention mask = causal_mask(q, k) # Compute similarity matrix and add eventual mask sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale sim = add_mask(sim, mask) if exists(mask) else sim # Get attention matrix with softmax attn = sim.softmax(dim=-1, dtype=torch.float32) # Compute values out = einsum("... n m, ... m d -> ... n d", attn, v) else: with sdp_kernel(*self.sdp_kernel_config): out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) class Attention(nn.Module): def __init__( self, features: int, *, head_features: int, num_heads: int, out_features: Optional[int] = None, context_features: Optional[int] = None, causal: bool = False, ): super().__init__() self.context_features = context_features self.causal = causal mid_features = head_features * num_heads context_features = default(context_features, features) self.norm = nn.LayerNorm(features) self.norm_context = nn.LayerNorm(context_features) self.to_q = nn.Linear( in_features=features, out_features=mid_features, bias=False ) self.to_kv = nn.Linear( in_features=context_features, out_features=mid_features * 2, bias=False ) self.attention = AttentionBase( features, num_heads=num_heads, head_features=head_features, out_features=out_features, ) def forward( self, x: Tensor, # [b, n, c] context: Optional[Tensor] = None, # [b, m, d] context_mask: Optional[Tensor] = None, # [b, m], false is masked, causal: Optional[bool] = False, ) -> Tensor: assert_message = "You must provide a context when using context_features" assert not self.context_features or exists(context), assert_message # Use context if provided context = default(context, x) # Normalize then compute q from input and k,v from context x, context = self.norm(x), self.norm_context(context) # print("Shape of x:", x.shape) # print("Shape of context:", context.shape) # print("context_mask:", context_mask) q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) if exists(context_mask): # Mask out cross-attention for padding tokens mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) k, v = k * mask, v * mask # Compute and return attention return self.attention(q, k, v, is_causal=self.causal or causal) def FeedForward(features: int, multiplier: int) -> nn.Module: mid_features = features * multiplier return nn.Sequential( nn.Linear(in_features=features, out_features=mid_features), nn.GELU(), nn.Linear(in_features=mid_features, out_features=features), ) """ Transformer Blocks """ class TransformerBlock(nn.Module): def __init__( self, features: int, num_heads: int, head_features: int, multiplier: int, context_features: Optional[int] = None, ): super().__init__() self.use_cross_attention = exists(context_features) and context_features > 0 self.attention = Attention( features=features, num_heads=num_heads, head_features=head_features ) if self.use_cross_attention: self.cross_attention = Attention( features=features, num_heads=num_heads, head_features=head_features, context_features=context_features ) self.feed_forward = FeedForward(features=features, multiplier=multiplier) def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: # print(f'TransformerBlock input shape: {x.shape}') x = self.attention(x, causal=causal) + x if self.use_cross_attention: x = self.cross_attention(x, context=context, context_mask=context_mask) + x x = self.feed_forward(x) + x # print(f'TransformerBlock output shape: {x.shape}') return x """ Transformers """ class Transformer1d(nn.Module): def __init__( self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int, context_features: Optional[int] = None, ): super().__init__() self.to_in = nn.Sequential( nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), Conv1d( in_channels=channels, out_channels=channels, kernel_size=1, ), Rearrange("b c t -> b t c"), ) self.blocks = nn.ModuleList( [ TransformerBlock( features=channels, head_features=head_features, num_heads=num_heads, multiplier=multiplier, context_features=context_features, ) for i in range(num_layers) ] ) self.to_out = nn.Sequential( Rearrange("b t c -> b c t"), Conv1d( in_channels=channels, out_channels=channels, kernel_size=1, ), ) def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: # print(f'Transformer1d input shape: {x.shape}') x = self.to_in(x) for block in self.blocks: x = block(x, context=context, context_mask=context_mask, causal=causal) x = self.to_out(x) # print(f'Transformer1d output shape: {x.shape}') return x """ Time Embeddings """ class SinusoidalEmbedding(nn.Module): def __init__(self, dim: int): super().__init__() self.dim = dim def forward(self, x: Tensor) -> Tensor: device, half_dim = x.device, self.dim // 2 emb = torch.tensor(log(10000) / (half_dim - 1), device=device) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") return torch.cat((emb.sin(), emb.cos()), dim=-1) class LearnedPositionalEmbedding(nn.Module): """Used for continuous time""" def __init__(self, dim: int): super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, x: Tensor) -> Tensor: x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) fouriered = torch.cat((x, fouriered), dim=-1) return fouriered def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: return nn.Sequential( LearnedPositionalEmbedding(dim), nn.Linear(in_features=dim + 1, out_features=out_features), ) """ Encoder/Decoder Components """ class DownsampleBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, factor: int, num_groups: int, num_layers: int, kernel_multiplier: int = 2, use_pre_downsample: bool = True, use_skip: bool = False, use_snake: bool = False, extract_channels: int = 0, context_channels: int = 0, num_transformer_blocks: int = 0, attention_heads: Optional[int] = None, attention_features: Optional[int] = None, attention_multiplier: Optional[int] = None, context_mapping_features: Optional[int] = None, context_embedding_features: Optional[int] = None, ): super().__init__() self.use_pre_downsample = use_pre_downsample self.use_skip = use_skip self.use_transformer = num_transformer_blocks > 0 self.use_extract = extract_channels > 0 self.use_context = context_channels > 0 channels = out_channels if use_pre_downsample else in_channels self.downsample = Downsample1d( in_channels=in_channels, out_channels=out_channels, factor=factor, kernel_multiplier=kernel_multiplier, ) self.blocks = nn.ModuleList( [ ResnetBlock1d( in_channels=channels + context_channels if i == 0 else channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) for i in range(num_layers) ] ) if self.use_transformer: assert ( (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) ) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads if attention_heads is None and attention_features is not None: attention_heads = channels // attention_features self.transformer = Transformer1d( num_layers=num_transformer_blocks, channels=channels, num_heads=attention_heads, head_features=attention_features, multiplier=attention_multiplier, context_features=context_embedding_features ) if self.use_extract: num_extract_groups = min(num_groups, extract_channels) self.to_extracted = ResnetBlock1d( in_channels=out_channels, out_channels=extract_channels, num_groups=num_extract_groups, use_snake=use_snake ) def forward( self, x: Tensor, *, mapping: Optional[Tensor] = None, channels: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: # print(f'DownsampleBlock1d input shape: {x.shape}') if self.use_pre_downsample: x = self.downsample(x) if self.use_context and exists(channels): x = torch.cat([x, channels], dim=1) skips = [] for block in self.blocks: x = block(x, mapping=mapping, causal=causal) skips += [x] if self.use_skip else [] if self.use_transformer: x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) skips += [x] if self.use_skip else [] if not self.use_pre_downsample: x = self.downsample(x) if self.use_extract: extracted = self.to_extracted(x) return x, extracted # print(f'DownsampleBlock1d output shape: {x.shape}') return (x, skips) if self.use_skip else x class UpsampleBlock1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, *, factor: int, num_layers: int, num_groups: int, use_nearest: bool = False, use_pre_upsample: bool = False, use_skip: bool = False, use_snake: bool = False, skip_channels: int = 0, use_skip_scale: bool = False, extract_channels: int = 0, num_transformer_blocks: int = 0, attention_heads: Optional[int] = None, attention_features: Optional[int] = None, attention_multiplier: Optional[int] = None, context_mapping_features: Optional[int] = None, context_embedding_features: Optional[int] = None, ): super().__init__() self.use_extract = extract_channels > 0 self.use_pre_upsample = use_pre_upsample self.use_transformer = num_transformer_blocks > 0 self.use_skip = use_skip self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 channels = out_channels if use_pre_upsample else in_channels self.blocks = nn.ModuleList( [ ResnetBlock1d( in_channels=channels + skip_channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) for _ in range(num_layers) ] ) if self.use_transformer: assert ( (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) ) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads if attention_heads is None and attention_features is not None: attention_heads = channels // attention_features self.transformer = Transformer1d( num_layers=num_transformer_blocks, channels=channels, num_heads=attention_heads, head_features=attention_features, multiplier=attention_multiplier, context_features=context_embedding_features, ) self.upsample = Upsample1d( in_channels=in_channels, out_channels=out_channels, factor=factor, use_nearest=use_nearest, ) if self.use_extract: num_extract_groups = min(num_groups, extract_channels) self.to_extracted = ResnetBlock1d( in_channels=out_channels, out_channels=extract_channels, num_groups=num_extract_groups, use_snake=use_snake ) def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: return torch.cat([x, skip * self.skip_scale], dim=1) def forward( self, x: Tensor, *, skips: Optional[List[Tensor]] = None, mapping: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False ) -> Union[Tuple[Tensor, Tensor], Tensor]: # print(f'UpsampleBlock1d input shape: {x.shape}') if self.use_pre_upsample: x = self.upsample(x) for block in self.blocks: x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x x = block(x, mapping=mapping, causal=causal) if self.use_transformer: x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) if not self.use_pre_upsample: x = self.upsample(x) if self.use_extract: extracted = self.to_extracted(x) return x, extracted # print(f'UpsampleBlock1d output shape: {x.shape}') return x class BottleneckBlock1d(nn.Module): def __init__( self, channels: int, *, num_groups: int, num_transformer_blocks: int = 0, attention_heads: Optional[int] = None, attention_features: Optional[int] = None, attention_multiplier: Optional[int] = None, context_mapping_features: Optional[int] = None, context_embedding_features: Optional[int] = None, use_snake: bool = False, ): super().__init__() self.use_transformer = num_transformer_blocks > 0 self.pre_block = ResnetBlock1d( in_channels=channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) if self.use_transformer: assert ( (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) ) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads if attention_heads is None and attention_features is not None: attention_heads = channels // attention_features self.transformer = Transformer1d( num_layers=num_transformer_blocks, channels=channels, num_heads=attention_heads, head_features=attention_features, multiplier=attention_multiplier, context_features=context_embedding_features, ) self.post_block = ResnetBlock1d( in_channels=channels, out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, use_snake=use_snake ) def forward( self, x: Tensor, *, mapping: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False ) -> Tensor: # print(f'BottleneckBlock1d input shape: {x.shape}') x = self.pre_block(x, mapping=mapping, causal=causal) if self.use_transformer: x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) x = self.post_block(x, mapping=mapping, causal=causal) # print(f'BottleneckBlock1d output shape: {x.shape}') return x """ UNet """ class UNet1d(nn.Module): def __init__( self, in_channels: int, channels: int, multipliers: Sequence[int], factors: Sequence[int], num_blocks: Sequence[int], attentions: Sequence[int], patch_size: int = 1, resnet_groups: int = 8, use_context_time: bool = True, kernel_multiplier_downsample: int = 2, use_nearest_upsample: bool = False, use_skip_scale: bool = True, use_snake: bool = False, use_stft: bool = False, use_stft_context: bool = False, out_channels: Optional[int] = None, context_features: Optional[int] = None, context_features_multiplier: int = 4, context_channels: Optional[Sequence[int]] = None, context_embedding_features: Optional[int] = None, **kwargs, ): super().__init__() out_channels = default(out_channels, in_channels) context_channels = list(default(context_channels, [])) num_layers = len(multipliers) - 1 use_context_features = exists(context_features) use_context_channels = len(context_channels) > 0 context_mapping_features = None attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) self.num_layers = num_layers self.use_context_time = use_context_time self.use_context_features = use_context_features self.use_context_channels = use_context_channels self.use_stft = use_stft self.use_stft_context = use_stft_context self.context_features = context_features context_channels_pad_length = num_layers + 1 - len(context_channels) context_channels = context_channels + [0] * context_channels_pad_length self.context_channels = context_channels self.context_embedding_features = context_embedding_features if use_context_channels: has_context = [c > 0 for c in context_channels] self.has_context = has_context self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] assert ( len(factors) == num_layers and len(attentions) >= num_layers and len(num_blocks) == num_layers ) if use_context_time or use_context_features: context_mapping_features = channels * context_features_multiplier self.to_mapping = nn.Sequential( nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(), nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(), ) if use_context_time: assert exists(context_mapping_features) self.to_time = nn.Sequential( TimePositionalEmbedding( dim=channels, out_features=context_mapping_features ), nn.GELU(), ) if use_context_features: assert exists(context_features) and exists(context_mapping_features) self.to_features = nn.Sequential( nn.Linear( in_features=context_features, out_features=context_mapping_features ), nn.GELU(), ) if use_stft: stft_kwargs, kwargs = groupby("stft_", kwargs) assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 in_channels *= stft_channels out_channels *= stft_channels context_channels[0] *= stft_channels if use_stft_context else 1 assert exists(in_channels) and exists(out_channels) self.stft = STFT(**stft_kwargs) assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" self.to_in = Patcher( in_channels=in_channels + context_channels[0], out_channels=channels * multipliers[0], patch_size=patch_size, context_mapping_features=context_mapping_features, use_snake=use_snake ) self.downsamples = nn.ModuleList( [ DownsampleBlock1d( in_channels=channels * multipliers[i], out_channels=channels * multipliers[i + 1], context_mapping_features=context_mapping_features, context_channels=context_channels[i + 1], context_embedding_features=context_embedding_features, num_layers=num_blocks[i], factor=factors[i], kernel_multiplier=kernel_multiplier_downsample, num_groups=resnet_groups, use_pre_downsample=True, use_skip=True, use_snake=use_snake, num_transformer_blocks=attentions[i], **attention_kwargs, ) for i in range(num_layers) ] ) self.bottleneck = BottleneckBlock1d( channels=channels * multipliers[-1], context_mapping_features=context_mapping_features, context_embedding_features=context_embedding_features, num_groups=resnet_groups, num_transformer_blocks=attentions[-1], use_snake=use_snake, **attention_kwargs, ) self.upsamples = nn.ModuleList( [ UpsampleBlock1d( in_channels=channels * multipliers[i + 1], out_channels=channels * multipliers[i], context_mapping_features=context_mapping_features, context_embedding_features=context_embedding_features, num_layers=num_blocks[i] + (1 if attentions[i] else 0), factor=factors[i], use_nearest=use_nearest_upsample, num_groups=resnet_groups, use_skip_scale=use_skip_scale, use_pre_upsample=False, use_skip=True, use_snake=use_snake, skip_channels=channels * multipliers[i + 1], num_transformer_blocks=attentions[i], **attention_kwargs, ) for i in reversed(range(num_layers)) ] ) self.to_out = Unpatcher( in_channels=channels * multipliers[0], out_channels=out_channels, patch_size=patch_size, context_mapping_features=context_mapping_features, use_snake=use_snake ) def get_channels( self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 ) -> Optional[Tensor]: """Gets context channels at `layer` and checks that shape is correct""" use_context_channels = self.use_context_channels and self.has_context[layer] if not use_context_channels: return None assert exists(channels_list), "Missing context" # Get channels index (skipping zero channel contexts) channels_id = self.channels_ids[layer] # Get channels channels = channels_list[channels_id] message = f"Missing context for layer {layer} at index {channels_id}" assert exists(channels), message # Check channels num_channels = self.context_channels[layer] message = f"Expected context with {num_channels} channels at idx {channels_id}" assert channels.shape[1] == num_channels, message # STFT channels if requested channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa return channels def get_mapping( self, time: Optional[Tensor] = None, features: Optional[Tensor] = None ) -> Optional[Tensor]: """Combines context time features and features into mapping""" items, mapping = [], None # Compute time features if self.use_context_time: assert_message = "use_context_time=True but no time features provided" assert exists(time), assert_message items += [self.to_time(time)] # Compute features if self.use_context_features: assert_message = "context_features exists but no features provided" assert exists(features), assert_message items += [self.to_features(features)] # Compute joint mapping if self.use_context_time or self.use_context_features: mapping = reduce(torch.stack(items), "n b m -> b m", "sum") mapping = self.to_mapping(mapping) return mapping def forward( self, x: Tensor, time: Optional[Tensor] = None, *, features: Optional[Tensor] = None, channels_list: Optional[Sequence[Tensor]] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, causal: Optional[bool] = False, ) -> Tensor: # print(f'Unet1d input shape: {x.shape}') channels = self.get_channels(channels_list, layer=0) # Apply stft if required x = self.stft.encode1d(x) if self.use_stft else x # type: ignore # Concat context channels at layer 0 if provided x = torch.cat([x, channels], dim=1) if exists(channels) else x # Compute mapping from time and features mapping = self.get_mapping(time, features) x = self.to_in(x, mapping, causal=causal) skips_list = [x] for i, downsample in enumerate(self.downsamples): channels = self.get_channels(channels_list, layer=i + 1) x, skips = downsample( x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal ) skips_list += [skips] x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) for i, upsample in enumerate(self.upsamples): skips = skips_list.pop() x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) x += skips_list.pop() x = self.to_out(x, mapping, causal=causal) x = self.stft.decode1d(x) if self.use_stft else x # print(f'Unet1d output shape: {x.shape}') return x """ Conditioning Modules """ class FixedEmbedding(nn.Module): def __init__(self, max_length: int, features: int): super().__init__() self.max_length = max_length self.embedding = nn.Embedding(max_length, features) def forward(self, x: Tensor) -> Tensor: # print(f'FixedEmbedding input shape: {x.shape}') batch_size, length, device = *x.shape[0:2], x.device # print(f'FixedEmbedding length: {length}, self.max length: {self.max_length}') assert_message = "Input sequence length must be <= max_length" assert length <= self.max_length, assert_message position = torch.arange(length, device=device) fixed_embedding = self.embedding(position) fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) # print(f'FixedEmbedding output shape: {fixed_embedding.shape}') return fixed_embedding def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: if proba == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif proba == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) class UNetCFG1d(UNet1d): """UNet1d with Classifier-Free Guidance""" def __init__( self, context_embedding_max_length: int, context_embedding_features: int, use_xattn_time: bool = False, **kwargs, ): super().__init__( context_embedding_features=context_embedding_features, **kwargs ) self.use_xattn_time = use_xattn_time if use_xattn_time: assert exists(context_embedding_features) self.to_time_embedding = nn.Sequential( TimePositionalEmbedding( dim=kwargs["channels"], out_features=context_embedding_features ), nn.GELU(), ) context_embedding_max_length += 1 # Add one for time embedding self.fixed_embedding = FixedEmbedding( max_length=context_embedding_max_length, features=context_embedding_features ) def forward( # type: ignore self, x: Tensor, time: Tensor, *, embedding: Tensor, embedding_mask: Optional[Tensor] = None, embedding_scale: float = 1.0, embedding_mask_proba: float = 0.0, batch_cfg: bool = False, rescale_cfg: bool = False, scale_phi: float = 0.4, negative_embedding: Optional[Tensor] = None, negative_embedding_mask: Optional[Tensor] = None, **kwargs, ) -> Tensor: # print("Debugging UNetCFG1d forward method") # print(f"Input x shape: {x.shape}, type: {type(x)}") # print(f"Time embedding shape: {time.shape}, type: {type(time)}") # print(f"Cross-attention embedding shape: {embedding.shape}, type: {type(embedding)}") # print(f"Cross-attention embedding mask shape: {embedding_mask.shape if embedding_mask is not None else 'None'}, type: {type(embedding_mask)}") # print(f"Embedding scale: {embedding_scale}, type: {type(embedding_scale)}") # print(f"Embedding mask probability: {embedding_mask_proba}, type: {type(embedding_mask_proba)}") # print(f"Batch CFG: {batch_cfg}, type: {type(batch_cfg)}") # print(f"Rescale CFG: {rescale_cfg}, type: {type(rescale_cfg)}") # print(f"Scale Phi: {scale_phi}, type: {type(scale_phi)}") # if negative_embedding is not None: # print(f"Negative embedding shape: {negative_embedding.shape}, type: {type(negative_embedding)}") # if negative_embedding_mask is not None: # print(f"Negative embedding mask shape: {negative_embedding_mask.shape}, type: {type(negative_embedding_mask)}") b, device = embedding.shape[0], embedding.device if self.use_xattn_time: embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) if embedding_mask is not None: embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) fixed_embedding = self.fixed_embedding(embedding) # print(f'Fixed Embedding.shape {fixed_embedding.shape}') assert fixed_embedding.shape == embedding.shape, f"Shape mismatch: {fixed_embedding.shape} vs {embedding.shape}" if embedding_mask_proba > 0.0: # Randomly mask embedding batch_mask = rand_bool( shape=(b, 1, 1), proba=embedding_mask_proba, device=device ) embedding = torch.where(batch_mask, fixed_embedding, embedding) if embedding_scale != 1.0: if batch_cfg: batch_x = torch.cat([x, x], dim=0) batch_time = torch.cat([time, time], dim=0) if negative_embedding is not None: if negative_embedding_mask is not None: negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) batch_embed = torch.cat([embedding, negative_embedding], dim=0) else: batch_embed = torch.cat([embedding, fixed_embedding], dim=0) batch_mask = None if embedding_mask is not None: batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) batch_features = None features = kwargs.pop("features", None) if self.use_context_features: batch_features = torch.cat([features, features], dim=0) batch_channels = None channels_list = kwargs.pop("channels_list", None) if self.use_context_channels: batch_channels = [] for channels in channels_list: batch_channels += [torch.cat([channels, channels], dim=0)] # Compute both normal and fixed embedding outputs batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) out, out_masked = batch_out.chunk(2, dim=0) else: # Compute both normal and fixed embedding outputs out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) out_cfg = out_masked + (out - out_masked) * embedding_scale if rescale_cfg: out_std = out.std(dim=1, keepdim=True) out_cfg_std = out_cfg.std(dim=1, keepdim=True) return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg else: return out_cfg else: return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) class UNetNCCA1d(UNet1d): """UNet1d with Noise Channel Conditioning Augmentation""" def __init__(self, context_features: int, **kwargs): super().__init__(context_features=context_features, **kwargs) self.embedder = NumberEmbedder(features=context_features) def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: x = x if torch.is_tensor(x) else torch.tensor(x) return x.expand(shape) def forward( # type: ignore self, x: Tensor, time: Tensor, *, channels_list: Sequence[Tensor], channels_augmentation: Union[ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor ] = False, channels_scale: Union[ float, Sequence[float], Sequence[Sequence[float]], Tensor ] = 0, **kwargs, ) -> Tensor: b, n = x.shape[0], len(channels_list) channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) # Augmentation (for each channel list item) for i in range(n): scale = channels_scale[:, i] * channels_augmentation[:, i] scale = rearrange(scale, "b -> b 1 1") item = channels_list[i] channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa # Scale embedding (sum reduction if more than one channel list item) channels_scale_emb = self.embedder(channels_scale) channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") return super().forward( x=x, time=time, channels_list=channels_list, features=channels_scale_emb, **kwargs, ) class UNetAll1d(UNetCFG1d, UNetNCCA1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, *args, **kwargs): # type: ignore return UNetCFG1d.forward(self, *args, **kwargs) def XUNet1d(type: str = "base", **kwargs) -> UNet1d: if type == "base": return UNet1d(**kwargs) elif type == "all": return UNetAll1d(**kwargs) elif type == "cfg": return UNetCFG1d(**kwargs) elif type == "ncca": return UNetNCCA1d(**kwargs) else: raise ValueError(f"Unknown XUNet1d type: {type}") class NumberEmbedder(nn.Module): def __init__( self, features: int, dim: int = 256, ): super().__init__() self.features = features self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) def forward(self, x: Union[List[float], Tensor]) -> Tensor: if not torch.is_tensor(x): device = next(self.embedding.parameters()).device x = torch.tensor(x, device=device) assert isinstance(x, Tensor) shape = x.shape x = rearrange(x, "... -> (...)") embedding = self.embedding(x) x = embedding.view(*shape, self.features) return x # type: ignore """ Audio Transforms """ class STFT(nn.Module): """Helper for torch stft and istft""" def __init__( self, num_fft: int = 1023, hop_length: int = 256, window_length: Optional[int] = None, length: Optional[int] = None, use_complex: bool = False, ): super().__init__() self.num_fft = num_fft self.hop_length = default(hop_length, floor(num_fft // 4)) self.window_length = default(window_length, num_fft) self.length = length self.register_buffer("window", torch.hann_window(self.window_length)) self.use_complex = use_complex def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: b = wave.shape[0] wave = rearrange(wave, "b c t -> (b c) t") stft = torch.stft( wave, n_fft=self.num_fft, hop_length=self.hop_length, win_length=self.window_length, window=self.window, # type: ignore return_complex=True, normalized=True, ) if self.use_complex: # Returns real and imaginary stft_a, stft_b = stft.real, stft.imag else: # Returns magnitude and phase matrices magnitude, phase = torch.abs(stft), torch.angle(stft) stft_a, stft_b = magnitude, phase return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: b, l = stft_a.shape[0], stft_a.shape[-1] # noqa length = closest_power_2(l * self.hop_length) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") if self.use_complex: real, imag = stft_a, stft_b else: magnitude, phase = stft_a, stft_b real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) stft = torch.stack([real, imag], dim=-1) wave = torch.istft( stft, n_fft=self.num_fft, hop_length=self.hop_length, win_length=self.window_length, window=self.window, # type: ignore length=default(self.length, length), normalized=True, ) return rearrange(wave, "(b c) t -> b c t", b=b) def encode1d( self, wave: Tensor, stacked: bool = True ) -> Union[Tensor, Tuple[Tensor, Tensor]]: stft_a, stft_b = self.encode(wave) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) def decode1d(self, stft_pair: Tensor) -> Tensor: f = self.num_fft // 2 + 1 stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) return self.decode(stft_a, stft_b)