|
|
|
import pyworld as pw |
|
import os |
|
import math |
|
import warnings |
|
import logging |
|
import gzip |
|
import base64 |
|
import torch |
|
import torchaudio |
|
import torchcrepe |
|
import torch.nn.functional as F |
|
import torch.nn.init as init |
|
from torch import nn, Tensor |
|
import numpy as np |
|
from typing import Optional, Dict, Union, List, Tuple, Any |
|
from functools import partial |
|
from datetime import datetime |
|
from datasets import load_dataset, Audio |
|
from transformers.trainer_seq2seq import Seq2SeqTrainer |
|
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments |
|
import transformers |
|
import evaluate |
|
from dataclasses import dataclass |
|
|
|
torch.backends.cudnn.allow_tf32 = True |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.set_float32_matmul_precision('high') |
|
transformers.utils.logging.set_verbosity_error() |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
dtype = torch.float32 |
|
|
|
torch.set_default_dtype(dtype) |
|
warnings.filterwarnings("ignore") |
|
logging.basicConfig(level=logging.ERROR) |
|
tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), "dtype": torch.float32} |
|
|
|
extractor = None |
|
tokenizer = None |
|
optimizer = None |
|
scheduler = None |
|
model = None |
|
|
|
|
@dataclass |
|
class Dimensions: |
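    """Configuration for the Echo model: text-decoder and audio-encoder sizes,
    activation, debug flags, active feature streams, and the F0-rotary toggle."""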
|
vocab: int |
|
text_ctx: int |
|
text_dims: int |
|
text_head: int |
|
text_idx: int |
|
mels: int |
|
aud_ctx: int |
|
aud_dims: int |
|
aud_head: int |
|
aud_idx: int |
|
act: str |
|
debug: List[str] |
|
cross_attn: bool |
|
features: List[str] |
|
f0_rotary: bool |
|
|
|
def exists(v): |
|
return v is not None |
|
|
|
def default(v, b): |
|
return v if exists(v) else b |
|
|
|
class Conv1d(nn.Conv1d): |
|
def _conv_forward( |
|
self, x: Tensor, weight: Tensor, bias) -> Tensor: |
|
return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype)) |
|
|
|
class Conv2d(nn.Conv2d): |
|
def _conv_forward( |
|
self, x: Tensor, weight: Tensor, bias) -> Tensor: |
|
return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype)) |
|
|
|
class Linear(nn.Module): |
|
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: |
|
super(Linear, self).__init__() |
|
self.linear = nn.Linear(in_features, out_features, bias=bias) |
|
init.xavier_uniform_(self.linear.weight) |
|
if bias: |
|
init.zeros_(self.linear.bias) |
|
def forward(self, x: Tensor) -> Tensor: |
|
return self.linear(x) |
|
|
|
class RMSNorm(nn.Module): |
|
def __init__(self, dims: Union[int, Tensor, List, Tuple], |
|
eps = 1e-8, elementwise_affine = True): |
|
super(RMSNorm, self).__init__() |
|
if isinstance(dims, int): |
|
self.normalized_shape = (dims,) |
|
else: |
|
self.normalized_shape = tuple(dims) |
|
self.eps = eps |
|
self.elementwise_affine = elementwise_affine |
|
if self.elementwise_affine: |
|
self.weight = nn.Parameter(torch.empty(self.normalized_shape)) |
|
init.ones_(self.weight) |
|
else: |
|
self.register_parameter("weight", None) |
|
def forward(self, x): |
|
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) |
|
|
|
def LayerNorm(x: Tensor, normalized_shape: Union[int, List[int], Tuple[int, ...]],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape,)
    return F.layer_norm(x, normalized_shape, weight, bias, eps)
|
|
|
def get_device(): |
|
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
def get_dtype(): |
|
return torch.float32 if torch.cuda.is_available() else torch.float64 |
|
|
|
def get_tox(): |
|
return {"device": get_device(), "dtype": get_dtype()} |
|
|
|
def sinusoids(length, channels, max_timescale=10000): |
|
"""Returns sinusoids for positional embedding""" |
|
assert channels % 2 == 0 |
|
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) |
|
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) |
|
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] |
|
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) |
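
# Quick shape check (illustrative, not executed): sinusoids concatenates the sin
# and cos halves along the channel axis, e.g.
#   sinusoids(1500, 512).shape -> torch.Size([1500, 512])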
|
|
|
class rotary(nn.Module): |
|
_seen = set() |
|
def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False, |
|
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []): |
|
super().__init__() |
|
|
|
self.dims = dims |
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
dtype = torch.float32 |
|
self.device = device |
|
self.dtype = dtype |
|
self.debug = debug |
|
self._counter = 0 |
|
|
|
self.use_pbias = False |
|
self.max_ctx = max_ctx |
|
self.variable_radius = variable_radius |
|
|
|
self.inv_freq = nn.Parameter(1.0 / (theta ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)), |
|
requires_grad=learned_freq) |
|
|
|
self.theta = nn.Parameter(torch.tensor(float(theta)), |
|
requires_grad=learned_theta) |
|
|
|
self.pitch_scale = nn.Parameter(torch.tensor(1.0), requires_grad=learned_pitch) |
|
|
|
if variable_radius: |
|
self.radius = nn.Parameter(torch.ones(dims // 2), requires_grad=learned_radius) |
|
|
|
def get_pitch_bias(self, f0): |
|
if f0 is None: |
|
return None |
|
|
|
f0_flat = f0.squeeze().float() |
|
f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8) |
|
f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1), |
|
f0_norm.unsqueeze(1)) * self.pitch_scale) |
|
return f0_sim.unsqueeze(0).unsqueeze(0) |
|
|
|
def add_to_rotary(self): |
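        # Runtime monkey-patch: attaches two helpers to the rotary class.
        # get_sim computes pairwise cosine similarity between the rotary
        # position vectors; fwd_sim returns (freqs, similarity) in one call.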
|
def get_sim(self, freqs): |
|
real = freqs.real.squeeze(0) |
|
imag = freqs.imag.squeeze(0) |
|
vecs = torch.cat([real.unsqueeze(-2), imag.unsqueeze(-2)], dim=-1) |
|
vecs = vecs.squeeze(-2) |
|
return F.cosine_similarity(vecs.unsqueeze(1), vecs.unsqueeze(0), dim=-1) |
|
|
|
def fwd_sim(self, x=None, f0=None): |
|
freqs = self.forward(x, f0) |
|
sim = get_sim(self, freqs) |
|
return freqs, sim |
|
|
|
rotary.get_sim = get_sim |
|
rotary.fwd_sim = fwd_sim |
|
|
|
def align_f0_to_tokens(self, f0, token_length): |
|
ratio = len(f0) / token_length |
|
indices = [int(i * ratio) for i in range(token_length)] |
|
indices = [min(i, len(f0) - 1) for i in indices] |
|
return f0[indices] |
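
    # Example with assumed sizes: a 300-frame F0 track aligned to 100 positions
    # keeps roughly every third frame by nearest-index lookup:
    #   self.align_f0_to_tokens(torch.rand(300), 100).shape -> torch.Size([100])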
|
|
|
def forward(self, x=None, f0=None, stage=None) -> Tensor: |
|
if isinstance(x, int): |
|
t = torch.arange(x, device=self.device).float() |
|
else: |
|
t = x.float().to(self.inv_freq.device) |
|
if f0 is not None: |
|
f0_mean = f0.mean() |
|
f0_theta = (f0_mean**2) * self.pitch_scale |
|
|
|
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims)) |
|
else: |
|
inv_freq = self.inv_freq |
|
freqs = torch.einsum('i,j->ij', t, inv_freq) |
|
freqs = freqs.float() |
|
if self.variable_radius: |
|
            if f0 is not None:
                f0 = torch.as_tensor(f0[0], device=freqs.device, dtype=freqs.dtype)
                f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
                radius = 1.0 / (f0 + 1)
|
freqs = torch.polar(radius, freqs) |
|
else: |
|
freqs = torch.polar(torch.ones_like(freqs), freqs) |
|
freqs = freqs.unsqueeze(0) |
|
|
|
if "rotary" in self.debug: |
|
if f0 is not None: |
|
key = f"{self._counter}_{f0_theta:.2f}" |
|
if key not in rotary._seen: |
|
if not hasattr(self, '_prev_f0_theta'): |
|
self._prev_f0_theta = f0_theta |
|
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz") |
|
elif abs(self._prev_f0_theta - f0_theta) > 0.0: |
|
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz") |
|
self._prev_f0_theta = f0_theta |
|
rotary._seen.add(key) |
|
self._counter += 1 |
|
return freqs |
|
|
|
@staticmethod |
|
def apply_rotary(x, freqs): |
|
multihead_format = len(freqs.shape) == 4 |
|
if multihead_format: |
|
x1 = x[..., :freqs.shape[-1]*2] |
|
x2 = x[..., freqs.shape[-1]*2:] |
|
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous() |
|
x1 = torch.view_as_complex(x1) |
|
x1 = x1 * freqs |
|
x1 = torch.view_as_real(x1).flatten(-2) |
|
return torch.cat([x1.type_as(x), x2], dim=-1) |
|
else: |
|
x1 = x[..., :freqs.shape[-1]*2] |
|
x2 = x[..., freqs.shape[-1]*2:] |
|
|
|
if x.ndim == 2: |
|
x1 = x1.unsqueeze(0) |
|
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous() |
|
x1 = torch.view_as_complex(x1) |
|
x1 = x1 * freqs |
|
x1 = torch.view_as_real(x1).flatten(-2) |
|
x1 = x1.squeeze(0) |
|
return torch.cat([x1.type_as(x), x2], dim=-1) |
|
else: |
|
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous() |
|
x1 = torch.view_as_complex(x1) |
|
x1 = x1 * freqs |
|
x1 = torch.view_as_real(x1).flatten(-2) |
|
return torch.cat([x1.type_as(x), x2], dim=-1) |
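
    # Usage sketch (shapes are assumptions; tensors must share a device):
    #   rope = rotary(dims=64)             # per-head rotary module
    #   freqs = rope(10)                   # complex (1, 10, 32)
    #   q = torch.randn(2, 4, 10, 64)      # (batch, head, ctx, head_dim)
    #   q = rotary.apply_rotary(q, freqs)  # same shape, all 64 dims rotated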
|
|
|
class SliceAttention(nn.Module): |
|
def __init__(self, dims, heads, dropout=0.0): |
|
super().__init__() |
|
self.dims = dims |
|
self.heads = heads |
|
self.head_dim = dims // heads |
|
self.scale = self.head_dim ** -0.5 |
|
|
|
self.q_proj = Linear(dims, dims) |
|
self.k_proj = Linear(dims, dims) |
|
self.v_proj = Linear(dims, dims) |
|
self.out_proj = Linear(dims, dims) |
|
self.dropout = nn.Dropout(dropout) |
|
|
|
assert dims % heads == 0, f"Dimensions {dims} not divisible by heads {heads}" |
|
|
|
    def parallel_slice(self, q, k, v, mask=None):
        # q, k, v arrive as (batch, ctx, dims); heads are sliced out of the
        # feature dimension rather than reshaped into a separate axis
        batch, ctx, dims = q.shape
        head_dim = self.head_dim
        ctx_len = k.shape[1]
        num_heads = dims // head_dim
|
|
|
scores = torch.zeros(batch, num_heads, ctx, ctx_len, device=q.device) |
|
|
|
for h in range(num_heads): |
|
start_idx = h * head_dim |
|
end_idx = start_idx + head_dim |
|
q_h = q[:, :, start_idx:end_idx] |
|
k_h = k[:, :, start_idx:end_idx] |
|
|
|
scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim) |
|
|
|
if mask is not None: |
|
scores = scores + mask.unsqueeze(0).unsqueeze(0) |
|
|
|
attn_weights = F.softmax(scores, dim=-1) |
|
|
|
output = torch.zeros_like(q) |
|
for h in range(num_heads): |
|
start_idx = h * head_dim |
|
end_idx = start_idx + head_dim |
|
v_h = v[:, :, start_idx:end_idx] |
|
output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h) |
|
return output |
|
|
|
def forward(self, x, context=None, mask=None): |
|
batch, ctx, _ = x.shape |
|
if context is None: |
|
context = x |
|
|
|
ctx_len = context.shape[1] |
|
q = self.q_proj(x) |
|
k = self.k_proj(context) |
|
v = self.v_proj(context) |
|
output = torch.zeros_like(q) |
|
|
|
for h in range(self.heads): |
|
start_idx = h * self.head_dim |
|
end_idx = start_idx + self.head_dim |
|
|
|
q_h = q[:, :, start_idx:end_idx] |
|
k_h = k[:, :, start_idx:end_idx] |
|
v_h = v[:, :, start_idx:end_idx] |
|
|
|
attn_scores = torch.bmm(q_h, k_h.transpose(1, 2)) * self.scale |
|
if mask is not None: |
|
attn_scores = attn_scores + mask[:ctx, :ctx_len].unsqueeze(0) |
|
|
|
attn_weights = F.softmax(attn_scores, dim=-1) |
|
attn_weights = self.dropout(attn_weights) |
|
head_output = torch.bmm(attn_weights, v_h) |
|
output[:, :, start_idx:end_idx] = head_output |
|
return self.out_proj(output) |
|
|
|
def optim_attn(q, k, v, mask=None, scale=None, pad_token=0, fzero_val=0.0001): |
|
|
|
batch, heads, ctx, dims = q.shape |
|
token_ids = k[:, :, :, 0] |
|
is_padding = (token_ids.float() == pad_token).unsqueeze(-2) |
|
log_scale_factor = -10.0 |
|
attn_mask = torch.zeros((batch, heads, ctx, ctx), device=q.device) |
|
|
|
if mask is not None: |
|
attn_mask = attn_mask + mask.unsqueeze(0).unsqueeze(0) |
|
attn_mask = torch.where(is_padding, |
|
torch.tensor(log_scale_factor, device=q.device), |
|
attn_mask) |
|
attn_output = torch.nn.functional.scaled_dot_product_attention( |
|
q, k, v, attn_mask=attn_mask, |
|
dropout_p=0.0, is_causal=False) |
|
attn_output = attn_output.permute(0, 2, 1, 3).flatten(start_dim=2) |
|
return attn_output |
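
# optim_attn defers to PyTorch's fused scaled_dot_product_attention and builds
# an additive mask that down-weights key positions whose first feature equals
# pad_token. Note that fzero_val is accepted for symmetry with MultiheadA but
# is not folded into the mask; the fixed log_scale_factor is used instead.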
|
|
|
class MultiheadA(nn.Module): |
|
_seen = set() |
|
rbf = False |
|
def __init__(self, dims: int, head: int, rotary_emb: bool = False, |
|
zero_val: float = 0.0001, minz: float = 0.0, maxz: float = 0.001, debug: List[str] = [], optim_attn=False): |
|
|
|
super(MultiheadA, self).__init__() |
|
|
|
self.debug = debug |
|
self.pad_token = 0 |
|
self.dims = dims |
|
self.head = head |
|
self.head_dim = dims // head |
|
self.rotary_emb = rotary_emb |
|
self.minz = minz |
|
self.maxz = maxz |
|
self.zero_val = zero_val |
|
self.optim_attn = optim_attn |
|
self._counter = 0 |
|
if dims % head != 0: |
|
raise ValueError(f"Dimensions {dims} must be divisible by number of heads {head}.") |
|
if zero_val < minz or zero_val > maxz: |
|
raise ValueError(f"Zero value {zero_val} must be between {minz} and {maxz}.") |
|
|
|
self.q = Linear(dims, dims) |
|
self.k = Linear(dims, dims, bias=False) |
|
self.v = Linear(dims, dims) |
|
self.o = Linear(dims, dims) |
|
self.fzero = nn.Parameter(torch.tensor(zero_val, dtype=torch.float32), requires_grad=True) |
|
|
|
if rotary_emb: |
|
self.rope = rotary( |
|
dims=self.head_dim, |
|
debug = debug, |
|
max_ctx=1500, |
|
) |
|
else: |
|
self.rope = None |
|
|
|
def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0): |
|
scale = (self.dims // self.head) ** -0.25 |
|
dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale |
|
if rbf_ratio <= 0.0: |
|
return dot_scores |
|
q_norm = q.pow(2).sum(dim=-1, keepdim=True) |
|
k_norm = k.pow(2).sum(dim=-1, keepdim=True) |
|
qk = torch.matmul(q, k.transpose(-1, -2)) |
|
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk |
|
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2)) |
|
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores |
|
|
|
def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, |
|
return_attn: bool = False, f0: Tensor = None) -> tuple: |
|
|
|
batch, ctx, dims = x.shape |
|
scale = (self.dims // self.head) ** -0.25 |
|
|
|
z = default(xa, x) |
|
q = self.q(x).to(x.dtype) |
|
k = self.k(z).to(x.dtype) |
|
v = self.v(z).to(x.dtype) |
|
|
|
if self.rotary_emb: |
|
if f0 is not None: |
|
qf = self.rope(q.size(1), f0=f0) |
|
kf = self.rope(k.size(1), f0=f0) |
|
else: |
|
qf = self.rope(q.size(1)) |
|
kf = self.rope(k.size(1)) |
|
|
|
q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3) |
|
k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3) |
|
v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3) |
|
|
|
q = self.rope.apply_rotary(q, qf) |
|
k = self.rope.apply_rotary(k, kf) |
|
|
|
else: |
|
q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3) |
|
k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3) |
|
v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3) |
|
batch, head, ctx, head_dim = q.shape |
|
|
|
if self.optim_attn and not return_attn: |
|
wv = optim_attn(q * scale, k * scale, v, mask=mask, |
|
pad_token=self.pad_token, fzero_val=torch.clamp(F.softplus(self.fzero), self.minz, self.maxz).item()) |
|
return self.o(wv), None |
|
|
|
        if self.rbf:
            qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
        if f0 is not None and self.rope is not None and self.rope.use_pbias:
            pbias = self.rope.get_pitch_bias(f0)
|
if pbias is not None: |
|
qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]] |
|
token_ids = k[:, :, :, 0] |
|
zscale = torch.ones_like(token_ids) |
|
fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz) |
|
zscale[token_ids.float() == self.pad_token] = fzero.to(q.device, q.dtype) |
|
|
|
if mask is not None: |
|
mask = mask[:q.shape[2], :q.shape[2]] |
|
qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape) |
|
qk = qk * zscale.unsqueeze(-2) |
|
if return_attn: |
|
return qk, v |
|
w = F.softmax(qk, dim=-1).to(q.dtype) |
|
wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) |
|
|
|
if "multihead" in self.debug and self._counter % 100 == 0: |
|
print(f"Step {self._counter}: Using rotary embeddings: {self.rotary_emb}") |
|
print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape}") |
|
print(f"Attention shape: {qk.shape}, wv shape: {wv.shape}") |
|
self._counter += 1 |
|
return self.o(wv), qk.detach() |
|
|
|
class FCGate(nn.Module): |
|
def __init__(self, dims, dim): |
|
super().__init__() |
|
self.proj = Linear(dim, dims // 4) |
|
self.gate = nn.Sequential( |
|
Linear(dims + dims // 4, dims // 2), |
|
nn.SiLU(), |
|
Linear(dims // 2, 1), |
|
nn.Sigmoid() |
|
) |
|
def forward(self, x, embedding): |
|
info = self.proj(embedding) |
|
info = info.unsqueeze(1).expand(-1, x.shape[1], -1) |
|
gate_input = torch.cat([x, info], dim=-1) |
|
return self.gate(gate_input) |
|
|
|
class TTGate(nn.Module): |
|
def __init__(self, dims, num_types=4): |
|
super().__init__() |
|
self.gate_projections = nn.ModuleList([ |
|
nn.Sequential(Linear(dims, 1), nn.Sigmoid()) |
|
for _ in range(num_types)]) |
|
self.type_classifier = nn.Sequential( |
|
Linear(dims, num_types), |
|
nn.Softmax(dim=-1)) |
|
def forward(self, x): |
|
type_probs = self.type_classifier(x) |
|
gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1) |
|
combined_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1) |
|
return combined_gate |
|
|
|
class MGate(nn.Module): |
|
def __init__(self, dims, memory_size=64): |
|
super().__init__() |
|
self.mkey = nn.Parameter(torch.randn(memory_size, dims)) |
|
self.mvalue = nn.Parameter(torch.randn(memory_size, 1)) |
|
self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1)) |
|
|
|
def forward(self, x): |
|
dgate = torch.sigmoid(self.gate_proj(x)) |
|
attention = torch.matmul(x, self.mkey.transpose(0, 1)) |
|
attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1) |
|
mgate = torch.matmul(attention, self.mvalue) |
|
mgate = torch.sigmoid(mgate) |
|
return 0.5 * (dgate + mgate) |
|
|
|
class CMGate(nn.Module): |
|
def __init__(self, dims): |
|
super().__init__() |
|
self.sgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) |
|
self.wgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) |
|
self.pgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) |
|
self.integration = Linear(dims*3, dims) |
|
|
|
def forward(self, x, features): |
|
sfeat = features.get("spectrogram", x) |
|
wfeat = features.get("waveform", x) |
|
pfeat = features.get("pitch", x) |
|
spec = self.sgate(x) * sfeat |
|
wave = self.wgate(x) * wfeat |
|
pitch = self.pgate(x) * pfeat |
|
|
|
combined = torch.cat([spec, wave, pitch], dim=-1) |
|
return self.integration(combined) |
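
# Gate variants consumed by Residual below: FCGate conditions the MLP gate on
# an external embedding, TTGate mixes per-token-type gates, MGate attends over
# a learned memory bank, and CMGate blends spectrogram/waveform/pitch streams.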
|
|
|
class Residual(nn.Module): |
|
_seen = set() |
|
def __init__(self, dims: int, head: int, ctx, act, cross_attn=True, debug: List[str] = [], |
|
fgate=False, tgate=False, mgate=False, cgate=False, |
|
memory_size=512, features=None): |
|
super().__init__() |
|
self.ctx = ctx |
|
self._counter = 0 |
|
self.dropout = 0.01 |
|
self.dims = dims |
|
self.head = head |
|
self.head_dim = dims // head |
|
self.cross_attn = cross_attn |
|
self.debug = debug |
|
self.fgate = fgate |
|
self.tgate = tgate |
|
self.mgate = mgate |
|
self.cgate = cgate |
|
self.features = features |
|
self.blend = nn.Parameter(torch.tensor(0.5)) |
|
|
|
act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), |
|
"tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), |
|
"softplus": nn.Softplus(), "softshrink": nn.Softshrink(), |
|
"leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()} |
|
act_fn = act_map.get(act, nn.GELU()) |
|
|
|
self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug) |
|
self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None) |
|
|
|
mlp = dims * 4 |
|
self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims)) |
|
|
|
self.fgate = FCGate(dims=dims, dim=dims) if fgate else None |
|
self.tgate = TTGate(dims=dims, num_types=4) if tgate else None |
|
self.mgate = MGate(dims=dims, memory_size=memory_size) if mgate else None |
|
self.cgate = CMGate(dims=dims) if cgate else None |
|
|
|
self.lna = RMSNorm(dims) |
|
self.lnb = RMSNorm(dims) if cross_attn else None |
|
self.lnc = RMSNorm(dims) |
|
|
|
if not any([fgate, tgate, mgate, cgate]): |
|
self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) |
|
|
|
def forward(self, x, xa=None, mask=None, f0=None, mode=None): |
|
x = x + self.attna(self.lna(x), mask=mask, f0=f0)[0] |
|
|
|
if self.attnb and xa is not None: |
|
cross = self.attnb(self.lnb(x), xa, f0=f0, mask=None)[0] |
|
blend = torch.sigmoid(self.blend) |
|
x = blend * x + (1 - blend) * cross |
|
|
|
normx = self.lnc(x) |
|
mlp_out = self.mlp(normx) |
|
|
|
if self.tgate: |
|
gate = self.tgate(normx) |
|
x = x + gate * mlp_out |
|
|
|
elif self.fgate: |
|
embedding = f0.mean(dim=1) if f0 is not None else xa.mean(dim=1) |
|
            gate = self.fgate(normx, embedding)
|
x = x + gate * mlp_out |
|
|
|
elif self.mgate: |
|
gate = self.mgate(normx) |
|
x = x + gate * mlp_out |
|
|
|
elif self.cgate and mode is not None: |
|
gate_output = self.cgate(normx, self.features) |
|
x = x + gate_output |
|
|
|
else: |
|
if hasattr(self, 'mlp_gate'): |
|
mlp_gate = self.mlp_gate(normx) |
|
x = x + mlp_gate * mlp_out |
|
else: |
|
x = x + mlp_out |
|
if "residual" in self.debug and self._counter % 100 == 0: |
|
print(f"Step {self._counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}") |
|
self._counter += 1 |
|
return x |
|
|
|
class PEncoder(nn.Module): |
|
def __init__(self, input_dims, dims, head, layer, kernel_size, act): |
|
super().__init__() |
|
|
|
self.head_dim = dims // head |
|
self.dropout = 0.01 |
|
|
|
act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()} |
|
act_fn = act_map.get(act, nn.GELU()) |
|
|
|
        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)

        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)
|
|
|
def forward(self, x, f0=None): |
|
x = self.encoder(x).permute(0, 2, 1) |
|
x = x + self.positional(x.shape[1]).to(x.device, x.dtype) |
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training) |
|
x = self.norm(x) |
|
return x |
|
|
|
class WEncoder(nn.Module): |
|
def __init__(self, input_dims, dims, head, layer, kernel_size, act): |
|
super().__init__() |
|
|
|
self.head_dim = dims // head |
|
self.dropout = 0.01 |
|
|
|
act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()} |
|
act_fn = act_map.get(act, nn.GELU()) |
|
|
|
self.downsample = nn.Sequential( |
|
Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn, |
|
Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn, |
|
Conv1d(dims//4, dims, kernel_size=9, stride=5, padding=4), act_fn) |
|
|
|
self.encoder = nn.Sequential( |
|
Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn, |
|
Conv1d(dims, dims, kernel_size=1), act_fn) |
|
|
|
self.positional = lambda length: sinusoids(length, dims) |
|
self.norm = RMSNorm(dims) |
|
|
|
def forward(self, x, f0=None): |
|
x = self.downsample(x) |
|
x = self.encoder(x) |
|
x = x.permute(0, 2, 1) |
|
x = x + self.positional(x.shape[1]).to(x.device, x.dtype) |
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training) |
|
return self.norm(x) |
|
|
|
class FEncoder(nn.Module): |
|
def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1): |
|
super().__init__() |
|
|
|
self.head_dim = dims // head |
|
self.dropout = 0.01 |
|
|
|
act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()} |
|
act_fn = act_map.get(act, nn.GELU()) |
|
|
|
self.encoder = nn.Sequential( |
|
Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn, |
|
Conv1d(dims, dims, kernel_size=5, padding=2), act_fn, |
|
Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn) |
|
|
|
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)
|
|
|
def forward(self, x, f0=None): |
|
x = self.encoder(x).permute(0, 2, 1) |
|
x = x + self.positional(x.shape[1]).to(x.device, x.dtype) |
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training) |
|
        x = self.norm(x)
|
return x |
|
|
|
class AudioEncoder(nn.Module): |
|
_seen = set() |
|
def __init__(self, mels: int, layer: int, dims: int, head: int, ctx: int, features: List[str], |
|
debug: List[str], f0_rotary: bool = False, act: str = "gelu"): |
|
super(AudioEncoder, self).__init__() |
|
|
|
self.debug = debug |
|
self.features = features |
|
self._counter = 0 |
|
self.dropout = 0.01 |
|
self.f0_rotary = f0_rotary |
|
self.dims = dims |
|
self.ctx = ctx |
|
self.head = head |
|
self.head_dim = dims // head |
|
|
|
self.rope = rotary( |
|
dims=self.head_dim, |
|
debug=debug,) |
|
|
|
        # the composite-modality gate only makes sense when all three feature
        # streams are present
        cgate = set(features) == {"spectrogram", "waveform", "pitch"}
|
|
|
        self.blocks = nn.ModuleDict({
            "spectrogram": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug, cgate=cgate, features=features) for _ in range(layer)]
                if "spectrogram" in features else None),
            "waveform": nn.ModuleList(
                [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act)] +
                [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug, cgate=cgate, features=features) for _ in range(layer)]
                if "waveform" in features else None),
            "pitch": nn.ModuleList(
                [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
                [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug, cgate=cgate, features=features) for _ in range(layer)]
                if "pitch" in features else None),
            "spec_envelope": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug) for _ in range(layer)]
                if "spec_envelope" in features else None),
            "spec_phase": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug) for _ in range(layer)]
                if "spec_phase" in features else None),
        })
|
|
|
def forward(self, x, f0=None): |
|
outputs = {} |
|
if self.f0_rotary: |
|
f0 = f0 if f0 is not None else x.get("pitch") |
|
else: |
|
f0 = None |
|
for y in self.features: |
|
if y in x and y in self.blocks: |
|
f = x[y] |
|
for block in self.blocks[y]: |
|
f = block(f, f0=f0) |
|
outputs[y] = f |
|
|
|
if "encoder" in self.debug and self._counter % 100 == 0: |
|
names = list(x.keys()) |
|
shapes = {k: v.shape for k, v in x.items()} |
|
print(f"Step {self._counter}: mode: {names}") |
|
print(f"shapes: {shapes}") |
|
self._counter += 1 |
|
return outputs |
|
|
|
class TextDecoder(nn.Module): |
|
def __init__(self, vocab: int, layer: int, dims: int, head: int, ctx: int, cross_attn: bool, |
|
features: List[str], debug: List[str], f0_rotary: bool = False, sequential=False): |
|
super(TextDecoder, self).__init__() |
|
|
|
self._counter = 0 |
|
self.dropout = 0.01 |
|
self.debug = debug |
|
self.sequential = sequential |
|
self.features = features |
|
self.f0_rotary = f0_rotary |
|
|
|
self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims) |
|
with torch.no_grad(): |
|
self.token.weight[0].zero_() |
|
self.positional = nn.Parameter(data=torch.empty(ctx, dims), requires_grad=True) |
|
|
|
self._blocks = nn.ModuleList([ |
|
Residual(dims=dims, head=head, ctx=ctx, act="gelu", cross_attn=cross_attn, debug=debug, features=features) |
|
for _ in range(layer)]) |
|
|
|
self.blocks = nn.ModuleDict({ |
|
f: nn.ModuleList([Residual(dims=dims, head=head, ctx=ctx, act="gelu", cross_attn=cross_attn, debug=debug, features=features) |
|
for _ in range(layer)]) for f in features}) |
|
|
|
self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features}) |
|
|
|
self.ln_dec = RMSNorm(dims) |
|
|
|
        # additive causal mask: zeros on and below the diagonal, -inf above
        mask = torch.triu(torch.full((ctx, ctx), float("-inf")), diagonal=1)
|
self.register_buffer("mask", mask, persistent=False) |
|
|
|
    def forward(self, x, enc, order=None, f0=None) -> Tensor:
        x = x.to(device)
        if not self.f0_rotary:
            f0 = None
|
if order is None: |
|
order = self.features |
|
mask = self.mask[:x.shape[1], :x.shape[1]] |
|
x = self.token(x) + self.positional[:x.shape[1]] |
|
x = F.dropout(x, p=self.dropout, training=self.training) |
|
for block in self._blocks: |
|
x = block(x, f0=f0, mask=mask) |
|
for f in order: |
|
if f in enc: |
|
xa = enc[f] |
|
for block in self.blocks[f]: |
|
out = block(x=x, xa=xa, f0=f0, mask=None) |
|
a = torch.sigmoid(self.blend[f]) |
|
x = a * out + (1 - a) * x |
|
x = self.ln_dec(x) |
|
        return (x @ self.token.weight.to(x.dtype).t()).float()
|
|
|
class Echo(nn.Module): |
|
def __init__(self, param: Dimensions): |
|
super().__init__() |
|
self.param = param |
|
|
|
self.encoder = AudioEncoder( |
|
mels=param.mels, |
|
ctx=param.aud_ctx, |
|
dims=param.aud_dims, |
|
head=param.aud_head, |
|
layer=param.aud_idx, |
|
act=param.act, |
|
debug=param.debug, |
|
features=param.features, |
|
f0_rotary=param.f0_rotary, |
|
) |
|
|
|
self.decoder = TextDecoder( |
|
vocab=param.vocab, |
|
ctx=param.text_ctx, |
|
dims=param.text_dims, |
|
head=param.text_head, |
|
layer=param.text_idx, |
|
cross_attn=param.cross_attn, |
|
debug=param.debug, |
|
features=param.features, |
|
f0_rotary=param.f0_rotary, |
|
) |
|
|
|
all_head = torch.zeros(self.param.text_idx, self.param.text_head, dtype=torch.bool) |
|
all_head[self.param.text_idx // 2 :] = True |
|
self.register_buffer("alignment_head", all_head.to_sparse(), persistent=False) |
|
|
|
def set_alignment_head(self, dump: bytes): |
|
array = np.frombuffer( |
|
gzip.decompress(base64.b85decode(dump)), dtype=bool).copy() |
|
mask = torch.from_numpy(array).reshape( |
|
self.param.text_idx, self.param.text_head) |
|
self.register_buffer("alignment_head", mask.to_sparse(), persistent=False) |
|
|
|
def embed_audio(self, spectrogram: torch.Tensor): |
|
return self.encoder(spectrogram) |
|
|
|
def logits(self,input_ids: torch.Tensor, encoder_output: torch.Tensor): |
|
return self.decoder(input_ids, encoder_output) |
|
|
|
def forward(self, |
|
decoder_input_ids=None, |
|
labels=None, |
|
waveform: Optional[torch.Tensor]=None, |
|
input_ids=None, |
|
spectrogram: torch.Tensor=None, |
|
pitch: Optional[torch.Tensor]=None, |
|
f0: Optional[torch.Tensor]=None, |
|
envelope: Optional[torch.Tensor]=None, |
|
phase: Optional[torch.Tensor]=None, |
|
) -> Dict[str, torch.Tensor]: |
|
|
|
decoder_input_ids = input_ids |
|
encoder_inputs = {} |
|
if spectrogram is not None: |
|
encoder_inputs["spectrogram"] = spectrogram |
|
if waveform is not None: |
|
encoder_inputs["waveform"] = waveform |
|
if pitch is not None: |
|
encoder_inputs["pitch"] = pitch |
|
if envelope is not None: |
|
encoder_inputs["envelope"] = envelope |
|
if phase is not None: |
|
encoder_inputs["phase"] = phase |
|
|
|
encoder_outputs = self.encoder(encoder_inputs, f0=f0) |
|
logits = self.decoder(input_ids, encoder_outputs, f0=f0) |
|
|
|
loss = None |
|
if labels is not None: |
|
loss = F.cross_entropy( |
|
logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0) |
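            # ignore_index=0 assumes PAD id 0 (as fixed in setup_tokenizer), so
            # padded label positions do not contribute to the loss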
|
|
|
return { |
|
"logits": logits, |
|
"loss": loss, |
|
"labels": labels, |
|
"input_ids": input_ids, |
|
"decoder_input_ids": decoder_input_ids, |
|
"encoder_output": encoder_outputs |
|
} |
|
|
|
    @property
    def device(self):
        return next(self.parameters()).device
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
def _init_weights(self, module): |
|
std = 0.02 |
|
self.init_counts = { |
|
"Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0, |
|
"Conv2d": 0, "SEBlock": 0, "TextDecoder": 0, "AudioEncoder": 0, |
|
"Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0, |
|
"MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0, |
|
"WEncoder": 0, "PEncoder": 0} |
|
|
|
        for name, module in self.named_modules():
            if isinstance(module, Linear):
                # the custom Linear wraps nn.Linear as .linear
                nn.init.xavier_uniform_(module.linear.weight)
                if module.linear.bias is not None:
                    nn.init.zeros_(module.linear.bias)
                self.init_counts["Linear"] += 1
|
elif isinstance(module, Conv1d): |
|
nn.init.normal_(module.weight, mean=0.0, std=std) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
self.init_counts["Conv1d"] += 1 |
|
|
|
elif isinstance(module, RMSNorm): |
|
nn.init.ones_(module.weight) |
|
self.init_counts["RMSNorm"] += 1 |
|
elif isinstance(module, MultiheadA): |
|
self.init_counts["MultiheadA"] += 1 |
|
elif isinstance(module, Conv2d): |
|
nn.init.normal_(module.weight, mean=0.0, std=std) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
self.init_counts["Conv2d"] += 1 |
|
|
|
elif isinstance(module, TextDecoder): |
|
self.init_counts["TextDecoder"] += 1 |
|
elif isinstance(module, AudioEncoder): |
|
self.init_counts["AudioEncoder"] += 1 |
|
elif isinstance(module, Residual): |
|
self.init_counts["Residual"] += 1 |
|
|
|
    def init_weights(self):
        print("Initializing all weights")
        # walk the module tree once rather than via self.apply, which would
        # re-run the full traversal for every submodule
        self._init_weights(self)
|
print("Initialization summary:") |
|
for module_type, count in self.init_counts.items(): |
|
if count > 0: |
|
print(f"{module_type}: {count}") |
|
|
|
metric = evaluate.load(path="wer") |
|
|
|
@dataclass |
|
class DataCollator: |
|
tokenizer: Any |
|
def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: |
|
        pad_token_id = getattr(self.tokenizer, "pad_token_id", 0)
        bos_token_id = getattr(self.tokenizer, "bos_token_id", 1)
|
|
|
batch = {} |
|
|
|
if "spectrogram" in features[0] and features[0]["spectrogram"] is not None: |
|
spectrogram_list = [f["spectrogram"] for f in features] |
|
max_len_feat = max(f.shape[-1] for f in spectrogram_list) |
|
pad_spectrogram = [] |
|
for feat in spectrogram_list: |
|
current_len = feat.shape[-1] |
|
padding = max_len_feat - current_len |
|
if padding > 0: |
|
pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id) |
|
else: |
|
pad_feat = feat |
|
pad_spectrogram.append(pad_feat) |
|
batch["spectrogram"] = torch.stack(pad_spectrogram) |
|
|
|
if "waveform" in features[0] and features[0]["waveform"] is not None: |
|
waveform_list = [f["waveform"] for f in features] |
|
max_len_wav = max(w.shape[-1] for w in waveform_list) |
|
pad_waveforms = [] |
|
for wav in waveform_list: |
|
current_len = wav.shape[-1] |
|
padding = max_len_wav - current_len |
|
if padding > 0: |
|
if wav.ndim == 1: |
|
wav = wav.unsqueeze(0) |
|
pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id) |
|
else: |
|
pad_wav = wav |
|
pad_waveforms.append(pad_wav) |
|
batch["waveform"] = torch.stack(pad_waveforms) |
|
|
|
if "label" in features[0] and features[0]["label"] is not None: |
|
labels_list = [f["label"] for f in features] |
|
max_len = max(len(l) for l in labels_list) |
|
all_ids = [] |
|
all_labels = [] |
|
|
|
for label in labels_list: |
|
label_list = label.tolist() if isinstance(label, torch.Tensor) else label |
|
decoder_input = [bos_token_id] + label_list |
|
label_eos = label_list + [pad_token_id] |
|
input_len = max_len + 1 - len(decoder_input) |
|
label_len = max_len + 1 - len(label_eos) |
|
padded_input = decoder_input + [pad_token_id] * input_len |
|
padded_labels = label_eos + [pad_token_id] * label_len |
|
all_ids.append(padded_input) |
|
all_labels.append(padded_labels) |
|
batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long) |
|
batch["labels"] = torch.tensor(all_labels, dtype=torch.long) |
|
|
|
if "pitch" in features[0] and features[0]["pitch"] is not None: |
|
pitch_list = [f["pitch"] for f in features] |
|
max_len_pitch = max(e.shape[-1] for e in pitch_list) |
|
pad_pitch = [] |
|
for pitch in pitch_list: |
|
current_len = pitch.shape[-1] |
|
padding = max_len_pitch - current_len |
|
if padding > 0: |
|
pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id) |
|
else: |
|
pad_pitch_item = pitch |
|
pad_pitch.append(pad_pitch_item) |
|
batch["pitch"] = torch.stack(pad_pitch) |
|
|
|
if "f0" in features[0] and features[0]["f0"] is not None: |
|
all_f0 = torch.cat([f["f0"] for f in features]) |
|
batch["f0"] = all_f0.unsqueeze(0) |
|
|
|
if "envelope" in features[0] and features[0]["envelope"] is not None: |
|
env_list = [f["envelope"] for f in features] |
|
max_len = max(f.shape[-1] for f in env_list) |
|
pad_env = [] |
|
for feat in env_list: |
|
current_len = feat.shape[-1] |
|
                padding = max_len - current_len
|
if padding > 0: |
|
pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id) |
|
else: |
|
pad_feat = feat |
|
pad_env.append(pad_feat) |
|
batch["envelope"] = torch.stack(pad_env) |
|
|
|
if "phase" in features[0] and features[0]["phase"] is not None: |
|
ph_list = [f["phase"] for f in features] |
|
max_len = max(f.shape[-1] for f in ph_list) |
|
pad_ph = [] |
|
for feat in ph_list: |
|
current_len = feat.shape[-1] |
|
                padding = max_len - current_len
|
if padding > 0: |
|
pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id) |
|
else: |
|
pad_feat = feat |
|
pad_ph.append(pad_feat) |
|
batch["phase"] = torch.stack(pad_ph) |
|
return batch |
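
# Usage sketch (hypothetical shapes): each modality is right-padded
# independently and decoder inputs/labels are built by BOS/EOS shifting:
#   collator = DataCollator(tokenizer=tok)
#   out = collator([{"spectrogram": torch.randn(128, 90), "label": [5, 6, 7]},
#                   {"spectrogram": torch.randn(128, 70), "label": [8]}])
#   out["spectrogram"].shape -> torch.Size([2, 128, 90])
#   out["input_ids"].shape   -> torch.Size([2, 4])   # [BOS] + label + PAD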
|
|
|
def hilbert_transform(x): |
|
N = x.shape[-1] |
|
xf = torch.fft.rfft(x) |
|
h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype) |
|
if N % 2 == 0: |
|
h[0] = h[N//2] = 1 |
|
h[1:N//2] = 2 |
|
else: |
|
h[0] = 1 |
|
h[1:(N+1)//2] = 2 |
|
return torch.fft.irfft(xf * h, n=N) |
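
# hilbert_transform applies the standard analytic-signal weighting in the
# frequency domain: H[0] = H[N/2] = 1 and H[k] = 2 for 0 < k < N/2 (even N),
# so analytic_signal(x) = x + i * hilbert(x) has a one-sided spectrum.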
|
|
|
def analytic_signal(x): |
|
return x + 1j * hilbert_transform(x) |
|
|
|
def hilbert_transform_2d(x, dim=-1): |
|
N = x.shape[dim] |
|
if dim == -1 or dim == len(x.shape) - 1: |
|
xf = torch.fft.rfft(x) |
|
else: |
|
xf = torch.fft.rfft(x, dim=dim) |
|
h_shape = [1] * len(x.shape) |
|
h_shape[dim] = N // 2 + 1 |
|
h = torch.zeros(h_shape, device=x.device, dtype=x.dtype) |
|
if dim == -1 or dim == len(x.shape) - 1: |
|
if N % 2 == 0: |
|
h[..., 0] = h[..., -1] = 1 |
|
h[..., 1:-1] = 2 |
|
else: |
|
h[..., 0] = 1 |
|
h[..., 1:] = 2 |
|
    else:
        # same one-sided weighting applied along an arbitrary dim
        idx = [slice(None)] * len(x.shape)
        if N % 2 == 0:
            idx[dim] = 0
            h[tuple(idx)] = 1
            idx[dim] = -1
            h[tuple(idx)] = 1
            idx[dim] = slice(1, -1)
            h[tuple(idx)] = 2
        else:
            idx[dim] = 0
            h[tuple(idx)] = 1
            idx[dim] = slice(1, None)
            h[tuple(idx)] = 2
|
return torch.fft.irfft(xf * h, n=N, dim=dim) |
|
|
|
def hilbert_transform_true_2d(x): |
|
xf = torch.fft.rfft2(x) |
|
h1, h2 = torch.meshgrid( |
|
torch.fft.rfftfreq(x.shape[-2]) * 2 - 1, |
|
torch.fft.rfftfreq(x.shape[-1]) * 2 - 1, |
|
indexing='ij') |
|
h = -1j / (math.pi * (h1 + 1j*h2)) |
|
h[0, 0] = 0 |
|
return torch.fft.irfft2(xf * h.to(x.device)) |
|
|
|
def process_spectrogram_with_hilbert(spec): |
|
analytic = spec + 1j * hilbert_transform(spec) |
|
envelope = torch.abs(analytic) |
|
phase = torch.angle(analytic) |
|
return envelope, phase |
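
# Sketch (assumed mel shape): envelope/phase decomposition of a spectrogram.
#   spec = torch.randn(128, 300)                   # (n_mels, frames)
#   env, phase = process_spectrogram_with_hilbert(spec)
#   env.shape == phase.shape == spec.shape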
|
|
|
def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, f0=False, |
|
hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000, |
|
pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk", |
|
norm=None, normalized=False, downsamples=False, period=False, hilbert=False): |
|
|
|
    dtype = torch.float32
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    audio = batch["audio"]

    # `sampling_rate` stays the target rate passed as an argument; `sr` is the
    # source rate of the stored audio, so the resample branch below can fire
    wav = torch.tensor(audio["array"]).float()
    sr = audio["sampling_rate"]
|
|
|
if sr != sampling_rate: |
|
original_length = wav.shape[-1] |
|
target_length = int(original_length * (sampling_rate / sr)) |
|
|
|
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sampling_rate) |
|
wav = resampler(wav) |
|
|
|
        if abs(wav.shape[-1] - target_length) > 1:
            # wav is 1-D at this point; pad or trim to the expected length
            new_waveform = torch.zeros(target_length, dtype=wav.dtype)
            copy_length = min(wav.shape[-1], target_length)
            new_waveform[:copy_length] = wav[:copy_length]
            wav = new_waveform
|
|
|
if spectrogram: |
|
transform = torchaudio.transforms.MelSpectrogram( |
|
f_max=fmax, |
|
f_min=fmin, |
|
n_mels=n_mels, |
|
            sample_rate=sampling_rate,
|
n_fft=n_fft, |
|
hop_length=hop_length, |
|
norm=norm, |
|
normalized=normalized, |
|
power=power, |
|
center=center, |
|
mel_scale=mel_scale, |
|
window_fn=window_fn, |
|
pad_mode=pad_mode) |
|
|
|
mel_spectrogram = transform(wav) |
|
log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10() |
|
log_mel = torch.maximum(log_mel, log_mel.max() - 8.0) |
|
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec
|
|
|
if hilbert: |
|
envelope_list = [] |
|
phase_list = [] |
|
|
|
for ch_idx in range(spec.shape[0]): |
|
envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx]) |
|
envelope_list.append(envelope) |
|
phase_list.append(phase) |
|
|
|
batch["envelope"] = torch.stack(envelope_list) |
|
batch["phase"] = torch.stack(phase_list) |
|
|
|
wav_1d = wav.unsqueeze(0) |
|
|
|
if waveforms: |
|
batch["waveform"] = wav_1d |
|
|
|
if pitch: |
|
if period: |
|
            pit, periodicity = torchcrepe.predict(
|
wav_1d, |
|
sampling_rate, |
|
hop_length, |
|
fmin=80, |
|
fmax=800, |
|
model="tiny", |
|
decoder=torchcrepe.decode.viterbi, |
|
return_periodicity=True, |
|
device=device, |
|
pad=True |
|
) |
|
batch["pitch"] = pit |
|
batch["period"] = periodocity |
|
else: |
|
pit = torchcrepe.predict( |
|
wav_1d, |
|
sampling_rate, |
|
hop_length, |
|
fmin=80, |
|
fmax=800, |
|
model="tiny", |
|
decoder=torchcrepe.decode.viterbi, |
|
return_periodicity=False, |
|
device=device, |
|
pad=True |
|
) |
|
batch["pitch"] = pit |
|
|
|
if f0: |
|
wav_np = wav.numpy().astype(np.float64) |
|
f0, t = pw.dio(wav_np, sampling_rate, |
|
frame_period=hop_length/sampling_rate*1000) |
|
f0 = pw.stonemask(wav_np, f0, t, sampling_rate) |
|
batch["f0"] = torch.from_numpy(f0).float() |
|
|
|
if spectrogram and waveforms and pitch: |
|
spec_mean = batch["spectrogram"].mean() |
|
spec_std = batch["spectrogram"].std() + 1e-6 |
|
batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std |
|
|
|
wav_mean = batch["waveform"].mean() |
|
wav_std = batch["waveform"].std() + 1e-6 |
|
batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std |
|
|
|
if batch["pitch"].max() > 1.0: |
|
pitch_min = 50.0 |
|
pitch_max = 600.0 |
|
batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min) |
|
|
|
batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False) |
|
return batch |
|
|
|
def compute_metrics(eval_pred, compute_result: bool = True, |
|
print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None): |
|
|
|
pred_logits = eval_pred.predictions |
|
label_ids = eval_pred.label_ids |
|
|
|
if hasattr(pred_logits, "cpu"): |
|
pred_logits = pred_logits.cpu() |
|
if hasattr(label_ids, "cpu"): |
|
label_ids = label_ids.cpu() |
|
if isinstance(pred_logits, tuple): |
|
pred_ids = pred_logits[0] |
|
else: |
|
pred_ids = pred_logits |
|
if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3: |
|
if not isinstance(pred_ids, torch.Tensor): |
|
pred_ids = torch.tensor(pred_ids) |
|
pred_ids = pred_ids.argmax(dim=-1) |
|
pred_ids = pred_ids.tolist() |
|
|
|
if hasattr(label_ids, "tolist"): |
|
label_ids = label_ids.tolist() |
|
|
|
label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids] |
|
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False) |
|
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False) |
|
|
|
if print_pred: |
|
for i in range(min(num_samples, len(pred_str))): |
|
print(f"Preds: {pred_str[i]}") |
|
print(f"Label: {label_str[i]}") |
|
print(f"preds: {pred_ids[i]}") |
|
print(f"label: {label_ids[i]}") |
|
print("--------------------------------") |
|
|
|
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) |
|
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True) |
|
wer = 100 * metric.compute(predictions=pred_str, references=label_str) |
|
|
|
if model is None: |
|
global global_model |
|
if 'global_model' in globals(): |
|
model = global_model |
|
|
|
if model is not None: |
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000 |
|
if trainable_params > 0: |
|
efficiency_score = (100 - wer) / trainable_params |
|
else: |
|
print("Warning: Zero trainable parameters detected") |
|
efficiency_score = 0.0 |
|
else: |
|
print("Warning: Model not available for parameter counting") |
|
trainable_params = 0.0 |
|
efficiency_score = 0.0 |
|
|
|
if hasattr(wer, "item"): |
|
wer = wer.item() |
|
|
|
metrics = { |
|
"wer": float(wer), |
|
"trainable_params_M": float(trainable_params), |
|
"efficiency_score": float(efficiency_score), |
|
} |
|
|
|
print(f"Computed metrics: WER={wer:.2f}%, Params={trainable_params:.2f}M, Efficiency={efficiency_score:.4f}") |
|
return metrics |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
def create_model(param: Dimensions) -> Echo: |
|
model = Echo(param).to('cuda') |
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
total_params = sum(p.numel() for p in model.parameters()) |
|
logger.info(f"Trainable parameters: {trainable_params:,}") |
|
logger.info(f"Total parameters: {total_params:,}") |
|
print(f"Trainable parameters: {trainable_params:,}") |
|
print(f"Total parameters: {total_params:,}") |
|
model.init_weights() |
|
return model |
|
|
|
def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"): |
|
from tokenizers import Tokenizer |
|
tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json") |
|
orig_encode = tokenizer.encode |
|
def enc(text, add_special_tokens=True): |
|
ids = orig_encode(text).ids |
|
if not add_special_tokens: |
|
sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]] |
|
ids = [id for id in ids if id not in sp_ids] |
|
return ids |
|
def bdec(ids_list, skip_special_tokens=True): |
|
results = [] |
|
for ids in ids_list: |
|
if skip_special_tokens: |
|
ids = [id for id in ids if id not in [0, 1, 2]] |
|
results.append(tokenizer.decode(ids)) |
|
return results |
|
def save_pretrained(save_dir): |
|
os.makedirs(save_dir, exist_ok=True) |
|
tokenizer.save(f"{save_dir}/tokenizer.json") |
|
tokenizer.encode = enc |
|
tokenizer.batch_decode = bdec |
|
tokenizer.save_pretrained = save_pretrained |
|
tokenizer.pad_token_id = 0 |
|
tokenizer.bos_token_id = 1 |
|
tokenizer.eos_token_id = 2 |
|
return tokenizer |
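
# Usage sketch (the path comes from the default argument): the patched
# tokenizer exposes exactly what this script needs, namely encode,
# batch_decode, save_pretrained, and pad/bos/eos ids fixed to 0/1/2:
#   tok = setup_tokenizer(token="")
#   ids = tok.encode("hello world", add_special_tokens=False)
#   text = tok.batch_decode([ids])[0]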
|
|
|
def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
|
if dataset_config is None: |
|
dataset_config = { |
|
"spectrogram": True, |
|
"waveforms": True, |
|
"pitch": True, |
|
"f0": True, |
|
"downsamples": True, |
|
"hop_length": 128, |
|
"fmin": 50, |
|
"fmax": 2000, |
|
"n_mels": 128, |
|
"n_fft": 1024, |
|
"sampling_rate": 16000, |
|
} |
|
|
|
dataset = load_dataset( |
|
"google/fleurs", |
|
"en_us", |
|
token=token, |
|
trust_remote_code=True, |
|
streaming=False |
|
) |
|
dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)) |
|
|
|
if sanity_check: |
|
dataset = dataset["test"].take(10).shuffle() |
|
dataset = dataset.select_columns(["audio", "transcription"]) |
|
logger.info(f"Sanity dataset size: {dataset.num_rows}") |
|
print(f"Sanity dataset size: {dataset.num_rows}") |
|
prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config) |
|
|
|
dataset = dataset.map( |
|
function=prepare_fn, |
|
remove_columns=["audio", "transcription"] |
|
).with_format(type="torch") |
|
train_dataset = dataset |
|
test_dataset = dataset |
|
else: |
|
def filter_func(x): |
|
return (0 < len(x["transcription"]) < 512 and |
|
len(x["audio"]["array"]) > 0 and |
|
len(x["audio"]["array"]) < 1500 * 160) |
|
|
|
dataset = dataset.filter(filter_func).shuffle() |
|
logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}") |
|
print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}") |
|
prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config) |
|
train_dataset = dataset["train"] |
|
test_dataset = dataset["test"] |
|
columns_to_remove = list(next(iter(dataset.values())).features) |
|
|
|
train_dataset = train_dataset.map( |
|
function=prepare_fn, |
|
remove_columns=columns_to_remove |
|
).with_format(type="torch") |
|
|
|
test_dataset = test_dataset.map( |
|
function=prepare_fn, |
|
remove_columns=columns_to_remove |
|
).with_format(type="torch") |
|
|
|
return train_dataset, test_dataset |
|
|
|
def get_training_args( |
|
log_dir: str, |
|
batch_eval_metrics: bool = False, |
|
max_steps: int = 10, |
|
save_steps: int = 1000, |
|
eval_steps: int = 100, |
|
warmup_steps: int = 0, |
|
num_train_epochs: int = 1, |
|
logging_steps: int = 10, |
|
eval_on_start: bool = False, |
|
learning_rate: float = 1e-4, |
|
weight_decay: float = 0.01, |
|
max_grad_norm: float = 1.0, |
|
) -> Seq2SeqTrainingArguments: |
|
|
|
return Seq2SeqTrainingArguments( |
|
output_dir=log_dir, |
|
per_device_train_batch_size=1, |
|
per_device_eval_batch_size=1, |
|
gradient_accumulation_steps=1, |
|
eval_accumulation_steps=1, |
|
tf32=True, |
|
bf16=True, |
|
eval_strategy="steps", |
|
save_strategy="steps", |
|
max_steps=max_steps, |
|
save_steps=save_steps, |
|
eval_steps=eval_steps, |
|
warmup_steps=warmup_steps, |
|
num_train_epochs=num_train_epochs, |
|
logging_steps=logging_steps, |
|
logging_dir=log_dir, |
|
logging_strategy="steps", |
|
report_to=["tensorboard"], |
|
push_to_hub=False, |
|
disable_tqdm=False, |
|
save_total_limit=1, |
|
label_names=["labels"], |
|
optim="adamw_torch", |
|
lr_scheduler_type="cosine", |
|
learning_rate=learning_rate, |
|
weight_decay=weight_decay, |
|
save_safetensors=False, |
|
eval_on_start=eval_on_start, |
|
batch_eval_metrics=batch_eval_metrics, |
|
max_grad_norm=max_grad_norm, |
|
|
|
) |
|
|
|
def main(): |
|
|
|
token = "" |
|
log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H')) |
|
os.makedirs(name=log_dir, exist_ok=True) |
|
tokenizer = setup_tokenizer(token) |
|
|
|
def sanity(sanity: bool): |
|
|
|
if sanity: |
|
training_args = get_training_args( |
|
log_dir, |
|
batch_eval_metrics = False, |
|
max_steps = 10, |
|
save_steps = 0, |
|
eval_steps = 1, |
|
warmup_steps = 0, |
|
logging_steps = 1, |
|
eval_on_start = True, |
|
learning_rate = 5e-6, |
|
weight_decay = 0.01, |
|
) |
|
else: |
|
training_args = get_training_args( |
|
log_dir, |
|
batch_eval_metrics = False, |
|
max_steps = 10000, |
|
save_steps = 10000, |
|
eval_steps = 1000, |
|
warmup_steps = 1000, |
|
logging_steps = 100, |
|
eval_on_start = False, |
|
learning_rate = 2.5e-4, |
|
weight_decay = 0.01, |
|
) |
|
|
|
return training_args |
|
|
|
param = Dimensions( |
|
mels=128, |
|
aud_ctx=1500, |
|
aud_head=4, |
|
aud_dims=512, |
|
aud_idx=4, |
|
vocab=40000, |
|
text_ctx=512, |
|
text_head=4, |
|
text_dims=512, |
|
text_idx=4, |
|
act="swish", |
|
        debug=[],
|
cross_attn=True, |
|
f0_rotary=True, |
|
features = ["spectrogram"], |
|
) |
|
|
|
sanity_check = False |
|
|
|
training_args = sanity(sanity_check) |
|
|
|
dataset_config = { |
|
"spectrogram": True, |
|
"waveforms": False, |
|
"pitch": False, |
|
"downsamples": False, |
|
"f0": True, |
|
"hilbert": False, |
|
"hop_length": 128, |
|
"fmin": 150, |
|
"fmax": 2000, |
|
"n_mels": 128, |
|
"n_fft": 1024, |
|
"sampling_rate": 16000, |
|
"pad_mode": "constant", |
|
"center": True, |
|
"power": 2.0, |
|
"window_fn": torch.hann_window, |
|
"mel_scale": "htk", |
|
"norm": None, |
|
"normalized": False} |
|
|
|
model = create_model(param) |
|
|
|
global global_model |
|
global_model = model |
|
|
|
metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5, |
|
tokenizer=tokenizer, model=model) |
|
|
|
print(f"{'Sanity check' if sanity_check else 'Training'} mode") |
|
train_dataset, test_dataset = prepare_datasets( |
|
tokenizer=tokenizer, |
|
token=token, |
|
sanity_check=sanity_check, |
|
dataset_config=dataset_config) |
|
|
|
|
|
trainer = Seq2SeqTrainer( |
|
args=training_args, |
|
model=model, |
|
train_dataset=train_dataset, |
|
eval_dataset=test_dataset, |
|
data_collator=DataCollator(tokenizer=tokenizer), |
|
compute_metrics=metrics_fn, |
|
) |
|
|
|
trainer.train() |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|
|
|
|
|