import math
import typing as tp
from dataclasses import dataclass
from typing import List, Optional, Union
import hydra
import librosa
import numpy as np
import soundfile as sf
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from dac.model.base import CodecMixin
from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
from omegaconf import OmegaConf
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
@dataclass
class VQResult:
z: torch.Tensor
codes: torch.Tensor
latents: torch.Tensor
codebook_loss: torch.Tensor
commitment_loss: torch.Tensor
semantic_distill_z: torch.Tensor | None = None
def find_multiple(n: int, k: int) -> int:
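    # Round n up to the nearest multiple of k, e.g. find_multiple(300, 256) == 512.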
if n % k == 0:
return n
return n + k - (n % k)
@dataclass
class ModelArgs:
block_size: int = 2048
n_layer: int = 8
n_head: int = 8
dim: int = 512
intermediate_size: int = 1536
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
dropout_rate: float = 0.1
attn_dropout_rate: float = 0.1
channels_first: bool = True # to be compatible with conv1d input/output
pos_embed_type: str = "rope" # can be "rope" or "conformer"
max_relative_position: int = 128 # for conformer-style relative position embedding
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
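            # e.g. dim=512 -> hidden_dim=2048, n_hidden=1365, intermediate_size=1536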
assert self.pos_embed_type in [
"rope",
"conformer",
], "pos_embed_type must be either 'rope' or 'conformer'"
class KVCache(nn.Module):
def __init__(
self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
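        # The cache tensors are written in place at `input_pos`; the valid prefix
        # (up to the largest position written so far) is returned so a decoding
        # step can attend over every cached timestep.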
assert input_pos.shape[0] == k_val.shape[2]
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
return (
k_out[:, :, : input_pos.max() + 1, :],
v_out[:, :, : input_pos.max() + 1, :],
)
def clear_cache(self, prompt_len):
self.k_cache[:, :, prompt_len:, :].fill_(0)
self.v_cache[:, :, prompt_len:, :].fill_(0)
class Transformer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.layers = nn.ModuleList(
TransformerBlock(config) for _ in range(config.n_layer)
)
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
# Only compute RoPE frequencies if using RoPE
if config.pos_embed_type == "rope":
freqs_cis = precompute_freqs_cis(
self.config.block_size, self.config.head_dim, self.config.rope_base
)
self.register_buffer("freqs_cis", freqs_cis)
else:
self.register_buffer("freqs_cis", None)
causal_mask = torch.tril(
torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
)
self.register_buffer("causal_mask", causal_mask)
self.max_batch_size = -1
self.max_seq_length = -1
self.use_kv_cache = False
def setup_caches(self, max_batch_size, max_seq_length):
"""
This method will only be called during inference when using KV cache.
"""
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype = self.norm.weight.dtype
device = self.norm.weight.device
for b in self.layers:
b.attention.kv_cache = KVCache(
max_batch_size,
max_seq_length,
self.config.n_local_heads,
head_dim,
dtype,
).to(device)
self.use_kv_cache = True
def forward(
self,
x: Tensor,
input_pos: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
) -> Tensor:
if self.config.pos_embed_type == "rope":
assert (
self.freqs_cis is not None
), "RoPE frequencies must be initialized for RoPE positional embedding"
freqs_cis = self.freqs_cis[input_pos]
else:
freqs_cis = None
if mask is None: # in case of non-causal model
if not self.training and self.use_kv_cache:
mask = self.causal_mask[None, None, input_pos]
mask = mask[..., : input_pos.max() + 1]
else:
mask = self.causal_mask[None, None, input_pos]
mask = mask[..., input_pos]
for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
x = self.norm(x)
return x
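# A minimal inference sketch for the transformer above (hypothetical sizes; the
# model consumes pre-projected features of shape (B, T, dim)):
#   model = Transformer(ModelArgs()).eval()
#   model.setup_caches(max_batch_size=1, max_seq_length=64)
#   x = torch.randn(1, 16, 512)
#   y = model(x, input_pos=torch.arange(16))                 # prefill 16 steps
#   y_next = model(torch.randn(1, 1, 512),
#                  input_pos=torch.tensor([16]))              # one incremental step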
class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
self.attention_layer_scale = LayerScale(config.dim, inplace=True)
self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
def forward(
self,
x: Tensor,
input_pos: Tensor,
freqs_cis: Tensor,
mask: Tensor,
) -> Tensor:
h = x + self.attention_layer_scale(
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
)
out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
return out
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
self.kv_cache = None
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self.attn_dropout_rate = config.attn_dropout_rate
self.pos_embed_type = config.pos_embed_type
# Add relative position embedding for conformer-style
if self.pos_embed_type == "conformer":
self.max_relative_position = config.max_relative_position
num_pos_embeddings = 2 * config.max_relative_position + 1
self.rel_pos_embeddings = nn.Parameter(
torch.zeros(num_pos_embeddings, self.head_dim)
)
nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
# q: [B, H, S, D]
# Returns: [B, H, S, S]
positions = torch.arange(seqlen, device=q.device)
relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
relative_positions = torch.clamp(
relative_positions + self.max_relative_position,
0,
2 * self.max_relative_position,
)
rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
# Compute attention scores with relative position embeddings
q = q.transpose(1, 2) # [B, S, H, D]
rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
return rel_logits
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
mask: Tensor,
input_pos: Optional[Tensor] = None,
) -> Tensor:
bsz, seqlen, _ = x.shape
        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
context_seqlen = seqlen
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
if self.pos_embed_type == "rope":
q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
if self.pos_embed_type == "conformer":
# Compute attention scores
scale = 1.0 / math.sqrt(self.head_dim)
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
# Add relative position embeddings for conformer-style
rel_scores = self._compute_conformer_pos_scores(q, seqlen)
scores = scores + rel_scores
# Apply attention
if mask is not None:
scores = scores.masked_fill(~mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
if self.attn_dropout_rate > 0 and self.training:
attn = F.dropout(attn, p=self.attn_dropout_rate)
y = torch.matmul(attn, v)
else:
y = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_dropout_rate if self.training else 0.0,
attn_mask=mask,
)
# is_causal=True)
y = (
y.transpose(1, 2)
.contiguous()
.view(bsz, seqlen, self.head_dim * self.n_head)
)
y = self.wo(y)
return y
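# The feed-forward below uses SwiGLU-style gating, w2(silu(w1(x)) * w3(x)),
# as in LLaMA-style transformer blocks, with dropout applied to the gated
# hidden activations before the output projection.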
class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x: Tensor) -> Tensor:
return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
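        # RMSNorm: x / sqrt(mean(x^2) + eps); unlike LayerNorm there is no mean
        # subtraction or bias, and the learned per-channel weight is applied in
        # forward() after normalizing in float32.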
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-2,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class WindowLimitedTransformer(Transformer):
"""
Transformer with window limited attention, causal.
"""
def __init__(
self,
config: ModelArgs,
input_dim: int = 512,
window_size: Optional[int] = None,
causal: bool = True,
look_ahead_conv: nn.Module = None,
):
super().__init__(config)
self.window_size = window_size
self.causal = causal
self.channels_first = config.channels_first
self.look_ahead_conv = (
look_ahead_conv if look_ahead_conv is not None else nn.Identity()
)
self.input_proj = (
nn.Linear(input_dim, config.dim)
if input_dim != config.dim
else nn.Identity()
)
self.output_proj = (
nn.Linear(config.dim, input_dim)
if input_dim != config.dim
else nn.Identity()
)
def make_window_limited_mask(
self,
max_length: int,
x_lens: Optional[Tensor] = None,
) -> Tensor:
"""
Make mask to form window limited attention.
"""
if self.causal:
mask = torch.tril(torch.ones(max_length, max_length))
row_indices = torch.arange(max_length).view(-1, 1)
window_size = self.window_size or max_length
valid_range = (row_indices - window_size + 1).clamp(min=0)
column_indices = torch.arange(max_length)
mask = (column_indices >= valid_range) & mask.bool()
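            # Example (window_size=2, max_length=4): each query t may attend to
            # keys in [max(0, t - 1), t]:
            #   [[1, 0, 0, 0],
            #    [1, 1, 0, 0],
            #    [0, 1, 1, 0],
            #    [0, 0, 1, 1]]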
else:
raise NotImplementedError
mask = mask.bool()[None, None]
return mask
def make_mask(
self,
max_length: int,
x_lens: Optional[Tensor] = None,
) -> Tensor:
"""
        Make an ordinary (optionally causal) attention mask; padded positions
        given by ``x_lens`` are masked out on the key axis.
        """
        if self.causal:
            mask = torch.tril(torch.ones(max_length, max_length, dtype=torch.bool))
        else:
            mask = torch.ones(max_length, max_length, dtype=torch.bool)
        mask = mask[None, None]  # (1, 1, T, T)
        if x_lens is not None:
            # Expand per sample and hide padded key positions.
            mask = mask.repeat(len(x_lens), 1, 1, 1)
            for i, x_len in enumerate(x_lens):
                mask[i, :, :, x_len:] = False
        return mask
def forward(
self,
x: Tensor,
x_lens: Optional[Tensor] = None,
) -> Tensor:
if self.channels_first:
x = x.transpose(1, 2)
x = self.input_proj(x) # (B, T, D)
x = self.look_ahead_conv(x)
input_pos = torch.arange(x.shape[1], device=x.device)
# construct mask to form window limited attention
max_length = x.shape[1]
if self.window_size is not None:
mask = self.make_window_limited_mask(max_length, x_lens)
else:
mask = self.make_mask(max_length, x_lens)
mask = mask.to(x.device)
x = super().forward(x, input_pos, mask)
x = self.output_proj(x) # (B, T, D)
if self.channels_first:
x = x.transpose(1, 2)
return x
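# A minimal usage sketch (hypothetical sizes): wrap channels-first features in a
# causal, window-limited transformer.
#   cfg = ModelArgs(dim=512, n_layer=2, n_head=8)
#   block = WindowLimitedTransformer(cfg, input_dim=512, window_size=128, causal=True)
#   feats = torch.randn(2, 512, 100)    # (B, C, T) since channels_first=True
#   out = block(feats)                   # -> (B, 512, T)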
def precompute_freqs_cis(
seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
freqs = 1.0 / (
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
)
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return cache.to(dtype=dtype)
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left:end]
def get_extra_padding_for_conv1d(
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
"""See `pad_for_conv1d`."""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
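# Example: length=10, kernel_size=4, stride=2, padding_total=2 gives
# n_frames = (10 - 4 + 2) / 2 + 1 = 5.0 and ideal_length = 10, so no extra
# padding is needed; a length of 9 would require 1 extra sample.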
def pad1d(
x: torch.Tensor,
paddings: tp.Tuple[int, int],
    mode: str = "constant",
value: float = 0.0,
):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right
    before the reflection happens.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
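# The causal convolution below keeps all padding on the left (plus the minimal
# extra right padding needed to complete the last frame), so each output step
# looks only at current and past samples rather than being centered on them.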
class CausalConvNet(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation=1,
stride=1,
groups=1,
padding=None,
):
super(CausalConvNet, self).__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
)
self.stride = stride
self.kernel_size = (kernel_size - 1) * dilation + 1
self.dilation = dilation
self.padding = self.kernel_size - self.stride
def forward(self, x):
pad = self.padding
extra_padding = get_extra_padding_for_conv1d(
x, self.kernel_size, self.stride, pad
)
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
return self.conv(x).contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
return self
class CausalTransConvNet(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
):
super(CausalTransConvNet, self).__init__()
self.conv = nn.ConvTranspose1d(
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
)
self.stride = stride
self.kernel_size = kernel_size
def forward(self, x):
x = self.conv(x)
pad = self.kernel_size - self.stride
padding_right = math.ceil(pad)
padding_left = pad - padding_right
x = unpad1d(x, (padding_left, padding_right))
return x.contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_weight_norm(self):
        self.conv = remove_parametrizations(self.conv, "weight")
return self
def CausalWNConv1d(*args, **kwargs):
return CausalConvNet(*args, **kwargs).weight_norm()
def CausalWNConvTranspose1d(*args, **kwargs):
return CausalTransConvNet(*args, **kwargs).weight_norm()
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
super().__init__()
conv_class = CausalWNConv1d if causal else WNConv1d
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
conv_class(dim, dim, kernel_size=1),
)
self.causal = causal
def forward(self, x):
y = self.block(x)
pad = x.shape[-1] - y.shape[-1]
if pad > 0:
if self.causal:
x = x[..., :-pad]
else:
x = x[..., pad // 2 : -pad // 2]
return x + y
class EncoderBlock(nn.Module):
def __init__(
self,
dim: int = 16,
stride: int = 1,
causal: bool = False,
n_t_layer: int = 0,
transformer_general_config=None,
):
super().__init__()
conv_class = CausalWNConv1d if causal else WNConv1d
transformer_module = (
nn.Identity()
if n_t_layer == 0
else (
WindowLimitedTransformer(
causal=causal,
input_dim=dim,
window_size=512,
config=transformer_general_config(
n_layer=n_t_layer,
n_head=dim // 64,
dim=dim,
intermediate_size=dim * 3,
),
)
)
)
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1, causal=causal),
ResidualUnit(dim // 2, dilation=3, causal=causal),
ResidualUnit(dim // 2, dilation=9, causal=causal),
Snake1d(dim // 2),
conv_class(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
transformer_module,
)
def forward(self, x):
return self.block(x)
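# The encoder stacks one EncoderBlock per stride: channels double at every block
# while time is downsampled by that block's stride, so the overall hop length is
# the product of `strides` (e.g. [2, 4, 8, 8] -> 512x downsampling).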
class Encoder(nn.Module):
def __init__(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 64,
n_transformer_layers: list = [0, 0, 4, 4],
transformer_general_config: ModelArgs = None,
causal: bool = False,
):
super().__init__()
conv_class = CausalWNConv1d if causal else WNConv1d
# Create first convolution
self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride, n_t_layer in zip(strides, n_transformer_layers):
d_model *= 2
self.block += [
EncoderBlock(
d_model,
stride=stride,
causal=causal,
n_t_layer=n_t_layer,
transformer_general_config=transformer_general_config,
)
]
# Create last convolution
self.block += [
Snake1d(d_model),
conv_class(d_model, d_latent, kernel_size=3, padding=1),
]
        # Wrap the block list into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
return self.block(x)
class DecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
stride: int = 1,
causal: bool = False,
n_t_layer: int = 0,
transformer_general_config=None,
):
super().__init__()
conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
transformer_module = (
nn.Identity()
if n_t_layer == 0
else (
WindowLimitedTransformer(
causal=causal,
input_dim=input_dim,
window_size=None,
config=transformer_general_config(
n_layer=n_t_layer,
n_head=input_dim // 64,
dim=input_dim,
intermediate_size=input_dim * 3,
),
)
)
)
self.block = nn.Sequential(
# transformer_module,
Snake1d(input_dim),
conv_trans_class(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
ResidualUnit(output_dim, dilation=1, causal=causal),
ResidualUnit(output_dim, dilation=3, causal=causal),
ResidualUnit(output_dim, dilation=9, causal=causal),
)
def forward(self, x):
return self.block(x)
class Decoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
d_out: int = 1,
causal: bool = False,
n_transformer_layers: list = [0, 0, 0, 0],
transformer_general_config=None,
):
super().__init__()
conv_class = CausalWNConv1d if causal else WNConv1d
# Add first conv layer
layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [
DecoderBlock(
input_dim,
output_dim,
stride,
causal=causal,
n_t_layer=n_t_layer,
transformer_general_config=transformer_general_config,
)
]
# Add final conv layer
layers += [
Snake1d(output_dim),
conv_class(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class DAC(BaseModel, CodecMixin):
def __init__(
self,
encoder_dim: int = 64,
encoder_rates: List[int] = [2, 4, 8, 8],
latent_dim: int = None,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 4, 2],
quantizer: torch.nn.Module = None,
sample_rate: int = 44100,
causal: bool = True,
encoder_transformer_layers: List[int] = [0, 0, 0, 0],
decoder_transformer_layers: List[int] = [0, 0, 0, 0],
transformer_general_config=None,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.sample_rate = sample_rate
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = Encoder(
encoder_dim,
encoder_rates,
latent_dim,
causal=causal,
n_transformer_layers=encoder_transformer_layers,
transformer_general_config=transformer_general_config,
)
self.quantizer = quantizer
self.decoder = Decoder(
latent_dim,
decoder_dim,
decoder_rates,
causal=causal,
n_transformer_layers=decoder_transformer_layers,
transformer_general_config=transformer_general_config,
)
self.sample_rate = sample_rate
self.apply(init_weights)
self.delay = self.get_delay()
self.frame_length = self.hop_length * 4
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
length = audio_data.shape[-1]
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
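    # Padding example: with encoder_rates=[2, 4, 8, 8], hop_length = 2*4*8*8 = 512,
    # so preprocess() right-pads a 44100-sample clip by 444 samples to 44544 = 87 * 512.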
def encode(
self,
audio_data: torch.Tensor,
audio_lengths: torch.Tensor = None,
n_quantizers: int = None,
**kwargs,
):
"""Encode given audio data and return quantized latent codes
Parameters
----------
audio_data : Tensor[B x T]
Audio data to encode
        audio_lengths : Tensor[B], optional
            Number of valid samples per batch item, by default None
            If None, the padded length is used for every item.
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.
        Returns
        -------
        indices : Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        indices_lens : Tensor[B]
            Number of quantized frames per batch item
        """
# pad to multiple of self.frame_length
if audio_data.ndim == 2:
audio_data = audio_data.unsqueeze(1)
length = audio_data.shape[-1]
right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
if audio_lengths is None:
            audio_lengths = torch.full(
                (audio_data.shape[0],),
                length + right_pad,
                dtype=torch.long,
                device=audio_data.device,
            )
z = self.encoder(audio_data)
vq_results = self.quantizer(z, n_quantizers, **kwargs)
indices = vq_results.codes
indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
return indices, indices_lens
def decode(self, indices: torch.Tensor, feature_lengths):
if indices.ndim == 2:
indices = indices[None]
z = self.quantizer.decode(indices)
audio_lengths = feature_lengths * self.frame_length
return self.decoder(z), audio_lengths
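    # A minimal round-trip sketch (the quantizer is injected at construction time;
    # `my_quantizer` is a placeholder, not defined in this file):
    #   dac = DAC(quantizer=my_quantizer, sample_rate=44100)
    #   codes, code_lens = dac.encode(audio)             # audio: (B, T) or (B, 1, T)
    #   recon, recon_lens = dac.decode(codes, code_lens)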
def forward(
self,
audio_data: torch.Tensor,
template: torch.Tensor = None,
mask: torch.Tensor = None,
sample_rate: int = None,
n_quantizers: int = None,
**kwargs,
):
"""Model forward pass
Parameters
----------
audio_data : Tensor[B x 1 x T]
Audio data to encode
sample_rate : int, optional
Sample rate of audio data in Hz, by default None
If None, defaults to `self.sample_rate`
n_quantizers : int, optional
Number of quantizers to use, by default None.
If None, all quantizers are used.
        Returns
        -------
        audio : Tensor[B x 1 x length]
            Decoded audio data, trimmed to the input length
        vq_results
            The result of `self.encode`, i.e. the codebook indices and their
            frame lengths
        """
length = audio_data.shape[-1]
audio_data = self.preprocess(audio_data, sample_rate)
        vq_results = self.encode(audio_data, n_quantizers=n_quantizers, **kwargs)
        indices, indices_lens = vq_results
        audio, _ = self.decode(indices, indices_lens)
        return audio[..., :length], vq_results