import pyworld as pw
import os
import math
import warnings
import logging
import gzip
import base64
import torch
import torchaudio
import torchcrepe
import torch.nn.functional as F
import torch.nn.init as init
from torch import nn, Tensor
import numpy as np
from typing import Optional, Dict, Union, List, Tuple, Any
from functools import partial
from datetime import datetime
from datasets import load_dataset, Audio
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
import transformers
import evaluate
from dataclasses import dataclass

torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
transformers.utils.logging.set_verbosity_error()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
torch.set_default_dtype(dtype)

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
       "dtype": torch.float32}

extractor = None
tokenizer = None
optimizer = None
scheduler = None
model = None


@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]
    f0_rotary: bool


def exists(v):
    return v is not None

def default(v, b):
    return v if exists(v) else b


class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.device, x.dtype),
            None if bias is None else bias.to(x.device, x.dtype))

class Conv2d(nn.Conv2d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.device, x.dtype),
            None if bias is None else bias.to(x.device, x.dtype))

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple], eps=1e-8, elementwise_affine=True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64

def get_tox():
    return {"device": get_device(), "dtype": get_dtype()}

def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
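# Quick shape check for sinusoids (a sketch, not part of the training path):
# the first channels//2 columns are sine terms, the remainder cosine terms.
#   pe = sinusoids(length=1500, channels=512)   # -> Tensor of shape (1500, 512)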
class rotary(nn.Module):
    _seen = set()

    def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False,
                 variable_radius=False, learned_radius=False, learned_theta=False,
                 learned_pitch=False, debug: List[str] = []):
        super().__init__()
        self.dims = dims
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        dtype = torch.float32
        self.device = device
        self.dtype = dtype
        self.debug = debug
        self._counter = 0
        self.use_pbias = False
        self.max_ctx = max_ctx
        self.variable_radius = variable_radius
        self.inv_freq = nn.Parameter(
            1.0 / (theta ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
            requires_grad=learned_freq)
        self.theta = nn.Parameter(torch.tensor(float(theta)), requires_grad=learned_theta)
        self.pitch_scale = nn.Parameter(torch.tensor(1.0), requires_grad=learned_pitch)
        if variable_radius:
            self.radius = nn.Parameter(torch.ones(dims // 2), requires_grad=learned_radius)

    def get_pitch_bias(self, f0):
        if f0 is None:
            return None
        f0_flat = f0.squeeze().float()
        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
                                        f0_norm.unsqueeze(1)) * self.pitch_scale)
        return f0_sim.unsqueeze(0).unsqueeze(0)

    def add_to_rotary(self):
        def get_sim(self, freqs):
            real = freqs.real.squeeze(0)
            imag = freqs.imag.squeeze(0)
            vecs = torch.cat([real.unsqueeze(-2), imag.unsqueeze(-2)], dim=-1)
            vecs = vecs.squeeze(-2)
            return F.cosine_similarity(vecs.unsqueeze(1), vecs.unsqueeze(0), dim=-1)

        def fwd_sim(self, x=None, f0=None):
            freqs = self.forward(x, f0)
            sim = get_sim(self, freqs)
            return freqs, sim

        rotary.get_sim = get_sim
        rotary.fwd_sim = fwd_sim

    def align_f0_to_tokens(self, f0, token_length):
        ratio = len(f0) / token_length
        indices = [int(i * ratio) for i in range(token_length)]
        indices = [min(i, len(f0) - 1) for i in indices]
        return f0[indices]

    def forward(self, x=None, f0=None, stage=None) -> Tensor:
        if isinstance(x, int):
            t = torch.arange(x, device=self.device).float()
        else:
            t = x.float().to(self.inv_freq.device)
        if f0 is not None:
            f0_mean = f0.mean()
            f0_theta = (f0_mean ** 2) * self.pitch_scale
            inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
        else:
            inv_freq = self.inv_freq
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        freqs = freqs.float()
        if self.variable_radius:
            if f0 is not None:
                f0 = f0[0]
                # torch.as_tensor avoids the copy warning torch.tensor emits on tensors
                f0 = torch.as_tensor(f0, device=x.device if isinstance(x, torch.Tensor) else self.device)
                f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
                max_f0 = torch.max(f0)
                if max_f0 > 0:
                    radius = f0 / max_f0
                else:
                    radius = torch.ones_like(f0)
                freqs = torch.polar(radius, freqs)
        else:
            freqs = torch.polar(torch.ones_like(freqs), freqs)
        freqs = freqs.unsqueeze(0)
        # print(f"radius, {radius}")
        if "rotary" in self.debug:
            if f0 is not None:
                key = f"{self._counter}_{f0_theta:.2f}"
                if key not in rotary._seen:
                    if not hasattr(self, '_prev_f0_theta'):
                        self._prev_f0_theta = f0_theta
                        print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
                    elif abs(self._prev_f0_theta - f0_theta) > 0.0:
                        print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
                        self._prev_f0_theta = f0_theta
                    rotary._seen.add(key)
        self._counter += 1
        return freqs

    @staticmethod
    def apply_rotary(x, freqs):
        multihead_format = len(freqs.shape) == 4
        if multihead_format:
            x1 = x[..., :freqs.shape[-1] * 2]
            x2 = x[..., freqs.shape[-1] * 2:]
            x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
            x1 = torch.view_as_complex(x1)
            x1 = x1 * freqs
            x1 = torch.view_as_real(x1).flatten(-2)
            return torch.cat([x1.type_as(x), x2], dim=-1)
        else:
            x1 = x[..., :freqs.shape[-1] * 2]
            x2 = x[..., freqs.shape[-1] * 2:]
            if x.ndim == 2:
                x1 = x1.unsqueeze(0)
                x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
                x1 = torch.view_as_complex(x1)
                x1 = x1 * freqs
                x1 = torch.view_as_real(x1).flatten(-2)
                x1 = x1.squeeze(0)
                return torch.cat([x1.type_as(x), x2], dim=-1)
            else:
                x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
                x1 = torch.view_as_complex(x1)
                x1 = x1 * freqs
                x1 = torch.view_as_real(x1).flatten(-2)
                return torch.cat([x1.type_as(x), x2], dim=-1)
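# Usage sketch for rotary (illustrative values, not from the training run):
#   rope = rotary(dims=128)                  # dims is the per-head dimension
#   freqs = rope(10)                         # complex tensor, shape (1, 10, 64)
#   q = torch.randn(2, 4, 10, 128)           # (batch, head, ctx, head_dim)
#   q_rot = rotary.apply_rotary(q, freqs)    # same shape as q, rotated in pairs
# When f0 is passed, theta is replaced by (mean F0)^2 * pitch_scale, so the
# rotation frequencies track the speaker's pitch.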
class SliceAttention(nn.Module):
    def __init__(self, dims, heads, dropout=0.0):
        super().__init__()
        self.dims = dims
        self.heads = heads
        self.head_dim = dims // heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = Linear(dims, dims)
        self.k_proj = Linear(dims, dims)
        self.v_proj = Linear(dims, dims)
        self.out_proj = Linear(dims, dims)
        self.dropout = nn.Dropout(dropout)
        assert dims % heads == 0, f"Dimensions {dims} not divisible by heads {heads}"

    def parallel_slice(self, q, k, v, mask=None):
        # q, k, v: (batch, ctx, dims); attention is computed per head-slice
        batch, ctx, dims = q.shape
        head_dim = self.head_dim
        ctx_len = k.shape[1]
        num_heads = dims // head_dim
        scores = torch.zeros(batch, num_heads, ctx, ctx_len, device=q.device)
        for h in range(num_heads):
            start_idx = h * head_dim
            end_idx = start_idx + head_dim
            q_h = q[:, :, start_idx:end_idx]
            k_h = k[:, :, start_idx:end_idx]
            scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.zeros_like(q)
        for h in range(num_heads):
            start_idx = h * head_dim
            end_idx = start_idx + head_dim
            v_h = v[:, :, start_idx:end_idx]
            output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
        return output

    def forward(self, x, context=None, mask=None):
        batch, ctx, _ = x.shape
        if context is None:
            context = x
        ctx_len = context.shape[1]
        q = self.q_proj(x)
        k = self.k_proj(context)
        v = self.v_proj(context)
        output = torch.zeros_like(q)
        for h in range(self.heads):
            start_idx = h * self.head_dim
            end_idx = start_idx + self.head_dim
            q_h = q[:, :, start_idx:end_idx]
            k_h = k[:, :, start_idx:end_idx]
            v_h = v[:, :, start_idx:end_idx]
            attn_scores = torch.bmm(q_h, k_h.transpose(1, 2)) * self.scale
            if mask is not None:
                attn_scores = attn_scores + mask[:ctx, :ctx_len].unsqueeze(0)
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            head_output = torch.bmm(attn_weights, v_h)
            output[:, :, start_idx:end_idx] = head_output
        return self.out_proj(output)

def optim_attn(q, k, v, mask=None, scale=None, pad_token=0, fzero_val=0.0001):
    batch, heads, ctx, dims = q.shape
    token_ids = k[:, :, :, 0]
    is_padding = (token_ids.float() == pad_token).unsqueeze(-2)
    log_scale_factor = -10.0
    attn_mask = torch.zeros((batch, heads, ctx, ctx), device=q.device)
    if mask is not None:
        attn_mask = attn_mask + mask.unsqueeze(0).unsqueeze(0)
    attn_mask = torch.where(is_padding, torch.tensor(log_scale_factor, device=q.device), attn_mask)
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
    attn_output = attn_output.permute(0, 2, 1, 3).flatten(start_dim=2)
    return attn_output
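# optim_attn routes padding-aware masking through PyTorch's fused
# scaled_dot_product_attention: key positions whose first channel equals
# pad_token receive a flat -10.0 logit bias instead of the learned fzero scale.
# SliceAttention above is a reference implementation that loops over head
# slices explicitly; it is not referenced elsewhere in this file, and
# optim_attn only runs when MultiheadA is built with optim_attn=True.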
class MultiheadA(nn.Module):
    _seen = set()
    rbf = False

    def __init__(self, dims: int, head: int, rotary_emb: bool = False,
                 zero_val: float = 0.0001, minz: float = 0.0, maxz: float = 0.001,
                 debug: List[str] = [], optim_attn=False):
        super(MultiheadA, self).__init__()
        self.debug = debug
        self.pad_token = 0
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.rotary_emb = rotary_emb
        self.minz = minz
        self.maxz = maxz
        self.zero_val = zero_val
        self.optim_attn = optim_attn
        self._counter = 0
        if dims % head != 0:
            raise ValueError(f"Dimensions {dims} must be divisible by number of heads {head}.")
        if zero_val < minz or zero_val > maxz:
            raise ValueError(f"Zero value {zero_val} must be between {minz} and {maxz}.")
        self.q = Linear(dims, dims)
        self.k = Linear(dims, dims, bias=False)
        self.v = Linear(dims, dims)
        self.o = Linear(dims, dims)
        self.fzero = nn.Parameter(torch.tensor(zero_val, dtype=torch.float32), requires_grad=True)
        if rotary_emb:
            self.rope = rotary(dims=self.head_dim, debug=debug, max_ctx=1500)
        else:
            self.rope = None

    def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
        scale = (self.dims // self.head) ** -0.25
        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
        if rbf_ratio <= 0.0:
            return dot_scores
        q_norm = q.pow(2).sum(dim=-1, keepdim=True)
        k_norm = k.pow(2).sum(dim=-1, keepdim=True)
        qk = torch.matmul(q, k.transpose(-1, -2))
        dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma ** 2))
        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None,
                return_attn: bool = False, f0: Tensor = None) -> tuple:
        batch, ctx, dims = x.shape
        scale = (self.dims // self.head) ** -0.25
        z = default(xa, x)
        q = self.q(x).to(x.dtype)
        k = self.k(z).to(x.dtype)
        v = self.v(z).to(x.dtype)
        if self.rotary_emb:
            if f0 is not None:
                qf = self.rope(q.size(1), f0=f0)
                kf = self.rope(k.size(1), f0=f0)
            else:
                qf = self.rope(q.size(1))
                kf = self.rope(k.size(1))
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            q = self.rope.apply_rotary(q, qf)
            k = self.rope.apply_rotary(k, kf)
        else:
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        batch, head, ctx, head_dim = q.shape
        if self.optim_attn and not return_attn:
            wv = optim_attn(q * scale, k * scale, v, mask=mask, pad_token=self.pad_token,
                            fzero_val=torch.clamp(F.softplus(self.fzero), self.minz, self.maxz).item())
            return self.o(wv), None
        if self.rbf:
            qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
        else:
            # original code computed this unconditionally, clobbering the RBF scores
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
        if f0 is not None and self.rope is not None and self.rope.use_pbias:
            pbias = self.rope.get_pitch_bias(f0)
            if pbias is not None:
                qk = qk + pbias[:, :, :q.shape[2], :q.shape[2]]
        token_ids = k[:, :, :, 0]
        zscale = torch.ones_like(token_ids)
        fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
        zscale[token_ids.float() == self.pad_token] = fzero.to(q.device, q.dtype)
        if mask is not None:
            mask = mask[:q.shape[2], :q.shape[2]]
            qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
        qk = qk * zscale.unsqueeze(-2)
        if return_attn:
            return qk, v
        w = F.softmax(qk, dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
        if "multihead" in self.debug and self._counter % 100 == 0:
            print(f"Step {self._counter}: Using rotary embeddings: {self.rotary_emb}")
            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape}")
            print(f"Attention shape: {qk.shape}, wv shape: {wv.shape}")
        self._counter += 1
        return self.o(wv), qk.detach()
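# Usage sketch for MultiheadA (shapes illustrative):
#   mha = MultiheadA(dims=512, head=4, rotary_emb=True)
#   x = torch.randn(2, 10, 512)
#   out, qk = mha(x)    # out: (2, 10, 512); qk: detached attention logits
# Padding is handled softly: key positions whose first channel equals
# pad_token are rescaled by the learned fzero gate rather than masked to -inf.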
class FCGate(nn.Module):
    def __init__(self, dims, dim):
        super().__init__()
        self.proj = Linear(dim, dims // 4)
        self.gate = nn.Sequential(
            Linear(dims + dims // 4, dims // 2), nn.SiLU(),
            Linear(dims // 2, 1), nn.Sigmoid())

    def forward(self, x, embedding):
        info = self.proj(embedding)
        info = info.unsqueeze(1).expand(-1, x.shape[1], -1)
        gate_input = torch.cat([x, info], dim=-1)
        return self.gate(gate_input)

class TTGate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gate_projections = nn.ModuleList([
            nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
        self.type_classifier = nn.Sequential(
            Linear(dims, num_types), nn.Softmax(dim=-1))

    def forward(self, x):
        type_probs = self.type_classifier(x)
        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
        combined_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
        return combined_gate

class MGate(nn.Module):
    def __init__(self, dims, memory_size=64):
        super().__init__()
        self.mkey = nn.Parameter(torch.randn(memory_size, dims))
        self.mvalue = nn.Parameter(torch.randn(memory_size, 1))
        self.gate_proj = nn.Sequential(Linear(dims, dims // 2), nn.SiLU(), Linear(dims // 2, 1))

    def forward(self, x):
        dgate = torch.sigmoid(self.gate_proj(x))
        attention = torch.matmul(x, self.mkey.transpose(0, 1))
        attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
        mgate = torch.matmul(attention, self.mvalue)
        mgate = torch.sigmoid(mgate)
        return 0.5 * (dgate + mgate)

class CMGate(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.sgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.wgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.pgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.integration = Linear(dims * 3, dims)

    def forward(self, x, features):
        sfeat = features.get("spectrogram", x)
        wfeat = features.get("waveform", x)
        pfeat = features.get("pitch", x)
        spec = self.sgate(x) * sfeat
        wave = self.wgate(x) * wfeat
        pitch = self.pgate(x) * pfeat
        combined = torch.cat([spec, wave, pitch], dim=-1)
        return self.integration(combined)
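# Gate variants used by Residual below (each produces a scalar gate per position):
#   FCGate - conditions the gate on an external embedding (e.g. mean F0)
#   TTGate - soft-classifies each token into num_types and mixes per-type gates
#   MGate  - blends a direct gate with one read from a learned key/value memory
#   CMGate - fuses spectrogram/waveform/pitch streams with per-modality gates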
class Residual(nn.Module):
    _seen = set()

    def __init__(self, dims: int, head: int, ctx, act, cross_attn=True,
                 debug: List[str] = [], fgate=False, tgate=False, mgate=False,
                 cgate=False, memory_size=512, features=None):
        super().__init__()
        self.ctx = ctx
        self._counter = 0
        self.dropout = 0.01
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.cross_attn = cross_attn
        self.debug = debug
        self.features = features
        self.blend = nn.Parameter(torch.tensor(0.5))
        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())
        self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
        self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug)
                      if cross_attn else None)
        mlp = dims * 4
        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
        self.fgate = FCGate(dims=dims, dim=dims) if fgate else None
        self.tgate = TTGate(dims=dims, num_types=4) if tgate else None
        self.mgate = MGate(dims=dims, memory_size=memory_size) if mgate else None
        self.cgate = CMGate(dims=dims) if cgate else None
        self.lna = RMSNorm(dims)
        self.lnb = RMSNorm(dims) if cross_attn else None
        self.lnc = RMSNorm(dims)
        if not any([fgate, tgate, mgate, cgate]):
            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

    def forward(self, x, xa=None, mask=None, f0=None, mode=None):
        x = x + self.attna(self.lna(x), mask=mask, f0=f0)[0]
        if self.attnb and xa is not None:
            cross = self.attnb(self.lnb(x), xa, f0=f0, mask=None)[0]
            blend = torch.sigmoid(self.blend)
            x = blend * x + (1 - blend) * cross
        normx = self.lnc(x)
        mlp_out = self.mlp(normx)
        if self.tgate:
            gate = self.tgate(normx)
            x = x + gate * mlp_out
        elif self.fgate:
            embedding = f0.mean(dim=1) if f0 is not None else xa.mean(dim=1)
            gate = self.fgate(normx, embedding)  # was self.fg, which does not exist
            x = x + gate * mlp_out
        elif self.mgate:
            gate = self.mgate(normx)
            x = x + gate * mlp_out
        elif self.cgate and mode is not None:
            gate_output = self.cgate(normx, self.features)
            x = x + gate_output
        else:
            if hasattr(self, 'mlp_gate'):
                mlp_gate = self.mlp_gate(normx)
                x = x + mlp_gate * mlp_out
            else:
                x = x + mlp_out
        if "residual" in self.debug and self._counter % 100 == 0:
            print(f"Step {self._counter}: Residual block output shape: {x.shape}, "
                  f"xa shape: {xa.shape if xa is not None else None}")
        self._counter += 1
        return x

# The convolutional front-ends below downsample time by their stride product:
# PEncoder and WEncoder reduce by 8*4*5 = 160x (16 kHz -> ~100 frames/s),
# while FEncoder keeps the input frame rate (stride=1; stride=2 for pitch).
class PEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act):
        super().__init__()
        self.head_dim = dims // head
        self.dropout = 0.01
        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())
        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims // 4, kernel_size=7, stride=8, padding=3), act_fn,
            Conv1d(dims // 4, dims // 2, kernel_size=5, stride=4, padding=2), act_fn,
            Conv1d(dims // 2, dims, kernel_size=5, stride=5, padding=2), act_fn)
        # positional/norm were referenced in forward() but never defined; add them
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def forward(self, x, f0=None):
        x = self.encoder(x).permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

class WEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act):
        super().__init__()
        self.head_dim = dims // head
        self.dropout = 0.01
        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())
        self.downsample = nn.Sequential(
            Conv1d(input_dims, dims // 8, kernel_size=15, stride=8, padding=7), act_fn,
            Conv1d(dims // 8, dims // 4, kernel_size=7, stride=4, padding=3), act_fn,
            Conv1d(dims // 4, dims, kernel_size=9, stride=5, padding=4), act_fn)
        self.encoder = nn.Sequential(
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims // 8), act_fn,
            Conv1d(dims, dims, kernel_size=1), act_fn)
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def forward(self, x, f0=None):
        x = self.downsample(x)
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        return self.norm(x)

class FEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1):
        super().__init__()
        self.head_dim = dims // head
        self.dropout = 0.01
        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())
        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2), act_fn,
            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)
        self._norm = RMSNorm(dims)

    def forward(self, x, f0=None):
        x = self.encoder(x).permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self._norm(x)
        return x

class AudioEncoder(nn.Module):
    _seen = set()

    def __init__(self, mels: int, layer: int, dims: int, head: int, ctx: int,
                 features: List[str], debug: List[str], f0_rotary: bool = False,
                 act: str = "gelu"):
        super(AudioEncoder, self).__init__()
        self.debug = debug
        self.features = features
        self._counter = 0
        self.dropout = 0.01
        self.f0_rotary = f0_rotary
        self.dims = dims
        self.ctx = ctx
        self.head = head
        self.head_dim = dims // head
        self.rope = rotary(dims=self.head_dim, debug=debug)
        if features == ["spectrogram", "waveform", "pitch"]:
            cgate = True
        else:
            cgate = False
        # The encoders expect the activation *name*; the original passed a module
        # instance for some branches, which silently fell back to GELU.
        self.blocks = nn.ModuleDict({
            "spectrogram": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)]
                + [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug, cgate=cgate, features=features)
                   for _ in range(layer)]
                if "spectrogram" in features else None),
            "waveform": nn.ModuleList(
                [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act)]
                + [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug, cgate=cgate, features=features)
                   for _ in range(layer)]
                if "waveform" in features else None),
            "pitch": nn.ModuleList(
                [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)]
                + [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug, cgate=cgate, features=features)
                   for _ in range(layer)]
                if "pitch" in features else None),
            "spec_envelope": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)]
                + [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug) for _ in range(layer)]
                if "spec_envelope" in features else None),
            "spec_phase": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)]
                + [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug) for _ in range(layer)]
                if "spec_phase" in features else None),
        })

    def forward(self, x, f0=None):
        outputs = {}
        if self.f0_rotary:
            f0 = f0 if f0 is not None else x.get("pitch")
        else:
            f0 = None
        for y in self.features:
            if y in x and y in self.blocks:
                f = x[y]
                for block in self.blocks[y]:
                    f = block(f, f0=f0)
                outputs[y] = f
        if "encoder" in self.debug and self._counter % 100 == 0:
            names = list(x.keys())
            shapes = {k: v.shape for k, v in x.items()}
            print(f"Step {self._counter}: mode: {names}")
            print(f"shapes: {shapes}")
        self._counter += 1
        return outputs
class TextDecoder(nn.Module):
    def __init__(self, vocab: int, layer: int, dims: int, head: int, ctx: int,
                 cross_attn: bool, features: List[str], debug: List[str],
                 f0_rotary: bool = False, sequential=False):
        super(TextDecoder, self).__init__()
        self._counter = 0
        self.dropout = 0.01
        self.debug = debug
        self.sequential = sequential
        self.features = features
        self.f0_rotary = f0_rotary
        self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
        with torch.no_grad():
            self.token.weight[0].zero_()
        # torch.empty alone leaves uninitialized memory; give it a small init
        self.positional = nn.Parameter(torch.empty(ctx, dims).normal_(std=0.02),
                                       requires_grad=True)
        self._blocks = nn.ModuleList([
            Residual(dims=dims, head=head, ctx=ctx, act="gelu",
                     cross_attn=cross_attn, debug=debug, features=features)
            for _ in range(layer)])
        self.blocks = nn.ModuleDict({
            f: nn.ModuleList([
                Residual(dims=dims, head=head, ctx=ctx, act="gelu",
                         cross_attn=cross_attn, debug=debug, features=features)
                for _ in range(layer)])
            for f in features})
        self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
        self.ln_dec = RMSNorm(dims)
        mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, enc, order=None, f0=None) -> Tensor:
        x = x.to(device)
        if not self.f0_rotary:
            f0 = None
        if order is None:
            order = self.features
        mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x) + self.positional[:x.shape[1]]
        x = F.dropout(x, p=self.dropout, training=self.training)
        for block in self._blocks:
            x = block(x, f0=f0, mask=mask)
        for f in order:
            if f in enc:
                xa = enc[f]
                for block in self.blocks[f]:
                    out = block(x=x, xa=xa, f0=f0, mask=None)
                    a = torch.sigmoid(self.blend[f])
                    x = a * out + (1 - a) * x
        x = self.ln_dec(x)
        return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

class Echo(nn.Module):
    def __init__(self, param: Dimensions):
        super().__init__()
        self.param = param
        self.encoder = AudioEncoder(
            mels=param.mels, ctx=param.aud_ctx, dims=param.aud_dims,
            head=param.aud_head, layer=param.aud_idx, act=param.act,
            debug=param.debug, features=param.features, f0_rotary=param.f0_rotary)
        self.decoder = TextDecoder(
            vocab=param.vocab, ctx=param.text_ctx, dims=param.text_dims,
            head=param.text_head, layer=param.text_idx, cross_attn=param.cross_attn,
            debug=param.debug, features=param.features, f0_rotary=param.f0_rotary)
        all_head = torch.zeros(self.param.text_idx, self.param.text_head, dtype=torch.bool)
        all_head[self.param.text_idx // 2:] = True
        self.register_buffer("alignment_head", all_head.to_sparse(), persistent=False)

    def set_alignment_head(self, dump: bytes):
        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
        mask = torch.from_numpy(array).reshape(self.param.text_idx, self.param.text_head)
        self.register_buffer("alignment_head", mask.to_sparse(), persistent=False)

    def embed_audio(self, spectrogram: torch.Tensor):
        return self.encoder(spectrogram)

    def logits(self, input_ids: torch.Tensor, encoder_output: torch.Tensor):
        return self.decoder(input_ids, encoder_output)

    def forward(self, decoder_input_ids=None, labels=None,
                waveform: Optional[torch.Tensor] = None, input_ids=None,
                spectrogram: torch.Tensor = None,
                pitch: Optional[torch.Tensor] = None,
                f0: Optional[torch.Tensor] = None,
                envelope: Optional[torch.Tensor] = None,
                phase: Optional[torch.Tensor] = None,
                ) -> Dict[str, torch.Tensor]:
        decoder_input_ids = input_ids
        encoder_inputs = {}
        if spectrogram is not None:
            encoder_inputs["spectrogram"] = spectrogram
        if waveform is not None:
            encoder_inputs["waveform"] = waveform
        if pitch is not None:
            encoder_inputs["pitch"] = pitch
        if envelope is not None:
            encoder_inputs["envelope"] = envelope
        if phase is not None:
            encoder_inputs["phase"] = phase
        encoder_outputs = self.encoder(encoder_inputs, f0=f0)
        logits = self.decoder(input_ids, encoder_outputs, f0=f0)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
        return {
            "logits": logits,
            "loss": loss,
            "labels": labels,
            "input_ids": input_ids,
            "decoder_input_ids": decoder_input_ids,
            "encoder_output": encoder_outputs,
        }

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _init_weights(self, module):
        # Called per-module via self.apply(); the original iterated named_modules()
        # here (yielding (name, module) tuples, so no isinstance ever matched) and
        # reset init_counts on every call. Counts are now reset in init_weights().
        std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, MultiheadA):
            self.init_counts["MultiheadA"] += 1
        elif isinstance(module, Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, TextDecoder):
            self.init_counts["TextDecoder"] += 1
        elif isinstance(module, AudioEncoder):
            self.init_counts["AudioEncoder"] += 1
        elif isinstance(module, Residual):
            self.init_counts["Residual"] += 1

    def init_weights(self):
        print("Initializing all weights")
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "SEBlock": 0, "TextDecoder": 0, "AudioEncoder": 0,
            "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0, "WEncoder": 0,
            "PEncoder": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")

metric = evaluate.load(path="wer")

@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # use the collator's own tokenizer, not the module-level global (which is None)
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
        batch = {}
        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
            spectrogram_list = [f["spectrogram"] for f in features]
            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
            pad_spectrogram = []
            for feat in spectrogram_list:
                current_len = feat.shape[-1]
                padding = max_len_feat - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_spectrogram.append(pad_feat)
            batch["spectrogram"] = torch.stack(pad_spectrogram)
        if "waveform" in features[0] and features[0]["waveform"] is not None:
            waveform_list = [f["waveform"] for f in features]
            max_len_wav = max(w.shape[-1] for w in waveform_list)
            pad_waveforms = []
            for wav in waveform_list:
                current_len = wav.shape[-1]
                padding = max_len_wav - current_len
                if padding > 0:
                    if wav.ndim == 1:
                        wav = wav.unsqueeze(0)
                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_wav = wav
                pad_waveforms.append(pad_wav)
            batch["waveform"] = torch.stack(pad_waveforms)
        if "label" in features[0] and features[0]["label"] is not None:
            labels_list = [f["label"] for f in features]
            max_len = max(len(l) for l in labels_list)
            all_ids = []
            all_labels = []
            for label in labels_list:
                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                decoder_input = [bos_token_id] + label_list
                label_eos = label_list + [pad_token_id]
                input_len = max_len + 1 - len(decoder_input)
                label_len = max_len + 1 - len(label_eos)
                padded_input = decoder_input + [pad_token_id] * input_len
                padded_labels = label_eos + [pad_token_id] * label_len
                all_ids.append(padded_input)
                all_labels.append(padded_labels)
            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
        if "pitch" in features[0] and features[0]["pitch"] is not None:
            pitch_list = [f["pitch"] for f in features]
            max_len_pitch = max(e.shape[-1] for e in pitch_list)
            pad_pitch = []
            for pitch in pitch_list:
                current_len = pitch.shape[-1]
                padding = max_len_pitch - current_len
                if padding > 0:
                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_pitch_item = pitch
                pad_pitch.append(pad_pitch_item)
            batch["pitch"] = torch.stack(pad_pitch)
        if "f0" in features[0] and features[0]["f0"] is not None:
            all_f0 = torch.cat([f["f0"] for f in features])
            batch["f0"] = all_f0.unsqueeze(0)
        if "envelope" in features[0] and features[0]["envelope"] is not None:
            env_list = [f["envelope"] for f in features]
            max_len = max(f.shape[-1] for f in env_list)
            pad_env = []
            for feat in env_list:
                current_len = feat.shape[-1]
                # was max_len_feat, which is undefined when no spectrogram is present
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_env.append(pad_feat)
            batch["envelope"] = torch.stack(pad_env)
        if "phase" in features[0] and features[0]["phase"] is not None:
            ph_list = [f["phase"] for f in features]
            max_len = max(f.shape[-1] for f in ph_list)
            pad_ph = []
            for feat in ph_list:
                current_len = feat.shape[-1]
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_ph.append(pad_feat)
            batch["phase"] = torch.stack(pad_ph)
        return batch

def hilbert_transform(x):
    N = x.shape[-1]
    xf = torch.fft.rfft(x)
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    return torch.fft.irfft(xf * h, n=N)

def analytic_signal(x):
    return x + 1j * hilbert_transform(x)

def hilbert_transform_2d(x, dim=-1):
    N = x.shape[dim]
    if dim == -1 or dim == len(x.shape) - 1:
        xf = torch.fft.rfft(x)
    else:
        xf = torch.fft.rfft(x, dim=dim)
    h_shape = [1] * len(x.shape)
    h_shape[dim] = N // 2 + 1
    h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
    if dim == -1 or dim == len(x.shape) - 1:
        if N % 2 == 0:
            h[..., 0] = h[..., -1] = 1
            h[..., 1:-1] = 2
        else:
            h[..., 0] = 1
            h[..., 1:] = 2
    else:
        pass  # non-final dims not implemented in the original
    return torch.fft.irfft(xf * h, n=N, dim=dim)

def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    h1, h2 = torch.meshgrid(
        torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j * h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device))

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase
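# Envelope/phase sketch via the Hilbert transform (shapes illustrative):
#   x = torch.randn(128, 300)                      # one spectrogram (mels, frames)
#   env, phase = process_spectrogram_with_hilbert(x)
#   env.shape, phase.shape                         # both (128, 300)
# These feed the optional envelope/phase inputs; note the encoder blocks are
# keyed "spec_envelope"/"spec_phase", so feature names must agree to be used.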
def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, f0=False,
                     hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024,
                     sampling_rate=16000, pad_mode="constant", center=True,
                     power=2.0, window_fn=torch.hann_window, mel_scale="htk",
                     norm=None, normalized=False, downsamples=False,
                     period=False, hilbert=False):
    dtype = torch.float32
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    audio = batch["audio"]
    wav = torch.tensor(audio["array"]).float()
    sr = audio["sampling_rate"]
    # Note: the original overwrote the target sampling_rate with the source rate,
    # so the resampling branch could never run.
    if sr != sampling_rate:
        original_length = wav.shape[-1]
        target_length = int(original_length * (sampling_rate / sr))
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sampling_rate)
        wav = resampler(wav)
        if abs(wav.shape[-1] - target_length) > 1:
            if wav.ndim == 1:
                wav = wav.unsqueeze(0)
            new_waveform = torch.zeros((wav.shape[0], target_length), dtype=dtype)
            copy_length = min(wav.shape[-1], target_length)
            new_waveform[:, :copy_length] = wav[:, :copy_length]
            wav = new_waveform.squeeze(0)
    if spectrogram:
        transform = torchaudio.transforms.MelSpectrogram(
            f_max=fmax, f_min=fmin, n_mels=n_mels, sample_rate=sampling_rate,
            n_fft=n_fft, hop_length=hop_length, norm=norm, normalized=normalized,
            power=power, center=center, mel_scale=mel_scale,
            window_fn=window_fn, pad_mode=pad_mode)
        mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec
        if hilbert:
            envelope_list = []
            phase_list = []
            for ch_idx in range(spec.shape[0]):
                envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
                envelope_list.append(envelope)
                phase_list.append(phase)
            batch["envelope"] = torch.stack(envelope_list)
            batch["phase"] = torch.stack(phase_list)
    wav_1d = wav.unsqueeze(0)
    if waveforms:
        batch["waveform"] = wav_1d
    if pitch:
        if period:
            pit, periodicity = torchcrepe.predict(
                wav_1d, sampling_rate, hop_length, fmin=80, fmax=800, model="tiny",
                decoder=torchcrepe.decode.viterbi, return_periodicity=True,
                device=device, pad=True)
            batch["pitch"] = pit
            batch["period"] = periodicity
        else:
            pit = torchcrepe.predict(
                wav_1d, sampling_rate, hop_length, fmin=80, fmax=800, model="tiny",
                decoder=torchcrepe.decode.viterbi, return_periodicity=False,
                device=device, pad=True)
            batch["pitch"] = pit
    if f0:
        wav_np = wav.numpy().astype(np.float64)
        # renamed from f0 to avoid shadowing the boolean flag above
        f0_contour, t = pw.dio(wav_np, sampling_rate,
                               frame_period=hop_length / sampling_rate * 1000)
        f0_contour = pw.stonemask(wav_np, f0_contour, t, sampling_rate)
        batch["f0"] = torch.from_numpy(f0_contour).float()
    if spectrogram and waveforms and pitch:
        spec_mean = batch["spectrogram"].mean()
        spec_std = batch["spectrogram"].std() + 1e-6
        batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std
        wav_mean = batch["waveform"].mean()
        wav_std = batch["waveform"].std() + 1e-6
        batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std
        if batch["pitch"].max() > 1.0:
            pitch_min = 50.0
            pitch_max = 600.0
            batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
    batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
    return batch
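# extract_features is applied per example through datasets.map (see
# prepare_datasets below). A direct call would look like this, assuming a
# FLEURS-style example dict and a tokenizer `tok` from setup_tokenizer:
#   ex = {"audio": {"array": np.zeros(16000), "sampling_rate": 16000},
#         "transcription": "hello"}
#   ex = extract_features(ex, tokenizer=tok, spectrogram=True,
#                         waveforms=False, pitch=False)
#   ex["spectrogram"].shape   # (128, 126) with the default hop_length=128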
def compute_metrics(eval_pred, compute_result: bool = True, print_pred: bool = False,
                    num_samples: int = 0, tokenizer=None, pitch=None, model=None):
    pred_logits = eval_pred.predictions
    label_ids = eval_pred.label_ids
    if hasattr(pred_logits, "cpu"):
        pred_logits = pred_logits.cpu()
    if hasattr(label_ids, "cpu"):
        label_ids = label_ids.cpu()
    if isinstance(pred_logits, tuple):
        pred_ids = pred_logits[0]
    else:
        pred_ids = pred_logits
    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
        if not isinstance(pred_ids, torch.Tensor):
            pred_ids = torch.tensor(pred_ids)
        pred_ids = pred_ids.argmax(dim=-1)
    pred_ids = pred_ids.tolist()
    if hasattr(label_ids, "tolist"):
        label_ids = label_ids.tolist()
    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
    if print_pred:
        for i in range(min(num_samples, len(pred_str))):
            print(f"Preds: {pred_str[i]}")
            print(f"Label: {label_str[i]}")
            print(f"preds: {pred_ids[i]}")
            print(f"label: {label_ids[i]}")
            print("--------------------------------")
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    if model is None:
        global global_model
        if 'global_model' in globals():
            model = global_model
    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
        else:
            print("Warning: Zero trainable parameters detected")
            efficiency_score = 0.0
    else:
        print("Warning: Model not available for parameter counting")
        trainable_params = 0.0
        efficiency_score = 0.0
    if hasattr(wer, "item"):
        wer = wer.item()
    metrics = {
        "wer": float(wer),
        "trainable_params_M": float(trainable_params),
        "efficiency_score": float(efficiency_score),
    }
    print(f"Computed metrics: WER={wer:.2f}%, Params={trainable_params:.2f}M, "
          f"Efficiency={efficiency_score:.4f}")
    return metrics

logger = logging.getLogger(__name__)

def create_model(param: Dimensions) -> Echo:
    model = Echo(param).to(device)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")
    model.init_weights()
    return model

def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
    orig_encode = tokenizer.encode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            # NOTE: the special-token strings here were lost in the source
            # (likely stripped as markup); "<pad>", "<s>", "</s>" is an assumption
            # consistent with pad/bos/eos ids 0/1/2 below.
            sp_ids = [tokenizer.token_to_id(t) for t in ["<pad>", "<s>", "</s>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids

    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                ids = [id for id in ids if id not in [0, 1, 2]]
            results.append(tokenizer.decode(ids))
        return results

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

def prepare_datasets(tokenizer, token: str, sanity_check: bool = False,
                     dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
    if dataset_config is None:
        dataset_config = {
            "spectrogram": True, "waveforms": True, "pitch": True, "f0": True,
            "downsamples": True, "hop_length": 128, "fmin": 50, "fmax": 2000,
            "n_mels": 128, "n_fft": 1024, "sampling_rate": 16000,
        }
    dataset = load_dataset("google/fleurs", "en_us", token=token,
                           trust_remote_code=True, streaming=False)
    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000))
    if sanity_check:
        dataset = dataset["test"].take(10).shuffle()
        dataset = dataset.select_columns(["audio", "transcription"])
        logger.info(f"Sanity dataset size: {dataset.num_rows}")
        print(f"Sanity dataset size: {dataset.num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        dataset = dataset.map(
            function=prepare_fn,
            remove_columns=["audio", "transcription"]).with_format(type="torch")
        train_dataset = dataset
        test_dataset = dataset
    else:
        def filter_func(x):
            return (0 < len(x["transcription"]) < 512
                    and len(x["audio"]["array"]) > 0
                    and len(x["audio"]["array"]) < 1500 * 160)
        dataset = dataset.filter(filter_func).shuffle()
        logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]
        columns_to_remove = list(next(iter(dataset.values())).features)
        train_dataset = train_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove).with_format(type="torch")
        test_dataset = test_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove).with_format(type="torch")
    return train_dataset, test_dataset

def get_training_args(
    log_dir: str,
    batch_eval_metrics: bool = False,
    max_steps: int = 10,
    save_steps: int = 1000,
    eval_steps: int = 100,
    warmup_steps: int = 0,
    num_train_epochs: int = 1,
    logging_steps: int = 10,
    eval_on_start: bool = False,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
) -> Seq2SeqTrainingArguments:
    return Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,
        eval_accumulation_steps=1,
        tf32=True,
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        max_steps=max_steps,
        save_steps=save_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs,
        logging_steps=logging_steps,
        logging_dir=log_dir,
        logging_strategy="steps",
        report_to=["tensorboard"],
        push_to_hub=False,
        disable_tqdm=False,
        save_total_limit=1,
        label_names=["labels"],
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        save_safetensors=False,
        eval_on_start=eval_on_start,
        batch_eval_metrics=batch_eval_metrics,
        max_grad_norm=max_grad_norm,
    )

def main():
    token = ""
    log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H'))
    os.makedirs(log_dir, exist_ok=True)
    tokenizer = setup_tokenizer(token)

    def sanity(sanity: bool):
        if sanity:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10,
                save_steps=0,
                eval_steps=1,
                warmup_steps=0,
                logging_steps=1,
                eval_on_start=True,
                learning_rate=5e-6,
                weight_decay=0.01,
            )
        else:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10000,
                save_steps=10000,
                eval_steps=1000,
                warmup_steps=1000,
                logging_steps=100,
                eval_on_start=False,
                learning_rate=2.5e-4,
                weight_decay=0.01,
            )
        return training_args

    param = Dimensions(
        mels=128,
        aud_ctx=1500,
        aud_head=4,
        aud_dims=512,
        aud_idx=4,
        vocab=40000,
        text_ctx=512,
        text_head=4,
        text_dims=512,
        text_idx=4,
        act="swish",
        debug={},  # e.g. {"encoder", "decoder", "residual", "rotary"}: debug prints per module
        cross_attn=True,
        f0_rotary=True,
        features=["spectrogram"],  # any combination of ["spectrogram", "waveform", "pitch"]; order matters
    )
    sanity_check = False
    training_args = sanity(sanity_check)
    dataset_config = {
        "spectrogram": True,
        "waveforms": False,
        "pitch": False,
        "downsamples": False,
        "f0": True,  # must be True together with f0_rotary to feed F0 into rotary; pitch passes as a feature
        "hilbert": False,
        "hop_length": 128,
        "fmin": 150,
        "fmax": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sampling_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 2.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}

    model = create_model(param)
    global global_model
    global_model = model
    metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
                         tokenizer=tokenizer, model=model)
    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer,
        token=token,
        sanity_check=sanity_check,
        dataset_config=dataset_config)
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
    )
    trainer.train()

if __name__ == "__main__":
    main()