mrfakename's picture
Upload 114 files
c8448bc verified
raw
history blame
55.8 kB
# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
# License can be found in LICENSES/LICENSE_ADP.txt
from inspect import isfunction
from math import ceil, floor, log, pi, log2
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from packaging import version
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum
from torch.backends.cuda import sdp_kernel
from torch.nn import functional as F
from dac.nn.layers import Snake1d
from audiocraft.modules.conv import get_extra_padding_for_conv1d, pad1d, unpad1d
"""
Utils
"""
class ConditionedSequential(nn.Module):
    """Sequential container that hands the same `mapping` to every child.

    The children can be supplied either as separate positional arguments
    (``ConditionedSequential(a, b)``) or as one iterable of modules
    (``ConditionedSequential([a, b])``).
    """

    def __init__(self, *modules):
        super().__init__()
        # nn.ModuleList takes ONE iterable. Forwarding `*modules` verbatim (as
        # the previous code did) raised TypeError for several positional
        # modules and for a single non-iterable module. A lone iterable
        # argument keeps its old meaning: it is expanded into its members.
        if len(modules) == 1 and hasattr(type(modules[0]), "__iter__"):
            modules = tuple(modules[0])
        self.module_list = nn.ModuleList(modules)

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
        """Run `x` through each child in order, calling `child(x, mapping)`."""
        for module in self.module_list:
            x = module(x, mapping)
        return x
T = TypeVar("T")


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return `val` unless it is None, otherwise fall back to `d`.

    When `d` is a plain Python function (as judged by `inspect.isfunction`)
    it is called without arguments and its result is the fallback; any other
    object is returned unchanged.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
def exists(val: Optional[Any]) -> bool:
    """Return True when `val` is not None.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    # The result is a plain bool, not the wrapped value, hence `-> bool`.
    return val is not None
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to `x`.

    Only the two exponents surrounding ``log2(x)`` are considered; when both
    are equally far from `x`, the smaller power wins.
    """
    log_x = log2(x)
    lower, upper = floor(log_x), ceil(log_x)
    # Strict comparison keeps the lower exponent on a tie.
    best = upper if abs(x - 2 ** upper) < abs(x - 2 ** lower) else lower
    return 2 ** int(best)
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split `d` in two: entries whose key starts with `prefix`, and the rest.

    Keys are kept unchanged in both returned dictionaries.
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key, value in d.items():
        bucket = with_prefix if key.startswith(prefix) else without_prefix
        bucket[key] = value
    return with_prefix, without_prefix
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Separate the entries of `d` whose key starts with `prefix` from the rest.

    Returns ``(matching, others)``. Unless `keep_prefix` is True, `prefix` is
    cut from the front of every key in `matching`.
    """
    matching, others = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        cut = len(prefix)
        matching = {key[cut:]: value for key, value in matching.items()}
    return matching, others
"""
Convolutional Blocks
"""
class Conv1d(nn.Conv1d):
    """`nn.Conv1d` whose `forward` pads the input explicitly before convolving.

    The padding amounts are derived from kernel size, dilation and stride and
    applied with `pad1d`; `causal=True` puts all of the fixed padding on the
    left.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        """Pad `x` (causally or split left/right) and apply the convolution."""
        stride, dilation = self.stride[0], self.dilation[0]
        # Span covered by the kernel once dilation is taken into account.
        effective_kernel = (self.kernel_size[0] - 1) * dilation + 1
        padding_total = effective_kernel - stride
        # NOTE(review): presumably the extra right padding needed so the last
        # window is complete; semantics live in the audiocraft helper.
        extra_padding = get_extra_padding_for_conv1d(
            x, effective_kernel, stride, padding_total
        )
        if causal:
            # Causal: the whole fixed padding goes on the left.
            pads = (padding_total, extra_padding)
        else:
            # Non-causal: split the fixed padding; the left side gets the
            # larger half when the total is odd.
            right = padding_total // 2
            pads = (padding_total - right, right + extra_padding)
        return super().forward(pad1d(x, pads))
class ConvTranspose1d(nn.ConvTranspose1d):
    """`nn.ConvTranspose1d` whose `forward` trims ``kernel_size - stride`` samples.

    With `causal=True` everything is trimmed from the right; otherwise the
    trim is split between both ends.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        """Apply the transposed convolution and remove the fixed padding."""
        padding_total = self.kernel_size[0] - self.stride[0]
        y = super().forward(x)
        # Only the fixed padding is trimmed here. Any extra padding added by
        # the matching encoder convolution is left for the caller to remove at
        # the very end, since trimming it here would need the encoder-side
        # length.
        if causal:
            trim_right = ceil(padding_total)
        else:
            # The left side gets the larger half when the total is odd.
            trim_right = padding_total // 2
        trim_left = padding_total - trim_right
        return unpad1d(y, (trim_left, trim_right))
def Downsample1d(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Build a strided `Conv1d` used for downsampling.

    The convolution has stride `factor` and kernel size
    ``factor * kernel_multiplier + 1``; `kernel_multiplier` must be even.
    """
    multiplier_is_even = kernel_multiplier % 2 == 0
    assert multiplier_is_even, "Kernel multiplier must be even"
    kernel_size = factor * kernel_multiplier + 1
    return Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
    )
def Upsample1d(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Build the upsampling layer for a given `factor`.

    * ``factor == 1``: a plain kernel-3 `Conv1d` (no resampling).
    * ``use_nearest``: nearest-neighbour `nn.Upsample` followed by a kernel-3
      `Conv1d`.
    * otherwise: a `ConvTranspose1d` with kernel ``2 * factor`` and stride
      `factor`.
    """
    if factor == 1:
        return Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)
    if not use_nearest:
        return ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
        )
    resize = nn.Upsample(scale_factor=factor, mode="nearest")
    smooth = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)
    return nn.Sequential(resize, smooth)
class ConvBlock1d(nn.Module):
    """Normalisation -> optional scale/shift -> activation -> `Conv1d`.

    `use_norm` selects `nn.GroupNorm` over `nn.Identity`; `use_snake` selects
    `Snake1d` over `nn.SiLU` as the activation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
        use_norm: bool = True,
        use_snake: bool = False
    ) -> None:
        super().__init__()
        if use_norm:
            self.groupnorm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        else:
            self.groupnorm = nn.Identity()
        self.activation = Snake1d(in_channels) if use_snake else nn.SiLU()
        self.project = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(
        self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
    ) -> Tensor:
        """Apply the block; `scale_shift`, if given, modulates the normalised input."""
        x = self.groupnorm(x)
        if scale_shift is not None:
            scale, shift = scale_shift
            # Scale is stored as an offset from 1, so a zero scale is a no-op.
            x = x * (scale + 1) + shift
        return self.project(self.activation(x), causal=causal)
class MappingToScaleShift(nn.Module):
    """Project a mapping vector to a per-channel `(scale, shift)` pair."""

    def __init__(
        self,
        features: int,
        channels: int,
    ):
        super().__init__()
        # One linear layer emits scale and shift together (2 * channels).
        projection = nn.Linear(in_features=features, out_features=channels * 2)
        self.to_scale_shift = nn.Sequential(nn.SiLU(), projection)

    def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
        """Return `(scale, shift)`, each with a trailing singleton axis."""
        projected = rearrange(self.to_scale_shift(mapping), "b c -> b c 1")
        # First half of the channel axis is the scale, second half the shift.
        scale, shift = projected.chunk(2, dim=1)
        return scale, shift
class ResnetBlock1d(nn.Module):
    """Two `ConvBlock1d`s with a residual connection.

    When `context_mapping_features` is given, a `MappingToScaleShift` turns
    the `mapping` passed to `forward` into a scale/shift applied inside the
    second block. The residual path is a 1x1 `Conv1d` when the channel count
    changes and an identity otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        use_norm: bool = True,
        use_snake: bool = False,
        num_groups: int = 8,
        context_mapping_features: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.use_mapping = exists(context_mapping_features)

        # Only the first block receives kernel_size / stride / dilation.
        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake,
        )

        if self.use_mapping:
            assert exists(context_mapping_features)
            self.to_scale_shift = MappingToScaleShift(
                features=context_mapping_features, channels=out_channels
            )

        self.block2 = ConvBlock1d(
            in_channels=out_channels,
            out_channels=out_channels,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake,
        )

        if in_channels != out_channels:
            self.to_out = Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        """Return ``block2(block1(x)) + to_out(x)``.

        `mapping` must be provided exactly when the block was built with
        `context_mapping_features`.
        """
        assert_message = "context mapping required if context_mapping_features > 0"
        mapping_given = exists(mapping)
        assert self.use_mapping == mapping_given, assert_message

        hidden = self.block1(x, causal=causal)
        scale_shift = self.to_scale_shift(mapping) if self.use_mapping else None
        hidden = self.block2(hidden, scale_shift=scale_shift, causal=causal)
        return hidden + self.to_out(x)
class Patcher(nn.Module):
    """Resnet block followed by folding `patch_size` time steps into channels.

    The inner block emits ``out_channels // patch_size`` channels; the
    rearrange then multiplies channels by `patch_size` and divides the length
    by it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
        assert out_channels % patch_size == 0, assert_message
        self.patch_size = patch_size
        block_out_channels = out_channels // patch_size
        self.block = ResnetBlock1d(
            in_channels=in_channels,
            out_channels=block_out_channels,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        """Run the block, then move each group of `patch_size` samples into channels."""
        hidden = self.block(x, mapping, causal=causal)
        return rearrange(hidden, "b c (l p) -> b (c p) l", p=self.patch_size)
class Unpatcher(nn.Module):
    """Unfold channels back into time steps, then apply a resnet block.

    The rearrange divides channels by `patch_size` and multiplies the length
    by it; the inner block therefore takes ``in_channels // patch_size``
    channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False
    ):
        super().__init__()
        assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
        assert in_channels % patch_size == 0, assert_message
        self.patch_size = patch_size
        block_in_channels = in_channels // patch_size
        self.block = ResnetBlock1d(
            in_channels=block_in_channels,
            out_channels=out_channels,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        """Spread each channel group over `patch_size` samples, then run the block."""
        unfolded = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
        return self.block(unfolded, mapping, causal=causal)
"""
Attention Components
"""
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Build a Linear -> GELU -> Linear stack.

    The hidden width is ``features * multiplier``; input and output widths are
    both `features`.
    """
    hidden = features * multiplier
    expand = nn.Linear(in_features=features, out_features=hidden)
    contract = nn.Linear(in_features=hidden, out_features=features)
    return nn.Sequential(expand, nn.GELU(), contract)
def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
    """Fill `sim` with the most negative finite value wherever `mask` is False.

    A 3-D mask gets a singleton axis inserted at position 1; a 2-D mask is
    additionally repeated along a new leading axis of size ``sim.shape[0]``.
    Masks of any other rank are used as given.
    """
    batch = sim.shape[0]
    rank = mask.ndim
    if rank == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    elif rank == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=batch)
    # Largest-magnitude finite negative for sim's dtype.
    fill_value = -torch.finfo(sim.dtype).max
    return sim.masked_fill(~mask, fill_value)
def causal_mask(q: Tensor, k: Tensor) -> Tensor:
    """Build a boolean mask that is False strictly above diagonal ``j - i``.

    `i` and `j` are the second-to-last dimensions of `q` and `k`. The mask is
    created on `q`'s device and repeated ``q.shape[0]`` times along a new
    leading axis.
    """
    batch, q_len, k_len = q.shape[0], q.shape[-2], k.shape[-2]
    all_true = torch.ones((q_len, k_len), dtype=torch.bool, device=q.device)
    # True marks the positions to hide; the returned mask is its complement.
    hidden = all_true.triu(k_len - q_len + 1)
    return repeat(~hidden, "n m -> b n m", b=batch)
class AttentionBase(nn.Module):
    """Multi-head attention over already-projected q/k/v plus an output projection.

    Uses `F.scaled_dot_product_attention` when CUDA is available and torch is
    at least 2.0.0; otherwise falls back to an explicit einsum/softmax
    implementation.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        self.scale = head_features**-0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)

        self.to_out = nn.Linear(
            in_features=mid_features, out_features=out_features
        )

        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # Positional arguments forwarded to `sdp_kernel` in `forward`.
        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
    ) -> Tensor:
        """Attend `q` over `k`/`v` and project the merged heads.

        Args:
            q, k, v: tensors laid out as ``b n (h d)`` with ``h == num_heads``.
            mask: optional boolean mask (False = masked) applied to the
                similarity matrix.
            is_causal: request causal attention. On the fallback path a causal
                mask is generated only when no explicit `mask` was given.
        """
        # Split heads
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        if not self.use_flash:
            # `not mask` would evaluate the truth value of a Tensor, which
            # raises for any mask with more than one element; test for None.
            if is_causal and not exists(mask):
                # Mask out future tokens for causal attention
                mask = causal_mask(q, k)

            # Compute similarity matrix and add eventual mask
            sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
            sim = add_mask(sim, mask) if exists(mask) else sim

            # Get attention matrix with softmax
            attn = sim.softmax(dim=-1, dtype=torch.float32)

            # Compute values
            out = einsum("... n m, ... m d -> ... n d", attn, v)
        else:
            with sdp_kernel(*self.sdp_kernel_config):
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)

        # Merge heads back into the feature axis before the output projection.
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
class Attention(nn.Module):
    """Pre-norm attention layer: self-attention, or cross-attention over a context.

    Queries come from the input; keys and values come from `context` when one
    is supplied, and from the input otherwise.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        causal: bool = False,
    ):
        super().__init__()
        self.context_features = context_features
        self.causal = causal
        mid_features = head_features * num_heads
        # Without a dedicated context width, keys/values use the input width.
        kv_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(kv_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        # Keys and values share one projection and are split in `forward`.
        self.to_kv = nn.Linear(
            in_features=kv_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            out_features=out_features,
        )

    def forward(
        self,
        x: Tensor,  # [b, n, c]
        context: Optional[Tensor] = None,  # [b, m, d]
        context_mask: Optional[Tensor] = None,  # [b, m], false is masked,
        causal: Optional[bool] = False,
    ) -> Tensor:
        """Attend from `x` to `context` (or to `x` itself when no context is given).

        A context is mandatory when the layer was built with a truthy
        `context_features`. Attention is causal if either the constructor flag
        or the `causal` argument is set.
        """
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message

        # Fall back to the (un-normalised) input as the key/value source.
        kv_source = x if context is None else context

        normed_x = self.norm(x)
        normed_context = self.norm_context(kv_source)

        q = self.to_q(normed_x)
        k, v = torch.chunk(self.to_kv(normed_context), chunks=2, dim=-1)

        if exists(context_mask):
            # Mask out cross-attention for padding tokens by multiplying their
            # keys and values with the mask.
            mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
            k, v = k * mask, v * mask

        return self.attention(q, k, v, is_causal=self.causal or causal)
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Return a two-layer MLP: Linear(features -> features * multiplier), GELU, Linear back."""
    inner_features = features * multiplier
    layers = [
        nn.Linear(in_features=features, out_features=inner_features),
        nn.GELU(),
        nn.Linear(in_features=inner_features, out_features=features),
    ]
    return nn.Sequential(*layers)
"""
Transformer Blocks
"""
class TransformerBlock(nn.Module):
    """Self-attention, optional cross-attention and feed-forward, each with a residual.

    Cross-attention is created only when `context_features` is a positive
    number.
    """

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_cross_attention = context_features is not None and context_features > 0

        self.attention = Attention(
            features=features,
            num_heads=num_heads,
            head_features=head_features,
        )

        if self.use_cross_attention:
            self.cross_attention = Attention(
                features=features,
                num_heads=num_heads,
                head_features=head_features,
                context_features=context_features,
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
        """Apply the sub-layers in order; `causal` affects only the self-attention."""
        attended = self.attention(x, causal=causal)
        x = attended + x
        if self.use_cross_attention:
            crossed = self.cross_attention(x, context=context, context_mask=context_mask)
            x = crossed + x
        return self.feed_forward(x) + x
"""
Transformers
"""
class Transformer1d(nn.Module):
    """Stack of `TransformerBlock`s applied to a ``b c t`` tensor.

    The input is group-normalised (32 groups), passed through a 1x1 `Conv1d`
    and transposed to ``b t c`` for the blocks; the output is transposed back
    and passed through another 1x1 `Conv1d`.
    """

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        norm_in = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
        conv_in = Conv1d(in_channels=channels, out_channels=channels, kernel_size=1)
        self.to_in = nn.Sequential(norm_in, conv_in, Rearrange("b c t -> b t c"))

        layers = []
        for _ in range(num_layers):
            layers.append(
                TransformerBlock(
                    features=channels,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    context_features=context_features,
                )
            )
        self.blocks = nn.ModuleList(layers)

        to_channels_first = Rearrange("b t c -> b c t")
        conv_out = Conv1d(in_channels=channels, out_channels=channels, kernel_size=1)
        self.to_out = nn.Sequential(to_channels_first, conv_out)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
        """Run `x` through the input stem, every block, and the output head."""
        hidden = self.to_in(x)
        for block in self.blocks:
            hidden = block(hidden, context=context, context_mask=context_mask, causal=causal)
        return self.to_out(hidden)
"""
Time Embeddings
"""
class SinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal embedding with geometrically spaced frequencies.

    Produces ``2 * (dim // 2)`` features per input value: the sines of all
    frequencies followed by the cosines.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Embed `x`; a new trailing axis of size ``2 * (dim // 2)`` is appended."""
        device, half_dim = x.device, self.dim // 2
        # Frequencies run from 1 down to 1/10000. With a single frequency
        # (dim 2 or 3) the spacing denominator `half_dim - 1` would be zero and
        # the division raised ZeroDivisionError, so it is clamped to 1.
        step = torch.tensor(log(10000) / max(half_dim - 1, 1), device=device)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        # Broadcast over the new trailing axis; for 1-D `x` this equals the
        # outer product of `x` and `freqs`.
        emb = x.unsqueeze(-1) * freqs
        return torch.cat((emb.sin(), emb.cos()), dim=-1)
class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time.

    Holds ``dim // 2`` learnable frequencies (`dim` must be even) and returns,
    per input value, the value itself followed by the sines and cosines of
    ``value * weight * 2 * pi`` -- ``dim + 1`` features in total.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a 1-D tensor `x` into ``[x, sin(...), cos(...)]`` along the last axis."""
        column = rearrange(x, "b -> b 1")
        angles = column * rearrange(self.weights, "d -> 1 d") * 2 * pi
        return torch.cat((column, angles.sin(), angles.cos()), dim=-1)
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build `LearnedPositionalEmbedding(dim)` followed by a linear projection.

    The linear layer takes ``dim + 1`` inputs because the embedding prepends
    the raw value to its ``dim`` sinusoidal features.
    """
    embed = LearnedPositionalEmbedding(dim)
    project = nn.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(embed, project)
"""
Encoder/Decoder Components
"""
class DownsampleBlock1d(nn.Module):
    """Encoder stage: downsample, resnet blocks, optional transformer, optional extract.

    The downsampling convolution runs before the resnet blocks when
    `use_pre_downsample` is True and after them otherwise. With `use_skip`,
    the output of every resnet block (and of the transformer, if present) is
    collected and returned alongside the result.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_groups: int,
        num_layers: int,
        kernel_multiplier: int = 2,
        use_pre_downsample: bool = True,
        use_skip: bool = False,
        use_snake: bool = False,
        extract_channels: int = 0,
        context_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_pre_downsample = use_pre_downsample
        self.use_skip = use_skip
        self.use_transformer = num_transformer_blocks > 0
        self.use_extract = extract_channels > 0
        self.use_context = context_channels > 0

        # Width seen by the resnet blocks depends on where downsampling happens.
        channels = out_channels if use_pre_downsample else in_channels

        self.downsample = Downsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            kernel_multiplier=kernel_multiplier,
        )

        resnet_blocks = []
        for index in range(num_layers):
            # Only the first block takes the concatenated context channels.
            block_in = channels + context_channels if index == 0 else channels
            resnet_blocks.append(
                ResnetBlock1d(
                    in_channels=block_in,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake,
                )
            )
        self.blocks = nn.ModuleList(resnet_blocks)

        if self.use_transformer:
            has_head_spec = exists(attention_heads) or exists(attention_features)
            assert has_head_spec and exists(attention_multiplier)

            # Derive whichever of heads / per-head features was left out.
            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads
            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        if self.use_extract:
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=min(num_groups, extract_channels),
                use_snake=use_snake,
            )

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        channels: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
        """Run the stage.

        Returns ``(x, extracted)`` when the block was built with
        `extract_channels`, else ``(x, skips)`` when `use_skip` is set, else
        just ``x``.
        """
        if self.use_pre_downsample:
            x = self.downsample(x)

        if self.use_context and exists(channels):
            x = torch.cat([x, channels], dim=1)

        skips = []
        for block in self.blocks:
            x = block(x, mapping=mapping, causal=causal)
            if self.use_skip:
                skips.append(x)

        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
            if self.use_skip:
                skips.append(x)

        if not self.use_pre_downsample:
            x = self.downsample(x)

        if self.use_extract:
            return x, self.to_extracted(x)

        if self.use_skip:
            return x, skips
        return x
class UpsampleBlock1d(nn.Module):
    """Decoder stage: resnet blocks fed with skips, optional transformer, upsample.

    The upsampling layer runs before the resnet blocks when `use_pre_upsample`
    is True and after them otherwise. Every resnet block expects
    `skip_channels` additional input channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_layers: int,
        num_groups: int,
        use_nearest: bool = False,
        use_pre_upsample: bool = False,
        use_skip: bool = False,
        use_snake: bool = False,
        skip_channels: int = 0,
        use_skip_scale: bool = False,
        extract_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_extract = extract_channels > 0
        self.use_pre_upsample = use_pre_upsample
        self.use_transformer = num_transformer_blocks > 0
        self.use_skip = use_skip
        # Factor applied to each skip tensor before concatenation.
        self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0

        # Width seen by the resnet blocks depends on where upsampling happens.
        channels = out_channels if use_pre_upsample else in_channels

        resnet_blocks = []
        for _ in range(num_layers):
            resnet_blocks.append(
                ResnetBlock1d(
                    in_channels=channels + skip_channels,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake,
                )
            )
        self.blocks = nn.ModuleList(resnet_blocks)

        if self.use_transformer:
            has_head_spec = exists(attention_heads) or exists(attention_features)
            assert has_head_spec and exists(attention_multiplier)

            # Derive whichever of heads / per-head features was left out.
            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads
            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.upsample = Upsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            use_nearest=use_nearest,
        )

        if self.use_extract:
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=min(num_groups, extract_channels),
                use_snake=use_snake,
            )

    def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
        """Concatenate the scaled `skip` onto `x` along the channel axis."""
        scaled_skip = skip * self.skip_scale
        return torch.cat([x, scaled_skip], dim=1)

    def forward(
        self,
        x: Tensor,
        *,
        skips: Optional[List[Tensor]] = None,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Run the stage.

        When `skips` is given, one tensor is popped from the end of the list
        for every resnet block (the caller's list is consumed). Returns
        ``(x, extracted)`` when built with `extract_channels`, else ``x``.
        """
        if self.use_pre_upsample:
            x = self.upsample(x)

        for block in self.blocks:
            if skips is not None:
                x = self.add_skip(x, skip=skips.pop())
            x = block(x, mapping=mapping, causal=causal)

        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)

        if not self.use_pre_upsample:
            x = self.upsample(x)

        if self.use_extract:
            return x, self.to_extracted(x)

        return x
class BottleneckBlock1d(nn.Module):
    """Bottleneck stage: resnet block, optional transformer, resnet block.

    The channel count is unchanged throughout.
    """

    def __init__(
        self,
        channels: int,
        *,
        num_groups: int,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        self.use_transformer = num_transformer_blocks > 0

        # Both resnet blocks share the same configuration.
        resnet_kwargs = dict(
            in_channels=channels,
            out_channels=channels,
            num_groups=num_groups,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

        self.pre_block = ResnetBlock1d(**resnet_kwargs)

        if self.use_transformer:
            has_head_spec = exists(attention_heads) or exists(attention_features)
            assert has_head_spec and exists(attention_multiplier)

            # Derive whichever of heads / per-head features was left out.
            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads
            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.post_block = ResnetBlock1d(**resnet_kwargs)

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Tensor:
        """Apply pre-block, the transformer (if configured) and post-block."""
        hidden = self.pre_block(x, mapping=mapping, causal=causal)
        if self.use_transformer:
            hidden = self.transformer(
                hidden, context=embedding, context_mask=embedding_mask, causal=causal
            )
        return self.post_block(hidden, mapping=mapping, causal=causal)
"""
UNet
"""
class UNet1d(nn.Module):
def __init__(
self,
in_channels: int,
channels: int,
multipliers: Sequence[int],
factors: Sequence[int],
num_blocks: Sequence[int],
attentions: Sequence[int],
patch_size: int = 1,
resnet_groups: int = 8,
use_context_time: bool = True,
kernel_multiplier_downsample: int = 2,
use_nearest_upsample: bool = False,
use_skip_scale: bool = True,
use_snake: bool = False,
use_stft: bool = False,
use_stft_context: bool = False,
out_channels: Optional[int] = None,
context_features: Optional[int] = None,
context_features_multiplier: int = 4,
context_channels: Optional[Sequence[int]] = None,
context_embedding_features: Optional[int] = None,
**kwargs,
):
super().__init__()
out_channels = default(out_channels, in_channels)
context_channels = list(default(context_channels, []))
num_layers = len(multipliers) - 1
use_context_features = exists(context_features)
use_context_channels = len(context_channels) > 0
context_mapping_features = None
attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
self.num_layers = num_layers
self.use_context_time = use_context_time
self.use_context_features = use_context_features
self.use_context_channels = use_context_channels
self.use_stft = use_stft
self.use_stft_context = use_stft_context
self.context_features = context_features
context_channels_pad_length = num_layers + 1 - len(context_channels)
context_channels = context_channels + [0] * context_channels_pad_length
self.context_channels = context_channels
self.context_embedding_features = context_embedding_features
if use_context_channels:
has_context = [c > 0 for c in context_channels]
self.has_context = has_context
self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
assert (
len(factors) == num_layers
and len(attentions) >= num_layers
and len(num_blocks) == num_layers
)
if use_context_time or use_context_features:
context_mapping_features = channels * context_features_multiplier
self.to_mapping = nn.Sequential(
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
)
if use_context_time:
assert exists(context_mapping_features)
self.to_time = nn.Sequential(
TimePositionalEmbedding(
dim=channels, out_features=context_mapping_features
),
nn.GELU(),
)
if use_context_features:
assert exists(context_features) and exists(context_mapping_features)
self.to_features = nn.Sequential(
nn.Linear(
in_features=context_features, out_features=context_mapping_features
),
nn.GELU(),
)
if use_stft:
stft_kwargs, kwargs = groupby("stft_", kwargs)
assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
in_channels *= stft_channels
out_channels *= stft_channels
context_channels[0] *= stft_channels if use_stft_context else 1
assert exists(in_channels) and exists(out_channels)
self.stft = STFT(**stft_kwargs)
assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
self.to_in = Patcher(
in_channels=in_channels + context_channels[0],
out_channels=channels * multipliers[0],
patch_size=patch_size,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
self.downsamples = nn.ModuleList(
[
DownsampleBlock1d(
in_channels=channels * multipliers[i],
out_channels=channels * multipliers[i + 1],
context_mapping_features=context_mapping_features,
context_channels=context_channels[i + 1],
context_embedding_features=context_embedding_features,
num_layers=num_blocks[i],
factor=factors[i],
kernel_multiplier=kernel_multiplier_downsample,
num_groups=resnet_groups,
use_pre_downsample=True,
use_skip=True,
use_snake=use_snake,
num_transformer_blocks=attentions[i],
**attention_kwargs,
)
for i in range(num_layers)
]
)
self.bottleneck = BottleneckBlock1d(
channels=channels * multipliers[-1],
context_mapping_features=context_mapping_features,
context_embedding_features=context_embedding_features,
num_groups=resnet_groups,
num_transformer_blocks=attentions[-1],
use_snake=use_snake,
**attention_kwargs,
)
self.upsamples = nn.ModuleList(
[
UpsampleBlock1d(
in_channels=channels * multipliers[i + 1],
out_channels=channels * multipliers[i],
context_mapping_features=context_mapping_features,
context_embedding_features=context_embedding_features,
num_layers=num_blocks[i] + (1 if attentions[i] else 0),
factor=factors[i],
use_nearest=use_nearest_upsample,
num_groups=resnet_groups,
use_skip_scale=use_skip_scale,
use_pre_upsample=False,
use_skip=True,
use_snake=use_snake,
skip_channels=channels * multipliers[i + 1],
num_transformer_blocks=attentions[i],
**attention_kwargs,
)
for i in reversed(range(num_layers))
]
)
self.to_out = Unpatcher(
in_channels=channels * multipliers[0],
out_channels=out_channels,
patch_size=patch_size,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
def get_channels(
self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
) -> Optional[Tensor]:
"""Gets context channels at `layer` and checks that shape is correct"""
use_context_channels = self.use_context_channels and self.has_context[layer]
if not use_context_channels:
return None
assert exists(channels_list), "Missing context"
# Get channels index (skipping zero channel contexts)
channels_id = self.channels_ids[layer]
# Get channels
channels = channels_list[channels_id]
message = f"Missing context for layer {layer} at index {channels_id}"
assert exists(channels), message
# Check channels
num_channels = self.context_channels[layer]
message = f"Expected context with {num_channels} channels at idx {channels_id}"
assert channels.shape[1] == num_channels, message
# STFT channels if requested
channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
return channels
def get_mapping(
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
) -> Optional[Tensor]:
"""Combines context time features and features into mapping"""
items, mapping = [], None
# Compute time features
if self.use_context_time:
assert_message = "use_context_time=True but no time features provided"
assert exists(time), assert_message
items += [self.to_time(time)]
# Compute features
if self.use_context_features:
assert_message = "context_features exists but no features provided"
assert exists(features), assert_message
items += [self.to_features(features)]
# Compute joint mapping
if self.use_context_time or self.use_context_features:
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
mapping = self.to_mapping(mapping)
return mapping
def forward(
    self,
    x: Tensor,
    time: Optional[Tensor] = None,
    *,
    features: Optional[Tensor] = None,
    channels_list: Optional[Sequence[Tensor]] = None,
    embedding: Optional[Tensor] = None,
    embedding_mask: Optional[Tensor] = None,
    causal: Optional[bool] = False,
) -> Tensor:
    """Run the U-Net: input stage, downsamples, bottleneck, upsamples, output stage.

    `time` and `features` are combined into a mapping by `get_mapping`;
    `channels_list` supplies the per-layer context channels read through
    `get_channels`; `embedding`, `embedding_mask` and `causal` are forwarded
    unchanged to the downsample, bottleneck and upsample stages.
    """
    # Layer-0 context is resolved before the optional STFT of the input
    context = self.get_channels(channels_list, layer=0)
    if self.use_stft:
        x = self.stft.encode1d(x)  # type: ignore
    if exists(context):
        x = torch.cat([x, context], dim=1)

    mapping = self.get_mapping(time, features)
    # Keyword arguments shared by every conditioned stage below
    shared = dict(
        mapping=mapping,
        embedding=embedding,
        embedding_mask=embedding_mask,
        causal=causal,
    )

    x = self.to_in(x, mapping, causal=causal)
    skip_stack = [x]

    # Context layers are numbered from 1 for the downsample stages
    for layer, down in enumerate(self.downsamples, start=1):
        context = self.get_channels(channels_list, layer=layer)
        x, skips = down(x, channels=context, **shared)
        skip_stack.append(skips)

    x = self.bottleneck(x, **shared)

    # Skips are consumed in reverse order of how they were produced
    for up in self.upsamples:
        x = up(x, skips=skip_stack.pop(), **shared)

    # What remains on the stack is the output of the input stage (in-place add)
    x += skip_stack.pop()
    x = self.to_out(x, mapping, causal=causal)
    if self.use_stft:
        x = self.stft.decode1d(x)
    return x
""" Conditioning Modules """
class FixedEmbedding(nn.Module):
    """Learned embedding indexed by sequence position, repeated over the batch.

    Only the first two dimensions of the input (batch size and length) and
    its device are used; the input values themselves are never read.
    """

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        """Return the embeddings of positions 0..length-1 for every batch item."""
        batch_size, length = x.shape[:2]
        assert (
            length <= self.max_length
        ), "Input sequence length must be <= max_length"
        positions = torch.arange(length, device=x.device)
        per_position = self.embedding(positions)
        # Same table for every batch item
        return repeat(per_position, "n d -> b n d", b=batch_size)
def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    """Sample a boolean tensor of `shape` whose entries are True with probability `proba`.

    Probabilities of exactly 0 or 1 return a constant tensor without sampling.
    """
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    # General case: one Bernoulli draw per element
    probabilities = torch.full(shape, proba, device=device)
    return torch.bernoulli(probabilities).to(torch.bool)
class UNetCFG1d(UNet1d):
    """UNet1d with Classifier-Free Guidance"""

    def __init__(
        self,
        context_embedding_max_length: int,
        context_embedding_features: int,
        use_xattn_time: bool = False,
        **kwargs,
    ):
        """Build the U-Net plus the learned fixed embedding used as the unconditional input.

        When `use_xattn_time` is set, a time embedding module is created
        (its `dim` is read from `kwargs["channels"]`) and the fixed embedding
        gets one extra position to hold the appended time token.
        """
        super().__init__(
            context_embedding_features=context_embedding_features, **kwargs
        )
        self.use_xattn_time = use_xattn_time
        if use_xattn_time:
            assert exists(context_embedding_features)
            self.to_time_embedding = nn.Sequential(
                TimePositionalEmbedding(
                    dim=kwargs["channels"], out_features=context_embedding_features
                ),
                nn.GELU(),
            )
            context_embedding_max_length += 1  # Add one for time embedding
        self.fixed_embedding = FixedEmbedding(
            max_length=context_embedding_max_length, features=context_embedding_features
        )

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        embedding: Tensor,
        embedding_mask: Optional[Tensor] = None,
        embedding_scale: float = 1.0,
        embedding_mask_proba: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        scale_phi: float = 0.4,
        negative_embedding: Optional[Tensor] = None,
        negative_embedding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """Forward pass with optional embedding dropout and classifier-free guidance.

        Args:
            x: Input passed to the parent forward.
            time: Time input passed to the parent forward (and, with
                `use_xattn_time`, embedded and appended to `embedding`).
            embedding: Conditioning embedding; dim 0 is the batch and dim 1
                the sequence axis that the time token is appended to.
            embedding_mask: Optional mask for `embedding`; extended by one
                position when the time token is appended.
            embedding_scale: Guidance scale. At exactly 1.0 a single
                conditional pass is run; otherwise the output is
                `out_masked + (out - out_masked) * embedding_scale`.
            embedding_mask_proba: Per-batch-item probability of replacing the
                embedding with the fixed embedding.
            batch_cfg: Run the conditional and unconditional passes as one
                doubled batch instead of two separate passes.
            rescale_cfg: Blend the guided output with a version rescaled to
                the conditional output's standard deviation over dim 1.
            scale_phi: Blend weight of the rescaled output.
            negative_embedding: Optional replacement for the fixed embedding
                in the unconditional half. Only read when `batch_cfg` is True.
            negative_embedding_mask: Optional mask; positions where it is
                False fall back to the fixed embedding.
            **kwargs: Forwarded to the parent forward. With `batch_cfg`,
                `features` and `channels_list` are popped and doubled first.

        Returns:
            The (optionally guided) output of the parent forward.
        """
        b, device = embedding.shape[0], embedding.device

        if self.use_xattn_time:
            embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)

            if embedding_mask is not None:
                # Build the extra column from the mask itself so its dtype (and
                # device) are kept; a default float32 column would promote a
                # boolean mask to float when concatenated.
                embedding_mask = torch.cat([embedding_mask, embedding_mask.new_ones((b, 1))], dim=1)

        fixed_embedding = self.fixed_embedding(embedding)
        assert fixed_embedding.shape == embedding.shape, f"Shape mismatch: {fixed_embedding.shape} vs {embedding.shape}"

        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        # No guidance requested: a single conditional pass is enough
        if embedding_scale == 1.0:
            return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)

        if batch_cfg:
            batch_x = torch.cat([x, x], dim=0)
            batch_time = torch.cat([time, time], dim=0)

            if negative_embedding is not None:
                # NOTE(review): with use_xattn_time the embedding and fixed
                # embedding are one position longer than negative_embedding,
                # so the where/cat below look like they would fail on shape -
                # confirm whether the time token should be appended here too.
                if negative_embedding_mask is not None:
                    negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)

                    negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)

                batch_embed = torch.cat([embedding, negative_embedding], dim=0)

            else:
                batch_embed = torch.cat([embedding, fixed_embedding], dim=0)

            batch_mask = None
            if embedding_mask is not None:
                batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)

            batch_features = None
            features = kwargs.pop("features", None)
            if self.use_context_features:
                batch_features = torch.cat([features, features], dim=0)

            batch_channels = None
            channels_list = kwargs.pop("channels_list", None)
            if self.use_context_channels:
                batch_channels = []
                for channels in channels_list:
                    batch_channels += [torch.cat([channels, channels], dim=0)]

            # Compute both normal and fixed embedding outputs
            batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
            out, out_masked = batch_out.chunk(2, dim=0)

        else:
            # Compute both normal and fixed embedding outputs
            out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
            out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)

        out_cfg = out_masked + (out - out_masked) * embedding_scale

        if rescale_cfg:
            # Match the guided output's std (over dim 1) to the conditional one
            out_std = out.std(dim=1, keepdim=True)
            out_cfg_std = out_cfg.std(dim=1, keepdim=True)

            return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg

        return out_cfg
class UNetNCCA1d(UNet1d):
    """UNet1d with Noise Channel Conditioning Augmentation"""

    def __init__(self, context_features: int, **kwargs):
        super().__init__(context_features=context_features, **kwargs)
        # Embeds the per-item noise scales into the context feature space
        self.embedder = NumberEmbedder(features=context_features)

    def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
        """Convert `x` to a tensor if needed and expand it to `shape`."""
        x = x if torch.is_tensor(x) else torch.tensor(x)
        return x.expand(shape)

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        channels_list: Sequence[Tensor],
        channels_augmentation: Union[
            bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
        ] = False,
        channels_scale: Union[
            float, Sequence[float], Sequence[Sequence[float]], Tensor
        ] = 0,
        **kwargs,
    ) -> Tensor:
        """Noise-augment the context channels and condition on the noise scale.

        Args:
            x: Input passed to the parent forward.
            time: Time input passed to the parent forward.
            channels_list: Context tensors, one per context item. The
                sequence is not modified; augmented copies are passed on.
            channels_augmentation: Whether to augment, expanded to
                (batch, number of context items).
            channels_scale: Noise mixing scale, expanded the same way. Each
                item becomes `noise * scale + item * (1 - scale)`, with the
                scale zeroed where augmentation is off.
            **kwargs: Forwarded to the parent forward.

        Returns:
            The parent forward's output, with the summed scale embeddings
            supplied as `features`.
        """
        b, n = x.shape[0], len(channels_list)
        channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
        channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)

        # Augmentation (for each channel list item). A new list is built so
        # the caller's sequence is left untouched: writing back in place fails
        # for tuples and re-noises the inputs if the same list is passed again.
        augmented_channels = []
        for i, item in enumerate(channels_list):
            scale = channels_scale[:, i] * channels_augmentation[:, i]
            scale = rearrange(scale, "b -> b 1 1")
            augmented_channels.append(torch.randn_like(item) * scale + item * (1 - scale))

        # Scale embedding (sum reduction if more than one channel list item)
        channels_scale_emb = self.embedder(channels_scale)
        channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")

        return super().forward(
            x=x,
            time=time,
            channels_list=augmented_channels,
            features=channels_scale_emb,
            **kwargs,
        )
class UNetAll1d(UNetCFG1d, UNetNCCA1d):
    """UNet1d with both classifier-free guidance and noise channel augmentation.

    NOTE(review): with `batch_cfg=True` the CFG forward passes `features=`
    explicitly, which looks like it collides with the `features=` that
    UNetNCCA1d.forward supplies itself - confirm before relying on that path.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):  # type: ignore
        # Enter through the CFG forward explicitly; its super() calls then
        # continue along the MRO (UNetNCCA1d, then UNet1d).
        cfg_forward = UNetCFG1d.forward
        return cfg_forward(self, *args, **kwargs)
def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
    """Construct the U-Net variant named by `type`, forwarding `kwargs` to it.

    Known names are "base", "all", "cfg" and "ncca"; anything else raises
    ValueError.
    """
    variants = (
        ("base", UNet1d),
        ("all", UNetAll1d),
        ("cfg", UNetCFG1d),
        ("ncca", UNetNCCA1d),
    )
    for name, constructor in variants:
        if type == name:
            return constructor(**kwargs)
    raise ValueError(f"Unknown XUNet1d type: {type}")
class NumberEmbedder(nn.Module):
    """Embeds every number of an arbitrarily shaped input into `features` dimensions."""

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        """Return embeddings shaped like the input with a trailing `features` axis."""
        if not torch.is_tensor(x):
            # Plain Python numbers are placed on the embedding's own device
            target_device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=target_device)
        assert isinstance(x, Tensor)
        original_shape = x.shape
        # Embed the flattened values, then restore the leading shape
        flat = rearrange(x, "... -> (...)")
        embedded = self.embedding(flat)
        return embedded.view(*original_shape, self.features)  # type: ignore
"""
Audio Transforms
"""
class STFT(nn.Module):
    """Helper for torch stft and istft"""

    def __init__(
        self,
        num_fft: int = 1023,
        hop_length: Optional[int] = 256,
        window_length: Optional[int] = None,
        length: Optional[int] = None,
        use_complex: bool = False,
    ):
        """Store the transform settings and register a Hann window buffer.

        Args:
            num_fft: FFT size passed to torch.stft / torch.istft.
            hop_length: Hop between frames; `num_fft // 4` when None.
            window_length: Window size; `num_fft` when None.
            length: Fixed output length for decoding; when None the length is
                derived from the number of frames in `decode`.
            use_complex: Represent spectra as (real, imaginary) instead of
                (magnitude, phase).
        """
        super().__init__()
        self.num_fft = num_fft
        self.hop_length = default(hop_length, num_fft // 4)
        self.window_length = default(window_length, num_fft)
        self.length = length
        self.register_buffer("window", torch.hann_window(self.window_length))
        self.use_complex = use_complex

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        """Transform a (b, c, t) waveform into two (b, c, f, l) spectral tensors.

        The pair is (real, imaginary) when `use_complex` is set, otherwise
        (magnitude, phase).
        """
        b = wave.shape[0]
        # torch.stft works on (batch, time), so fold channels into the batch
        wave = rearrange(wave, "b c t -> (b c) t")

        stft = torch.stft(
            wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            return_complex=True,
            normalized=True,
        )

        if self.use_complex:
            # Returns real and imaginary
            stft_a, stft_b = stft.real, stft.imag
        else:
            # Returns magnitude and phase matrices
            magnitude, phase = torch.abs(stft), torch.angle(stft)
            stft_a, stft_b = magnitude, phase

        return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)

    def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
        """Invert a pair of (b, c, f, l) spectral tensors back to a (b, c, t) waveform.

        The output length is `self.length` when set, otherwise the power of
        two closest to `l * hop_length`.
        """
        b, l = stft_a.shape[0], stft_a.shape[-1]  # noqa
        length = closest_power_2(l * self.hop_length)

        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")

        if self.use_complex:
            real, imag = stft_a, stft_b
        else:
            magnitude, phase = stft_a, stft_b
            real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)

        # torch.istft requires a complex tensor (real inputs with a trailing
        # size-2 axis are rejected since torch 2.0), so build one explicitly
        stft = torch.complex(real, imag)

        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            length=default(self.length, length),
            normalized=True,
        )

        return rearrange(wave, "(b c) t -> b c t", b=b)

    def encode1d(
        self, wave: Tensor, stacked: bool = True
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Encode and fold the frequency axis into channels: (b, c * f, l).

        With `stacked` the two spectral tensors are concatenated along dim 1;
        otherwise they are returned as a pair.
        """
        stft_a, stft_b = self.encode(wave)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
        return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)

    def decode1d(self, stft_pair: Tensor) -> Tensor:
        """Invert `encode1d(..., stacked=True)`: split along dim 1, unfold, decode."""
        # Number of frequency bins used to unfold the channel axis
        f = self.num_fft // 2 + 1
        stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
        return self.decode(stft_a, stft_b)