# asr-model / echoutils.py
# Sin2pi's picture
# Upload echoutils.py
# 78ab354 verified
import torch
import os
import pyworld as pw
import numpy as np
import torchaudio
import torch.nn.functional as F
from datasets import load_dataset, Audio
from dataclasses import dataclass
from typing import Any, List, Dict
import math
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.init as init
from torch import Tensor
from typing import Optional, Union, Tuple
from torch.nn.functional import scaled_dot_product_attention
# Default placement for tensors created by helpers in this module:
# the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Default floating-point type for helper-created tensors.
dtype = torch.float32
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Uses the biased (population) variance and eps = 1e-5. Note that a function also
    named ``LayerNorm`` is defined later in this module and rebinds this name.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.scale + self.shift
class GELU(nn.Module):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        coeff = torch.sqrt(torch.tensor(2.0 / torch.pi))
        inner = coeff * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))
def sinusoids(ctx, dims, max_tscale=10000):
    """Return a trainable ``(ctx, dims)`` sinusoidal position table on the module ``device``.

    The first ``dims // 2`` columns are sin(pos * ts_i), the last ``dims // 2`` are
    cos(pos * ts_i), with ts_i = exp(-i * log(max_tscale) / (dims // 2 - 1)).
    ``dims`` must be even (asserted). With ``dims == 2`` the step divides by zero and
    the table contains non-finite values. An identical ``sinusoids`` defined later in
    this module rebinds this name.
    """
    assert dims % 2 == 0
    half = dims // 2
    log_step = torch.log(torch.tensor(float(max_tscale))) / (half - 1)
    inv_timescales = torch.exp(-log_step * torch.arange(half, device=device, dtype=torch.float32))
    positions = torch.arange(ctx, device=device, dtype=torch.float32)
    angles = positions[:, None] * inv_timescales[None, :]
    table = torch.cat((angles.sin(), angles.cos()), dim=1)
    return nn.Parameter(table, requires_grad=True)
class SLSTM(nn.Module):
    """LSTM run along the time axis of a channels-first ``(batch, dimension, time)`` tensor.

    The output keeps the ``(batch, dimension, time)`` layout. With ``skip=True`` the
    LSTM input is added to its output (residual connection).

    Fix: the input used to be permuted to ``(time, batch, dim)`` unconditionally while the
    LSTM is built with ``batch_first=True`` by default. The recurrence therefore ran
    across the batch axis instead of time, mixing samples. The permutation now follows
    the LSTM's ``batch_first`` setting.
    """

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = False, bias = True, batch_first = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bias, batch_first)

    def forward(self, x):
        if self.lstm.batch_first:
            seq = x.permute(0, 2, 1)  # (B, C, T) -> (B, T, C)
        else:
            seq = x.permute(2, 0, 1)  # (B, C, T) -> (T, B, C)
        y, _ = self.lstm(seq)
        if self.skip:
            y = y + seq
        if self.lstm.batch_first:
            return y.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)
        return y.permute(1, 2, 0)  # (T, B, C) -> (B, C, T)
def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask=None) -> Tensor:
    """Attention whose logits are cosine similarities between queries and keys.

    ``q`` and ``k`` are L2-normalised along the last dim, so the logits lie in [-1, 1]
    before ``mask`` is added. ``mask`` is an additive bias (e.g. 0 / -inf) broadcast onto
    the ``(..., Lq, Lk)`` logits. It is now optional; ``None`` means no masking.
    Returns ``softmax(logits) @ v``.
    """
    q_unit = F.normalize(q, dim=-1, eps=1e-12)
    k_unit = F.normalize(k, dim=-1, eps=1e-12)
    logits = torch.matmul(q_unit, k_unit.transpose(-1, -2))
    if mask is not None:
        logits = logits + mask
    weights = F.softmax(logits, dim=-1)
    return torch.matmul(weights, v)
def scaled_relu(x, sequence_length):
    """Apply ReLU to ``x`` and divide the result by ``sequence_length``."""
    return torch.relu(x) / sequence_length
def taylor_softmax(x, order=2):
    """Softmax surrogate: the order-``order`` Taylor polynomial of exp(x), normalised.

    Computes 1 + x/1! + x**2/2! + ... + x**order/order! and divides by its sum over the
    last dimension. Factorials are evaluated as exp(lgamma(i + 1)) in float32.
    """
    def _term(i):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
        return x**i / factorial_i

    poly = sum((_term(i) for i in range(1, order + 1)), 1.0)
    return poly / torch.sum(poly, dim=-1, keepdim=True)
def taylor_masked(x, mask, order=2):
    """Taylor-softmax restricted to the positions selected by a boolean ``mask``.

    For each kept entry the order-``order`` Taylor polynomial of exp,
    1 + x/1! + x**2/2! + ..., is evaluated. Masked-out entries contribute 0. Values are
    normalised along the last dimension (with a 1e-9 guard) and masked positions are
    zero in the output.

    Fix: the polynomial used to be seeded with ``1 + x`` and the i = 1 term was then
    added again, giving 1 + 2x + x**2/2. That is not the Taylor series of exp and can
    go negative (e.g. x = -3), producing negative "probabilities".
    """
    kept = x.masked_select(mask)
    poly = torch.ones_like(kept)
    for i in range(1, order + 1):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
        poly = poly + kept**i / factorial_i
    tapprox = torch.zeros_like(x)
    tapprox.masked_scatter_(mask, poly)
    sum_approx = torch.sum(tapprox, dim=-1, keepdim=True)
    toutput = tapprox / (sum_approx + 1e-9)
    return toutput * mask
def taylor_softmax2(x, mask=None, order=2):
    """Taylor-softmax with an optional boolean ``mask``.

    The order-``order`` Taylor polynomial of exp, 1 + x/1! + ... + x**order/order!, is
    normalised over the last dimension. With ``mask`` only the selected entries take
    part: unselected entries are 0 before normalisation (1e-9 guard on the sum) and the
    output is multiplied by ``mask.float()``.

    Fix: both branches seeded the polynomial with ``1 + x`` and then added the i = 1 term
    again (1 + 2x + x**2/2), which can go negative. The masked branch also built an
    unused ``zeros_like`` buffer.
    """
    def _poly(t):
        out = torch.ones_like(t)
        for i in range(1, order + 1):
            factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
            out = out + t**i / factorial_i
        return out

    if mask is None:
        tapprox = _poly(x)
        return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)
    tapprox_full = torch.zeros_like(x)
    tapprox_full.masked_scatter_(mask, _poly(x.masked_select(mask)))
    sum_approx = torch.sum(tapprox_full, dim=-1, keepdim=True)
    toutput = tapprox_full / (sum_approx + 1e-9)
    return toutput * mask.float()
def taylor_softmax_2nd_order(x):
    """Second-order Taylor softmax: (1 + x + x**2 / 2) normalised over the last dim."""
    numerator = 1 + x + (x**2) / 2
    denominator = numerator.sum(dim=-1, keepdim=True)
    return numerator / denominator
def taylor_softmax_approximation(x, order=2):
    """Taylor-series softmax surrogate of order 0, 1 or 2 over the last dimension.

    order 0 returns a uniform distribution, order 1 normalises (1 + x), and order 2
    normalises (1 + x + x**2 / 2). Other orders raise ``NotImplementedError``.
    """
    if order == 0:
        return torch.ones_like(x) / x.size(-1)
    if order not in (1, 2):
        raise NotImplementedError("Higher orders are not implemented yet.")
    numerator = 1 + x if order == 1 else 1 + x + 0.5 * x**2
    return numerator / torch.sum(numerator, dim=-1, keepdim=True)
def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
    """Blend dot-product scores with Gaussian (RBF) similarity scores.

    Returns ``q @ k^T`` when ``rbf_ratio <= 0``. Otherwise returns
    ``(1 - rbf_ratio) * (q @ k^T) + rbf_ratio * exp(-||q_i - k_j||^2 / (2 * rbf_sigma^2))``.

    Changes: the squared distance reuses the already-computed ``q @ k^T`` (it was computed
    twice), is clamped at 0 so rounding cannot produce RBF values above 1, and the local
    no longer shadows the function name.
    """
    dot_scores = torch.matmul(q, k.transpose(-1, -2))
    if rbf_ratio <= 0.0:
        return dot_scores
    q_sq = q.pow(2).sum(dim=-1, keepdim=True)
    k_sq = k.pow(2).sum(dim=-1, keepdim=True)
    # ||q - k||^2 = |q|^2 + |k|^2 - 2 q.k; cancellation can leave tiny negatives.
    dist_sq = (q_sq + k_sq.transpose(-1, -2) - 2 * dot_scores).clamp(min=0.0)
    rbf = torch.exp(-dist_sq / (2 * rbf_sigma**2))
    return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf
def sliding_window_mask(q_len, k_len, window, device):
    """Float ``(q_len, k_len)`` mask with 1.0 where key j satisfies i - window + 1 <= j <= i."""
    rows = torch.arange(q_len, device=device)[:, None]
    cols = torch.arange(k_len, device=device)[None, :]
    inside_window = cols >= (rows - window + 1)
    not_future = cols <= rows
    return (inside_window & not_future).float()
def mask_win(text_ctx, aud_ctx):
    """Concatenate a causal ``(text_ctx, text_ctx)`` block with a lower-triangular
    ``(text_ctx, aud_ctx - text_ctx)`` block, using the module ``device``/``dtype``."""
    text_part = torch.ones(text_ctx, text_ctx, device=device, dtype=dtype).tril(diagonal=0)
    audio_part = torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype).tril()
    return torch.cat((text_part, audio_part), dim=-1)
def maskc(ctx, device):
    """Causal ``(ctx, ctx)`` mask: ones on and below the diagonal, in the module ``dtype``."""
    ones = torch.ones(ctx, ctx, device=device, dtype=dtype)
    return ones.tril(diagonal=0)
def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, device=None):
    """Build a ``(batch_size, 1, ctx, ctx)`` float mask in which 1.0 marks a blocked position.

    The 1 = blocked convention follows from the padding handling below, where positions
    with a falsy ``padding_mask`` are OR-ed in as 1.0. With ``is_causal`` every query i
    blocks keys j > i. ``padding_mask`` is assumed to be ``(batch_size, ctx)`` with
    nonzero = real token (TODO confirm against callers).

    Fix: the causal part used ``triu(diagonal=0)``, which also blocked the diagonal. Each
    query was barred from attending to itself, and query 0 had every key blocked.
    """
    if is_causal:
        mask = torch.triu(torch.ones((ctx, ctx), device=device), diagonal=1)
        mask = mask.expand(batch_size, 1, ctx, ctx)
    else:
        mask = torch.zeros((batch_size, 1, ctx, ctx), device=device)
    if padding_mask is not None:
        # (batch, ctx) -> (batch, 1, 1, ctx): block padded keys for every query.
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
        mask = (mask.bool() | (~padding_mask)).float()
    return mask
def calculate_attention(q, k, v, mask=None, temp=1.0):
    """Scaled dot-product attention where ``mask`` only toggles causal masking.

    ``mask`` is not applied as a tensor. When it is given and there is more than one
    query position, ``is_causal=True`` is passed to ``scaled_dot_product_attention``.
    For ``temp > 0`` and ``temp != 1``, ``q`` is multiplied by sqrt(1 / temp) first.

    Fix: the query length used to be read from ``q.shape[1]``, which is the head axis
    for ``(batch, heads, seq, head_dim)`` inputs. A single decoding query with several
    heads therefore got ``is_causal=True``, and with SDPA's top-left causal alignment it
    could only see the first key. The length is now ``q.shape[-2]``, which is correct for
    both 3-D ``(batch, seq, dim)`` and 4-D inputs.
    """
    scaled_q = q
    if temp != 1.0 and temp > 0:
        scaled_q = q * (1.0 / temp) ** 0.5
    causal = mask is not None and q.shape[-2] > 1
    return scaled_dot_product_attention(scaled_q, k, v, is_causal=causal)
def calculate_attentionb(q_norm, k_norm, v_iter, mask=None, temp=1.0):
    """Explicit softmax attention with logits ``q @ k^T / (sqrt(d_k) / temp)``.

    Where ``mask == 0`` the logit is set to -inf before the softmax. Note that ``temp``
    multiplies the logits here, so larger values sharpen the distribution.
    """
    d_k = q_norm.size(-1)
    denom = torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) / temp
    logits = q_norm @ k_norm.transpose(-2, -1) / denom
    if mask is not None:
        logits = logits.masked_fill(mask == 0, float('-inf'))
    return F.softmax(logits, dim=-1) @ v_iter
class LocalOut(nn.Module):
    """Holds per-head ``Linear(head_dim, head_dim)`` q/k/v/o projections and a helper that
    merges ``(batch, heads, ctx, head_dim)`` attention output back to ``(batch, ctx, dims)``."""

    def __init__(self, dims: int, head: int):
        super().__init__()
        self.head_dim = dims // head
        self.dims = dims
        hd = self.head_dim
        self.q_module = nn.Linear(hd, hd)
        self.k_module = nn.Linear(hd, hd)
        self.v_module = nn.Linear(hd, hd)
        self.o_proj = nn.Linear(hd, hd)

    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
        batch, _, ctx, _ = attn_output.shape
        merged = attn_output.transpose(1, 2).contiguous()
        return merged.view(batch, ctx, self.dims)
def qkv_init(dims, head):
    """Create q/k/v/o ``Linear(dims, dims)`` layers, two ``LayerNorm(dims)`` and two
    ``LayerNorm(dims // head)``; returned as ``(q, k, v, o, lna, lnb, lnc, lnd)``."""
    head_dim = dims // head
    q, k, v, o = (nn.Linear(dims, dims) for _ in range(4))
    lna, lnb = nn.LayerNorm(dims), nn.LayerNorm(dims)
    lnc, lnd = nn.LayerNorm(head_dim), nn.LayerNorm(head_dim)
    return q, k, v, o, lna, lnb, lnc, lnd
def shape(dims, head, q, k, v):
    """Split ``(batch, len, dims)`` q/k/v into ``(batch, head, len, dims // head)`` views.

    q keeps its own length. k and v both use ``k``'s length.
    """
    head_dim = dims // head
    batch_size, seq_len_q = q.shape[0], q.shape[1]
    seq_len_kv = k.shape[1]

    def split(t, length):
        return t.view(batch_size, length, head, head_dim).transpose(1, 2)

    return split(q, seq_len_q), split(k, seq_len_kv), split(v, seq_len_kv)
def create_qkv(dims, head, q, k, v, x, xa):
    """Project and split queries/keys/values into heads.

    ``q``, ``k`` and ``v`` are callables (e.g. ``nn.Linear``) applied to ``x`` (queries)
    and ``xa`` (keys/values). q and k are each scaled by ``head_dim ** -0.25``, so their
    dot product carries the usual 1/sqrt(head_dim) factor. Returns three contiguous
    ``(batch, head, length, head_dim)`` tensors. k and v take ``xa``'s length.

    Fix: every tensor used to be reshaped with ``x``'s sequence length, so cross-attention
    with an ``xa`` of a different length failed inside ``view``.
    """
    head_dim = dims // head
    scale = head_dim ** -0.25
    q = q(x) * scale
    k = k(xa) * scale
    v = v(xa)

    def _shape(tensor):
        # (batch, length, dims) -> (batch, head, length, head_dim), using the tensor's own length.
        batch, length = tensor.shape[0], tensor.shape[1]
        return tensor.view(batch, length, head, head_dim).transpose(1, 2).contiguous()

    return _shape(q), _shape(k), _shape(v)
class KVCache(nn.Module):
    """Preallocated key/value cache stored as ``(max_batch, n_heads, max_seq, head_dim)`` buffers."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        buffer_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(buffer_name, torch.zeros(buffer_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val``/``v_val`` at sequence positions ``input_pos`` (in place) and
        return the full cache buffers."""
        assert input_pos.shape[0] == k_val.shape[2]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache
def mel_scale_scalar(freq: float) -> float:
    """HTK-style mel value of a frequency in Hz: 1127 * ln(1 + freq / 700)."""
    return math.log(1.0 + freq / 700.0) * 1127.0
def mel_scale(freq: Tensor) -> Tensor:
    """Element-wise HTK-style mel scale: 1127 * ln(1 + freq / 700)."""
    return torch.log(1.0 + freq / 700.0) * 1127.0
def trace_x(func):
    """Decorator that prints every call of ``func`` and the shape of tensor results.

    Uses ``functools.wraps`` so the wrapper keeps ``func``'s ``__name__``, ``__doc__``
    and other metadata. Previously every traced function reported itself as ``wrapper``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__}")
        result = func(*args, **kwargs)
        if isinstance(result, torch.Tensor):
            print(f" {func.__name__} returned shape: {result.shape}")
        return result

    return wrapper
def track_x(new_x, operation=""):
    """ track_x(x, "x")

    Debug helper: prints whether ``new_x`` is a different object from the one passed on
    the previous call ("x FLOW: old -> new") or the same one ("x REUSE"). Returns
    ``new_x`` unchanged. ``None`` passes through without printing.

    Fix: the previously seen id used to be rebuilt from ``new_x`` on every call, so the
    comparison always matched and FLOW could never be reported. It is now kept on the
    function object between calls. The FLOW message also joined both ids with no
    separator. Note: CPython may reuse ids of freed objects.
    """
    if new_x is None:
        return new_x
    previous_id = getattr(track_x, "_last_id", None)
    current_id = id(new_x)
    if current_id != previous_id:
        print(f"x FLOW: {previous_id} -> {current_id} in {operation}")
        track_x._last_id = current_id
    else:
        print(f"x REUSE: {current_id} in {operation}")
    return new_x
def track_xa(new_xa, operation=""):
    """ track_xa(xa, "xa - decoder")

    Debug helper for the cross-attention input: prints "xa FLOW: old -> new" when
    ``new_xa`` is a different object from the previous call, and "xa REUSE" when it is
    the same one. Returns ``new_xa`` unchanged. ``None`` passes through without printing.

    Fix: the previously seen id used to be rebuilt from ``new_xa`` on every call, so FLOW
    was unreachable. It is now kept on the function object between calls, and the two
    ids in the FLOW message are separated.
    """
    if new_xa is None:
        return new_xa
    previous_id = getattr(track_xa, "_last_id", None)
    current_id = id(new_xa)
    if current_id != previous_id:
        print(f"xa FLOW: {previous_id} -> {current_id} in {operation}")
        track_xa._last_id = current_id
    else:
        print(f"xa REUSE: {current_id} in {operation}")
    return new_xa
def get_activation(act: str) -> nn.Module:
    """Return a new activation module for the name ``act``; unknown names give ``nn.GELU()``."""
    factories = {
        "gelu": nn.GELU,
        "relu": nn.ReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
        "swish": nn.SiLU,
        "tanhshrink": nn.Tanhshrink,
        "softplus": nn.Softplus,
        "softshrink": nn.Softshrink,
        "leaky_relu": nn.LeakyReLU,
        "elu": nn.ELU,
    }
    factory = factories.get(act, nn.GELU)
    return factory()
def get_generation_config(param):
    """Build a greedy-decoding ``GenerationConfig`` from a model-parameter object.

    ``param.text_ctx`` is required and becomes ``max_length``. ``pad/bos/eos_token_id``
    fall back to 0/1/2 when ``param`` lacks them. Sampling, beam search, repetition and
    length penalties, caching and timestamps are all switched off.

    NOTE(review): ``GenerationConfig`` is not imported in this module's visible import
    block, so calling this raises NameError unless it is brought into scope elsewhere
    (presumably from ``transformers``). TODO confirm.
    """
    max_length = param.text_ctx
    token_ids = {
        name: getattr(param, name, fallback)
        for name, fallback in (("pad_token_id", 0), ("bos_token_id", 1), ("eos_token_id", 2))
    }
    return GenerationConfig(
        max_length=max_length,
        **token_ids,
        do_sample=False,
        num_beams=1,
        early_stopping=False,
        length_penalty=1.0,
        no_repeat_ngram_size=0,
        repetition_penalty=1.0,
        temperature=1.0,
        decoder_start_token_id=1,
        is_multilingual=False,
        use_cache=False,
        return_timestamps=False,
    )
class feature_encoder(nn.Module):
    def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None, use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None):
        """
        Feature encoder for audio processing.

        Maps one acoustic feature to a ``(batch, time, dims)`` sequence. Three
        convolutional front-ends are built:
          * ``spectrogram``: ``mels -> dims`` via three unpadded kernel-3 convs (each
            shortens the time axis by 2), the last one depthwise (``groups=dims``);
          * ``waveform``: ``1 -> dims`` via convs with strides 4, 2, 2 (16x downsampling);
          * ``pitch``: ``1 -> dims`` via length-preserving stride-1 convs.
        ``forward`` selects the front-end from its ``feature`` argument, adds a
        sinusoidal positional table, and applies dropout (p=0.01) and ``RMSNorm``.

        ``input_dims``, ``layer`` and ``features`` are accepted but not used here.
        NOTE(review): ``debug=[]`` is a mutable default shared between instances; it is
        only stored in this class.
        """
        super().__init__()
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.attend_feature = attend_feature
        self.target_length = target_length
        self.feature = feature
        self.debug = debug
        # One activation instance, reused at every position in the Sequential stacks below.
        act_fn = get_activation(act)
        if self.attend_feature:
            self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
        else:
            # q/k/v/o/scale exist (as None) only on this branch.
            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
            self.mlp = None
        self.spectrogram = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3), act_fn,
            Conv1d(dims, dims, kernel_size=3), act_fn,
            Conv1d(dims, dims, kernel_size=3, groups=dims), act_fn)
        self.waveform = nn.Sequential(
            Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
        self.pitch = nn.Sequential(
            Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        # Both branches use the same positional lambda. The table is rebuilt on every
        # forward call and is not registered on this module.
        if use_rope:
            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
            # NOTE(review): ``rotary`` is not defined or imported in this module's visible
            # source. This class also defines a ``rope`` method; nn.Module keeps submodules
            # in ``_modules``, which is presumably consulted only after normal attribute
            # lookup, so ``self.rope`` would resolve to the method rather than this module.
            # Looks like a naming clash -- confirm.
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            # Assigning None stores a plain instance attribute that hides the ``rope`` method.
            self.rope = None
            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)
    def rope(self, x, xa=None, mask=None, feats=None, feature=None, layer=None):
        # Intended to apply rotary embeddings per head:
        # (batch, ctx, dims) -> (batch, head, ctx, head_dim) -> rotate -> back.
        if isinstance(x, int):
            ctx = x
        elif isinstance(x, torch.Tensor):
            ctx = x.shape[1] if x.dim() > 1 else x.shape[0]
        # NOTE(review): x.shape is read unconditionally here, so the int branch above
        # cannot succeed.
        batch, ctx, dims = x.shape[0], ctx, x.shape[-1]
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x
    def mel_scalar(self, freq: float) -> float:
        # HTK-style mel value: 1127 * ln(1 + freq / 700).
        return 1127.0 * math.log(1.0 + freq / 700.0)
    def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
        # Local target length: the configured one, else the input's dim-1 size.
        target_length = x.shape[1] if self.target_length is None else self.target_length
        if feature == "pitch":
            # Store an untouched copy of the raw pitch input under "f0" in a shallow copy
            # of ``feats``, so the caller's dict is not mutated.
            xp = x.clone()
            enc_dict = feats if feats is not None else {}
            enc_dict = dict(enc_dict)
            enc_dict["f0"] = xp
            feats = enc_dict
            if x.dim() == 2:
                x = x.unsqueeze(0)
            # Conv output is (batch, dims, time); permute to (batch, time, dims).
            x = self.pitch(x).permute(0, 2, 1)
        if feature == "phase":
            # Phase input reuses the pitch front-end.
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.pitch(x).permute(0, 2, 1)
        if feature == "waveform":
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.waveform(x).permute(0, 2, 1)
            # NOTE(review): original indentation was lost; this resize is assumed to belong
            # to the waveform branch -- confirm. It compares against ``self.target_length``
            # but pools to the local ``target_length``.
            if target_length and x.shape[1] != self.target_length:
                x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)
        if feature == "harmonics":
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.spectrogram(x).permute(0, 2, 1)
        if feature == "aperiodic":
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.spectrogram(x).permute(0, 2, 1)
        if feature == "spectrogram":
            if x.dim() == 2:
                x = x.unsqueeze(0)
            x = self.spectrogram(x).permute(0, 2, 1)
        if self.use_rope:
            x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
            x = self.rope(x=x, xa=None, mask=None, feats=feats, feature=feature, layer=layer)
        else:
            max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
            x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
        # NOTE(review): indentation lost; the dropout/norm pairs below are assumed to run
        # after either branch. Dropout followed by RMSNorm is applied twice in a row --
        # confirm intent.
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x
class OneShot(nn.Module):
    """Cross-attention logit bias between ``x`` and ``xa``.

    ``forward`` returns ``(q @ k^T) * scale / sqrt(head_dim)`` with shape
    ``(B, head, L, K)``, where q comes from ``x`` and k from ``xa``. ``scale`` is
    ``1 / len(features)`` when ``features`` is non-empty (the five default feature names
    give 0.2), otherwise the ``scale`` argument.

    Fix: the scale used floor division (``1.0 // len(features)``). That is 0.0 for more
    than one feature, which silently zeroed the bias.
    """

    def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
        super().__init__()
        if features is None:
            features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
        self.head = head
        self.head_dim = dims // head
        self.scale = 1.0 / len(features) if features else scale
        self.q = Linear(dims, dims)
        self.k = Linear(dims, dims)

    def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor | None:
        B, L, D = x.shape
        K = xa.size(1)
        q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1, 2)
        k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1, 2)
        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
        return bias
class curiosity(nn.Module):
    """Two-stream attention that blends self-attention with attention into ``xa``.

    Queries from ``x`` attend to ``x``'s own keys/values (optionally masked, where True =
    blocked) and, separately, to keys/values projected from ``xa``. The two per-head
    results are mixed with weight ``sigmoid(g)`` on the ``xa`` stream (one learnable
    ``g`` per head, initialised to 0, i.e. an even 50/50 mix) and projected by ``o``.
    The query slice of ``qkv_aux`` is computed but not used.
    """

    def __init__(self, d, h, bias=True):
        super().__init__()
        self.h = h
        self.dh = d // h
        self.qkv = nn.Linear(d, d * 3, bias=bias)
        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
        self.o = nn.Linear(d, d, bias=bias)
        self.g = nn.Parameter(torch.zeros(h))

    def split(self, x):
        b, t, _ = x.shape
        return x.view(b, t, self.h, self.dh).transpose(1, 2)

    def merge(self, x):
        b, h, t, dh = x.shape
        return x.transpose(1, 2).contiguous().view(b, t, h * dh)

    def forward(self, x, xa, mask=None):
        q, k, v = (self.split(part) for part in self.qkv(x).chunk(3, -1))
        _, ka, va = self.qkv_aux(xa).chunk(3, -1)
        ka, va = self.split(ka), self.split(va)
        root_dh = self.dh ** 0.5
        logits = (q @ k.transpose(-2, -1)) / root_dh
        logits_aux = (q @ ka.transpose(-2, -1)) / root_dh
        if mask is not None:
            logits = logits.masked_fill(mask, -9e15)
        main_out = logits.softmax(-1) @ v
        aux_out = logits_aux.softmax(-1) @ va
        gate = torch.sigmoid(self.g).view(1, -1, 1, 1)
        blended = main_out * (1 - gate) + aux_out * gate
        return self.o(self.merge(blended))
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to ``sqrt(dims)``-scaled inputs.

    ``pe[pos, 2i] = sin(pos * 10000^(-2i/dims))`` and ``pe[pos, 2i+1] = cos(...)`` are
    built once for ``ctx`` positions. ``forward`` returns ``x * sqrt(dims) + pe[:, :len]``.

    Changes: the table is registered as a non-persistent buffer, so ``module.to(...)``
    moves it while ``state_dict`` keys are unchanged. Odd ``dims`` no longer crash: the
    cosine columns use only the first ``dims // 2`` frequencies.
    """

    def __init__(self, dims, ctx):
        super(PositionalEncoding, self).__init__()
        self.dims = dims
        self.ctx = ctx
        self.register_buffer("pe", self.get_positional_encoding(max_ctx=ctx), persistent=False)

    def get_positional_encoding(self, max_ctx):
        """Return the ``(1, max_ctx, dims)`` table on the module-level ``device``."""
        pe = torch.zeros(max_ctx, self.dims)
        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dims, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.dims)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd dims there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: self.dims // 2])
        pe = pe.unsqueeze(0)
        return pe.to(device)

    def forward(self, x):
        ctx = x.size(1)
        pe = self.pe[:, :ctx, :]
        x = x * math.sqrt(self.dims)
        return x + pe
def valid(default_value, *items):
    """Return the first of ``items`` that is not ``None``, or ``default_value`` if none is."""
    return next((item for item in items if item is not None), default_value)
def dict_to(d, device, dtype=dtype):
    """Return a new dict with tensor values moved/cast to ``device``/``dtype``; other values pass through."""
    moved = {}
    for key, value in d.items():
        moved[key] = value.to(device, dtype) if isinstance(value, torch.Tensor) else value
    return moved
def exists(v):
    """True when ``v`` is anything other than ``None``."""
    return not (v is None)
def default(v, b):
    """Return ``v`` unless it is ``None``, in which case return the fallback ``b``."""
    if v is None:
        return b
    return v
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that casts its weight and bias to the input's device and dtype on each call."""

    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        cast_weight = weight.to(x.device, x.dtype)
        cast_bias = bias.to(x.device, x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` that casts its weight and bias to the input's device and dtype on each call."""

    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        cast_weight = weight.to(x.device, x.dtype)
        cast_bias = bias.to(x.device, x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)
class Linear(nn.Module):
    """Wrapper around ``nn.Linear`` with Xavier-uniform weights and a zero-initialised bias."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        layer = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(layer.weight)
        if bias:
            init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)
class RMSNorm(nn.Module):
def __init__(self, dims: Union[int, Tensor, List, Tuple],
eps = 1e-8, elementwise_affine = True):
super(RMSNorm, self).__init__()
if isinstance(dims, int):
self.normalized_shape = (dims,)
else:
self.normalized_shape = tuple(dims)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape))
init.ones_(self.weight)
else:
self.register_parameter("weight", None)
def forward(self, x):
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    """Functional layer norm: forwards every argument unchanged to ``F.layer_norm``.

    This definition rebinds the module-level name ``LayerNorm`` (an ``nn.Module`` class
    of the same name is defined earlier in this module).
    """
    return F.layer_norm(x, normalized_shape, weight=weight, bias=bias, eps=eps)
def get_device():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
def get_dtype():
    """Return float32 on machines with CUDA, float64 otherwise."""
    return torch.float64 if not torch.cuda.is_available() else torch.float32
def tox():
    """Keyword bundle for tensor factories: ``{"device": get_device(), "dtype": get_dtype()}``."""
    return dict(device=get_device(), dtype=get_dtype())
def sinusoids(ctx, dims, max_tscale=10000):
    """Trainable ``(ctx, dims)`` sinusoidal position table on the module ``device``.

    This definition repeats the earlier ``sinusoids`` in this module and is the one bound
    at import. Columns ``[0, dims/2)`` hold sin(p * ts_i) and columns ``[dims/2, dims)``
    hold cos(p * ts_i), with ts_i = exp(-i * log(max_tscale) / (dims/2 - 1)). ``dims``
    must be even (asserted).
    """
    assert dims % 2 == 0
    n_freq = dims // 2
    increment = torch.log(torch.tensor(float(max_tscale))) / (n_freq - 1)
    freq_idx = torch.arange(n_freq, device=device, dtype=torch.float32)
    inv_ts = torch.exp(-increment * freq_idx)
    pos_idx = torch.arange(ctx, device=device, dtype=torch.float32)
    phase = pos_idx.unsqueeze(1) * inv_ts.unsqueeze(0)
    return nn.Parameter(torch.cat([torch.sin(phase), torch.cos(phase)], dim=1), requires_grad=True)
class SelfCriticalRL(nn.Module):
    """Self-critical sequence training loss (REINFORCE with a greedy-decoding baseline).

    For each example, a greedy decode and a sampled decode (``do_sample=True``,
    ``top_k=5``) are scored against the decoded reference with ``reward_fn``. The
    advantage (sampled reward minus greedy reward) weights the summed log-probability
    of the sampled tokens, and the negated mean is returned as the loss to minimise.
    """
    def __init__(self, model, tokenizer, reward_fn):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.reward_fn = reward_fn
    def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"):
        # NOTE(review): ``labels`` is required in practice -- zip() below fails on None.
        # NOTE(review): original indentation was lost; both generate calls are assumed to
        # sit inside no_grad (they only produce token ids) -- confirm.
        with torch.no_grad():
            greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len)
            greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids]
            sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len, do_sample=True, top_k=5)
            sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids]
        rewards = []
        baseline = []
        for s, g, ref in zip(sampled_text, greedy_text, labels):
            ref_text = self.tokenizer.decode(ref)
            rewards.append(self.reward_fn(s, ref_text))
            baseline.append(self.reward_fn(g, ref_text))
        # Rewards are placed on the module-level ``device``.
        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
        baseline = torch.tensor(baseline, device=device, dtype=torch.float)
        advantage = rewards - baseline
        # Re-run the model with gradients on the sampled sequence to score it.
        logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"]
        log_probs = F.log_softmax(logits, dim=-1)
        # NOTE(review): logits at position t are paired with sampled_ids[:, t]. If the
        # model's logits at t predict token t + 1 (typical for causal decoders), this needs
        # a one-step shift. Padding/EOS positions are not masked either -- confirm.
        log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1)
        log_probs_sum = log_probs_seq.sum(dim=1)
        loss = -(advantage * log_probs_sum).mean()
        return loss
class SelfTrainingModule(nn.Module):
    """Pseudo-label self-training: decode unlabeled audio, optionally keep only rows whose
    ``quality_fn`` score exceeds ``threshold``, and train on the decoded ids with
    cross-entropy (id 0 ignored)."""
    def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.quality_fn = quality_fn
        self.threshold = threshold
    def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        # Greedy decode without gradients; the ids become training targets.
        with torch.no_grad():
            pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len)
        if self.quality_fn is not None:
            # NOTE(review): ``features`` is passed as-is here; ``confidence_indicator`` in this
            # module expands it with ``**features`` and therefore needs a mapping -- confirm.
            quality_scores = self.quality_fn(pred_ids, self.model, features)
            mask = quality_scores > self.threshold
            # Drops whole rows that fall below the threshold.
            pred_ids = pred_ids[mask]
        return pred_ids
    def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name)
        # NOTE(review): when rows were filtered above, ``unlabeled_batch``/``features`` still
        # carry the full batch, so logits and pseudo_labels disagree in size -- confirm.
        logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"]
        # Token id 0 (the pad id used by setup_tokenizer) is ignored in the loss.
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0)
        return loss
def confidence_indicator(pred_ids, model, features):
    """Mean over positions of the model's top-1 softmax probability, one score per row.

    ``model`` is called as ``model(input_ids=pred_ids, **features)`` and must return a
    mapping with ``"logits"`` of shape ``(batch, seq, vocab)``. Runs under ``no_grad``.
    """
    with torch.no_grad():
        outputs = model(input_ids=pred_ids, **features)
        per_token_conf = torch.softmax(outputs["logits"], dim=-1).amax(dim=-1)
    return per_token_conf.mean(dim=1)
def wer_reward(hyp, ref):
    """Negative word error rate of ``hyp`` against ``ref`` (0 is a perfect match).

    WER = word-level Levenshtein distance / max(1, number of reference words).
    Uses a rolling two-row dynamic-programming table.
    """
    hyp_words, ref_words = hyp.split(), ref.split()
    prev = list(range(len(ref_words) + 1))
    for i, h_word in enumerate(hyp_words, start=1):
        cur = [i] + [0] * len(ref_words)
        for j, r_word in enumerate(ref_words, start=1):
            if h_word == r_word:
                cur[j] = prev[j - 1]
            else:
                cur[j] = 1 + min(prev[j], cur[j - 1], prev[j - 1])
        prev = cur
    return -(prev[-1] / max(1, len(ref_words)))
def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    """Return ``ids`` as a list of ints without pad/bos/eos ids or -100 (ignore index).

    Tensors are converted with ``tolist()`` first.
    """
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    dropped = (-100, pad_token_id, bos_token_id, eos_token_id)
    return [int(tok) for tok in ids if tok not in dropped]
def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    """Apply ``clean_ids`` to every sequence in ``batch_ids``; returns a list of int lists."""
    return [
        clean_ids(sequence, pad_token_id, bos_token_id, eos_token_id)
        for sequence in batch_ids
    ]
def setup_tokenizer(dir: str):
    """Load a ``tokenizers.Tokenizer`` from a JSON file and patch helper methods onto it.

    On the returned object:
      * ``encode(text, add_special_tokens=True)`` returns a list of ids. With
        ``add_special_tokens=False`` the ids of "<PAD>", "<BOS>" and "<EOS>" are removed.
      * ``decode(ids, ...)`` / ``batch_decode(ids_list, ...)`` drop pad/bos/eos ids
        (defaults 0/1/2) and -100 before decoding. ``batch_decode`` accepts tensors and
        NumPy arrays; its ``skip_special_tokens`` argument is accepted but unused.
      * ``save_pretrained(save_dir)`` writes ``<save_dir>/tokenizer.json``.
    ``pad_token_id``, ``bos_token_id`` and ``eos_token_id`` are set to 0, 1 and 2.
    """
    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file(f"{dir}")
    raw_encode, raw_decode = tokenizer.encode, tokenizer.decode

    def enc(text, add_special_tokens=True):
        token_ids = raw_encode(text).ids
        if add_special_tokens:
            return token_ids
        specials = [tokenizer.token_to_id(name) for name in ["<PAD>", "<BOS>", "<EOS>"]]
        return [tok for tok in token_ids if tok not in specials]

    def bdec(ids_list, pad_token_id=0, bos_token_id=1, eos_token_id=2, skip_special_tokens=True):
        if isinstance(ids_list, (torch.Tensor, np.ndarray)):
            ids_list = ids_list.tolist()
        dropped = (pad_token_id, bos_token_id, eos_token_id, -100)
        return [raw_decode([int(tok) for tok in seq if tok not in dropped]) for seq in ids_list]

    def dec(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
        dropped = (pad_token_id, bos_token_id, eos_token_id, -100)
        return raw_decode([int(tok) for tok in ids if tok not in dropped])

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.decode = dec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer
def tokenize_pitch(pitch_features, target_length):
    """Resample the last axis of a ``(batch, channels, length)`` tensor to ``target_length``.

    Uses adaptive average pooling when shrinking, and ``F.interpolate`` (its default mode)
    when keeping or growing the length.
    """
    if pitch_features.shape[-1] <= target_length:
        return F.interpolate(pitch_features, target_length)
    return F.adaptive_avg_pool1d(pitch_features, target_length)
def load_wave(wave_data, sample_rate=16000):
    """Return the waveform tensor from a file path or a ``{"array", "sampling_rate"}`` dict.

    Only the waveform is returned; the ``sample_rate`` argument and the rate read from
    the input are not propagated. Other input types raise ``TypeError``.
    NOTE(review): the file path uses ``torchaudio.load(..., normalize=False)``, which
    presumably keeps the file's native sample format rather than float in [-1, 1] like
    the dict path -- confirm this is intended.
    """
    if isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sample_rate = wave_data["sampling_rate"]
        return waveform
    if isinstance(wave_data, str):
        waveform, sample_rate = torchaudio.load(uri=wave_data, normalize=False)
        return waveform
    raise TypeError("Invalid wave_data format.")
def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
    """Project WORLD spectral envelope ``sp`` and aperiodicity ``ap`` onto a mel basis.

    Both inputs are right-multiplied by the transposed ``librosa`` mel filter bank
    (n_fft=1024, ``n_mels`` bands). Returns ``(sp_mel, ap_mel)``.
    NOTE(review): the filter bank presumably has n_fft // 2 + 1 = 513 frequency columns,
    so ``sp``/``ap`` must have 513 bins in their last dimension. Elsewhere in this module
    WORLD is run with fft_size=256 -- confirm the sizes agree.
    """
    import librosa
    basis = torch.from_numpy(
        librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
    ).float()
    basis_t = basis.T
    return torch.matmul(sp, basis_t), torch.matmul(ap, basis_t)
def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False):
    """Extract the requested audio features and token labels for one dataset example.

    ``batch["audio"]`` is read as a dict with ``"sampling_rate"`` (and ``"array"`` via
    ``load_wave``), and ``batch["transcription"]`` is encoded with ``tokenizer``. Each
    boolean flag enables one output; disabled outputs are ``None``.

    Returns a dict with keys "waveform", "spectrogram", "pitch_tokens", "pitch",
    "harmonic", "aperiodic", "labels", "phase", "crepe_time", "crepe_frequency",
    "crepe_confidence", "crepe_activation" and "dummy".

    With ``dummy=True`` the function returns early with fixed zero tensors under a
    different key set ("f0", "harmonics", "aperiodics", no "labels"), so the later
    ``if dummy:`` block is unreachable.

    NOTE(review): the ``sample_rate`` argument is overwritten by the audio's own rate,
    while ``mel_spectrogram`` below hard-codes 16000 Hz.
    """
    # Window-function lookup used by the nested ``spectrogram`` helper when given a name.
    torch_windows = {
        'hann': torch.hann_window,
        'hamming': torch.hamming_window,
        'blackman': torch.blackman_window,
        'bartlett': torch.bartlett_window,
        'ones': torch.ones,
        None: torch.ones,
    }
    if dummy:
        # Early exit with placeholder tensors; nothing below runs in this case.
        return {
            "spectrogram": torch.zeros((1, 128, 100)),
            "f0": torch.zeros((1, 100)),
            "pitch_tokens": torch.zeros((1, 100)),
            "pitch": torch.zeros((1, 100)),
            "harmonics": torch.zeros((1, 128, 100)),
            "aperiodics": torch.zeros((1, 128, 100)),
            "crepe_time": None,
            "crepe_frequency": None,
            "crepe_confidence": None,
            "crepe_activation": None,
        }
    audio = batch["audio"]
    sample_rate = audio["sampling_rate"]
    labels = tokenizer.encode(batch["transcription"])
    wav = load_wave(wave_data=audio, sample_rate=sample_rate)
    def crepe_predict(wav, sample_rate, viterbi=False):
        # NOTE(review): unpacks four arrays (time, frequency, confidence, activation) and
        # wraps them with torch.from_numpy; confirm torchcrepe.predict returns that tuple
        # of NumPy arrays for a NumPy input.
        import torchcrepe
        wav = wav.numpy().astype(np.float32)
        time, frequency, confidence, activation = torchcrepe.predict(
            wav, sample_rate=sample_rate, viterbi=viterbi)
        crepe_time = torch.from_numpy(time)
        crepe_frequency = torch.from_numpy(frequency)
        crepe_confidence = torch.from_numpy(confidence)
        crepe_activation = torch.from_numpy(activation)
        return crepe_time, crepe_frequency, crepe_confidence, crepe_activation
    if crepe:
        crepe_time, crepe_frequency, crepe_confidence, crepe_activation = crepe_predict(wav, sample_rate, viterbi=True)
    else:
        crepe_time = None
        crepe_frequency = None
        crepe_confidence = None
        crepe_activation = None
    # Linear-magnitude STFT helper. NOTE(review): defined but never called below.
    def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
        if isinstance(window_fn, str):
            window_fn = torch_windows[window_fn]
        if window_fn is None:
            window_fn = torch.ones(n_fft)
        if isinstance(window_fn, torch.Tensor):
            window_fn = window_fn.to(device)
        return torchaudio.functional.spectrogram(
            wav, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
            window=window_fn, center=True, pad_mode="reflect", power=1.0)
    def mel_spectrogram(wav, sample_rate):
        # 128-band HTK mel magnitude spectrogram limited to 150-2000 Hz. The transform is
        # configured for 16 kHz regardless of the ``sample_rate`` argument.
        spectrogram_config = {
            "hop_length": 256,
            "f_min": 150,
            "f_max": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sample_rate": 16000,
            "pad_mode": "constant",
            "center": True,
            "power": 1.0,
            "window_fn": torch.hann_window,
            "mel_scale": "htk",
            "norm": None,
            "normalized": False,
        }
        transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
        mel_spectrogram = transform(wav)
        # log10 compression, floored at (global max - 8), then shifted/scaled by (x + 4) / 4.
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spectrogram_tensor = (log_mel + 4.0) / 4.0
        return spectrogram_tensor
    # ``spectrogram_tensor`` is only bound when spec=True.
    if spec:
        spectrogram_tensor = mel_spectrogram(wav, sample_rate)
    # NOTE(review): defined but never called below.
    def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
        transform = torchaudio.transforms.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mels,
            melkwargs={
                "n_fft": n_fft,
                "hop_length": hop_length,
                "window_fn": window_fn,
                "n_mels": n_mels,
                "center": True,
                "pad_mode": "reflect",
                "norm": None,
                "mel_scale": "htk",
            }
        )
        mfcc_tensor = transform(wav)
        return mfcc_tensor
    # NOTE(review): defined but never called; the same computation is repeated inline
    # in the ``if harmonics or aperiodics:`` block further down.
    def harmonics_and_aperiodics(wav, f0, t, sample_rate):
        import pyworld as pw
        wav_np = wav.numpy().astype(np.float64)
        sp = pw.cheaptrick(wav_np, f0, t, sample_rate, fft_size=256)
        ap = pw.d4c(wav_np, f0, t, sample_rate, fft_size=256)
        harmonic_tensor = torch.from_numpy(sp)
        aperiodic_tensor = torch.from_numpy(ap)
        harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
        aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
        harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
        aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)
        return harmonic_tensor, aperiodic_tensor
    if pitch or pitch_tokens or harmonics or aperiodics:
        # WORLD f0: DIO with one frame per hop (frame_period in milliseconds), refined by StoneMask.
        wavnp = wav.numpy().astype(np.float64)
        f0_np, t = pw.dio(wavnp, sample_rate, frame_period=hop_length / sample_rate * 1000)
        f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
    if pitch_tokens:
        # One pitch value per label token: the audio duration is split evenly across the
        # tokens, and the f0 frames whose timestamps fall in each token's span are summarised.
        # Note: ``wav`` is replaced by a float64 tensor here, which also feeds the "pitch"
        # and "waveform" outputs below.
        wav = torch.from_numpy(wavnp)
        t2 = torch.from_numpy(t)
        audio_duration = len(wav) / sample_rate
        T = len(labels)
        # NOTE(review): an empty ``labels`` list raises ZeroDivisionError here.
        tok_dur_sec = audio_duration / T
        token_starts = torch.arange(T) * tok_dur_sec
        token_ends = token_starts + tok_dur_sec
        start_idx = torch.searchsorted(t2, token_starts, side="left")
        end_idx = torch.searchsorted(t2, token_ends, side="right")
        pitch_tok = torch.zeros(T, dtype=torch.float32)
        for i in range(T):
            # At least one frame per token.
            lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
            segment = f0_np[lo:hi]
            if mode == "mean":
                # NOTE(review): an empty slice (lo past the last frame) yields NaN.
                pitch_tok[i] = segment.mean()
            elif mode == "median":
                # NOTE(review): ``segment`` is a NumPy slice; torch.median expects a tensor -- confirm.
                pitch_tok[i] = torch.median(segment)
            else:
                pitch_tok[i] = segment[-1]
        # Values below 100 Hz are zeroed (treated as unvoiced).
        pitch_tok[pitch_tok < 100.0] = 0.0
        # Prepend a copy of the first token's pitch for the BOS position.
        bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
        pitch_tokens_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
        # Rescale nonzero values as (f - 71) / (500 - 71); zeros stay zero.
        pitch_tokens_tensor = torch.where(pitch_tokens_tensor == 0.0, torch.zeros_like(pitch_tokens_tensor), (pitch_tokens_tensor - 71.0) / (500.0 - 71.0))
    else:
        pitch_tokens_tensor = None
    if phase_mod:
        # Phase from integrating 2*pi*f0 over the mean frame interval, wrapped to [0, 2*pi).
        # NOTE(review): ``f0_tensor`` is never defined in this function, and ``t2`` exists
        # only when pitch_tokens=True, so this branch raises NameError as written.
        tframe = torch.mean(t2[1:] - t2[:-1])
        phi0 = 0.0
        omega = 2 * torch.pi * f0_tensor
        dphi = omega * tframe
        phi = torch.cumsum(dphi, dim=0) + phi0
        phase = torch.remainder(phi, 2 * torch.pi)
    else:
        phase = None
    if pitch:
        # torchaudio's pitch detector, with an extra leading axis.
        p_tensor = torchaudio.functional.detect_pitch_frequency(wav, sample_rate).unsqueeze(0)
    else:
        p_tensor = None
    if harmonics or aperiodics:
        # WORLD CheapTrick spectral envelope and D4C aperiodicity (fft_size=256). The first
        # 128 columns are kept and the result is transposed.
        spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
        apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
        harmonic_tensor = torch.from_numpy(spnp)
        aperiodic_tensor = torch.from_numpy(apnp)
        harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
        aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
        # NOTE(review): these torch.where calls leave values unchanged (x / 1.0 or 0 for 0).
        harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
        aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)
    else:
        harmonic_tensor = None
        aperiodic_tensor = None
    if waveform:
        wave_tensor = wav
    else:
        wave_tensor = None
    # NOTE(review): unreachable -- dummy=True already returned at the top of the function.
    if dummy:
        if spectrogram_tensor is not None:
            dummy_tensor = torch.ones_like(spectrogram_tensor)
        elif p_tensor is not None:
            dummy_tensor = torch.ones_like(p_tensor)
        elif pitch_tokens_tensor is not None:
            dummy_tensor = torch.ones_like(pitch_tokens_tensor)
        else:
            batch_size = 128
            seq_len = 1024
            dummy_tensor = torch.ones(batch_size, seq_len)
        dummy_tensor = dummy_tensor.to(device)
    else:
        dummy_tensor = None
    if debug:
        print(f"['pitch_tokens']: {pitch_tokens_tensor.shape if pitch_tokens else None}")
        print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
        print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")
        print(f"['spectrogram']: {spectrogram_tensor.shape if spec else None}")
        print(f"['waveform']: {wave_tensor.shape if waveform else None}")
        print(f"['labels']: {len(labels) if labels else None}")
        # NOTE(review): ``if phase`` takes the truth value of a tensor, which raises for
        # multi-element tensors.
        print(f"['phase']: {phase.shape if phase else None}")
        print(f"['pitch']: {p_tensor.shape if pitch else None}")
        print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
        print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
        print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
        print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")
        print(f"['dummy']: {dummy_tensor.shape if dummy else None}")
    return {
        "waveform": wave_tensor if waveform else None,
        "spectrogram": spectrogram_tensor if spec else None,
        "pitch_tokens": pitch_tokens_tensor if pitch_tokens else None,
        "pitch": p_tensor if pitch else None,
        "harmonic": harmonic_tensor if harmonics else None,
        "aperiodic": aperiodic_tensor if aperiodics else None,
        "labels": labels,
        "phase": phase if phase_mod else None,
        "crepe_time": crepe_time if crepe else None,
        "crepe_frequency": crepe_frequency if crepe else None,
        "crepe_confidence": crepe_confidence if crepe else None,
        "crepe_activation": crepe_activation if crepe else None,
        "dummy": dummy_tensor if dummy else None,
    }
def plot_waveform(waveform, sr, title="Waveform", ax=None):
    """Plot the first channel of ``waveform`` against time in seconds.

    Args:
        waveform: tensor of shape (channels, frames); a 1-D tensor is treated
            as a single channel. May live on any device or require grad.
        sr: sample rate in Hz, used to convert frame indices to seconds.
        title: title for the axes.
        ax: optional matplotlib Axes to draw into. When omitted, a new figure
            with a single Axes is created.
    """
    # detach/cpu so GPU tensors and tensors that require grad can be converted.
    waveform = waveform.detach().cpu().numpy()
    if waveform.ndim == 1:
        waveform = waveform[None, :]
    num_frames = waveform.shape[-1]
    time_axis = np.arange(num_frames) / sr
    if ax is None:
        # Only channel 0 is drawn, so a single Axes is enough. Asking for one
        # row per channel would return an array of Axes, which has no .plot.
        _, ax = plt.subplots(1, 1)
    ax.plot(time_axis, waveform[0], linewidth=1)
    ax.grid(True)
    ax.set_xlim([0, time_axis[-1]])
    ax.set_title(title)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
    """Render ``specgram`` as an image after converting power to decibels.

    Conversion uses ``librosa.power_to_db`` with its default settings. The
    image is drawn with the origin at the bottom and nearest-neighbour
    interpolation. A fresh single-Axes figure is created when ``ax`` is None.
    """
    import librosa
    target = plt.subplots(1, 1)[1] if ax is None else ax
    if title is not None:
        target.set_title(title)
    target.set_ylabel(ylabel)
    decibels = librosa.power_to_db(specgram)
    target.imshow(decibels, origin="lower", aspect="auto", interpolation="nearest")
def plot_fbank(fbank, title=None):
    """Display a filter-bank matrix as an image in a new figure.

    Rows are labelled "frequency bin" and columns "mel bin". The title falls
    back to "Filter bank" when ``title`` is falsy.
    """
    axes = plt.subplots(1, 1)[1]
    heading = title if title else "Filter bank"
    axes.set_title(heading)
    axes.imshow(fbank, aspect="auto")
    axes.set_ylabel("frequency bin")
    axes.set_xlabel("mel bin")
def plot_pitch(waveform, sr, pitch):
    """Overlay a pitch track on its waveform in a new figure.

    Row 0 of ``waveform`` is drawn faintly on the left axis. Row 0 of
    ``pitch`` is drawn on a twin right axis. Both are spread evenly over
    ``waveform.shape[1] / sr`` seconds.
    """
    _, wave_ax = plt.subplots(1, 1)
    wave_ax.set_title("Pitch Feature")
    wave_ax.grid(True)
    n_samples = waveform.shape[1]
    duration = n_samples / sr
    wave_times = torch.linspace(0, duration, n_samples)
    wave_ax.plot(wave_times, waveform[0], linewidth=1, color="gray", alpha=0.3)
    pitch_ax = wave_ax.twinx()
    pitch_times = torch.linspace(0, duration, pitch.shape[1])
    pitch_ax.plot(pitch_times, pitch[0], linewidth=2, label="Pitch", color="green")
    pitch_ax.legend(loc=0)
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
                     load_saved=False, save_dataset=False, cache_dir=None, extract_args=None, max_ctx=2048):
    """Build FLEURS ``en_us`` train/test splits with ``extract_features`` applied.

    Args:
        tokenizer: passed to ``extract_features`` for every example.
        token: Hugging Face access token forwarded to ``load_dataset``.
        sanity_check: load a single test example and return it as both splits.
        sample_rate: target sampling rate for the ``audio`` column cast.
        streaming: forwarded to ``load_dataset``.
        load_saved: if both ``train.arrow`` and ``test.arrow`` exist in
            ``cache_dir``, return them instead of rebuilding.
        save_dataset: save the processed splits to ``cache_dir``. This is not
            applied in sanity-check mode.
        cache_dir: cache directory; defaults to ``./processed_datasets``.
        extract_args: keyword arguments for ``extract_features``; defaults to a
            dict with every feature disabled.
        max_ctx: the transcription length must be in ``(0, max_ctx)``, and the
            audio must be non-empty and shorter than ``max_ctx * 160`` samples.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    if extract_args is None:
        extract_args = {
            "waveform": False,
            "spec": False,
            "f0": False,
            "pitch_tokens": False,
            "pitch": False,
            "harmonic": False,
            "aperiodic": False,
            "sample_rate": 16000,
            "hop_length": 256,
            "mode": "mean",
            "debug": False,
            "phase_mod": False,
            "crepe": False,
            "dummy": False,
        }

    # Resolve the cache paths unconditionally: both the load path and the save
    # path need them, and saving must not depend on load_saved being set.
    if cache_dir is None:
        cache_dir = "./processed_datasets"
    cache_file_train = os.path.join(cache_dir, "train.arrow")
    cache_file_test = os.path.join(cache_dir, "test.arrow")

    if load_saved:
        os.makedirs(cache_dir, exist_ok=True)
        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
            from datasets import Dataset
            train_dataset = Dataset.load_from_disk(cache_file_train)
            test_dataset = Dataset.load_from_disk(cache_file_test)
            return train_dataset, test_dataset

    def _extract(example):
        return extract_features(example, tokenizer, **extract_args)

    if sanity_check:
        test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True,
            streaming=streaming).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)
        dataset = test.map(_extract, remove_columns=test.column_names)
        return dataset, dataset

    def filter_func(x):
        # NOTE(review): max_ctx * 160 presumably assumes 160 samples per
        # frame (10 ms at 16 kHz) -- confirm against the model's hop length.
        return (0 < len(x["transcription"]) < max_ctx and
                len(x["audio"]["array"]) > 0 and
                len(x["audio"]["array"]) < max_ctx * 160)

    raw_train = load_dataset(
        "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True,
        streaming=streaming).take(1000)
    raw_test = load_dataset(
        "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True,
        streaming=streaming).take(100)
    raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
    raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
    train_dataset = raw_train.map(_extract, remove_columns=raw_train.column_names)
    test_dataset = raw_test.map(_extract, remove_columns=raw_test.column_names)

    if save_dataset:
        os.makedirs(cache_dir, exist_ok=True)
        train_dataset.save_to_disk(cache_file_train)
        test_dataset.save_to_disk(cache_file_test)
    return train_dataset, test_dataset
class tgate(nn.Module):
    """Mixture of sigmoid gates weighted by a soft type classifier.

    ``num_types`` gates each score ``x`` with ``Linear(dims, 1)`` followed by a
    sigmoid. A ``Linear(dims, num_types)`` + softmax classifier produces mixing
    weights over those gates, and the output is the weighted sum. Assuming
    ``Linear`` acts on the last axis like ``nn.Linear``, the result has shape
    ``x.shape[:-1] + (1,)`` for input of any rank.
    """
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gates = nn.ModuleList([nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
        self.classifier = nn.Sequential(Linear(dims, num_types), nn.Softmax(dim=-1))

    def forward(self, x):
        types = self.classifier(x)                                     # (..., num_types)
        gates = torch.stack([gate(x) for gate in self.gates], dim=-1)  # (..., 1, num_types)
        # unsqueeze(-2) lines the type weights up with the gates' singleton axis
        # for any input rank. A fixed axis 2 only matches 3-D input; for 2-D or
        # 4-D input it broadcasts to a wrong, larger shape.
        cgate = torch.sum(gates * types.unsqueeze(-2), dim=-1)
        return cgate
def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int, layer: int, act=None, features=None) -> nn.Module:
    """Build the front-end encoder for one input feature type.

    Args:
        feature: one of "spectrogram", "waveform" or "pitch".
        mels: channel count of spectrogram input (used by ``FEncoder`` only).
        input_dims: input channel count for the waveform/pitch encoders.
        dims, head, layer, act: model width, head count, layer count and
            activation name passed to the encoder.
        features: passed to ``FEncoder`` only.

    Raises:
        ValueError: for an unrecognised ``feature``.
    """
    # Keyword arguments are used throughout because the three constructors
    # have different positional layouts. Positional calls sent ``act`` into
    # WEncoder's ``kernel_size`` and the feature name into PEncoder's
    # ``use_rope``.
    if feature == "spectrogram":
        return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head,
                        layer=layer, act=act, feature=feature, features=features)
    if feature == "waveform":
        # WEncoder requires a ``kernel_size`` argument but does not read it.
        return WEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer,
                        kernel_size=None, act=act)
    if feature == "pitch":
        return PEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, act=act)
    raise ValueError(f"Unknown feature type: {feature}")
class FEncoder(nn.Module):
    """Convolutional encoder for spectrogram-style features.

    Input ``x`` is ``(batch, mels, time)``. Three stride-1 Conv1d layers map
    it to ``(batch, dims, time)``, which is permuted to ``(batch, time, dims)``.
    Position information is then added: rotary embeddings when ``use_rope`` is
    set and ``spec_shape`` was given, sinusoidal embeddings otherwise. Dropout
    and RMSNorm follow.

    ``layer`` and ``features`` are accepted for call-site compatibility and are
    not used by this class.
    """

    def __init__(self, mels, input_dims, dims, head, layer, act, feature, features,
                 use_rope=False, spec_shape=None, debug=None):
        super().__init__()
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        # None instead of a shared mutable default; only membership is checked.
        self.debug = debug if debug is not None else []
        self.feature = feature
        self.mels = mels
        self.input_dims = input_dims
        act_fn = get_activation(act)
        self.encoder = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        # The rotary module is only built when spec_shape is known. self.rope
        # is always defined so forward can fall back to sinusoids instead of
        # failing on a missing/None rope when use_rope is requested alone.
        if use_rope and spec_shape is not None:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
        self.norm = RMSNorm(dims)

    def positional(self, length, dims, max_tscale):
        """Return the ``(length, dims)`` table from ``sinusoids``.

        Defined as a method rather than a lambda attribute so that the module
        can be pickled (e.g. by ``torch.save(model)``).
        """
        return sinusoids(length, dims, max_tscale)

    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
        """Split ``(batch, ctx, dims)`` into heads, rotate with ``self.rope``, merge back.

        ``xa`` and ``mask`` are accepted but not used.
        """
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
        """Encode ``(batch, mels, time)`` into normalised ``(batch, time, dims)``."""
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope and self.rope is not None:
            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        if "fencoder" in self.debug:
            print(f"feature encoder: {x.shape} {feature}")
        x = self.norm(x)
        return x
class WEncoder(nn.Module):
    """Strided convolutional encoder for raw waveforms.

    ``x`` is ``(batch, input_dims, samples)``. Three Conv1d layers
    (``input_dims -> dims//4 -> dims//2 -> dims``, strides 4, 2, 2) shorten the
    time axis about 16x, and the result is permuted to ``(batch, time, dims)``.
    If ``self.target_length`` is set (it is None after construction), the time
    axis is resampled to that length with adaptive average pooling. Position
    information is then added: rotary embeddings when ``use_rope`` is set and
    ``spec_shape`` was given, sinusoids otherwise. Dropout and RMSNorm follow.

    ``layer`` and ``kernel_size`` are accepted but not used.
    """

    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=None, spec_shape=None):
        super().__init__()
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        # None instead of a shared mutable default; only membership is checked.
        self.debug = debug if debug is not None else []
        act_fn = get_activation(act)
        self.target_length = None
        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims // 4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims // 4, dims // 2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims // 2, dims, kernel_size=5, stride=2, padding=2), act_fn)
        # self.rope is always defined, so use_rope without spec_shape falls back
        # to sinusoids in forward instead of calling a missing/None module.
        if use_rope and spec_shape is not None:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
        self.norm = RMSNorm(dims)

    def positional(self, length, dims, max_tscale):
        """Return the ``(length, dims)`` table from ``sinusoids``.

        Defined as a method rather than a lambda attribute so that the module
        can be pickled.
        """
        return sinusoids(length, dims, max_tscale)

    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        """Split ``(batch, ctx, dims)`` into heads, rotate with ``self.rope``, merge back.

        ``xa`` and ``mask`` are accepted but not used.
        """
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        """Encode a waveform batch into normalised ``(batch, time, dims)`` features."""
        x = self.encoder(x).permute(0, 2, 1)
        if self.target_length and x.shape[1] != self.target_length:
            # adaptive_avg_pool1d pools the last axis, so move time there and back.
            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
        if self.use_rope and self.rope is not None:
            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        if "fencoder" in self.debug:
            print(f"waveform encoder: {x.shape} {feature}")
        return self.norm(x)
class PEncoder(nn.Module):
    """Convolutional encoder for pitch-like features.

    Three stride-1 Conv1d layers (``input_dims -> dims``) are applied, and the
    result is permuted to ``(batch, time, dims)``. Rotary embeddings are
    applied when ``use_rope`` is set, then sinusoidal positions are added, and
    RMSNorm is applied. ``layer`` and ``one_shot`` are accepted but not used.
    """
    def __init__(self, input_dims, dims, head, layer, act, use_rope=False, debug=[], one_shot=False, spec_shape=None):
        super().__init__()
        self.head = head
        self.head_dim = dims // head
        self.dims = dims
        self.dropout = 0.01  # stored but never applied in forward
        self.use_rope = use_rope
        self.debug = debug  # NOTE(review): shared mutable default; only membership is read
        act_fn = get_activation(act)
        # Hard-coded off, so the attention projections and MLP below are never
        # built: q/k/v/o/scale and mlp are always None.
        self.attend_pitch = False
        if self.attend_pitch:
            self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
            self.mlp = nn.Sequential(
                nn.Linear(dims, dims),
                nn.ReLU(),
                nn.Linear(dims, dims),
            )
        else:
            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
            self.mlp = None
        self.pitch_encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        # Unlike FEncoder/WEncoder, the rotary module is built even when
        # spec_shape is None.
        if use_rope:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
        # NOTE(review): a lambda attribute makes the module unpicklable.
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)
    def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
        """Split ``(batch, ctx, dims)`` into heads, rotate with ``self.rope``, merge back.

        ``xa`` and ``mask`` are accepted but not used.
        """
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x
    def forward(self, x, xa=None, mask=None, feats= None, feature="pitch", layer="PEncoder"):
        """Encode pitch features into normalised ``(batch, time, dims)``.

        A 2-D ``x`` (presumably ``(channels, time)`` -- confirm) gets a
        leading batch axis before the convolutions.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)
        x = self.pitch_encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        # NOTE(review): sinusoids are added even when rope was applied, whereas
        # FEncoder/WEncoder use one or the other -- confirm this is intended.
        x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        if self.mlp is not None:
            x = self.mlp(x)
        # Dead while attend_pitch is hard-coded False.
        if self.attend_pitch:
            if xa is not None:
                # NOTE(review): xa is passed as ``x`` and the pitch features as
                # ``xa``; ``out`` is then added to the pitch features, which
                # requires matching sequence lengths -- confirm roles/shapes.
                q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
                out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
                x = x + out
        x = self.norm(x)
        print(f"Pitch encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
        return x
@dataclass
class DataCollator:
    """Collate per-example feature dicts into a padded batch of tensors.

    ``labels`` yields two aligned ``long`` tensors of width
    ``max_label_len + 1``:

    * ``input_ids``: BOS followed by the tokens.
    * ``labels``: the tokens followed by EOS.

    Both are right-padded with the pad id. Each recognised feature key is
    converted to a tensor, right-padded along its last axis to the longest
    item, and stacked. ``None`` values are skipped, and other keys are
    dropped.

    Special ids come from ``tokenizer``. A missing attribute, or one set to
    ``None``, falls back to pad=0, bos=1, eos=2.
    """
    tokenizer: Any

    # Plain class attribute (no annotation), so not a dataclass field.
    _FEATURE_KEYS = frozenset((
        "spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "pitch_tokens", "f0", "phase",
        "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation", "dummy",
    ))

    def _special_id(self, name: str, default: int) -> int:
        """Return ``tokenizer.<name>``, or ``default`` when it is missing or None."""
        # getattr's own default does not cover an attribute that exists but is
        # None (common for pad ids), which would otherwise break padding below.
        value = getattr(self.tokenizer, name, None)
        return default if value is None else value

    @staticmethod
    def _collate_labels(labels_list, pad_id, bos_id, eos_id):
        """Build right-padded (BOS + tokens, tokens + EOS) ``long`` tensors."""
        width = max(len(seq) for seq in labels_list) + 1
        input_rows, label_rows = [], []
        for seq in labels_list:
            # Convert to a plain list first: ``[bos] + ndarray`` would be numpy
            # elementwise addition, not concatenation.
            tokens = seq.tolist() if isinstance(seq, (torch.Tensor, np.ndarray)) else list(seq)
            decoder_input = [bos_id] + tokens
            target = tokens + [eos_id]
            input_rows.append(decoder_input + [pad_id] * (width - len(decoder_input)))
            label_rows.append(target + [pad_id] * (width - len(target)))
        return (torch.tensor(input_rows, dtype=torch.long),
                torch.tensor(label_rows, dtype=torch.long))

    @staticmethod
    def _pad_and_stack(items, pad_value):
        """Right-pad tensors along the last axis to a common length and stack them."""
        tensors = [item if isinstance(item, torch.Tensor) else torch.tensor(item) for item in items]
        width = max(t.shape[-1] for t in tensors)
        padded = [
            F.pad(t, (0, width - t.shape[-1]), mode='constant', value=pad_value)
            if t.shape[-1] < width else t
            for t in tensors
        ]
        return torch.stack(padded)

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        all_keys = set()
        for f in features:
            all_keys.update(f.keys())
        pad_token_id = self._special_id("pad_token_id", 0)
        bos_token_id = self._special_id("bos_token_id", 1)
        eos_token_id = self._special_id("eos_token_id", 2)
        batch = {}
        for key in all_keys:
            if key == "labels":
                batch["input_ids"], batch["labels"] = self._collate_labels(
                    [f["labels"] for f in features], pad_token_id, bos_token_id, eos_token_id)
            elif key in self._FEATURE_KEYS:
                items = [f[key] for f in features if key in f and f[key] is not None]
                if not items:
                    continue
                # NOTE(review): audio features are padded with the token pad id
                # (kept for compatibility); 0.0 may be the intended value.
                batch[key] = self._pad_and_stack(items, pad_token_id)
        return batch
def levenshtein(reference_words, hypothesis_words):
    """Return the edit distance between two sequences (unit costs).

    Substitution, insertion and deletion each cost 1, and matching items cost
    0. Only the previous dynamic-programming row is kept in memory.
    """
    previous = list(range(len(hypothesis_words) + 1))
    for row, ref_item in enumerate(reference_words, start=1):
        current = [row]
        for col, hyp_item in enumerate(hypothesis_words, start=1):
            if ref_item == hyp_item:
                current.append(previous[col - 1])
            else:
                current.append(1 + min(previous[col - 1], current[col - 1], previous[col]))
        previous = current
    return previous[-1]
def wer_batch(references, hypotheses):
    """Return the corpus word error rate, in percent, over paired strings.

    Text is lower-cased and split on whitespace. The total word-level edit
    distance is divided by the total number of reference words. Pairs are
    formed with ``zip``, so extra items in the longer list are ignored.
    Returns ``0.0`` when there are no reference words.
    """
    errors = words = 0
    for reference, hypothesis in zip(references, hypotheses):
        ref_tokens = reference.lower().split()
        errors += levenshtein(ref_tokens, hypothesis.lower().split())
        words += len(ref_tokens)
    if not words:
        return 0.0
    return (errors / words) * 100
def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, logits=None):
    """Compute corpus WER (percent) and a parameter-efficiency score.

    Args:
        pred: object with ``predictions`` (token ids, or a tuple whose first
            element is the token ids) and ``label_ids``.
        tokenizer: provides ``batch_decode``. Its pad/bos/eos ids, falling back
            to 0/1/2, are stripped before decoding.
        model: if given, its trainable parameters (in millions) are counted.
        print_pred: print up to ``num_samples`` token/decoded pairs.
        num_samples: number of pairs to print when ``print_pred`` is set.
        logits: unused.

    Returns:
        ``{"wer": float, "efficiency_score": float}``. Efficiency is
        ``(100 - wer) / trainable_params_in_millions``, or 0.0 without a model
        or with no trainable parameters.
    """
    def special_id(name, default):
        value = getattr(tokenizer, name, None)
        return default if value is None else value

    # Strip -100 (commonly used as the ignored-label value) and the tokenizer's
    # own pad/bos/eos ids. These are the same ids DataCollator inserts, so
    # only content tokens are decoded and scored.
    strip_ids = {
        -100,
        special_id("pad_token_id", 0),
        special_id("bos_token_id", 1),
        special_id("eos_token_id", 2),
    }

    def clean(ids):
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()
        if len(ids) > 0 and isinstance(ids[0], (list, tuple, torch.Tensor, np.ndarray)):
            return [[int(i) for i in seq if int(i) not in strip_ids] for seq in ids]
        return [int(i) for i in ids if int(i) not in strip_ids]

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    label_ids = clean(pred.label_ids)
    pred_ids = clean(pred_ids)
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids)
    if print_pred:
        for i in range(min(num_samples, len(pred_ids))):
            print(f"Pred tokens: {pred_ids[i]}")
            print(f"Label tokens: {label_ids[i]}")
            print(f"Pred: '{pred_str[i]}'")
            print(f"Label: '{label_str[i]}'")
            print("-" * 40)
    wer = wer_batch(label_str, pred_str)
    efficiency_score = 0.0
    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
    return {
        "wer": float(wer),
        "efficiency_score": float(efficiency_score),
    }
def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted ids by taking the argmax over the last axis.

    Returns ``(pred_ids, labels)``, with ``labels`` passed through untouched.
    """
    return logits.argmax(dim=-1), labels
def hilbert_transform(x):
    """Return the Hilbert transform of ``x`` along its last axis.

    The transform is computed in the frequency domain. Every strictly positive
    frequency bin of ``rfft(x)`` is multiplied by ``-1j``. The DC bin and, for
    even lengths, the Nyquist bin are set to zero. ``irfft`` with ``n=N`` then
    gives a real tensor with the same shape as ``x``; for example, the
    transform of a cosine is the matching sine.

    Weighting the bins by 1/2/.../1 instead would produce the magnitude
    profile of the analytic-signal spectrum. After a real ``irfft`` that is
    roughly ``2 * x``, not the quadrature component.
    """
    N = x.shape[-1]
    xf = torch.fft.rfft(x)
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=xf.dtype)
    # In rfft layout the strictly positive, non-Nyquist bins are
    # 1 .. (N + 1) // 2 - 1. That excludes Nyquist for even N and covers every
    # non-DC bin for odd N.
    h[1:(N + 1) // 2] = -1j
    return torch.fft.irfft(xf * h, n=N)
def analytic_signal(x):
    """Return the complex tensor ``x + 1j * hilbert_transform(x)``.

    The real part is ``x`` itself. The imaginary part is what
    ``hilbert_transform`` returns for ``x``.
    """
    quadrature = hilbert_transform(x)
    return x + 1j * quadrature
def hilbert_transform_2d(x, dim=-1):
    """Return the Hilbert transform of ``x`` along axis ``dim``.

    Despite the name, this is a 1-D transform applied independently along one
    axis of a tensor of any rank. Positive-frequency bins of
    ``rfft(x, dim=dim)`` are multiplied by ``-1j``. The DC bin and, for even
    lengths, the Nyquist bin are zeroed. The result is inverted with ``irfft``
    so it is real and has the shape of ``x``.

    Any valid axis works. For a non-final ``dim``, the spectral weights are
    broadcast along that axis, so the result is no longer zero.
    """
    dim = dim % x.dim()  # normalise negative axes for the shape bookkeeping
    N = x.shape[dim]
    xf = torch.fft.rfft(x, dim=dim)
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=xf.dtype)
    # Strictly positive, non-Nyquist bins in rfft layout.
    h[1:(N + 1) // 2] = -1j
    h_shape = [1] * x.dim()
    h_shape[dim] = N // 2 + 1
    return torch.fft.irfft(xf * h.view(h_shape), n=N, dim=dim)
def hilbert_transform_true_2d(x):
    """Apply a 2-D frequency-domain kernel over the last two axes of ``x``.

    Multiplies ``rfft2(x)`` by ``h = -1j / (pi * (h1 + 1j * h2))``, zeroes
    ``h[0, 0]``, and inverts with ``irfft2``. The grids ``h1`` and ``h2`` come
    from ``rfftfreq(...) * 2 - 1``.

    NOTE(review), verify before use:
      * ``rfft2`` keeps a full axis of length ``x.shape[-2]`` at -2, but ``h1``
        has length ``x.shape[-2] // 2 + 1``. The product only broadcasts when
        ``x.shape[-2] <= 2``.
      * The ``* 2 - 1`` mapping sends frequency 0 to -1, so ``h[0, 0]`` is the
        (-1, -1) entry, not DC. For even lengths, the last entry of each grid
        is 0, so the denominator can be zero at ``h[-1, -1]``.
      * ``irfft2`` is called without ``s=``, so an odd last-axis length is not
        restored.
      * The kernel should be checked against the intended 2-D
        (e.g. Riesz/monogenic) transform.
    """
    xf = torch.fft.rfft2(x)
    h1, h2 = torch.meshgrid(
        torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j*h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device))
def process_spectrogram_with_hilbert(spec):
    """Split ``spec`` into the magnitude and angle of ``spec + 1j * H(spec)``.

    ``H`` is ``hilbert_transform``, applied along the last axis. Returns
    ``(envelope, phase)``, both with the shape of ``spec``.
    """
    quadrature = hilbert_transform(spec)
    analytic = spec + 1j * quadrature
    return analytic.abs(), analytic.angle()