import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchaudio
import pyworld as pw
import matplotlib.pyplot as plt
from torch import Tensor
from torch.nn.functional import scaled_dot_product_attention
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from datasets import load_dataset, Audio
# Assumes the Hugging Face `transformers` package; needed by get_generation_config below.
from transformers import GenerationConfig

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))

def sinusoids(ctx, dims, max_tscale=10000):
    assert dims % 2 == 0
    pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
    tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
    scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
    # Return a plain tensor: wrapping it in nn.Parameter inside a free function
    # creates a fresh, unregistered parameter on every call, so it could never
    # actually be trained.
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)

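# Usage sketch (illustrative shapes): a 100-frame, 256-dim table for additive
# position encoding.
#   pe = sinusoids(100, 256)                          # -> (100, 256), on `device`
#   x = torch.randn(2, 100, 256, device=device) + pe  # broadcasts over the batch dim
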
class SLSTM(nn.Module):
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = False, bias: bool = True):
        super().__init__()
        self.skip = skip
        # forward permutes (batch, channels, time) -> (time, batch, channels),
        # so the LSTM must run time-first (batch_first=False).
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bias=bias, batch_first=False)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y

def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
    qk_cosine = qk_cosine + mask
    weights = F.softmax(qk_cosine, dim=-1)
    out = torch.matmul(weights, v)
    return out


def scaled_relu(x, sequence_length):
    relu_output = torch.relu(x)
    return relu_output / sequence_length

def taylor_softmax(x, order=2):
    # exp(x) ~= 1 + x + x^2/2! + ... up to `order`; the 1.0 is the zeroth-order term.
    tapprox = 1.0
    for i in range(1, order + 1):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
        tapprox += x**i / factorial_i
    return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)

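# Sanity check (illustrative): the 2nd-order approximation tracks softmax for
# small logits.
#   x = torch.tensor([[0.1, 0.2, 0.3]])
#   taylor_softmax(x)     # roughly [[0.301, 0.332, 0.366]]
#   F.softmax(x, dim=-1)  # roughly [[0.301, 0.332, 0.367]]
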
def taylor_masked(x, mask, order=2):
    tapprox = torch.zeros_like(x)
    unmasked = x.masked_select(mask)
    # Seed with the zeroth-order term only; seeding with 1.0 + unmasked would
    # double-count the first-order term added by the loop below.
    approx_values = torch.ones_like(unmasked)
    for i in range(1, order + 1):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
        approx_values += unmasked**i / factorial_i
    tapprox.masked_scatter_(mask, approx_values)
    sum_approx = torch.sum(tapprox, dim=-1, keepdim=True)
    toutput = tapprox / (sum_approx + 1e-9)
    toutput = toutput * mask.float()
    return toutput

def taylor_softmax2(x, mask=None, order=2):
    if mask is None:
        # Zeroth-order term; the loop adds x, x^2/2!, ... (starting from 1.0 + x
        # would add the first-order term twice).
        tapprox = torch.ones_like(x)
        for i in range(1, order + 1):
            factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
            tapprox += x**i / factorial_i
        return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)
    else:
        unmasked = x.masked_select(mask)
        tapprox = torch.ones_like(unmasked)
        for i in range(1, order + 1):
            factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
            tapprox += unmasked**i / factorial_i

        tapprox_full = torch.zeros_like(x)
        tapprox_full.masked_scatter_(mask, tapprox)

        sum_approx = torch.sum(tapprox_full, dim=-1, keepdim=True)
        toutput = tapprox_full / (sum_approx + 1e-9)
        toutput = toutput * mask.float()
        return toutput

def taylor_softmax_2nd_order(x):
    exp_approx = 1 + x + (x**2) / 2
    return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True)


def taylor_softmax_approximation(x, order=2):
    if order == 0:
        return torch.ones_like(x) / x.size(-1)
    elif order == 1:
        numerator = 1 + x
    elif order == 2:
        numerator = 1 + x + 0.5 * x**2
    else:
        raise NotImplementedError("Higher orders are not implemented yet.")
    denominator = torch.sum(numerator, dim=-1, keepdim=True)
    return numerator / denominator

def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
    dot_scores = torch.matmul(q, k.transpose(-1, -2))
    if rbf_ratio <= 0.0:
        return dot_scores
    q_norm = q.pow(2).sum(dim=-1, keepdim=True)
    k_norm = k.pow(2).sum(dim=-1, keepdim=True)
    qk = torch.matmul(q, k.transpose(-1, -2))
    dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
    rbf = torch.exp(-dist_sq / (2 * rbf_sigma**2))
    return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf

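# Usage sketch (illustrative shapes): blend 30% RBF similarity into the
# dot-product scores.
#   q = torch.randn(2, 4, 16, 64); k = torch.randn(2, 4, 16, 64)
#   scores = rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.3)   # (2, 4, 16, 16)
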
def sliding_window_mask(q_len, k_len, window, device):
    idxs = torch.arange(q_len, device=device).unsqueeze(1)
    jdxs = torch.arange(k_len, device=device).unsqueeze(0)
    mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
    return mask.float()

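# Example (illustrative): each query attends to itself and the 2 previous keys.
#   sliding_window_mask(4, 4, window=3, device=device)
#   # tensor([[1., 0., 0., 0.],
#   #         [1., 1., 0., 0.],
#   #         [1., 1., 1., 0.],
#   #         [0., 1., 1., 1.]])
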
def mask_win(text_ctx, aud_ctx):
    mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
    audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
    full_mask = torch.cat([mask, audio_mask], dim=-1)
    return full_mask


def maskc(ctx, device):
    return torch.tril(torch.ones(ctx, ctx, device=device, dtype=dtype), diagonal=0)

def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, device=None):
    if is_causal:
        # 1 marks a blocked position; diagonal=1 keeps the diagonal itself
        # visible so each token can attend to its own position.
        mask = torch.triu(torch.ones((ctx, ctx), device=device), diagonal=1)
        mask = mask.expand(batch_size, 1, ctx, ctx)
    else:
        mask = torch.zeros((batch_size, 1, ctx, ctx), device=device)
    if padding_mask is not None:
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
        mask = (mask.bool() | (~padding_mask)).float()
    return mask

def calculate_attention(q, k, v, mask=None, temp=1.0):
    scaled_q = q
    if temp != 1.0 and temp > 0:
        scaled_q = q * (1.0 / temp)**.5
    # q is (batch, head, ctx, head_dim): the sequence length lives in dim 2,
    # not dim 1 (which is the head count).
    out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[2] > 1)
    return out

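# Usage sketch (illustrative shapes): heads-first layout, causal when a mask
# is passed.
#   q = k = v = torch.randn(2, 8, 16, 64)
#   out = calculate_attention(q, k, v, mask=torch.ones(16, 16), temp=1.0)  # (2, 8, 16, 64)
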
def calculate_attentionb(q_norm, k_norm, v_iter, mask=None, temp=1.0):
    d_k = q_norm.size(-1)
    scores = torch.matmul(q_norm, k_norm.transpose(-2, -1)) / (torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) / temp)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, v_iter)
    return output


class LocalOut(nn.Module):
    def __init__(self, dims: int, head: int):
        super().__init__()
        self.head_dim = dims // head
        self.dims = dims
        self.q_module = nn.Linear(self.head_dim, self.head_dim)
        self.k_module = nn.Linear(self.head_dim, self.head_dim)
        self.v_module = nn.Linear(self.head_dim, self.head_dim)
        self.o_proj = nn.Linear(self.head_dim, self.head_dim)

    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
        batch, _, ctx, _ = attn_output.shape
        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)


def qkv_init(dims, head):
    head_dim = dims // head
    q = nn.Linear(dims, dims)
    k = nn.Linear(dims, dims)
    v = nn.Linear(dims, dims)
    o = nn.Linear(dims, dims)
    lna = nn.LayerNorm(dims)
    lnb = nn.LayerNorm(dims)
    lnc = nn.LayerNorm(head_dim)
    lnd = nn.LayerNorm(head_dim)
    return q, k, v, o, lna, lnb, lnc, lnd


def shape(dims, head, q, k, v):
    batch_size = q.shape[0]
    seq_len_q = q.shape[1]
    seq_len_kv = k.shape[1]
    head_dim = dims // head

    q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
    k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
    v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
    return q, k, v

def create_qkv(dims, head, q, k, v, x, xa):
    head_dim = dims // head
    scale = head_dim ** -0.25
    q = q(x) * scale
    k = k(xa) * scale
    v = v(xa)
    batch = x.shape[0]

    def _shape(tensor):
        # Use each tensor's own sequence length: k and v come from xa, which
        # may differ in length from x.
        return tensor.view(batch, tensor.shape[1], head, head_dim).transpose(1, 2).contiguous()
    return _shape(q), _shape(k), _shape(v)

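# Usage sketch (illustrative; `text` and `audio` are assumed hidden-state
# tensors): project text queries against audio keys/values.
#   q_proj, k_proj, v_proj, o_proj, *_ = qkv_init(dims=512, head=8)
#   q, k, v = create_qkv(512, 8, q_proj, k_proj, v_proj, x=text, xa=audio)
#   # q: (batch, 8, text_len, 64); k, v: (batch, 8, audio_len, 64)
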
class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

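# Usage sketch (illustrative): write decoding step `t` into the cache.
# k_step/v_step are assumed (1, 8, 1, 64) tensors in the cache's dtype.
#   cache = KVCache(max_batch_size=1, max_seq_length=448, n_heads=8, head_dim=64)
#   k_all, v_all = cache.update(torch.tensor([t]), k_step, v_step)
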
def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)


def mel_scale(freq: Tensor) -> Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()

def trace_x(func):
    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__}")
        result = func(*args, **kwargs)
        if isinstance(result, torch.Tensor):
            print(f"  {func.__name__} returned shape: {result.shape}")
        return result
    return wrapper


def track_x(new_x, operation=""):
    """track_x(x, "x")"""
    if new_x is None:
        return new_x
    # The last-seen id is persisted on the function object itself; a fresh
    # local would always compare equal and report REUSE on every call.
    current_id = id(new_x)
    last_id = getattr(track_x, "_last_id", None)
    if last_id is not None and current_id != last_id:
        print(f"x FLOW: {last_id} → {current_id} in {operation}")
    else:
        print(f"x REUSE: {current_id} in {operation}")
    track_x._last_id = current_id
    return new_x


def track_xa(new_xa, operation=""):
    """track_xa(xa, "xa - decoder")"""
    if new_xa is None:
        return new_xa
    current_id = id(new_xa)
    last_id = getattr(track_xa, "_last_id", None)
    if last_id is not None and current_id != last_id:
        print(f"xa FLOW: {last_id} → {current_id} in {operation}")
    else:
        print(f"xa REUSE: {current_id} in {operation}")
    track_xa._last_id = current_id
    return new_xa

def get_activation(act: str) -> nn.Module:
    """Get activation function by name."""
    act_map = {
        "gelu": nn.GELU(),
        "relu": nn.ReLU(),
        "sigmoid": nn.Sigmoid(),
        "tanh": nn.Tanh(),
        "swish": nn.SiLU(),
        "tanhshrink": nn.Tanhshrink(),
        "softplus": nn.Softplus(),
        "softshrink": nn.Softshrink(),
        "leaky_relu": nn.LeakyReLU(),
        "elu": nn.ELU(),
    }
    return act_map.get(act, nn.GELU())

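# Example (illustrative): unknown names fall back to GELU.
#   get_activation("swish")    # -> nn.SiLU()
#   get_activation("unknown")  # -> nn.GELU()
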
def get_generation_config(param):
    return GenerationConfig(
        max_length=param.text_ctx,
        pad_token_id=getattr(param, "pad_token_id", 0),
        bos_token_id=getattr(param, "bos_token_id", 1),
        eos_token_id=getattr(param, "eos_token_id", 2),
        do_sample=False,
        num_beams=1,
        early_stopping=False,
        length_penalty=1.0,
        no_repeat_ngram_size=0,
        repetition_penalty=1.0,
        temperature=1.0,
        decoder_start_token_id=1,
        is_multilingual=False,
        use_cache=False,
        return_timestamps=False)

class feature_encoder(nn.Module):
    """Feature encoder for audio processing."""
    def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None,
                 use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None):
        super().__init__()

        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.attend_feature = attend_feature
        self.target_length = target_length
        self.feature = feature
        self.debug = debug
        act_fn = get_activation(act)

        if self.attend_feature:
            self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
        else:
            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
            self.mlp = None

        # padding=1 keeps the frame count unchanged, matching FEncoder below.
        self.spectrogram = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)

        self.waveform = nn.Sequential(
            Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)

        self.pitch = nn.Sequential(
            Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        # `rotary` is provided elsewhere in the codebase. The module is stored
        # under `rope`; the application logic lives in apply_rope so the two
        # names cannot shadow each other.
        self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False,
                           axial=False, spec_shape=spec_shape) if use_rope else None
        self.norm = RMSNorm(dims)

    def apply_rope(self, x, feats=None, feature=None, layer=None):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def mel_scalar(self, freq: float) -> float:
        return 1127.0 * math.log(1.0 + freq / 700.0)

    def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
        target_length = x.shape[1] if self.target_length is None else self.target_length

        if feature == "pitch":
            xp = x.clone()
            enc_dict = dict(feats) if feats is not None else {}
            enc_dict["f0"] = xp
            feats = enc_dict
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.pitch(x).permute(0, 2, 1)

        if feature == "phase":
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.pitch(x).permute(0, 2, 1)

        if feature == "waveform":
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.waveform(x).permute(0, 2, 1)
            if target_length and x.shape[1] != target_length:
                x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)

        if feature in ("harmonics", "aperiodic", "spectrogram"):
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.spectrogram(x).permute(0, 2, 1)

        x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
        if self.use_rope:
            x = self.apply_rope(x, feats=feats, feature=feature, layer=layer)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

class OneShot(nn.Module):
    def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
        super().__init__()
        if features is None:
            features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
        self.head = head
        self.head_dim = dims // head
        # True division: floor division (1.0 // len(features)) rounds down to
        # zero for more than one feature and would silence the bias entirely.
        self.scale = 1.0 / len(features) if features else scale

        self.q = Linear(dims, dims)
        self.k = Linear(dims, dims)

    def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor:
        B, L, D = x.shape
        K = xa.size(1)
        q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1, 2)
        k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1, 2)
        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
        return bias

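# Usage sketch (illustrative; `text_hidden`/`audio_hidden`/`scores` are assumed
# tensors): the result is an additive attention bias.
#   one_shot = OneShot(dims=512, head=8)
#   bias = one_shot(text_hidden, audio_hidden)   # (batch, 8, text_len, audio_len)
#   scores = scores + bias                       # added to cross-attention logits
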
class curiosity(nn.Module):
    def __init__(self, d, h, bias=True):
        super().__init__()
        self.h = h
        self.dh = d // h
        self.qkv = nn.Linear(d, d * 3, bias=bias)
        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
        self.o = nn.Linear(d, d, bias=bias)
        self.g = nn.Parameter(torch.zeros(h))

    def split(self, x):
        b, t, _ = x.shape
        return x.view(b, t, self.h, self.dh).transpose(1, 2)

    def merge(self, x):
        b, h, t, dh = x.shape
        return x.transpose(1, 2).contiguous().view(b, t, h * dh)

    def forward(self, x, xa, mask=None):
        q, k, v = self.qkv(x).chunk(3, -1)
        qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
        q, k, v = map(self.split, (q, k, v))
        qa, ka, va = map(self.split, (qa, ka, va))
        dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
        # The auxiliary stream contributes keys/values only; qa is unused.
        dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
        if mask is not None:
            dots = dots.masked_fill(mask, -9e15)
        p = dots.softmax(-1)
        pa = dots_aux.softmax(-1)
        h_main = p @ v
        h_aux = pa @ va
        g = torch.sigmoid(self.g).view(1, -1, 1, 1)
        out = self.merge(h_main * (1 - g) + h_aux * g)
        return self.o(out)

class PositionalEncoding(nn.Module):
    def __init__(self, dims, ctx):
        super(PositionalEncoding, self).__init__()
        self.dims = dims
        self.ctx = ctx
        self.pe = self.get_positional_encoding(max_ctx=ctx)

    def get_positional_encoding(self, max_ctx):
        pe = torch.zeros(max_ctx, self.dims)
        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dims, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.dims)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        return pe.to(device)

    def forward(self, x):
        ctx = x.size(1)
        pe = self.pe[:, :ctx, :]
        x = x * math.sqrt(self.dims)
        x = x + pe
        return x

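# Example (illustrative): scale-then-add, as in the original Transformer.
#   pos = PositionalEncoding(dims=512, ctx=2048)
#   y = pos(torch.randn(2, 100, 512, device=device))  # x * sqrt(512) + pe[:, :100]
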
def valid(default_value, *items):
    for item in items:
        if item is not None:
            return item
    return default_value


def dict_to(d, device, dtype=dtype):
    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
            for k, v in d.items()}


def exists(v):
    return v is not None


def default(v, b):
    return v if exists(v) else b


class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.device, x.dtype),
            None if bias is None else bias.to(x.device, x.dtype))


class Conv2d(nn.Conv2d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.device, x.dtype),
            None if bias is None else bias.to(x.device, x.dtype))


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple],
                 eps=1e-8, elementwise_affine=True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

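# Example (illustrative): normalizes over the last dimension with a learned gain.
#   norm = RMSNorm(512)
#   y = norm(torch.randn(2, 100, 512))   # same shape, unit RMS per vector
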
def layer_norm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
               weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
               eps: float = 1e-5) -> Tensor:
    # Lower-case name avoids shadowing the LayerNorm module defined above;
    # F.layer_norm expects a shape sequence, so a bare int is wrapped.
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape,)
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64


def tox():
    return {"device": get_device(), "dtype": get_dtype()}

class SelfCriticalRL(nn.Module):
    def __init__(self, model, tokenizer, reward_fn):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.reward_fn = reward_fn

    def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"):
        with torch.no_grad():
            greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len)
            greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids]
            sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features},
                                              max_length=max_len, do_sample=True, top_k=5)
            sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids]

        rewards = []
        baseline = []
        for s, g, ref in zip(sampled_text, greedy_text, labels):
            ref_text = self.tokenizer.decode(ref)
            rewards.append(self.reward_fn(s, ref_text))
            baseline.append(self.reward_fn(g, ref_text))
        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
        baseline = torch.tensor(baseline, device=device, dtype=torch.float)
        advantage = rewards - baseline
        logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"]
        log_probs = F.log_softmax(logits, dim=-1)
        log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1)
        log_probs_sum = log_probs_seq.sum(dim=1)
        loss = -(advantage * log_probs_sum).mean()
        return loss

class SelfTrainingModule(nn.Module):
    def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.quality_fn = quality_fn
        self.threshold = threshold

    def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        with torch.no_grad():
            pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len)

        if self.quality_fn is not None:
            quality_scores = self.quality_fn(pred_ids, self.model, features)
            mask = quality_scores > self.threshold
            pred_ids = pred_ids[mask]
        return pred_ids

    def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name)
        logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"]
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0)
        return loss

def confidence_indicator(pred_ids, model, features):
    with torch.no_grad():
        logits = model(input_ids=pred_ids, **features)["logits"]
    probs = torch.softmax(logits, dim=-1)
    max_probs, _ = probs.max(dim=-1)
    return max_probs.mean(dim=1)

def wer_reward(hyp, ref):
    hyp_words = hyp.split()
    ref_words = ref.split()
    d = [[0] * (len(ref_words)+1) for _ in range(len(hyp_words)+1)]
    for i in range(len(hyp_words)+1):
        d[i][0] = i
    for j in range(len(ref_words)+1):
        d[0][j] = j
    for i in range(1, len(hyp_words)+1):
        for j in range(1, len(ref_words)+1):
            if hyp_words[i-1] == ref_words[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
    wer = d[-1][-1] / max(1, len(ref_words))
    return -wer

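# Example (illustrative): one substitution over four reference words.
#   wer_reward("the cat sat down", "the cat sat up")   # -> -0.25
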
def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    return [int(id) for id in ids if id != -100 and id != pad_token_id and id != bos_token_id and id != eos_token_id]


def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]

def setup_tokenizer(path: str):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(path)
    orig_encode = tokenizer.encode
    orig_decode = tokenizer.decode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids

    def bdec(ids_list, pad_token_id=0, bos_token_id=1, eos_token_id=2, skip_special_tokens=True):
        results = []
        if isinstance(ids_list, torch.Tensor):
            ids_list = ids_list.tolist()
        elif isinstance(ids_list, np.ndarray):
            ids_list = ids_list.tolist()
        for ids in ids_list:
            ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
            results.append(orig_decode(ids))
        return results

    def dec(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
        ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
        return orig_decode(ids)

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.decode = dec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

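# Usage sketch (illustrative path): wraps a `tokenizers` JSON file with
# HF-style encode/decode helpers and fixed special-token ids.
#   tokenizer = setup_tokenizer("./tokenizer.json")
#   ids = tokenizer.encode("hello world")
#   text = tokenizer.batch_decode([ids])[0]
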
def tokenize_pitch(pitch_features, target_length):
    pitch_len = pitch_features.shape[-1]
    token_len = target_length
    if pitch_len > token_len:
        pitch_tokens = F.adaptive_avg_pool1d(pitch_features, token_len)
    else:
        pitch_tokens = F.interpolate(pitch_features, size=token_len)
    return pitch_tokens

def load_wave(wave_data, sample_rate=16000):
    if isinstance(wave_data, str):
        waveform, sample_rate = torchaudio.load(uri=wave_data, normalize=False)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sample_rate = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")
    # Note: the waveform is returned at its native rate; no resampling to
    # `sample_rate` is performed here.
    return waveform

def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
    import librosa
    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
    mel_basis = torch.from_numpy(mel_basis).float()
    sp_mel = torch.matmul(sp, mel_basis.T)
    ap_mel = torch.matmul(ap, mel_basis.T)
    return sp_mel, ap_mel

def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens=False, pitch=False,
                     harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False,
                     phase_mod=False, crepe=False, aperiodics=False, dummy=False):

    torch_windows = {
        'hann': torch.hann_window,
        'hamming': torch.hamming_window,
        'blackman': torch.blackman_window,
        'bartlett': torch.bartlett_window,
        'ones': torch.ones,
        None: torch.ones,
    }

    if dummy:
        # Keys mirror the real return dict below so downstream collators see a
        # consistent schema.
        return {
            "waveform": None,
            "spectrogram": torch.zeros((1, 128, 100)),
            "f0": torch.zeros((1, 100)),
            "pitch_tokens": torch.zeros((1, 100)),
            "pitch": torch.zeros((1, 100)),
            "harmonic": torch.zeros((1, 128, 100)),
            "aperiodic": torch.zeros((1, 128, 100)),
            "labels": None,
            "phase": None,
            "crepe_time": None,
            "crepe_frequency": None,
            "crepe_confidence": None,
            "crepe_activation": None,
        }

    audio = batch["audio"]
    sample_rate = audio["sampling_rate"]
    labels = tokenizer.encode(batch["transcription"])
    wav = load_wave(wave_data=audio, sample_rate=sample_rate)

    def crepe_predict(wav, sample_rate, viterbi=False):
        # The (time, frequency, confidence, activation) 4-tuple is the API of
        # the original `crepe` package, not `torchcrepe`.
        import crepe
        wav_np = wav.numpy().astype(np.float32)
        time, frequency, confidence, activation = crepe.predict(
            wav_np, sample_rate, viterbi=viterbi)
        return (torch.from_numpy(time), torch.from_numpy(frequency),
                torch.from_numpy(confidence), torch.from_numpy(activation))

    if crepe:
        crepe_time, crepe_frequency, crepe_confidence, crepe_activation = crepe_predict(wav, sample_rate, viterbi=True)
    else:
        crepe_time = crepe_frequency = crepe_confidence = crepe_activation = None

    def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
        if isinstance(window_fn, str) or window_fn is None:
            window_fn = torch_windows[window_fn]
        # torchaudio expects a window *tensor*, so materialize the window function.
        window = window_fn(n_fft) if callable(window_fn) else window_fn
        return torchaudio.functional.spectrogram(
            wav, pad=0, window=window.to(wav.device), n_fft=n_fft, hop_length=hop_length,
            win_length=n_fft, power=1.0, normalized=False, center=True, pad_mode="reflect")

    def mel_spectrogram(wav, sample_rate):
        spectrogram_config = {
            "hop_length": 256,
            "f_min": 150,
            "f_max": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sample_rate": 16000,
            "pad_mode": "constant",
            "center": True,
            "power": 1.0,
            "window_fn": torch.hann_window,
            "mel_scale": "htk",
            "norm": None,
            "normalized": False,
        }
        transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
        mel = transform(wav)
        log_mel = torch.clamp(mel, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        return (log_mel + 4.0) / 4.0

    # Defined unconditionally so the return dict never references an unbound name.
    spectrogram_tensor = mel_spectrogram(wav, sample_rate) if spec else None

    def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
        transform = torchaudio.transforms.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mels,
            melkwargs={
                "n_fft": n_fft,
                "hop_length": hop_length,
                "window_fn": window_fn,
                "n_mels": n_mels,
                "center": True,
                "pad_mode": "reflect",
                "norm": None,
                "mel_scale": "htk",
            })
        return transform(wav)

    def harmonics_and_aperiodics(wav_np, f0, t, sample_rate):
        # WORLD spectral envelope (cheaptrick) and aperiodicity (d4c), cropped
        # to 128 bins and transposed to (bins, frames).
        sp = pw.cheaptrick(wav_np, f0, t, sample_rate, fft_size=256)
        ap = pw.d4c(wav_np, f0, t, sample_rate, fft_size=256)
        harmonic_tensor = torch.from_numpy(sp)[:, :128].contiguous().T
        aperiodic_tensor = torch.from_numpy(ap)[:, :128].contiguous().T
        return harmonic_tensor, aperiodic_tensor

    pitch_tokens_tensor = None
    p_tensor = None
    harmonic_tensor = None
    aperiodic_tensor = None
    phase = None

    if pitch or pitch_tokens or harmonics or aperiodics:
        wavnp = wav.numpy().astype(np.float64)
        f0_np, t = pw.dio(wavnp, sample_rate, frame_period=hop_length / sample_rate * 1000)
        f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
        t2 = torch.from_numpy(t)

        if pitch_tokens:
            wav = torch.from_numpy(wavnp)
            audio_duration = len(wav) / sample_rate
            T = len(labels)
            tok_dur_sec = audio_duration / T
            token_starts = (torch.arange(T) * tok_dur_sec).to(t2.dtype)
            token_ends = token_starts + tok_dur_sec
            start_idx = torch.searchsorted(t2, token_starts, side="left")
            end_idx = torch.searchsorted(t2, token_ends, side="right")
            pitch_tok = torch.zeros(T, dtype=torch.float32)
            for i in range(T):
                lo, hi = start_idx[i], max(start_idx[i] + 1, end_idx[i])
                segment = torch.from_numpy(f0_np[lo:hi])
                if mode == "mean":
                    pitch_tok[i] = segment.mean()
                elif mode == "median":
                    pitch_tok[i] = torch.median(segment)
                else:
                    pitch_tok[i] = segment[-1]
            pitch_tok[pitch_tok < 100.0] = 0.0
            bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
            pitch_tokens_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
            pitch_tokens_tensor = torch.where(
                pitch_tokens_tensor == 0.0,
                torch.zeros_like(pitch_tokens_tensor),
                (pitch_tokens_tensor - 71.0) / (500.0 - 71.0))

        if phase_mod:
            # Accumulate phase from the WORLD f0 track: phi[n] = phi0 + sum(omega * dt).
            f0_tensor = torch.from_numpy(f0_np)
            tframe = torch.mean(t2[1:] - t2[:-1])
            phi0 = 0.0
            omega = 2 * torch.pi * f0_tensor
            dphi = omega * tframe
            phi = torch.cumsum(dphi, dim=0) + phi0
            phase = torch.remainder(phi, 2 * torch.pi)

        if pitch:
            p_tensor = torchaudio.functional.detect_pitch_frequency(wav, sample_rate).unsqueeze(0)

        if harmonics or aperiodics:
            harmonic_tensor, aperiodic_tensor = harmonics_and_aperiodics(wavnp, f0_np, t, sample_rate)

    wave_tensor = wav if waveform else None

    if debug:
        print(f"['pitch_tokens']: {pitch_tokens_tensor.shape if pitch_tokens else None}")
        print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
        print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")
        print(f"['spectrogram']: {spectrogram_tensor.shape if spec else None}")
        print(f"['waveform']: {wave_tensor.shape if waveform else None}")
        print(f"['labels']: {len(labels) if labels else None}")
        print(f"['phase']: {phase.shape if phase is not None else None}")
        print(f"['pitch']: {p_tensor.shape if pitch else None}")
        print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
        print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
        print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
        print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")

    return {
        "waveform": wave_tensor,
        "spectrogram": spectrogram_tensor,
        "pitch_tokens": pitch_tokens_tensor,
        "pitch": p_tensor,
        "harmonic": harmonic_tensor,
        "aperiodic": aperiodic_tensor,
        "labels": labels,
        "phase": phase,
        "crepe_time": crepe_time,
        "crepe_frequency": crepe_frequency,
        "crepe_confidence": crepe_confidence,
        "crepe_activation": crepe_activation,
        "dummy": None,
    }

def plot_waveform(waveform, sr, title="Waveform", ax=None):
|
|
waveform = waveform.numpy()
|
|
|
|
num_channels, num_frames = waveform.shape
|
|
time_axis = torch.arange(0, num_frames) / sr
|
|
|
|
if ax is None:
|
|
_, ax = plt.subplots(num_channels, 1)
|
|
ax.plot(time_axis, waveform[0], linewidth=1)
|
|
ax.grid(True)
|
|
ax.set_xlim([0, time_axis[-1]])
|
|
ax.set_title(title)
|
|
|
|
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
|
|
import librosa
|
|
if ax is None:
|
|
_, ax = plt.subplots(1, 1)
|
|
if title is not None:
|
|
ax.set_title(title)
|
|
ax.set_ylabel(ylabel)
|
|
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
|
|
|
|
def plot_fbank(fbank, title=None):
|
|
fig, axs = plt.subplots(1, 1)
|
|
axs.set_title(title or "Filter bank")
|
|
axs.imshow(fbank, aspect="auto")
|
|
axs.set_ylabel("frequency bin")
|
|
axs.set_xlabel("mel bin")
|
|
|
|
def plot_pitch(waveform, sr, pitch):
|
|
figure, axis = plt.subplots(1, 1)
|
|
axis.set_title("Pitch Feature")
|
|
axis.grid(True)
|
|
|
|
end_time = waveform.shape[1] / sr
|
|
time_axis = torch.linspace(0, end_time, waveform.shape[1])
|
|
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
|
|
|
|
axis2 = axis.twinx()
|
|
time_axis = torch.linspace(0, end_time, pitch.shape[1])
|
|
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
|
|
|
|
axis2.legend(loc=0)
|
|
|
|
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
                     load_saved=False, save_dataset=False, cache_dir=None, extract_args=None, max_ctx=2048):

    if extract_args is None:
        # Keys must match extract_features' keyword arguments exactly.
        extract_args = {
            "waveform": False,
            "spec": False,
            "pitch_tokens": False,
            "pitch": False,
            "harmonics": False,
            "aperiodics": False,
            "sample_rate": 16000,
            "hop_length": 256,
            "mode": "mean",
            "debug": False,
            "phase_mod": False,
            "crepe": False,
            "dummy": False,
        }

    # Resolve cache paths up front so save_dataset works even when load_saved is False.
    if cache_dir is None:
        cache_dir = "./processed_datasets"
    os.makedirs(cache_dir, exist_ok=True)
    cache_file_train = os.path.join(cache_dir, "train.arrow")
    cache_file_test = os.path.join(cache_dir, "test.arrow")

    if load_saved and os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
        from datasets import Dataset
        train_dataset = Dataset.load_from_disk(cache_file_train)
        test_dataset = Dataset.load_from_disk(cache_file_test)
        return train_dataset, test_dataset

    if sanity_check:
        test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True,
            streaming=streaming).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)
        dataset = test.map(lambda x: extract_features(x, tokenizer, **extract_args),
                           remove_columns=test.column_names)
        train_dataset = dataset
        test_dataset = dataset
        return train_dataset, test_dataset
    else:
        def filter_func(x):
            return (0 < len(x["transcription"]) < max_ctx and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < max_ctx * 160)

        raw_train = load_dataset(
            "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True,
            streaming=streaming).take(1000)
        raw_test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True,
            streaming=streaming).take(100)

        raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
        raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
        train_dataset = raw_train.map(lambda x: extract_features(x, tokenizer, **extract_args),
                                      remove_columns=raw_train.column_names)
        test_dataset = raw_test.map(lambda x: extract_features(x, tokenizer, **extract_args),
                                    remove_columns=raw_test.column_names)
        if save_dataset:
            train_dataset.save_to_disk(cache_file_train)
            test_dataset.save_to_disk(cache_file_test)
        return train_dataset, test_dataset

class tgate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gates = nn.ModuleList([nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
        self.classifier = nn.Sequential(Linear(dims, num_types), nn.Softmax(dim=-1))

    def forward(self, x):
        types = self.classifier(x)
        gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
        cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
        return cgate

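# Usage sketch (illustrative shapes): a soft mixture of per-type gates.
#   gate = tgate(dims=512, num_types=4)
#   g = gate(torch.randn(2, 100, 512))   # (2, 100, 1), values in (0, 1)
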
def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int,
                        layer: int, act=None, features=None) -> nn.Module:
    # Keyword arguments keep each call aligned with its constructor's signature.
    if feature == "spectrogram":
        return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head,
                        layer=layer, act=act, feature=feature, features=features)
    elif feature == "waveform":
        return WEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, act=act)
    elif feature == "pitch":
        return PEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, act=act)
    else:
        raise ValueError(f"Unknown feature type: {feature}")

class FEncoder(nn.Module):
    def __init__(self, mels, input_dims, dims, head, layer, act, feature, features,
                 use_rope=False, spec_shape=None, debug=[]):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        self.debug = debug
        self.feature = feature
        self.mels = mels
        self.input_dims = input_dims
        act_fn = get_activation(act)

        self.encoder = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        # self.rope must exist on every path; rotary positions need a spec_shape.
        if use_rope and spec_shape is not None:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[],
                               use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope and self.rope is not None:
            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        if "fencoder" in self.debug:
            print(f"feature encoder: {x.shape} {feature}")
        x = self.norm(x)
        return x

class WEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, act, use_rope=False, debug=[], spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        self.debug = debug
        act_fn = get_activation(act)
        self.target_length = None

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)

        # self.rope must exist on every path; rotary positions need a spec_shape.
        if use_rope and spec_shape is not None:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[],
                               use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.target_length and x.shape[1] != self.target_length:
            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
        if self.use_rope and self.rope is not None:
            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        if "fencoder" in self.debug:
            print(f"waveform encoder: {x.shape} {feature}")
        return self.norm(x)

class PEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, act, use_rope=False, debug=[], one_shot=False, spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dims = dims
        self.dropout = 0.01
        self.use_rope = use_rope
        self.debug = debug
        act_fn = get_activation(act)

        self.attend_pitch = False

        if self.attend_pitch:
            # qkv_init returns the four projections plus four norms; only the
            # projections are needed here.
            self.q, self.k, self.v, self.o, *_ = qkv_init(dims, head)
            self.mlp = nn.Sequential(
                nn.Linear(dims, dims),
                nn.ReLU(),
                nn.Linear(dims, dims),
            )
        else:
            self.q, self.k, self.v, self.o = None, None, None, None
            self.mlp = None

        self.pitch_encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        if use_rope:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[],
                               use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
        if x.dim() == 2:
            x = x.unsqueeze(0)
        x = self.pitch_encoder(x).permute(0, 2, 1)

        if self.use_rope:
            x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)

        x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        if self.mlp is not None:
            x = self.mlp(x)

        if self.attend_pitch and xa is not None:
            # Experimental path; assumes x and xa share the same sequence length.
            q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, x=xa, xa=x)
            out = calculate_attention(q, k, v, mask=None, temp=1.0)
            out = out.transpose(1, 2).contiguous().view(out.shape[0], -1, self.dims)
            x = x + out

        x = self.norm(x)
        if "fencoder" in self.debug:
            print(f"Pitch encoder: {x.shape} {feature}")
        return x

@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        all_keys = set()
        for f in features:
            all_keys.update(f.keys())
        batch = {}
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
        eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)

        for key in all_keys:
            if key == "labels":
                labels_list = [f["labels"] for f in features]
                max_len = max(len(l) for l in labels_list)
                all_ids, all_labels = [], []

                for label in labels_list:
                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                    decoder_input = [bos_token_id] + label_list
                    label_eos = label_list + [eos_token_id]
                    input_len = max_len + 1 - len(decoder_input)
                    label_len = max_len + 1 - len(label_eos)
                    padded_input = decoder_input + [pad_token_id] * input_len
                    padded_labels = label_eos + [pad_token_id] * label_len
                    all_ids.append(padded_input)
                    all_labels.append(padded_labels)
                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

            elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "pitch_tokens",
                         "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence",
                         "crepe_activation", "dummy"]:
                items = [f[key] for f in features if key in f]
                items = [item for item in items if item is not None]
                if not items:
                    continue
                items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
                max_len = max(item.shape[-1] for item in items)
                padded = []
                for item in items:
                    pad_width = max_len - item.shape[-1]
                    if pad_width > 0:
                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
                    else:
                        pad_item = item
                    padded.append(pad_item)
                batch[key] = torch.stack(padded)
        return batch

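# Usage sketch (illustrative): labels get BOS-shifted inputs and EOS-padded
# targets; feature tensors are right-padded to the batch maximum.
#   collator = DataCollator(tokenizer=tokenizer)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collator)
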
def levenshtein(reference_words, hypothesis_words):
    m, n = len(reference_words), len(hypothesis_words)
    dist_matrix = [[0 for _ in range(n+1)] for _ in range(m+1)]
    for i in range(m+1):
        dist_matrix[i][0] = i
    for j in range(n+1):
        dist_matrix[0][j] = j
    for i in range(1, m+1):
        for j in range(1, n+1):
            if reference_words[i-1] == hypothesis_words[j-1]:
                dist_matrix[i][j] = dist_matrix[i-1][j-1]
            else:
                substitution = dist_matrix[i-1][j-1] + 1
                insertion = dist_matrix[i][j-1] + 1
                deletion = dist_matrix[i-1][j] + 1
                dist_matrix[i][j] = min(substitution, insertion, deletion)
    return dist_matrix[m][n]

def wer_batch(references, hypotheses):
    total_errors = 0
    total_words = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words = ref.lower().split()
        errors = levenshtein(ref_words, hyp.lower().split())
        total_errors += errors
        total_words += len(ref_words)
    return (total_errors / total_words) * 100 if total_words > 0 else 0.0

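# Example (illustrative): one substitution over four reference words -> 25% WER.
#   wer_batch(["hello world", "good day"], ["hello word", "good day"])   # -> 25.0
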
def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, logits=None):
    def clean(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        if isinstance(ids[0], (list, torch.Tensor, np.ndarray)):
            return [[int(i) for i in seq if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] for seq in ids]
        else:
            return [int(i) for i in ids if i not in (-100, pad_token_id, bos_token_id, eos_token_id)]

    pred_ids = pred.predictions
    label_ids = pred.label_ids

    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    if not isinstance(pred_ids, torch.Tensor):
        pred_ids = torch.tensor(pred_ids)

    label_ids = clean(label_ids)
    pred_ids = clean(pred_ids)
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids)

    if print_pred:
        for i in range(min(num_samples, len(pred_ids))):
            print(f"Pred tokens: {pred_ids[i]}")
            print(f"Label tokens: {label_ids[i]}")
            print(f"Pred: '{pred_str[i]}'")
            print(f"Label: '{label_str[i]}'")
            print("-" * 40)

    wer = wer_batch(label_str, pred_str)
    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        efficiency_score = (100 - wer) / trainable_params if trainable_params > 0 else 0.0
    else:
        trainable_params = 0.0
        efficiency_score = 0.0

    return {
        "wer": float(wer),
        "efficiency_score": float(efficiency_score),
    }

def preprocess_logits_for_metrics(logits, labels):
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels

def hilbert_transform(x):
    N = x.shape[-1]
    xf = torch.fft.fft(x)
    h = torch.zeros(N, device=x.device, dtype=xf.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    # ifft(fft(x) * h) is the analytic signal; its imaginary part is the
    # Hilbert transform. The earlier rfft/irfft round trip forced the result
    # back to a real signal and discarded exactly that imaginary part.
    return torch.fft.ifft(xf * h).imag

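# Sanity check (illustrative): the Hilbert transform of cos is sin for a
# bin-aligned frequency.
#   t = torch.arange(256) / 256
#   x = torch.cos(2 * torch.pi * 8 * t)
#   hilbert_transform(x)   # ~= torch.sin(2 * torch.pi * 8 * t)
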
def analytic_signal(x):
    return x + 1j * hilbert_transform(x)


def hilbert_transform_2d(x, dim=-1):
    N = x.shape[dim]
    if dim not in (-1, x.dim() - 1):
        # The filter construction below only handles the last dimension.
        raise NotImplementedError("hilbert_transform_2d currently supports the last dimension only.")
    xf = torch.fft.fft(x, dim=dim)
    h = torch.zeros(N, device=x.device, dtype=xf.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    shape = [1] * x.dim()
    shape[dim] = N
    return torch.fft.ifft(xf * h.view(shape), dim=dim).imag

def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    # rfft2 halves only the last axis, so the kernel needs the full (two-sided)
    # frequency range on the second-to-last axis for the shapes to match xf.
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j * h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device), s=x.shape[-2:])

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase