import math
import typing as tp
from dataclasses import dataclass
from typing import List, Optional, Union

import hydra
import librosa
import numpy as np
import soundfile as sf
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from dac.model.base import CodecMixin
from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
from omegaconf import OmegaConf
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
@dataclass
class VQResult:
    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor
    codebook_loss: torch.Tensor
    commitment_loss: torch.Tensor
    semantic_distill_z: torch.Tensor | None = None
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
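# Example: find_multiple(18, 8) == 24, since 18 % 8 == 2 and 18 + (8 - 2) == 24;
# values that are already multiples of k are returned unchanged.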
@dataclass
class ModelArgs:
    block_size: int = 2048
    n_layer: int = 8
    n_head: int = 8
    dim: int = 512
    intermediate_size: int = 1536
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1
    channels_first: bool = True  # to be compatible with conv1d input/output
    pos_embed_type: str = "rope"  # can be "rope" or "conformer"
    max_relative_position: int = 128  # for conformer-style relative position embedding

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        assert self.pos_embed_type in [
            "rope",
            "conformer",
        ], "pos_embed_type must be either 'rope' or 'conformer'"
class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return (
            k_out[:, :, : input_pos.max() + 1, :],
            v_out[:, :, : input_pos.max() + 1, :],
        )

    def clear_cache(self, prompt_len):
        self.k_cache[:, :, prompt_len:, :].fill_(0)
        self.v_cache[:, :, prompt_len:, :].fill_(0)
class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        # Only compute RoPE frequencies if using RoPE
        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(
                self.config.block_size, self.config.head_dim, self.config.rope_base
            )
            self.register_buffer("freqs_cis", freqs_cis)
        else:
            self.register_buffer("freqs_cis", None)

        causal_mask = torch.tril(
            torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)
        self.max_batch_size = -1
        self.max_seq_length = -1
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        """
        This method will only be called during inference when using the KV cache.
        """
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_local_heads,
                head_dim,
                dtype,
            ).to(device)
        self.use_kv_cache = True

    def forward(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        if self.config.pos_embed_type == "rope":
            assert (
                self.freqs_cis is not None
            ), "RoPE frequencies must be initialized for RoPE positional embedding"
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None
        if mask is None:  # default to causal attention when no mask is supplied
            if not self.training and self.use_kv_cache:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]
        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        return x
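# Hedged usage sketch (illustrative, not part of the original file): prefill
# followed by single-step decoding with the KV cache, using absolute positions.
#
#     model = Transformer(ModelArgs()).eval()
#     model.setup_caches(max_batch_size=1, max_seq_length=256)
#     prompt = torch.randn(1, 16, 512)                # (B, S, dim) embeddings
#     y = model(prompt, input_pos=torch.arange(16))   # prefill positions 0..15
#     step = torch.randn(1, 1, 512)
#     y = model(step, input_pos=torch.tensor([16]))   # decode one new position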
class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.attention_layer_scale = LayerScale(config.dim, inplace=True)
        self.ffn_layer_scale = LayerScale(config.dim, inplace=True)

    def forward(
        self,
        x: Tensor,
        input_pos: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
    ) -> Tensor:
        h = x + self.attention_layer_scale(
            self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        )
        out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
        return out
class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.attn_dropout_rate = config.attn_dropout_rate
        self.pos_embed_type = config.pos_embed_type

        # Add relative position embedding for conformer-style
        if self.pos_embed_type == "conformer":
            self.max_relative_position = config.max_relative_position
            num_pos_embeddings = 2 * config.max_relative_position + 1
            self.rel_pos_embeddings = nn.Parameter(
                torch.zeros(num_pos_embeddings, self.head_dim)
            )
            nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)

    def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
        # q: [B, H, S, D]
        # Returns: [B, H, S, S]
        positions = torch.arange(seqlen, device=q.device)
        relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0)  # [S, S]
        relative_positions = torch.clamp(
            relative_positions + self.max_relative_position,
            0,
            2 * self.max_relative_position,
        )
        rel_embeddings = self.rel_pos_embeddings[relative_positions]  # [S, S, D]
        # Compute attention scores with relative position embeddings
        q = q.transpose(1, 2)  # [B, S, H, D]
        rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1))  # [B, S, H, S]
        rel_logits = rel_logits.transpose(1, 2)  # [B, H, S, S]
        return rel_logits
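    # Example of the clamping above: with max_relative_position=2, relative
    # offsets are shifted into table indices [0, 4], so any offset beyond
    # +/-2 (e.g. -3 or +5) reuses the edge embedding at index 0 or 4.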
    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        # Queries use n_head heads; keys/values use n_local_heads (grouped-query
        # attention), so the fused projection is split accordingly.
        q, k, v = self.wqkv(x).split(
            [self.n_head * self.head_dim, kv_size, kv_size], dim=-1
        )
        context_seqlen = seqlen

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)

        if self.pos_embed_type == "rope":
            q = apply_rotary_emb(q, freqs_cis)
            k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.pos_embed_type == "conformer":
            # Compute attention scores
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            # Add relative position embeddings for conformer-style
            rel_scores = self._compute_conformer_pos_scores(q, seqlen)
            scores = scores + rel_scores
            # Apply attention
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            attn = F.softmax(scores, dim=-1)
            if self.attn_dropout_rate > 0 and self.training:
                attn = F.dropout(attn, p=self.attn_dropout_rate)
            y = torch.matmul(attn, v)
        else:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_dropout_rate if self.training else 0.0,
                attn_mask=mask,
            )

        y = (
            y.transpose(1, 2)
            .contiguous()
            .view(bsz, seqlen, self.head_dim * self.n_head)
        )

        y = self.wo(y)
        return y
class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-2,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
class WindowLimitedTransformer(Transformer):
    """
    Causal Transformer with window-limited attention.
    """

    def __init__(
        self,
        config: ModelArgs,
        input_dim: int = 512,
        window_size: Optional[int] = None,
        causal: bool = True,
        look_ahead_conv: Optional[nn.Module] = None,
    ):
        super().__init__(config)
        self.window_size = window_size
        self.causal = causal
        self.channels_first = config.channels_first
        self.look_ahead_conv = (
            look_ahead_conv if look_ahead_conv is not None else nn.Identity()
        )
        self.input_proj = (
            nn.Linear(input_dim, config.dim)
            if input_dim != config.dim
            else nn.Identity()
        )
        self.output_proj = (
            nn.Linear(config.dim, input_dim)
            if input_dim != config.dim
            else nn.Identity()
        )

    def make_window_limited_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Make a mask for window-limited (banded causal) attention.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
            row_indices = torch.arange(max_length).view(-1, 1)
            window_size = self.window_size or max_length
            valid_range = (row_indices - window_size + 1).clamp(min=0)
            column_indices = torch.arange(max_length)
            mask = (column_indices >= valid_range) & mask.bool()
        else:
            raise NotImplementedError
        mask = mask.bool()[None, None]
        return mask
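    # Example of the banded mask above: max_length=4, window_size=2 yields
    # (1 = may attend)
    #     1 0 0 0
    #     1 1 0 0
    #     0 1 1 0
    #     0 0 1 1
    # i.e. each position sees itself and at most window_size - 1 earlier steps.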
    def make_mask(
        self,
        max_length: int,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Make an ordinary (non-windowed) mask, optionally masking padded keys.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length))
        else:
            mask = torch.ones(max_length, max_length)
        mask = mask.bool()[None, None]  # (1, 1, T, T)
        if x_lens is not None:
            # Disallow attention to the padded key positions of each sequence.
            mask = mask.repeat(len(x_lens), 1, 1, 1)
            for i, x_len in enumerate(x_lens):
                mask[i, :, :, x_len:] = False
        return mask
    def forward(
        self,
        x: Tensor,
        x_lens: Optional[Tensor] = None,
    ) -> Tensor:
        if self.channels_first:
            x = x.transpose(1, 2)
        x = self.input_proj(x)  # (B, T, D)
        x = self.look_ahead_conv(x)
        input_pos = torch.arange(x.shape[1], device=x.device)
        # construct mask to form window limited attention
        max_length = x.shape[1]
        if self.window_size is not None:
            mask = self.make_window_limited_mask(max_length, x_lens)
        else:
            mask = self.make_mask(max_length, x_lens)
        mask = mask.to(x.device)
        x = super().forward(x, input_pos, mask)
        x = self.output_proj(x)  # (B, T, D)
        if self.channels_first:
            x = x.transpose(1, 2)
        return x
def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
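# Sanity check for the pair of helpers above: position 0 has rotation angle 0,
# so applying the embedding there is the identity (up to dtype rounding).
#
#     fc = precompute_freqs_cis(seq_len=8, n_elem=64, dtype=torch.float32)
#     x = torch.randn(1, 8, 4, 64)  # (B, S, H, D)
#     assert torch.allclose(apply_rotary_emb(x, fc)[:, 0], x[:, 0], atol=1e-5)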
def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:  # guard: conv layers may be built without bias
            nn.init.constant_(m.bias, 0)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling zero padding properly. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
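# Worked example for get_extra_padding_for_conv1d: length=11, kernel_size=4,
# stride=3, padding_total=0 gives n_frames = (11 - 4) / 3 + 1 = 3.33..., an
# ideal_length of (4 - 1) * 3 + 4 = 13, hence 2 extra samples of right padding
# so the final partial frame is not dropped.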
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow reflect padding on small inputs.
    If the input is too short, we insert extra zero padding to the right
    before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
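# Example: reflect-padding a length-2 signal by (0, 4) would fail in plain
# F.pad (reflection requires length > pad), so pad1d first zero-extends the
# input, reflects, then trims the temporary zeros from the end.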
class CausalConvNet(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
        padding=None,
    ):
        super(CausalConvNet, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size
        self.dilation = dilation
        self.padding = self.kernel_size - self.stride

    def forward(self, x):
        pad = self.padding
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, pad
        )
        # All regular padding goes to the left, which keeps the conv causal.
        x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self
class CausalTransConvNet(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
    ):
        super(CausalTransConvNet, self).__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        x = self.conv(x)
        # Trim all transposed-conv overlap from the right to stay causal.
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
        return self
def CausalWNConv1d(*args, **kwargs):
    return CausalConvNet(*args, **kwargs).weight_norm()


def CausalWNConvTranspose1d(*args, **kwargs):
    return CausalTransConvNet(*args, **kwargs).weight_norm()
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=1),
        )
        self.causal = causal

    def forward(self, x):
        y = self.block(x)
        pad = x.shape[-1] - y.shape[-1]
        if pad > 0:
            if self.causal:
                x = x[..., :-pad]
            else:
                x = x[..., pad // 2 : -pad // 2]
        return x + y
class EncoderBlock(nn.Module):
    def __init__(
        self,
        dim: int = 16,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=dim,
                    window_size=512,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=dim // 64,
                        dim=dim,
                        intermediate_size=dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv_class(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            transformer_module,
        )

    def forward(self, x):
        return self.block(x)
class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        n_transformer_layers: list = [0, 0, 4, 4],
        transformer_general_config: ModelArgs = None,
        causal: bool = False,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # Create first convolution
        self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride, n_t_layer in zip(strides, n_transformer_layers):
            d_model *= 2
            self.block += [
                EncoderBlock(
                    d_model,
                    stride=stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            conv_class(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
class DecoderBlock(nn.Module):
    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=input_dim,
                    window_size=None,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=input_dim // 64,
                        dim=input_dim,
                        intermediate_size=input_dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            # transformer_module,
            Snake1d(input_dim),
            conv_trans_class(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x):
        return self.block(x)
class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        n_transformer_layers: list = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # Add first conv layer
        layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [
                DecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            conv_class(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class DAC(BaseModel, CodecMixin):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        quantizer: torch.nn.Module = None,
        sample_rate: int = 44100,
        causal: bool = True,
        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim
        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(
            encoder_dim,
            encoder_rates,
            latent_dim,
            causal=causal,
            n_transformer_layers=encoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )
        self.quantizer = quantizer
        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            causal=causal,
            n_transformer_layers=decoder_transformer_layers,
            transformer_general_config=transformer_general_config,
        )
        self.apply(init_weights)
        self.delay = self.get_delay()
        self.frame_length = self.hop_length * 4

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        return audio_data
    def encode(
        self,
        audio_data: torch.Tensor,
        audio_lengths: torch.Tensor = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Encode given audio data and return quantized latent codes

        Parameters
        ----------
        audio_data : Tensor[B x T] or Tensor[B x 1 x T]
            Audio data to encode
        audio_lengths : Tensor[B], optional
            Length in samples of each batch item, by default None
            If None, the padded input length is used for the whole batch.
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        indices : Tensor[B x N x T']
            Codebook indices for each codebook
            (quantized discrete representation of the input)
        indices_lens : Tensor[B]
            Number of valid code frames for each batch item
        """
        # pad to multiple of self.frame_length
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        if audio_lengths is None:
            audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)

        z = self.encoder(audio_data)
        vq_results = self.quantizer(z, n_quantizers, **kwargs)
        indices = vq_results.codes
        indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
        return indices, indices_lens
    def decode(self, indices: torch.Tensor, feature_lengths):
        if indices.ndim == 2:
            indices = indices[None]
        z = self.quantizer.decode(indices)
        audio_lengths = feature_lengths * self.frame_length
        return self.decoder(z), audio_lengths
    def forward(
        self,
        audio_data: torch.Tensor,
        template: torch.Tensor = None,
        mask: torch.Tensor = None,
        sample_rate: int = None,
        n_quantizers: int = None,
        **kwargs,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        audio : Tensor[B x 1 x T]
            Reconstructed audio, trimmed to the input length
        vq_results : tuple
            The `(indices, indices_lens)` pair returned by `encode`
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        vq_results = self.encode(audio_data, n_quantizers=n_quantizers, **kwargs)
        indices, indices_lens = vq_results
        audio, _ = self.decode(indices, indices_lens)
        return audio[..., :length], vq_results
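# Hedged smoke-test sketch (not part of the original file). `DAC` expects a
# quantizer module; `_PassthroughQuantizer` below is a hypothetical stand-in
# that returns its input latents as "codes" via a VQResult, purely to
# exercise tensor shapes end to end on random audio.
if __name__ == "__main__":
    class _PassthroughQuantizer(nn.Module):
        def forward(self, z, n_quantizers=None, **kwargs):
            zero = z.new_zeros(())
            return VQResult(
                z=z, codes=z, latents=z, codebook_loss=zero, commitment_loss=zero
            )

        def decode(self, indices):
            return indices  # "codes" already are the latents in this stub

    model = DAC(quantizer=_PassthroughQuantizer(), causal=True).eval()
    wav = torch.randn(1, 1, model.sample_rate)  # one second of noise
    with torch.no_grad():
        codes, code_lens = model.encode(wav)
        audio, audio_lens = model.decode(codes, code_lens)
    # With the default configuration, expect codes of shape (1, 1024, T')
    # and audio of shape (1, 1, T' * 512).
    print(codes.shape, audio.shape)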