diff --git "a/echoutils.py" "b/echoutils.py" new file mode 100644--- /dev/null +++ "b/echoutils.py" @@ -0,0 +1,3383 @@ +import torch +import os +import pyworld as pw +import numpy as np +import torchaudio +import torch.nn.functional as F +from datasets import load_dataset, Audio +from dataclasses import dataclass +from typing import Any, List, Dict +import math +import matplotlib.pyplot as plt +import torch.nn as nn +import torch.nn.init as init +from torch import Tensor +from typing import Optional, Union, Tuple +from torch.nn.functional import scaled_dot_product_attention +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + +def sinusoids(ctx, dims, max_tscale=10000): + assert dims % 2 == 0 + pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1) + tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32)) + scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0) + position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1) + positional_embedding = nn.Parameter(position, requires_grad=True) + return positional_embedding + +def get_activation(act: str) -> nn.Module: + act_map = { + "gelu": nn.GELU(), + "relu": nn.ReLU(), + "sigmoid": nn.Sigmoid(), + "tanh": nn.Tanh(), + "swish": nn.SiLU(), + "tanhshrink": nn.Tanhshrink(), + "softplus": nn.Softplus(), + "softshrink": nn.Softshrink(), + "leaky_relu": nn.LeakyReLU(), + "elu": nn.ELU() + } + return act_map.get(act, nn.GELU()) + +def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor: + q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12) + k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12) + qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2)) + qk_cosine = qk_cosine + mask + weights = F.softmax(qk_cosine, dim=-1) + out = torch.matmul(weights, v) + return out + +def taylor_softmax_2nd_order(x): + exp_approx = 1 + x + (x**2) / 2 + return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True) + +def taylor_softmax_approximation(x, order=2): + if order == 0: + return torch.ones_like(x) / x.size(-1) + elif order == 1: + numerator = 1 + x + elif order == 2: + numerator = 1 + x + 0.5 * x**2 + else: + raise NotImplementedError("Higher orders are not implemented yet.") + denominator = torch.sum(numerator, dim=-1, keepdim=True) + return numerator / denominator + +def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0): + dot_scores = torch.matmul(q, k.transpose(-1, -2)) + if rbf_ratio <= 0.0: + return dot_scores + q_norm = q.pow(2).sum(dim=-1, keepdim=True) + k_norm = k.pow(2).sum(dim=-1, keepdim=True) + qk = torch.matmul(q, k.transpose(-1, -2)) + dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk + rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2)) + return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores + +def sliding_window_mask(q_len, k_len, window, device): + # 
    # mask[i, j] = 1 if j in [i-window+1, i] else 0
+    idxs = torch.arange(q_len, device=device).unsqueeze(1)
+    jdxs = torch.arange(k_len, device=device).unsqueeze(0)
+    mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
+    return mask.float()  # shape: (q_len, k_len)
+
+def mask_win(text_ctx, aud_ctx):
+    mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
+    audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
+    full_mask = torch.cat([mask, audio_mask], dim=-1)
+    return full_mask
+
+def maskc(ctx, device):
+    return torch.tril(torch.ones(ctx, ctx, device=device, dtype=dtype), diagonal=0)
+
+def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, device=None):
+    if is_causal:
+        # 1 marks a masked position; the strictly upper triangle (diagonal=1) masks
+        # future positions while still letting each query attend to itself.
+        mask = torch.triu(torch.ones((ctx, ctx), device=device), diagonal=1)
+        mask = mask.expand(batch_size, 1, ctx, ctx)
+    else:
+        mask = torch.zeros((batch_size, 1, ctx, ctx), device=device)
+    if padding_mask is not None:
+        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
+        mask = (mask.bool() | (~padding_mask)).float()
+    return mask
+
+def calculate_attention(q, k, v, mask=None, temp=1.0):
+    scaled_q = q
+    if temp != 1.0 and temp > 0:
+        scaled_q = q * (1.0 / temp)**.5
+    # q is (batch, head, ctx, head_dim); dim 2 is the query length, so causal
+    # masking is skipped only for single-token (incremental) decoding.
+    out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[2] > 1)
+    return out
+
+def calculate_attentionb(q_norm, k_norm, v_iter, mask=None, temp=1.0):
+    d_k = q_norm.size(-1)
+    scores = torch.matmul(q_norm, k_norm.transpose(-2, -1)) / (torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) / temp)
+    if mask is not None:
+        scores = scores.masked_fill(mask == 0, float('-inf'))
+    attention_weights = F.softmax(scores, dim=-1)
+    output = torch.matmul(attention_weights, v_iter)
+    return output
+
+class LocalOut(nn.Module):
+    def __init__(self, dims: int, head: int):
+        super().__init__()
+        self.head_dim = dims // head
+        self.dims = dims
+        self.q_module = nn.Linear(self.head_dim, self.head_dim)
+        self.k_module = nn.Linear(self.head_dim, self.head_dim)
+        self.v_module = nn.Linear(self.head_dim, self.head_dim)
+        self.o_proj = nn.Linear(self.head_dim, self.head_dim)
+
+    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
+        batch, _, ctx, _ = attn_output.shape
+        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
+
+def qkv_init(dims, head):
+    head_dim = dims // head
+    q = nn.Linear(dims, dims)
+    k = nn.Linear(dims, dims)
+    v = nn.Linear(dims, dims)
+    o = nn.Linear(dims, dims)
+    lna = nn.LayerNorm(dims)
+    lnb = nn.LayerNorm(dims)
+    lnc = nn.LayerNorm(head_dim)
+    lnd = nn.LayerNorm(head_dim)
+    return q, k, v, o, lna, lnb, lnc, lnd
+
+def shape(dims, head, q, k, v):
+    batch_size = q.shape[0]
+    seq_len_q = q.shape[1]
+    seq_len_kv = k.shape[1]
+    head_dim = dims // head
+
+    q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
+    k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
+    v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
+    return q, k, v
+
+def create_qkv(dims, head, q, k, v, x, xa):
+    head_dim = dims // head
+    scale = head_dim ** -0.25
+    q = q(x) * scale
+    k = k(xa) * scale
+    v = v(xa)
+    batch = x.shape[0]
+    def _shape(tensor):
+        # use each tensor's own length so cross-attention, where xa is shorter
+        # or longer than x, reshapes correctly
+        return tensor.view(batch, tensor.shape[1], head, head_dim).transpose(1, 2).contiguous()
+    return _shape(q), _shape(k), _shape(v)
+
+# def calculate_attention(q, k, v, mask=None, temp=1.0):
+#     scaled_q = q
+#     if temp != 1.0 and temp > 0:
+#         scaled_q = q * (1.0 / temp)**.5
+#     out = scaled_dot_product_attention(scaled_q, k, v,
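+#
+# Usage sketch for the q/k/v helpers above (shapes are illustrative, not fixed
+# by this file): with dims=512 and head=8, create_qkv pre-scales q and k by
+# head_dim**-0.25 each, so their product carries the usual 1/sqrt(head_dim):
+#
+#   q_lin, k_lin, v_lin, o_lin, lna, lnb, lnc, lnd = qkv_init(512, 8)
+#   x = torch.randn(2, 100, 512)
+#   q, k, v = create_qkv(512, 8, q_lin, k_lin, v_lin, x, x)  # each (2, 8, 100, 64)
+#   out = calculate_attention(q, k, v, mask=None, temp=1.0)  # (2, 8, 100, 64)
+#   out = out.transpose(1, 2).reshape(2, 100, 512)           # back to (batch, ctx, dims)
+#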
is_causal=mask is not None and q.shape[1] > 1) +# return out + +# class LocalOut(nn.Module): +# def __init__(self, dims: int, head: int): +# super().__init__() +# head_dim = dims // head +# self.head_dim = head_dim +# self.query_module = nn.Linear(head_dim, head_dim) +# self.key_module = nn.Linear(head_dim, head_dim) +# self.value_module = nn.Linear(head_dim, head_dim) +# self.out_proj = nn.Linear(head_dim, head_dim) + +# def _reshape_to_output(self, x): +# return x + +# class attentiona(nn.Module): +# def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0): +# super(attentiona, self).__init__() +# self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head) +# self.dims = dims +# self.head = head +# self.head_dim = dims // head +# self.max_iter = max_iter +# self.threshold = nn.Parameter(torch.tensor(threshold)) +# self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True) +# self.factor = nn.Parameter(torch.tensor(factor)) +# self.lnc = nn.LayerNorm(self.head_dim, bias=False) +# self.lnd = nn.LayerNorm(self.head_dim, bias=False) +# self.attn_local = LocalOut(self.head_dim) + +# def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None): +# z = default(xa, x) +# q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z)) + +# iteration = 0 +# temp = self.temp.item() +# prev_out = torch.zeros_like(q) +# attn_out = torch.zeros_like(q) +# threshold = self.threshold.item() +# factor = self.factor.item() +# qcur = q + +# while iteration < self.max_iter: +# eff_span = min(qcur.shape[1], k.shape[1]) +# if xa is not None: +# eff_span = min(eff_span, xa.shape[1]) +# if eff_span == 0: +# break + +# qiter = qcur[:, :, :eff_span, :] +# kiter = k[:, :, :eff_span, :] +# viter = v[:, :, :eff_span, :] +# q = self.attn_local.query_module(qiter) +# k = self.attn_local.key_module(kiter) +# v = self.attn_local.value_module(viter) + +# iter_mask = None +# if mask is not None: +# if mask.dim() == 4: +# iter_mask = mask[:, :, :eff_span, :eff_span] +# elif mask.dim() == 2: +# iter_mask = mask[:eff_span, :eff_span] + +# attn_iter = calculate_attention( +# self.lnc(q), self.lnd(k), v, +# mask=iter_mask, temp=temp) + +# iter_out = torch.zeros_like(qcur) +# iter_out[:, :, :eff_span, :] = attn_iter +# diff = torch.abs(iter_out - prev_out).mean() +# dthresh = threshold + factor * diff +# if diff < dthresh and iteration > 0: +# attn_out = iter_out +# break + +# prev_out = iter_out.clone() +# qcur = qcur + iter_out +# attn_out = iter_out +# iteration += 1 +# temp += 0.005 + +# output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2) +# return self.o(output), None + + # def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor: + + # batch, ctx, dims = x.shape + # output = torch.zeros_like(x) + # num_win = (ctx + win_size - 1) // win_size + + # for i in range(num_win): + # qstart = i * win_size + # qend = min(qstart + win_size, ctx) + # win_qlen = qend - qstart + # if win_qlen == 0: + # continue + + # kstart = max(0, qend - span_len) + # kend = qend + # qwin = x[:, qstart:qend, :] + # kwin = x[:, kstart:kend, :] + + # win_mask = None + # if mask is not None: + # if mask.dim() == 4: + # win_mask = mask[:, :, qstart:qend, kstart:kend] + # elif mask.dim() == 2: + # win_mask = mask[qstart:qend, kstart:kend] + + # attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask) + # output[:, qstart:qend, :] = attn_out 
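+#
+# Worked example for sliding_window_mask (active code above): with window=2,
+# query i may attend to keys i-1 and i, so sliding_window_mask(4, 4, 2, "cpu") is
+#   tensor([[1., 0., 0., 0.],
+#           [1., 1., 0., 0.],
+#           [0., 1., 1., 0.],
+#           [0., 0., 1., 1.]])
+# which is the banded pattern the commented _slide_win_local loop above walks
+# window by window.
+#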
+ # return output + + # def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, + # use_sliding_win: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor: + # if use_sliding_win: + # return self._slide_win_local(x, win_size, span_len, mask) + # else: + # output, _ = self._focus(x, xa, mask) + # return output + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val # pyright: ignore[reportIndexIssue] + v_out[:, :, input_pos] = v_val # pyright: ignore[reportIndexIssue] + + return k_out, v_out + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + +def mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + +def trace_x(func): + def wrapper(*args, **kwargs): + print(f"Calling {func.__name__}") + result = func(*args, **kwargs) + if isinstance(result, torch.Tensor): + print(f" {func.__name__} returned shape: {result.shape}") + return result + return wrapper + +def track_x(new_x, operation=""): + """ track_x(x, "x") """ + x_id = [id(new_x)] + if new_x is None: + return new_x + current_id = id(new_x) + if current_id != x_id[0]: + print(f"x FLOW: {x_id[0]} → {current_id} in {operation}") + x_id[0] = current_id + else: + print(f"x REUSE: {current_id} in {operation}") + return new_x + +def track_xa(new_xa, operation=""): + """ track_xa(xa, "xa - decoder") """ + xa_id = [id(new_xa)] if new_xa is not None else [None] + if new_xa is None: + return new_xa + current_id = id(new_xa) + if current_id != xa_id[0]: + print(f"xa FLOW: {xa_id[0]} → {current_id} in {operation}") + xa_id[0] = current_id # pyright: ignore[reportArgumentType, reportCallIssue] + else: + print(f"xa REUSE: {current_id} in {operation}") + return new_xa + +def get_activation(act: str) -> nn.Module: + """Get activation function by name.""" + act_map = { + "gelu": nn.GELU(), + "relu": nn.ReLU(), + "sigmoid": nn.Sigmoid(), + "tanh": nn.Tanh(), + "swish": nn.SiLU(), + "tanhshrink": nn.Tanhshrink(), + "softplus": nn.Softplus(), + "softshrink": nn.Softshrink(), + "leaky_relu": nn.LeakyReLU(), + "elu": nn.ELU() + } + return act_map.get(act, nn.GELU()) + +def get_generation_config(param): + return GenerationConfig( # type: ignore + max_length=param.text_ctx, + pad_token_id=getattr(param, "pad_token_id", 0), + bos_token_id=getattr(param, "bos_token_id", 1), + eos_token_id=getattr(param, "eos_token_id", 2), + do_sample=False, + num_beams=1, + early_stopping=False, + length_penalty=1.0, + no_repeat_ngram_size=0, + repetition_penalty=1.0, + temperature=1.0, + decoder_start_token_id=1, + is_multilingual=False, + use_cache=False, + return_timestamps=False) + +class feature_encoder(nn.Module): + def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None, use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None): + """ + Feature encoder for audio processing. 
+        """
+        super().__init__()
+
+        self.dims = dims
+        self.head = head
+        self.head_dim = dims // head
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.attend_feature = attend_feature
+        self.target_length = target_length
+        self.feature = feature
+
+        self.debug = debug
+        act_fn = get_activation(act)
+
+        if self.attend_feature:
+            # self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
+            self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
+        else:
+            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
+            self.mlp = None
+
+        self.spectrogram = nn.Sequential(
+            Conv1d(mels, dims, kernel_size=3), act_fn,
+            Conv1d(dims, dims, kernel_size=3), act_fn,
+            Conv1d(dims, dims, kernel_size=3, groups=dims), act_fn)
+
+        self.waveform = nn.Sequential(
+            Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
+            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
+            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
+
+        self.pitch = nn.Sequential(
+            Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
+            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
+            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
+        if use_rope:
+            # if spec_shape is not None:
+            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore
+        else:
+            self.rope = None  # type: ignore
+            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+        self.norm = RMSNorm(dims)
+
+    def apply_rope(self, x, xa=None, mask=None, feats=None, feature=None, layer=None):
+        # applies the rotary module stored in self.rope; kept under a separate
+        # name so the method is not shadowed by that attribute
+        if isinstance(x, int):
+            ctx = x
+        elif isinstance(x, torch.Tensor):
+            ctx = x.shape[1] if x.dim() > 1 else x.shape[0]
+        batch, ctx, dims = x.shape[0], ctx, x.shape[-1]
+
+        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)  # pyright: ignore[reportOptionalCall]
+        x = self.rope.apply_rotary(x, freqs)  # pyright: ignore[reportOptionalSubscript, reportAttributeAccessIssue]
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
+        return x
+
+    def mel_scalar(self, freq: float) -> float:
+        return 1127.0 * math.log(1.0 + freq / 700.0)
+
+    def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
+        target_length = x.shape[1] if self.target_length is None else self.target_length
+
+        if feature == "pitch":
+            xp = x.clone()
+            enc_dict = feats if feats is not None else {}
+            enc_dict = dict(enc_dict)
+            enc_dict["f0"] = xp
+            # xp = self.mel_scalar(xp.mean())
+            # print(f"Using pitch scalar: {xp}")
+            # max_tscale = xp*300
+            # print(f"Using max_tscale: {max_tscale}")
+            feats = enc_dict
+            if x.dim() == 2:
+                x = x.unsqueeze(0)
+            x = self.pitch(x).permute(0, 2, 1)
+
+        if feature == "phase":
+            if x.dim() == 2:
+                x = x.unsqueeze(0)
+            x = self.pitch(x).permute(0, 2, 1)
+
+        if feature == "waveform":
+            if x.dim() == 2:
+                x = x.unsqueeze(0)
+            x = self.waveform(x).permute(0, 2, 1)
+            if target_length and x.shape[1] != target_length:
+                x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)
+
+        if feature == "harmonics":
+            if x.dim() == 2:
+                x = x.unsqueeze(0)
+            x = self.spectrogram(x).permute(0, 2, 1)
+
+        if feature == "aperiodic":
+            if x.dim() == 2:
+                x = x.unsqueeze(0)
+            x = self.spectrogram(x).permute(0, 2, 1)
+
+        if feature == "spectrogram":
+            if x.dim() == 2:
+                x = x.unsqueeze(0)
+            x = self.spectrogram(x).permute(0, 2, 1)
+
+        if self.use_rope:
+            x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
+            x = self.apply_rope(x=x, xa=None, mask=None, feats=feats, feature=feature, layer=layer)
+        else:
+            max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
+            x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
+            x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+            x = self.norm(x)
+
+        # if self.attend_feature:
+        #     xa = feats[feature]  # pyright: ignore[reportOptionalSubscript]
+        #     if xa is not None:
+        #         q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
+        #         out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
+        #         x = x + out
+
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = self.norm(x)
+        return x
+
+class OneShot(nn.Module):
+    def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
+        super().__init__()
+        if features is None:
+            features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
+        self.head = head
+        self.head_dim = dims // head
+        # true division: averages the bias over the number of fused features
+        self.scale = 1.0 / len(features) if features else scale
+
+        self.q = Linear(dims, dims)
+        self.k = Linear(dims, dims)
+
+    def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor:
+        B, L, D = x.shape
+        K = xa.size(1)
+        q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1, 2)
+        k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1, 2)
+        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
+        return bias
+
+class curiosity(nn.Module):
+    def __init__(self, d, h, bias=True):
+        super().__init__()
+        self.h = h
+        self.dh = d // h
+        self.qkv = nn.Linear(d, d * 3, bias=bias)
+        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
+        self.o = nn.Linear(d, d, bias=bias)
+        self.g = nn.Parameter(torch.zeros(h))
+
+    def split(self, x):
+        b, t, _ = x.shape
+        return x.view(b, t, self.h, self.dh).transpose(1, 2)
+
+    def merge(self, x):
+        b, h, t, dh = x.shape
+        return x.transpose(1, 2).contiguous().view(b, t, h * dh)
+
+    def forward(self, x, xa, mask=None):
+        q, k, v = self.qkv(x).chunk(3, -1)
+        qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
+        q, k, v = map(self.split, (q, k, v))
+        qa, ka, va = map(self.split, (qa, ka, va))
+        dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
+        dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
+        if mask is not None:
+            dots = dots.masked_fill(mask, -9e15)
+        p = dots.softmax(-1)
+        pa = dots_aux.softmax(-1)
+        h_main = p @ v
+        h_aux = pa @ va
+        g = torch.sigmoid(self.g).view(1, -1, 1, 1)
+        out = self.merge(h_main * (1 - g) + h_aux * g)
+        return self.o(out)
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, dims, ctx):
+        super(PositionalEncoding, self).__init__()
+        self.dims = dims
+        self.ctx = ctx
+        self.pe = self.get_positional_encoding(max_ctx=ctx)
+
+    def get_positional_encoding(self, max_ctx):
+        pe = torch.zeros(max_ctx, self.dims)
+        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.dims, 2, dtype=torch.float32)
+            * (-math.log(10000.0) / self.dims)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        return pe.to(device)
+
+    def forward(self, x):
+        ctx = x.size(1)
+        pe = self.pe[:, :ctx, :]
+        x = x * math.sqrt(self.dims)
+        x = x + pe
+        return x
+
+def valid(default_value, *items):
+    for item in items:
+        if
item is not None: + return item + return default_value + +def dict_to(d, device, dtype=dtype): + return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v + for k, v in d.items()} + +def exists(v): + return v is not None + +def default(v, b): + return v if exists(v) else b + +class Conv1d(nn.Conv1d): + def _conv_forward( + self, x: Tensor, weight: Tensor, bias) -> Tensor: + return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype)) + +class Conv2d(nn.Conv2d): + def _conv_forward( + self, x: Tensor, weight: Tensor, bias) -> Tensor: + return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype)) + +class Linear(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: + super(Linear, self).__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + init.xavier_uniform_(self.linear.weight) + if bias: + init.zeros_(self.linear.bias) + def forward(self, x: Tensor) -> Tensor: + return self.linear(x) + +class RMSNorm(nn.Module): + def __init__(self, dims: Union[int, Tensor, List, Tuple], + eps = 1e-8, elementwise_affine = True): + super(RMSNorm, self).__init__() + if isinstance(dims, int): + self.normalized_shape = (dims,) + else: + self.normalized_shape = tuple(dims) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore + init.ones_(self.weight) + else: + self.register_parameter("weight", None) + def forward(self, x): + return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore + +def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple], + weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, + eps: float = 1e-5) -> Tensor: + return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore + +def get_device(): + return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +def get_dtype(): + return torch.float32 if torch.cuda.is_available() else torch.float64 + +def tox(): + return {"device": get_device(), "dtype": get_dtype()} + +def sinusoids(ctx, dims, max_tscale=10000): + assert dims % 2 == 0 + pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1) + tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32)) + scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0) + position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1) + positional_embedding = nn.Parameter(position, requires_grad=True) + return positional_embedding + +class SelfCriticalRL(nn.Module): + def __init__(self, model, tokenizer, reward_fn): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.reward_fn = reward_fn + + def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"): + + with torch.no_grad(): + greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len) + greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids] + sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len, do_sample=True, top_k=5) + sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids] + + rewards = [] + baseline = [] + for s, g, ref in zip(sampled_text, greedy_text, labels): # type: ignore + ref_text = self.tokenizer.decode(ref) + 
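+            # Self-critical baseline: each sampled transcript is scored against
+            # the greedy decode of the same input, so the advantage used below is
+            # reward(sample) - reward(greedy) and the final loss is
+            # -(advantage * sum_t log p(sampled_t)).mean(); a sample is only
+            # reinforced when it beats its own greedy baseline.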
rewards.append(self.reward_fn(s, ref_text)) + baseline.append(self.reward_fn(g, ref_text)) + rewards = torch.tensor(rewards, device=device, dtype=torch.float) + baseline = torch.tensor(baseline, device=device, dtype=torch.float) + advantage = rewards - baseline + logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"] # logits: [batch, sampled_seq_len, vocab_size] + log_probs = F.log_softmax(logits, dim=-1) + log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1) + log_probs_sum = log_probs_seq.sum(dim=1) + loss = -(advantage * log_probs_sum).mean() + return loss + +class SelfTrainingModule(nn.Module): + def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.quality_fn = quality_fn + self.threshold = threshold + + def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"): + with torch.no_grad(): + pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len) + + if self.quality_fn is not None: + quality_scores = self.quality_fn(pred_ids, self.model, features) + mask = quality_scores > self.threshold + pred_ids = pred_ids[mask] + return pred_ids + + def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"): + pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name) + logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"] + loss = nn.functional.cross_entropy( + logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0) + return loss + +def confidence_indicator(pred_ids, model, features): + with torch.no_grad(): + logits = model(input_ids=pred_ids, **features)["logits"] + probs = torch.softmax(logits, dim=-1) + max_probs, _ = probs.max(dim=-1) + return max_probs.mean(dim=1) + +def wer_reward(hyp, ref): + hyp_words = hyp.split() + ref_words = ref.split() + d = [[0] * (len(ref_words)+1) for _ in range(len(hyp_words)+1)] + for i in range(len(hyp_words)+1): + d[i][0] = i + for j in range(len(ref_words)+1): + d[0][j] = j + for i in range(1, len(hyp_words)+1): + for j in range(1, len(ref_words)+1): + if hyp_words[i-1] == ref_words[j-1]: + d[i][j] = d[i-1][j-1] + else: + d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1]) + wer = d[-1][-1] / max(1, len(ref_words)) + return -wer # negative WER as reward + +def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2): + if isinstance(ids, torch.Tensor): + ids = ids.tolist() + return [int(id) for id in ids if id != -100 and id != pad_token_id and id != bos_token_id and id != eos_token_id] + +def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2): + return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids] + +def setup_tokenizer(dir: str): + from tokenizers import Tokenizer + tokenizer = Tokenizer.from_file(f"{dir}") + orig_encode = tokenizer.encode + orig_decode = tokenizer.decode + + def enc(text, add_special_tokens=True): + ids = orig_encode(text).ids + if not add_special_tokens: + sp_ids = [tokenizer.token_to_id(t) for t in ["", "", ""]] + ids = [id for id in ids if id not in sp_ids] + return ids + + def bdec(ids_list, pad_token_id=0, bos_token_id=1, eos_token_id=2, skip_special_tokens=True): + results = [] + if isinstance(ids_list, torch.Tensor): + ids_list = ids_list.tolist() + elif isinstance(ids_list, np.ndarray): + 
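+            # tensors and numpy arrays are converted to plain lists first; the
+            # loop below then strips pad/bos/eos and -100 ids per sequence,
+            # mirroring clean_ids, before decoding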
ids_list = ids_list.tolist() + for ids in ids_list: + ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)] + results.append(orig_decode(ids)) + return results + + def dec(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2): + ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)] + return orig_decode(ids) + + def save_pretrained(save_dir): + os.makedirs(save_dir, exist_ok=True) + tokenizer.save(f"{save_dir}/tokenizer.json") + + tokenizer.encode = enc + tokenizer.batch_decode = bdec + tokenizer.decode = dec + tokenizer.save_pretrained = save_pretrained + tokenizer.pad_token_id = 0 + tokenizer.bos_token_id = 1 + tokenizer.eos_token_id = 2 + return tokenizer + +def tokenize_pitch(pitch_features, target_length): + pitch_len = pitch_features.shape[-1] + token_len = target_length + if pitch_len > token_len: + pitch_tokens = F.adaptive_avg_pool1d(pitch_features, token_len) + else: + pitch_tokens = F.interpolate(pitch_features, token_len) + return pitch_tokens + +def load_wave(wave_data, sample_rate=16000): + + if isinstance(wave_data, str): + waveform, sample_rate = torchaudio.load(uri=wave_data, normalize=False) + elif isinstance(wave_data, dict): + waveform = torch.tensor(data=wave_data["array"]).float() + sample_rate = wave_data["sampling_rate"] # noqa: F841 + else: + raise TypeError("Invalid wave_data format.") + return waveform + +def world_to_mel(sp, ap, sample_rate=16000, n_mels=128): + import librosa + mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels) + mel_basis = torch.from_numpy(mel_basis).float() + sp_mel = torch.matmul(sp, mel_basis.T) # (frames, 128) + ap_mel = torch.matmul(ap, mel_basis.T) # (frames, 128) + return sp_mel, ap_mel + +def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False): + + # import torch + # import torchaudio + # import torchaudio.functional as F + # import torchaudio.transforms as T + + torch_windows = { + 'hann': torch.hann_window, + 'hamming': torch.hamming_window, + 'blackman': torch.blackman_window, + 'bartlett': torch.bartlett_window, + 'ones': torch.ones, + None: torch.ones, + } + if dummy: + return { + "spectrogram": torch.zeros((1, 128, 100)), + "f0": torch.zeros((1, 100)), + "pitch_tokens": torch.zeros((1, 100)), + "pitch": torch.zeros((1, 100)), + "harmonics": torch.zeros((1, 128, 100)), + "aperiodics": torch.zeros((1, 128, 100)), + "crepe_time": None, + "crepe_frequency": None, + "crepe_confidence": None, + "crepe_activation": None, + } + + audio = batch["audio"] + sample_rate = audio["sampling_rate"] + # audio_length = len(audio["array"]) / audio["sampling_rate"] + labels = tokenizer.encode(batch["transcription"]) + # sentence_length = len(batch["transcription"]) + wav = load_wave(wave_data=audio, sample_rate=sample_rate) + + def crepe_predict(wav, sample_rate, viterbi=False): + import torchcrepe + wav = wav.numpy().astype(np.float32) + time, frequency, confidence, activation = torchcrepe.predict( + wav, sample_rate=sample_rate, viterbi=viterbi) + crepe_time = torch.from_numpy(time) + crepe_frequency = torch.from_numpy(frequency) + crepe_confidence = torch.from_numpy(confidence) + crepe_activation = torch.from_numpy(activation) + return crepe_time, crepe_frequency, crepe_confidence, crepe_activation + + if crepe: + crepe_time, crepe_frequency, crepe_confidence, 
crepe_activation = crepe_predict(wav, sample_rate, viterbi=True) + + else: + crepe_time = None + crepe_frequency = None + crepe_confidence = None + crepe_activation = None + + def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window): + if isinstance(window_fn, str): + window_fn = torch_windows[window_fn] + if window_fn is None: + window_fn = torch.ones(n_fft) + if isinstance(window_fn, torch.Tensor): + window_fn = window_fn.to(device) + return torchaudio.functional.spectrogram( + wav, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, + window=window_fn, center=True, pad_mode="reflect", power=1.0) + + def mel_spectrogram(wav, sample_rate): + + spectrogram_config = { + "hop_length": 256, + "f_min": 150, + "f_max": 2000, + "n_mels": 128, + "n_fft": 1024, + "sample_rate": 16000, + "pad_mode": "constant", + "center": True, + "power": 1.0, + "window_fn": torch.hann_window, + "mel_scale": "htk", + "norm": None, + "normalized": False, + } + transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config) + mel_spectrogram = transform(wav) + log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10() + log_mel = torch.maximum(log_mel, log_mel.max() - 8.0) + spectrogram_tensor = (log_mel + 4.0) / 4.0 + return spectrogram_tensor + + if spec: + spectrogram_tensor = mel_spectrogram(wav, sample_rate) + + def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window): + transform = torchaudio.transforms.MFCC( + sample_rate=sample_rate, + n_mfcc=n_mels, + melkwargs={ + "n_fft": n_fft, + "hop_length": hop_length, + "window_fn": window_fn, + "n_mels": n_mels, + "center": True, + "pad_mode": "reflect", + "norm": None, + "mel_scale": "htk", + } + ) + mfcc_tensor = transform(wav) + return mfcc_tensor + + # def compute_pitch(wav, sample_rate, hop_length=256): + # # pitch = F.detect_pitch_frequency(wav, sample_rate) + # # f0 = pitch + # import pyworld as pw + # wav_np = wav.numpy().astype(np.float64) + # f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length / sample_rate * 1000) + # f0 = pw.stonemask(wav_np, f0, t, sample_rate) + # return f0, t + + def harmonics_and_aperiodics(wav, f0, t, sample_rate): + import pyworld as pw + wav_np = wav.numpy().astype(np.float64) + sp = pw.cheaptrick(wav_np, f0, t, sample_rate, fft_size=256) + ap = pw.d4c(wav_np, f0, t, sample_rate, fft_size=256) + harmonic_tensor = torch.from_numpy(sp) + aperiodic_tensor = torch.from_numpy(ap) + harmonic_tensor = harmonic_tensor[:, :128].contiguous().T + aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T + harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0) + aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0) + return harmonic_tensor, aperiodic_tensor + + if pitch or pitch_tokens or harmonics or aperiodics: + wavnp = wav.numpy().astype(np.float64) + f0_np, t = pw.dio(wavnp, sample_rate, frame_period=hop_length / sample_rate * 1000) + f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate) + + if pitch_tokens: + wav = torch.from_numpy(wavnp) + t2 = torch.from_numpy(t) + audio_duration = len(wav) / sample_rate + T = len(labels) + tok_dur_sec = audio_duration / T + token_starts = torch.arange(T) * tok_dur_sec + token_ends = token_starts + tok_dur_sec + start_idx = torch.searchsorted(t2, token_starts, side="left") + end_idx = torch.searchsorted(t2, token_ends, side="right") + pitch_tok = torch.zeros(T, dtype=torch.float32) + for i in 
range(T):
+                lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])  # type: ignore
+                segment = f0_np[lo:hi]
+                if mode == "mean":
+                    pitch_tok[i] = segment.mean()
+                elif mode == "median":
+                    # segment is a numpy slice, so use np.median rather than torch.median
+                    pitch_tok[i] = float(np.median(segment))
+                else:
+                    pitch_tok[i] = segment[-1]
+            pitch_tok[pitch_tok < 100.0] = 0.0
+            bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
+            pitch_tokens_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
+            pitch_tokens_tensor = torch.where(pitch_tokens_tensor == 0.0, torch.zeros_like(pitch_tokens_tensor), (pitch_tokens_tensor - 71.0) / (500.0 - 71.0))
+        else:
+            pitch_tokens_tensor = None
+
+        if phase_mod:
+            # frame times and f0 come from the pyworld pass above
+            t2 = torch.from_numpy(t)
+            f0_tensor = torch.from_numpy(f0_np)
+            tframe = torch.mean(t2[1:] - t2[:-1])
+            phi0 = 0.0
+            omega = 2 * torch.pi * f0_tensor
+            dphi = omega * tframe
+            phi = torch.cumsum(dphi, dim=0) + phi0
+            phase = torch.remainder(phi, 2 * torch.pi)
+        else:
+            phase = None
+
+        if pitch:
+            p_tensor = torchaudio.functional.detect_pitch_frequency(wav, sample_rate)
+            # p_tensor = torch.from_numpy(f0_np)
+            # p_tensor = p_tensor.unsqueeze(0)
+        else:
+            p_tensor = None
+
+        if harmonics or aperiodics:
+            spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
+            apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
+            harmonic_tensor = torch.from_numpy(spnp)
+            aperiodic_tensor = torch.from_numpy(apnp)
+            harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
+            aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
+            harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
+            aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)
+        else:
+            harmonic_tensor = None
+            aperiodic_tensor = None
+
+    if waveform:
+        wave_tensor = wav
+    else:
+        wave_tensor = None
+
+    if dummy:
+        # flag-first guards avoid touching tensors that were never computed
+        if spec and spectrogram_tensor is not None:
+            dummy_tensor = torch.ones_like(spectrogram_tensor)
+        elif pitch and p_tensor is not None:
+            dummy_tensor = torch.ones_like(p_tensor)
+        elif pitch_tokens and pitch_tokens_tensor is not None:
+            dummy_tensor = torch.ones_like(pitch_tokens_tensor)
+        else:
+            batch_size = 128
+            seq_len = 1024
+            dummy_tensor = torch.ones(batch_size, seq_len)
+        dummy_tensor = dummy_tensor.to(device)
+    else:
+        dummy_tensor = None
+    if debug:
+        print(f"['pitch_tokens']: {pitch_tokens_tensor.shape if pitch_tokens else None}")
+        print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
+        print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")
+        print(f"['spectrogram']: {spectrogram_tensor.shape if spec else None}")
+        print(f"['waveform']: {wave_tensor.shape if waveform else None}")
+        print(f"['labels']: {len(labels) if labels else None}")
+        print(f"['phase']: {phase.shape if phase is not None else None}")
+        print(f"['pitch']: {p_tensor.shape if pitch else None}")
+        print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
+        print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
+        print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
+        print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")
+        print(f"['dummy']: {dummy_tensor.shape if dummy else None}")
+
+    return {
+        "waveform": wave_tensor if waveform else None,
+        "spectrogram": spectrogram_tensor if spec else None,
+        "pitch_tokens": pitch_tokens_tensor if pitch_tokens else None,
+        "pitch": p_tensor if pitch else None,
+        "harmonic": harmonic_tensor if harmonics else None,
+        "aperiodic": aperiodic_tensor if aperiodics else None,
+        "labels": labels,
+        "phase": phase if phase_mod else None,
+        "crepe_time":
crepe_time if crepe else None, + "crepe_frequency": crepe_frequency if crepe else None, + "crepe_confidence": crepe_confidence if crepe else None, + "crepe_activation": crepe_activation if crepe else None, + "dummy": dummy_tensor if dummy else None, + } + +# class PEncoder(nn.Module): # pitch encoder +# def __init__(self, dims: int, head: int, layer: int, kernel_size: int, act: str, +# max_seq_len: int, input_dims: int = 1, use_rope=False): +# super().__init__() + +# self.head = head +# self.head_dim = dims // head +# self.dims = dims +# self.use_rope=use_rope +# self.dropout_rate = 0.01 +# act_fn = get_activation(act) + +# self.positional_encoding = nn.Parameter(torch.randn(1, max_seq_len, dims)) + +# if use_rope: +# self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore +# else: +# self.rope = None +# self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) + +# self.attend_pitch = False +# if self.attend_pitch: +# self.mlp = nn.Sequential( +# nn.Linear(dims, dims), +# nn.ReLU(), +# nn.Linear(dims, dims), +# ) +# else: +# self.mlp = None + +# self.pitch_encoder = nn.Sequential( +# nn.Conv1d(input_dims, dims, kernel_size=kernel_size, stride=1, padding=kernel_size // 2), act_fn, +# nn.Conv1d(dims, dims, kernel_size=kernel_size - 2, stride=1, padding=(kernel_size - 2) // 2), act_fn, +# nn.Conv1d(dims, dims, kernel_size=kernel_size - 4, stride=1, padding=(kernel_size - 4) // 2, groups=dims), act_fn +# ) + +# def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"): +# batch, ctx, dims = x.shape +# x = x.view(batch, ctx, head, self.head_dim).permute(0, 2, 1, 3) +# freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) # type: ignore +# x = self.rope.apply_rotary(x, freqs)# type: ignore +# x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims) +# return x + +# self.norm = nn.LayerNorm(dims) + +# def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, +# feats: Optional[Any] = None, feature: str = "pitch", layer: str = "PEncoder", +# audio_duration: Optional[float] = None, sample_rate: Optional[int] = None, +# labels_len: Optional[int] = None, f0_np: Optional[np.ndarray] = None, +# t2_np: Optional[np.ndarray] = None, mode: str = "mean") -> Tensor: + +# if x.dim() == 2 and feature == "pitch": +# x_processed = x.unsqueeze(1) # Input to pitch_encoder: (batch_size, 1, num_pitch_tokens) +# x_processed = self.pitch_encoder(x_processed) # Output: (batch_size, dims, num_pitch_tokens) +# x = x_processed.permute(0, 2, 1) # Reassign to x for consistency + +# if self.use_rope: +# pass # Placeholder for RoPE application + +# seq_len = x.shape[1] +# # x = x + self.positional_encoding[:, :seq_len, :] +# x = x + sinusoids(x.shape[1], x.shape[-1], 36000).to(device, dtype) + +# if self.mlp is not None: +# x = self.mlp(x) + +# x = nn.functional.dropout(x, p=self.dropout_rate, training=self.training) +# x = self.norm(x) + +# return x + +def plot_waveform(waveform, sr, title="Waveform", ax=None): + waveform = waveform.numpy() + + num_channels, num_frames = waveform.shape + time_axis = torch.arange(0, num_frames) / sr + + if ax is None: + _, ax = plt.subplots(num_channels, 1) + ax.plot(time_axis, waveform[0], linewidth=1) + ax.grid(True) + ax.set_xlim([0, time_axis[-1]]) + ax.set_title(title) + +def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): + import librosa + if ax is None: + _, ax = 
plt.subplots(1, 1)
+    if title is not None:
+        ax.set_title(title)
+    ax.set_ylabel(ylabel)
+    ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
+
+def plot_fbank(fbank, title=None):
+    fig, axs = plt.subplots(1, 1)
+    axs.set_title(title or "Filter bank")
+    axs.imshow(fbank, aspect="auto")
+    axs.set_ylabel("frequency bin")
+    axs.set_xlabel("mel bin")
+
+def plot_pitch(waveform, sr, pitch):
+    figure, axis = plt.subplots(1, 1)
+    axis.set_title("Pitch Feature")
+    axis.grid(True)
+
+    end_time = waveform.shape[1] / sr
+    time_axis = torch.linspace(0, end_time, waveform.shape[1])
+    axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
+
+    axis2 = axis.twinx()
+    time_axis = torch.linspace(0, end_time, pitch.shape[1])
+    axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
+
+    axis2.legend(loc=0)
+
+def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
+                     load_saved=False, save_dataset=False, cache_dir=None, extract_args=None, max_ctx=2048):
+
+    if extract_args is None:
+        # keys must match the keyword parameters of extract_features
+        extract_args = {
+            "waveform": False,
+            "spec": False,
+            "pitch_tokens": False,
+            "pitch": False,
+            "harmonics": False,
+            "aperiodics": False,
+            "sample_rate": 16000,
+            "hop_length": 256,
+            "mode": "mean",
+            "debug": False,
+            "phase_mod": False,
+            "crepe": False,
+            "dummy": False,
+        }
+
+    # cache paths are needed by both the load_saved and save_dataset paths
+    if cache_dir is None:
+        cache_dir = "./processed_datasets"
+    os.makedirs(cache_dir, exist_ok=True)
+    cache_file_train = os.path.join(cache_dir, "train.arrow")
+    cache_file_test = os.path.join(cache_dir, "test.arrow")
+
+    if load_saved:
+        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
+            from datasets import Dataset
+            train_dataset = Dataset.load_from_disk(cache_file_train)
+            test_dataset = Dataset.load_from_disk(cache_file_test)
+            return train_dataset, test_dataset
+
+    if sanity_check:
+        test = load_dataset(
+            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)
+        dataset = test.map(lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=test.column_names)
+        train_dataset = dataset
+        test_dataset = dataset
+        return train_dataset, test_dataset
+    else:
+        def filter_func(x):
+            return (0 < len(x["transcription"]) < max_ctx and
+                    len(x["audio"]["array"]) > 0 and
+                    len(x["audio"]["array"]) < max_ctx * 160)
+
+        # raw_train = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="train", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription")
+        # raw_test = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="test", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription").take(1000)
+
+        raw_train = load_dataset(
+            "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming).take(1000)
+        raw_test = load_dataset(
+            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).take(100)
+
+        raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
+        raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
+        train_dataset = raw_train.map(lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_train.column_names)
+        test_dataset = raw_test.map(lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_test.column_names)
+        if save_dataset:
+            train_dataset.save_to_disk(cache_file_train)
+            test_dataset.save_to_disk(cache_file_test)
+        return train_dataset, test_dataset
+
+class tgate(nn.Module):
+    def __init__(self, dims, num_types=4):
+        super().__init__()
+        self.gates = nn.ModuleList([nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
+        self.classifier = nn.Sequential(Linear(dims, num_types), nn.Softmax(dim=-1))
+
+    def forward(self, x):
+        types = self.classifier(x)
+        gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
+        cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
+        return cgate
+
+def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int, layer: int, act=None, features=None) -> nn.Module:
+    if feature == "spectrogram":
+        return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head, layer=layer, act=act, feature=feature, features=features)
+    elif feature == "waveform":
+        # WEncoder ignores kernel_size (its convs are hard-coded), so any placeholder works
+        return WEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, kernel_size=15, act=act)
+    elif feature == "pitch":
+        return PEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, act=act)
+    else:
+        raise ValueError(f"Unknown feature type: {feature}")
+
+class FEncoder(nn.Module):
+    def __init__(self, mels, input_dims, dims, head, layer, act, feature, features, use_rope=False, spec_shape=None, debug=[]):
+        super().__init__()
+
+        self.head = head
+        self.head_dim = dims // head
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.dims = dims
+        self.debug = debug
+        self.feature = feature
+        self.mels = mels
+        self.input_dims = input_dims
+        act_fn = get_activation(act)
+
+        self.encoder = nn.Sequential(
+            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
+            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
+            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
+        if use_rope and spec_shape is not None:
+            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore
+        else:
+            self.rope = None
+        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+        self.norm = RMSNorm(dims)
+
+    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
+        batch, ctx, dims = x.shape
+        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)  # type: ignore
+        x = self.rope.apply_rotary(x, freqs)  # type: ignore
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
+
+        return x
+
+    def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
+        x = self.encoder(x).permute(0, 2, 1)
+        if self.use_rope:
+            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
+        else:
+            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        print(f"feature encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
+        x = self.norm(x)
+        return x
+
+class WEncoder(nn.Module):  # waveform encoder
+    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
+        super().__init__()
+
+        self.head = head
+        self.head_dim = dims // head
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.dims = dims
+        self.debug = debug
+        act_fn = get_activation(act)
+        self.target_length = None
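+        # the three strided convs below downsample the waveform 4 * 2 * 2 = 16x
+        # in time while widening channels input_dims -> dims//4 -> dims//2 -> dims;
+        # at 16 kHz that is roughly one frame per millisecond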
+ self.encoder = nn.Sequential( + Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn, + Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn, + Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn) + + if use_rope: + if spec_shape is not None: + self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore + else: + self.rope = None + self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) + self.norm = RMSNorm(dims) + + def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"): + batch, ctx, dims = x.shape + x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3) + freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore + x = self.rope.apply_rotary(x, freqs)# type: ignore + x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims) + return x + + def forward(self, x, xa=None, mask=None, feats= None, feature="waveform", layer = "WEncoder"): + x = self.encoder(x).permute(0, 2, 1) # (batch, time, dims) + if self.target_length and x.shape[1] != self.target_length: + x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2) + if self.use_rope: + x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer) + else: + x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype) + x = nn.functional.dropout(x, p=self.dropout, training=self.training) + print(f"waveform encoder: {x.shape} {feature}") if "fencoder" in self.debug else None + return self.norm(x) + +class PEncoder(nn.Module): # pitch encoder + def __init__(self, input_dims, dims, head, layer, act, use_rope=False, debug=[], one_shot=False, spec_shape=None): + super().__init__() + + self.head = head + self.head_dim = dims // head + self.dims = dims + self.dropout = 0.01 + self.use_rope = use_rope + self.debug = debug + act_fn = get_activation(act) + + self.attend_pitch = False + + if self.attend_pitch: + self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head) + self.mlp = nn.Sequential( + nn.Linear(dims, dims), + nn.ReLU(), + nn.Linear(dims, dims), + ) + else: + self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None + self.mlp = None + + self.pitch_encoder = nn.Sequential( + Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn, + Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn, + Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn) + + if use_rope: + self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore + else: + self.rope = None + self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) + self.norm = RMSNorm(dims) + + def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"): + batch, ctx, dims = x.shape + x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3) + freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) # type: ignore + x = self.rope.apply_rotary(x, freqs)# type: ignore + x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims) + return x + + def forward(self, x, xa=None, mask=None, feats= None, feature="pitch", layer="PEncoder"): + # f0=x + # freqs = self.rope(f0.shape[1], feats=feats, feature=feature, layer=layer) + if x.dim() == 2: + x = x.unsqueeze(0) + # if 
feature == "pitch": + x = self.pitch_encoder(x).permute(0, 2, 1) + + if self.use_rope: + x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer) + + x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype) + if self.mlp is not None: + x = self.mlp(x) + + if self.attend_pitch: + if xa is not None: + q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head) + out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True) + x = x + out + + # x = nn.functional.dropout(x, p=self.dropout, training=self.training) + x = self.norm(x) + print(f"Pitch encoder: {x.shape} {feature}") if "fencoder" in self.debug else None + return x + +@dataclass +class DataCollator: + tokenizer: Any + + def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + all_keys = set() + for f in features: + all_keys.update(f.keys()) + batch = {} + pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0) + bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1) + eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2) + + for key in all_keys: + if key == "labels": + labels_list = [f["labels"] for f in features] + max_len = max(len(l) for l in labels_list) # noqa: E741 + all_ids, all_labels = [], [] + + for label in labels_list: + label_list = label.tolist() if isinstance(label, torch.Tensor) else label + decoder_input = [bos_token_id] + label_list + label_eos = label_list + [eos_token_id] + input_len = max_len + 1 - len(decoder_input) + label_len = max_len + 1 - len(label_eos) + padded_input = decoder_input + [pad_token_id] * input_len + padded_labels = label_eos + [pad_token_id] * label_len + all_ids.append(padded_input) + all_labels.append(padded_labels) + batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long) + batch["labels"] = torch.tensor(all_labels, dtype=torch.long) + + elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "pitch_tokens", "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation", "dummy"]: + + items = [f[key] for f in features if key in f] + items = [item for item in items if item is not None] + if not items: + continue + items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items] + max_len = max(item.shape[-1] for item in items) + padded = [] + for item in items: + pad_width = max_len - item.shape[-1] + if pad_width > 0: + pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id) + else: + pad_item = item + padded.append(pad_item) + batch[key] = torch.stack(padded) + # if key == "spectrogram": + # batch["spectrogram"] = batch[key] + return batch + +# import tiktoken +# import torch +# from torch.utils.data import Dataset, DataLoader + +# class tokenize(Dataset): +# def __init__(self, txt, tokenizer, max_length, stride): + +# self.input_ids = [] +# self.labels = [] + +# token_ids = tokenizer.encode(txt, allowed_special={""}) + +# for i in range(0, len(token_ids) - max_length, stride): + +# input_chunk = token_ids[i:i + max_length] +# target_chunk = token_ids[i + 1: i + max_length + 1] +# self.input_ids.append(torch.tensor(input_chunk)) +# self.labels.append(torch.tensor(target_chunk)) + +# def __len__(self): +# return len(self.input_ids) + +# def __getitem__(self, idx): +# return self.input_ids[idx], self.labels[idx] + +# def create_dataloader_v1(txt, batch_size, max_length, stride, shuffle=True, drop_last=True, num_workers=0): +# tokenizer = tiktoken.get_encoding("gpt2") +# 
dataset = tokenize(txt, tokenizer, max_length, stride) +# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers) +# return dataloader + +# def custom_collate_fn(batch, tokenizer_pad_token_id): +# max_len_in_batch = max(len(seq) for seq in batch) + +# padded_input_ids = [] +# attention_masks = [] +# for seq in batch: +# padded_seq = F.pad(seq, (0, max_len_in_batch - len(seq)), value=tokenizer_pad_token_id) +# attention_mask = torch.ones(max_len_in_batch) +# attention_mask[len(seq):] = 0 +# padded_input_ids.append(padded_seq) +# attention_masks.append(attention_mask) + +# input_ids_tensor = torch.stack(padded_input_ids) +# attention_mask_tensor = torch.stack(attention_masks) +# labels_tensor = input_ids_tensor.clone() + +# return { +# 'input_ids': input_ids_tensor, +# 'attention_mask': attention_mask_tensor, +# 'labels': labels_tensor +# } + +# with open("the-verdict.txt", "r", encoding="utf-8") as f: +# raw_text = f.read() + +# vocab_size = 50257 +# output_dim = 256 +# context_length = 1024 + +# token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim) +# pos_embedding_layer = torch.nn.Embedding(context_length, output_dim) + +# batch_size = 8 +# max_length = 4 +# dataloader = create_dataloader_v1( +# raw_text, +# batch_size=batch_size, +# max_length=max_length, +# stride=max_length +# ) + +def levenshtein(reference_words, hypothesis_words): + m, n = len(reference_words), len(hypothesis_words) + dist_matrix = [[0 for _ in range(n+1)] for _ in range(m+1)] + for i in range(m+1): + dist_matrix[i][0] = i + for j in range(n+1): + dist_matrix[0][j] = j + for i in range(1, m+1): + for j in range(1, n+1): + if reference_words[i-1] == hypothesis_words[j-1]: + dist_matrix[i][j] = dist_matrix[i-1][j-1] + else: + substitution = dist_matrix[i-1][j-1] + 1 + insertion = dist_matrix[i][j-1] + 1 + deletion = dist_matrix[i-1][j] + 1 + dist_matrix[i][j] = min(substitution, insertion, deletion) + return dist_matrix[m][n] + +def wer_batch(references, hypotheses): + total_errors = 0 + total_words = 0 + for ref, hyp in zip(references, hypotheses): + ref_words = ref.lower().split() + errors = levenshtein(ref_words, hyp.lower().split()) + total_errors += errors + total_words += len(ref_words) + return (total_errors / total_words) * 100 if total_words > 0 else 0.0 + +def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, logits=None): + def clean(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2): + + if isinstance(ids, torch.Tensor): + ids = ids.tolist() + if isinstance(ids[0], (list, torch.Tensor, np.ndarray)): + return [[int(i) for i in seq if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] for seq in ids] + else: + return [int(i) for i in ids if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] + + pred_ids = pred.predictions + label_ids = pred.label_ids + + if isinstance(pred_ids, tuple): + pred_ids = pred_ids[0] + + if not isinstance(pred_ids, torch.Tensor): + pred_ids = torch.tensor(pred_ids) + + label_ids = clean(label_ids) + pred_ids = clean(pred_ids) + pred_str = tokenizer.batch_decode(pred_ids) + label_str = tokenizer.batch_decode(label_ids) + + if print_pred: + for i in range(min(num_samples, len(pred_ids))): + + print(f"Pred tokens: {pred_ids[i]}") + print(f"Label tokens: {label_ids[i]}") + print(f"Pred: '{pred_str[i]}'") + print(f"Label: '{label_str[i]}'") + print("-" * 40) + + wer = wer_batch(label_str, pred_str) + if model is not None: + trainable_params = sum(p.numel() for p 
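+        # wer_batch above is word-level edit distance over total reference words:
+        # ref "the cat sat" vs hyp "the cat" is one deletion, WER = 1/3 = 33.3%;
+        # efficiency_score then folds WER and parameter count (in millions) together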
+def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, logits=None):
+    def clean(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
+        if isinstance(ids, torch.Tensor):
+            ids = ids.tolist()
+        if isinstance(ids[0], (list, torch.Tensor, np.ndarray)):
+            return [[int(i) for i in seq if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] for seq in ids]
+        else:
+            return [int(i) for i in ids if i not in (-100, pad_token_id, bos_token_id, eos_token_id)]
+
+    pred_ids = pred.predictions
+    label_ids = pred.label_ids
+
+    if isinstance(pred_ids, tuple):
+        pred_ids = pred_ids[0]
+
+    if not isinstance(pred_ids, torch.Tensor):
+        pred_ids = torch.tensor(pred_ids)
+
+    label_ids = clean(label_ids)
+    pred_ids = clean(pred_ids)
+    pred_str = tokenizer.batch_decode(pred_ids)
+    label_str = tokenizer.batch_decode(label_ids)
+
+    if print_pred:
+        for i in range(min(num_samples, len(pred_ids))):
+            print(f"Pred tokens: {pred_ids[i]}")
+            print(f"Label tokens: {label_ids[i]}")
+            print(f"Pred: '{pred_str[i]}'")
+            print(f"Label: '{label_str[i]}'")
+            print("-" * 40)
+
+    wer = wer_batch(label_str, pred_str)
+    if model is not None:
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
+        efficiency_score = (100 - wer) / trainable_params if trainable_params > 0 else 0.0
+    else:
+        trainable_params = 0.0
+        efficiency_score = 0.0
+
+    return {
+        "wer": float(wer),
+        "efficiency_score": float(efficiency_score),
+    }
+
+def preprocess_logits_for_metrics(logits, labels):
+    pred_ids = torch.argmax(logits, dim=-1)
+    return pred_ids, labels
+
+def hilbert_transform(x):
+    """Hilbert transform along the last axis via the FFT.
+
+    Builds the analytic-signal spectrum (double the positive frequencies,
+    keep DC/Nyquist, zero the negative frequencies) and returns the
+    imaginary part, matching scipy.signal.hilbert(x).imag. The full complex
+    FFT is required: an rfft/irfft round trip returns a real signal and
+    drops the imaginary (Hilbert) component entirely.
+    """
+    N = x.shape[-1]
+    xf = torch.fft.fft(x)
+    h = torch.zeros(N, device=x.device, dtype=xf.dtype)
+    if N % 2 == 0:
+        h[0] = h[N // 2] = 1
+        h[1:N // 2] = 2
+    else:
+        h[0] = 1
+        h[1:(N + 1) // 2] = 2
+    return torch.fft.ifft(xf * h).imag
+
+def analytic_signal(x):
+    return x + 1j * hilbert_transform(x)
+
+def hilbert_transform_2d(x, dim=-1):
+    """Hilbert transform along an arbitrary axis of an N-D tensor.
+
+    The frequency-domain filter is built along `dim` and broadcast across
+    the remaining axes, so any axis can be transformed, not just the last.
+    """
+    N = x.shape[dim]
+    xf = torch.fft.fft(x, dim=dim)
+    h = torch.zeros(N, device=x.device, dtype=xf.dtype)
+    if N % 2 == 0:
+        h[0] = h[N // 2] = 1
+        h[1:N // 2] = 2
+    else:
+        h[0] = 1
+        h[1:(N + 1) // 2] = 2
+    h_shape = [1] * x.dim()
+    h_shape[dim] = N
+    return torch.fft.ifft(xf * h.view(h_shape), dim=dim).imag
+
+def hilbert_transform_true_2d(x):
+    """Single-orthant 2-D Hilbert transform via a spiral-phase kernel.
+
+    rfft2 keeps the full frequency axis for dim -2 and the half axis for
+    dim -1, so the kernel grid pairs fftfreq with rfftfreq to match the
+    rfft2 output shape.
+    """
+    xf = torch.fft.rfft2(x)
+    h1, h2 = torch.meshgrid(
+        torch.fft.fftfreq(x.shape[-2], device=x.device),
+        torch.fft.rfftfreq(x.shape[-1], device=x.device),
+        indexing='ij')
+    h = -1j / (math.pi * (h1 + 1j * h2))
+    h[0, 0] = 0
+    return torch.fft.irfft2(xf * h, s=x.shape[-2:])
+
+def process_spectrogram_with_hilbert(spec):
+    """Envelope and instantaneous phase along the last (time) axis."""
+    analytic = spec + 1j * hilbert_transform(spec)
+    envelope = torch.abs(analytic)
+    phase = torch.angle(analytic)
+    return envelope, phase
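+
+# Hedged sketch: envelope/phase extraction on a pure tone. The analytic
+# envelope of a unit-amplitude sinusoid should be approximately 1 away
+# from the edges (the 40 Hz value below is an arbitrary illustration).
+# t = torch.linspace(0, 1, 1024)
+# tone = torch.sin(2 * math.pi * 40 * t)
+# env, phase = process_spectrogram_with_hilbert(tone.unsqueeze(0))
+# print(env.mean().item())  # ~1.0, with some ripple at the boundaries
+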
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from torch import Tensor
+# from typing import Optional, Tuple
+# import numpy as np
+# from torch.nn.functional import scaled_dot_product_attention
+# from torch.cuda.amp import autocast
+# from torch.nn import LayerNorm, Linear
+# import logging
+
+# logging.basicConfig(level=logging.WARNING)
+# log = logging.getLogger(__name__)
+
+# class ProjectionModule(nn.Module):
+#     """
+#     Projects input embeddings into query, key, or value representations
+#     for multi-head attention, handling scaling for Q/K.
+#     """
+#     def __init__(self, dims: int, head: int, proj_type: str = "query", use_bias: bool = True):
+#         """
+#         Args:
+#             dims: Input and output dimension.
+#             head: Number of attention heads.
+#             proj_type: Type of projection ("query", "key", "value").
+#             use_bias: Whether to use bias in the linear layer.
+#         """
+#         super().__init__()
+#         assert dims % head == 0, f"dims ({dims}) must be divisible by head ({head})"
+#         assert proj_type in ["query", "key", "value"], \
+#             f"proj_type must be 'query', 'key', or 'value', got {proj_type}"
+
+#         self.dims = dims
+#         self.head = head
+#         self.head_dim = dims // head
+#         self.proj_type = proj_type
+#         self.scale = self.head_dim ** -0.5 if proj_type != "value" else 1.0
+#         self.proj = Linear(in_features=dims, out_features=dims, bias=use_bias)
+#         self.init_weights()
+
+#     def init_weights(self):
+#         """Initialize projection weights."""
+#         nn.init.normal_(tensor=self.proj.weight, std=0.02)
+#         if self.proj.bias is not None:
+#             nn.init.zeros_(tensor=self.proj.bias)
+
+#     def forward(self, x: Tensor) -> Tensor:
+#         """
+#         Applies projection, scaling (for Q/K), and reshapes for multi-head attention.
+
+#         Args:
+#             x: Input tensor of shape (batch, seq_len, dims).
+
+#         Returns:
+#             Projected tensor of shape (batch, head, seq_len, head_dim).
+#         """
+#         batch, seq_len, _ = x.shape
+#         proj = self.proj(x)
+
+#         proj = proj.view(batch, seq_len, self.head, self.head_dim).permute(0, 2, 1, 3)
+
+#         if self.proj_type in ["query", "key"]:
+#             proj = proj * self.scale
+#         return proj
+
+# def calculate_attention(
+#     q: Tensor,
+#     k: Tensor,
+#     v: Tensor,
+#     mask: Optional[Tensor] = None,
+#     temperature: float = 1.0,
+#     use_sdpa: bool = True,
+#     is_causal: bool = False,
+#     dropout_p: float = 0.0
+# ) -> Tuple[Tensor, Optional[Tensor]]:
+#     """
+#     Calculates scaled dot-product attention.
+
+#     Uses torch.nn.functional.scaled_dot_product_attention if use_sdpa is True
+#     and inputs are compatible, otherwise falls back to a manual implementation.
+
+#     Args:
+#         q: Query tensor (Batch, Heads, SeqLen_Q, HeadDim). Already scaled if needed.
+#         k: Key tensor (Batch, Heads, SeqLen_K, HeadDim). Already scaled if needed.
+#         v: Value tensor (Batch, Heads, SeqLen_K, HeadDim).
+#         mask: Attention mask. Can be boolean (True means ignore) or float (-inf means ignore).
+#             Shape should be broadcastable to (Batch, Heads, SeqLen_Q, SeqLen_K).
+#         temperature: Softmax temperature scaling. Applied *before* softmax.
+#         use_sdpa: Flag to attempt using PyTorch's optimized SDPA implementation.
+#         is_causal: If True, applies a causal mask (for decoder self-attention).
+#             Used only if mask is None and use_sdpa is True.
+#         dropout_p: Dropout probability for attention weights.
+
+#     Returns:
+#         A tuple containing:
+#         - attn_output: Attention output tensor (Batch, Heads, SeqLen_Q, HeadDim).
+#         - attn_weights: Attention weights tensor (Batch, Heads, SeqLen_Q, SeqLen_K),
+#           or None if the SDPA implementation doesn't return them or if the fallback is used.
+#           *Note: SDPA's default doesn't return weights, requires specific backend support.*
+#           *Manual path always returns weights.*
+#     """
+#     batch_size, num_heads, q_len, head_dim = q.shape
+#     k_len = k.size(2)
+
+#     temp_scale = 1.0 / temperature if temperature > 0 else 1.0
+
+#     attn_output, attn_weights = None, None
+
+#     if use_sdpa:
+#         try:
+#             if temperature != 1.0:
+#                 raise NotImplementedError("SDPA does not directly support temperature scaling. 
Use manual path or scale Q.") + +# attn_output = scaled_dot_product_attention( +# q, k, v, +# attn_mask=mask, +# dropout_p=dropout_p, +# is_causal=is_causal and mask is None +# ) +# attn_weights = None +# return attn_output, attn_weights +# except (RuntimeError, NotImplementedError) as e: +# log.warning(f"SDPA failed or not used ({e}), falling back to manual attention.") +# attn_scores = torch.matmul(q, k.transpose(-2, -1)) * temp_scale + +# if mask is not None: +# if mask.dim() == 2: +# mask = mask.unsqueeze(0).unsqueeze(0) +# elif mask.dim() == 3: +# mask = mask.unsqueeze(1) + +# expected_mask_shape = (batch_size, num_heads, q_len, k_len) +# if mask.shape != expected_mask_shape: +# try: +# mask = mask.expand(expected_mask_shape) +# except RuntimeError: +# raise ValueError(f"Mask shape {mask.shape} is not compatible with attention scores shape {expected_mask_shape}") + +# if mask.dtype == torch.bool: +# attn_scores = attn_scores.masked_fill(mask, float("-inf")) +# else: +# attn_scores = attn_scores + mask + +# attn_weights = F.softmax(attn_scores, dim=-1) + +# if dropout_p > 0.0: +# attn_weights = F.dropout(attn_weights, p=dropout_p) + +# attn_output = torch.matmul(attn_weights, v) + +# return attn_output, attn_weights + +# class BaseAttention(nn.Module): +# """Base class for attention mechanisms with common functionality.""" +# use_sdpa = True + +# def __init__(self, dims: int, head: int, max_dist: int = 512, dropout: float = 0.0): +# """ +# Args: +# dims: Embedding dimension. +# head: Number of attention heads. +# max_dist: Maximum attention distance (used by some subclasses). +# dropout: Dropout probability for attention weights. +# """ +# super().__init__() +# assert dims % head == 0, f"dims ({dims}) must be divisible by head ({head})" +# self.dims = dims +# self.head = head +# self.head_dim = dims // head +# self.max_dist = max_dist +# self.dropout = dropout + +# def _shape(self, tensor: torch.Tensor) -> torch.Tensor: +# """ +# Reshape tensor from (batch, seq_len, dims) to +# (batch, head, seq_len, head_dim) for multi-head attention. +# """ +# batch, seq_len, _ = tensor.shape +# return tensor.view(batch, seq_len, self.head, self.head_dim).transpose(1, 2).contiguous() + +# def _reshape_to_output(self, attn_output: Tensor) -> Tensor: +# """ +# Reshape attention output from (batch, head, seq_len, head_dim) +# back to (batch, seq_len, dims). +# """ +# batch, _, seq_len, _ = attn_output.shape +# return attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.dims) + +# class AttentionCombiner(BaseAttention): +# """ +# Computes attention given Q, K, V projections and applies an output projection. +# Assumes Q, K, V inputs are already projected and appropriately shaped/scaled. +# """ +# def __init__(self, dims: int, head: int, use_bias: bool = True, dropout: float = 0.0): +# """ +# Args: +# dims: Embedding dimension. +# head: Number of attention heads. +# use_bias: Whether to use bias in the output projection. +# dropout: Dropout probability for attention weights. 
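+
+#         Example (hedged sketch; q/k/v are assumed pre-projected and
+#         pre-scaled, e.g. by ProjectionModule)::
+#             combiner = AttentionCombiner(dims=512, head=8)
+#             q = k = v = torch.randn(2, 8, 16, 64)  # (batch, head, seq, head_dim)
+#             out = combiner(q, k, v)                # -> (2, 16, 512)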
+# """ +# super().__init__(dims, head, dropout=dropout) +# self.out = Linear(in_features=dims, out_features=dims, bias=use_bias) +# self._init_weights() + +# def _init_weights(self): +# """Initialize output projection weights.""" +# nn.init.normal_(tensor=self.out.weight, std=0.02) +# if self.out.bias is not None: +# nn.init.zeros_(tensor=self.out.bias) + +# # @autocast('cuda', enabled=torch.cuda.is_available()) +# def forward(self, q: Tensor, k: Tensor, v: Tensor, +# mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor: +# """ +# Processes Q, K, V through attention and output projection. + +# Args: +# q: Query tensor (Batch, Heads, SeqLen_Q, HeadDim). Assumed scaled. +# k: Key tensor (Batch, Heads, SeqLen_K, HeadDim). Assumed scaled. +# v: Value tensor (Batch, Heads, SeqLen_K, HeadDim). +# mask: Attention mask. +# is_causal: Whether to apply causal masking (if mask is None). + +# Returns: +# Output tensor (Batch, SeqLen_Q, Dims). +# """ +# attn_output, _ = calculate_attention( +# q, k, v, mask=mask, +# temperature=1.0, +# use_sdpa=BaseAttention.use_sdpa, +# is_causal=is_causal, +# dropout_p = self.dropout +# ) + +# output = self._reshape_to_output(attn_output) +# return self.out(output) + +# class AdaptiveUpdateAttention(BaseAttention): +# """ +# Attention implementation where Key and Value representations are cached +# and only updated based on content-dependent predictors. Suitable for +# encoder layers or cross-attention where K/V context changes less frequently. + +# Note: Current implementation focuses on conditional update based on *current* +# input, not standard auto-regressive KV caching for generation. +# """ +# def __init__(self, dims: int, head: int, max_dist: int = 512, update_threshold: float = 0.5, dropout: float = 0.0): +# """ +# Args: +# dims: Embedding dimension. +# head: Number of attention heads. +# max_dist: Maximum attention distance (inherited, may not be directly used here). +# update_threshold: Threshold for sigmoid output of predictors to trigger update. +# dropout: Dropout probability for attention weights. +# """ +# super().__init__(dims, head, max_dist, dropout=dropout) + +# self.query_module = ProjectionModule(dims, head, "query") +# self.key_module = ProjectionModule(dims, head, "key") +# self.value_module = ProjectionModule(dims, head, "value") +# self.combiner = AttentionCombiner(dims, head, dropout=dropout) + +# self.key_update_predictor = nn.Sequential( +# Linear(dims, dims // 4), nn.ReLU(), Linear(dims // 4, 1), nn.Sigmoid()) +# self.value_update_predictor = nn.Sequential( +# Linear(dims, dims // 4), nn.ReLU(), Linear(dims // 4, 1), nn.Sigmoid()) + +# self.update_threshold = update_threshold +# self.stored_key_cache: Optional[Tensor] = None +# self.stored_value_cache: Optional[Tensor] = None +# self.reset_cache_on_forward = True + +# def _should_update(self, x: torch.Tensor, predictor: nn.Module) -> torch.Tensor: +# """Predict whether K or V should be updated based on content.""" +# avg_rep = x.mean(dim=1) +# update_prob = predictor(avg_rep) +# return update_prob > self.update_threshold + +# def forward(self, x: Tensor, xa: Optional[Tensor] = None, +# mask: Optional[Tensor] = None, +# is_causal: bool = False) -> Tensor: +# """ +# Process inputs with adaptive K/V update mechanism. + +# Args: +# x: Input tensor for queries (Batch, SeqLen_Q, Dims). +# xa: Optional input tensor for keys/values (for cross-attention). +# If None, uses x for self-attention (Batch, SeqLen_KV, Dims). +# mask: Attention mask. 
+# is_causal: Whether attention should be causal. + +# Returns: +# Output tensor (Batch, SeqLen_Q, Dims). +# """ +# if self.reset_cache_on_forward: +# self.stored_key_cache = None +# self.stored_value_cache = None + +# batch, ctx_q, _ = x.shape +# q = self.query_module(x) + +# kv_input = xa if xa is not None else x +# ctx_kv = kv_input.size(1) + +# update_k_batch = self._should_update(kv_input, self.key_update_predictor) +# update_v_batch = self._should_update(kv_input, self.value_update_predictor) + +# if self.stored_key_cache is None or self.stored_key_cache.shape[2] != ctx_kv or self.stored_key_cache.shape[0] != batch: +# k = self.key_module(kv_input) +# self.stored_key_cache = k +# elif update_k_batch.any(): +# new_k = self.key_module(kv_input) +# update_mask_k = update_k_batch.view(-1, 1, 1, 1).expand_as(self.stored_key_cache) +# k = torch.where(update_mask_k, new_k, self.stored_key_cache) +# self.stored_key_cache = k +# else: +# k = self.stored_key_cache + +# if self.stored_value_cache is None or self.stored_value_cache.shape[2] != ctx_kv or self.stored_value_cache.shape[0] != batch: +# v = self.value_module(kv_input) +# self.stored_value_cache = v +# elif update_v_batch.any(): +# new_v = self.value_module(kv_input) +# update_mask_v = update_v_batch.view(-1, 1, 1, 1).expand_as(self.stored_value_cache) +# v = torch.where(update_mask_v, new_v, self.stored_value_cache) +# self.stored_value_cache = v +# else: +# v = self.stored_value_cache + +# output = self.combiner(q, k, v, mask=mask, is_causal=is_causal) +# return output + +# class Refiner: +# """ +# Q-learning based agent to refine parameters (e.g., attention span). +# Operates outside the standard backpropagation loop. +# """ +# def __init__(self, states: int, actions: int, alpha: float = 0.1, gamma: float = 0.9, epsilon: float = 0.1): +# self.states = states +# self.actions = actions +# self.R = {} +# self.alpha = alpha +# self.gamma = gamma +# self.epsilon = epsilon +# self.default_value = 0.0 + +# def get_value(self, state: int, action: int) -> float: +# """Get Q-value for state-action pair.""" +# return self.R.get((state, action), self.default_value) + +# def set_value(self, state: int, action: int, value: float): +# """Set Q-value for state-action pair.""" +# self.R[(state, action)] = value + +# def choose_action(self, state: int) -> int: +# """Choose action using epsilon-greedy strategy.""" +# if np.random.random() < self.epsilon: +# return np.random.randint(self.actions) +# else: +# action_values = [self.get_value(state, a) for a in range(self.actions)] +# return np.argmax(action_values).item() + +# def update(self, state: int, action: int, reward: float, next_state: int): +# """Update Q-value using the Q-learning rule.""" +# next_values = [self.get_value(next_state, a) for a in range(self.actions)] +# best_next_value = max(next_values) if next_values else self.default_value + +# old_value = self.get_value(state, action) +# td_target = reward + self.gamma * best_next_value +# td_error = td_target - old_value +# new_value = old_value + self.alpha * td_error +# self.set_value(state, action, new_value) + +# class Predictor(nn.Module): +# """Neural predictor for estimating a scale value (e.g., for adaptive span).""" +# def __init__(self, dims: int): +# super().__init__() +# self.linear = Linear(in_features=dims, out_features=1) +# self._init_weights() + +# def _init_weights(self): +# """Initialize predictor weights.""" +# nn.init.xavier_normal_(self.linear.weight) +# if self.linear.bias is not None: +# 
nn.init.zeros_(self.linear.bias) + +# def forward(self, x: Tensor) -> Tensor: +# """ +# Predicts a scale factor (0-1) from input features. + +# Args: +# x: Input tensor (Batch, SeqLen, Dims) or (Batch, Dims). + +# Returns: +# Scale tensor (Batch, 1). +# """ +# if x.dim() > 2: +# x = x.mean(dim=1) +# scale = torch.sigmoid(self.linear(x)) +# return scale + +# class AdaptiveSpanAttention(BaseAttention): +# """ +# Attention mechanism where the span is dynamically adjusted based on a +# learnable parameter or predicted scale. This version focuses on slicing +# the input sequence to the effective span. + +# Note: This implementation attends only to the *first* `eff_span` tokens. +# For attending to a *relative* window, different logic (e.g., sliding window +# or masking) would be needed in `calculate_attention`. +# """ +# def __init__(self, dims: int, head: int, max_dist: int = 512, +# initial_span_scale: float = 1.0, learnable_scale: bool = True, +# sharpen: bool = True, temp_scale: float = 0.01, dropout: float = 0.0): +# """ +# Args: +# dims, head, max_dist, dropout: Standard BaseAttention params. +# initial_span_scale: Initial value for the span scale. +# learnable_scale: If True, span_scale is an nn.Parameter. +# sharpen, temp_scale: Parameters for dynamic temperature adjustment. +# """ +# super().__init__(dims, head, max_dist, dropout=dropout) +# self.sharpen = sharpen +# self.temp_scale = temp_scale +# if learnable_scale: +# self.span_scale = nn.Parameter(torch.tensor(initial_span_scale)) +# else: +# self.register_buffer("span_scale", torch.tensor(initial_span_scale)) + +# self.query_module = ProjectionModule(dims, head, "query") +# self.key_module = ProjectionModule(dims, head, "key") +# self.value_module = ProjectionModule(dims, head, "value") +# self.out_proj = Linear(dims, dims) + +# @autocast('cuda', enabled=torch.cuda.is_available()) +# def forward(self, x: Tensor, xa: Optional[Tensor] = None, +# mask: Optional[Tensor] = None, +# span_scale_override: Optional[Tensor] = None, +# is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: +# """ +# Computes attention over an adaptively determined span. + +# Args: +# x: Input tensor for Q (Batch, SeqLen_Q, Dims). +# xa: Optional input for K/V (Batch, SeqLen_KV, Dims). If None, use x. +# mask: External attention mask. +# span_scale_override: Optional tensor (Batch, 1) or scalar to override internal span_scale. +# is_causal: Whether to apply causal masking. + +# Returns: +# Tuple of (output tensor (Batch, SeqLen_Q, Dims), attention weights (optional)). 
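+
+#         Example (hedged sketch; only the first `eff_span` positions
+#         receive non-zero output, per the note in the class docstring)::
+#             attn = AdaptiveSpanAttention(dims=512, head=8, initial_span_scale=1.0)
+#             out, weights = attn(torch.randn(2, 100, 512))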
+# """ +# kv_input = xa if xa is not None else x +# batch, ctx_q, _ = x.shape +# ctx_kv = kv_input.size(1) + +# current_span_scale = span_scale_override if span_scale_override is not None else self.span_scale +# if isinstance(current_span_scale, nn.Parameter): +# span_scale_val = current_span_scale.sigmoid() +# elif current_span_scale.numel() == 1: +# span_scale_val = current_span_scale.expand(batch, 1) +# else: +# span_scale_val = current_span_scale + +# span_mean = span_scale_val.mean().item() +# max_span_len = ctx_kv +# target_span_len = max(1, int(max_span_len * span_mean)) + +# eff_span = min(target_span_len, self.max_dist, ctx_q, ctx_kv) + +# if eff_span == 0: +# return (torch.zeros_like(x), None) + +# q_span = x[:, :eff_span, :] +# k_span = kv_input[:, :eff_span, :] +# v_span = kv_input[:, :eff_span, :] + +# q_proj = self.query_module(q_span) +# k_proj = self.key_module(k_span) +# v_proj = self.value_module(v_span) + +# temperature = (1.0 + self.temp_scale * (1.0 - span_mean) +# if self.sharpen +# else 0.5 + self.temp_scale * span_mean) +# temperature = max(temperature, 1e-3) + +# span_mask = None +# if mask is not None: +# if mask.dim() == 4: +# span_mask = mask[:, :, :eff_span, :eff_span] +# elif mask.dim() == 2: +# span_mask = mask[:eff_span, :eff_span] +# attn_output_span, attn_weights = calculate_attention( +# q_proj, k_proj, v_proj, +# mask=span_mask, +# temperature=temperature, +# use_sdpa=BaseAttention.use_sdpa, +# is_causal=is_causal, +# dropout_p=self.dropout +# ) + +# output_span = self._reshape_to_output(attn_output_span) +# projected_output_span = self.out_proj(output_span) + +# output = torch.zeros_like(x) +# output[:, :eff_span, :] = projected_output_span + +# return output, attn_weights + +# class MyelinatedLayer(BaseAttention): +# """ +# A complex Transformer layer featuring: +# - Integrated local/global attention (via IntegratedAttention). +# - Optional adapters within sub-layers. +# - Node importance prediction for sparsity. +# - MLP block. +# - Working memory component. +# - Potential layer skipping ("jumping") based on a learned policy. + +# (This version assumes IntegratedAttention is the core attention mechanism). 
+# """ +# def __init__(self, dims: int, head: int, num_layers: int = 3, +# sparsity_threshold: float = 0.1, max_dist: int = 512, +# dropout: float = 0.1, mlp_ratio: int = 4): +# super().__init__(dims, head, max_dist, dropout) +# self.num_layers = num_layers +# self.sparsity_threshold = sparsity_threshold + +# self.attention = IntegratedAttention(dims, head, max_dist=max_dist, dropout=dropout) + +# self.sub_layers = nn.ModuleList() +# self.node_predictors = nn.ModuleList([ +# nn.Sequential(LayerNorm(dims), Linear(dims, 1), nn.Sigmoid()) +# for _ in range(num_layers)]) + +# for i in range(num_layers): +# self.sub_layers.append(nn.ModuleDict({ +# 'ln': LayerNorm(dims), +# 'gate': nn.Sequential(Linear(dims, 1), nn.Sigmoid()), +# 'adapter': Linear(dims, dims) if i % 2 == 0 else None +# })) + +# self.policy_net = nn.Sequential(Linear(dims, 128), nn.ReLU(), Linear(128, num_layers)) +# self.jump_weights = nn.Parameter(torch.tensor([0.1, 0.05, 0.01])) + +# n_mlp = dims * mlp_ratio +# self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) +# self.mlp = nn.Sequential(Linear(dims, n_mlp), nn.GELU(), Linear(n_mlp, dims), nn.Dropout(dropout)) +# self.mlp_ln = LayerNorm(dims) + +# self.working_memory = nn.Parameter(torch.zeros(1, 1, dims)) +# self.memory_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) + +# self.last_memory_gate_values: Optional[Tensor] = None + +# def predict_node_importance(self, x: Tensor, layer_idx: int) -> Tensor: +# """Predict token importance mask (0.0 or 1.0) for sparsity.""" +# importance = self.node_predictors[layer_idx](x) +# is_important = (importance > self.sparsity_threshold).float() +# return is_important + +# def forward(self, x: Tensor, xa: Optional[Tensor] = None, +# mask: Optional[Tensor] = None, kv_cache: Optional[Tensor] = None, +# is_causal: bool = True) -> Tensor: +# batch, ctx, _ = x.shape +# working_memory = self.working_memory.expand(batch, 1, -1).to(x.device) +# original_x = x + +# pooled_representation = x.mean(dim=1) +# policy_logits = self.policy_net(pooled_representation) +# policy = F.softmax(policy_logits, dim=-1) + +# jump_history = [] +# i = 0 +# last_processed_output = x + +# while i < self.num_layers: +# layer = self.sub_layers[i] + +# node_importance_mask = self.predict_node_importance(x, i) + +# if node_importance_mask.mean() < 0.2 and i > 0: +# i += 1 +# jump_history.append(f"skip_low_imp->{i}") +# continue + +# norm_x = layer['ln'](x) + +# current_attn_mask = node_importance_mask.permute(0, 2, 1) +# if mask is not None: +# pass + +# attn_output = self.attention( +# norm_x * node_importance_mask, +# xa=xa, +# mask=mask, +# kv_cache=kv_cache, +# is_causal=is_causal +# ) + +# if layer['adapter'] is not None: +# attn_output = layer['adapter'](attn_output) + +# gate_value = layer['gate'](norm_x) +# x = x + gate_value * attn_output * node_importance_mask +# last_processed_output = x + +# memory_gate = self.memory_gate(x.mean(dim=1, keepdim=True)) +# current_mean_x = x.mean(dim=1, keepdim=True) +# working_memory = memory_gate * working_memory + (1 - memory_gate) * current_mean_x + +# if i < self.num_layers - 1: +# jump_prob_dist = policy[:, 1:] +# jump_prob = jump_prob_dist.sum(dim=-1) + +# should_jump_batch = torch.rand_like(jump_prob) < jump_prob + +# if should_jump_batch.any(): +# jump_len_probs = policy[should_jump_batch, :self.num_layers-i] +# sampled_jump_len = torch.multinomial(jump_len_probs, 1)[:, 0] + 1 + +# jump_length = sampled_jump_len.max().item() +# i_next = min(i + jump_length, self.num_layers) + +# skip_weight_idx = 
min(jump_length - 1, len(self.jump_weights) - 1) +# skip_weight = self.jump_weights[skip_weight_idx] + +# x = skip_weight * original_x + (1 - skip_weight) * working_memory.expand_as(x) + x * (1-skip_weight) +# jump_history.append(f"jump_{jump_length} S:{skip_weight.item():.2f} ->{i_next}") +# i = i_next +# continue + +# i += 1 + +# mlp_input = last_processed_output +# norm_mlp_input = self.mlp_ln(mlp_input) +# mlp_output = self.mlp(norm_mlp_input) +# mlp_gate_value = self.mlp_gate(norm_mlp_input) +# final_output = mlp_input + mlp_gate_value * mlp_output + +# if 'memory_gate' in locals(): +# self.last_memory_gate_values = memory_gate.detach().clone() + +# return final_output + +# class IntegratedAttention(BaseAttention): +# """ +# Integrates multiple attention strategies: +# - Local attention (sliding window or adaptive span via AdaptiveSpanAttention). +# - Global attention (potentially with adaptive updates via AdaptiveUpdateAttention). +# - Cross-attention capability. +# - RL-based refinement (`Refiner`) of the local attention span. +# - Iterative refinement (`_focus`) within local attention. +# """ +# def __init__(self, dims: int, head: int, max_dist: int = 512, +# win_size: int = 256, max_span: int = 384, temp_scale: float = 0.01, +# dropout: float = 0.1, +# rl_states: int = 10000, rl_actions: int = 10, rl_alpha: float = 0.1, +# rl_gamma: float = 0.9, rl_epsilon: float = 0.1): +# super().__init__(dims, head, max_dist, dropout=dropout) +# self.max_span = max_span +# self.sliding_window = win_size +# self.temp_scale = temp_scale +# self.sharpen = True + +# self.refiner = Refiner( +# states=rl_states, actions=rl_actions, alpha=rl_alpha, +# gamma=rl_gamma, epsilon=rl_epsilon) +# self.span_pred = Predictor(dims=dims) + +# self.attn_local = AdaptiveSpanAttention( +# dims=dims, head=head, max_dist=max_dist, sharpen=self.sharpen, +# temp_scale=temp_scale, learnable_scale=False, +# dropout=dropout) + +# self.attn_global = AdaptiveUpdateAttention( +# dims=dims, head=head, max_dist=max_dist, dropout=dropout) + +# self.cross_attn = AttentionCombiner(dims=dims, head=head, dropout=dropout) + +# self.self_projection = Linear(in_features=2 * dims, out_features=dims) +# self.global_cross_projection = Linear(in_features=dims, out_features=dims) + +# self.ln_local_in = LayerNorm(normalized_shape=dims) +# self.ln_global_in = LayerNorm(normalized_shape=dims) +# self.ln_cross_in = LayerNorm(normalized_shape=dims) + +# self.register_buffer("threshold", torch.tensor(1e-4), persistent=False) +# self.register_buffer("s_factor", torch.tensor(0.1), persistent=False) + +# def forward(self, x: Tensor, xa: Optional[Tensor] = None, +# mask: Optional[Tensor] = None, kv_cache: Optional[Tensor] = None, +# is_causal: bool = True) -> Tensor: +# """ +# Main forward pass distributing to cross or self-attention pathways. + +# Args: +# x: Primary input tensor (Batch, SeqLen_Q, Dims). +# xa: Context tensor for cross-attention (Batch, SeqLen_KV, Dims). +# mask: Attention mask (padding or causal). +# kv_cache: Key/Value cache for generation (specific usage depends on sub-modules). +# is_causal: Flag for causal masking in self-attention. + +# Returns: +# Output tensor (Batch, SeqLen_Q, Dims). 
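+
+#         Example (hedged sketch)::
+#             attn = IntegratedAttention(dims=512, head=8)
+#             y = attn(torch.randn(2, 100, 512))           # self-attention path
+#             y = attn(torch.randn(2, 100, 512), xa=ctx)   # cross-attention path, ctx: (2, 50, 512)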
+# """ +# batch, ctx_q, _ = x.shape + +# if xa is not None: +# q_norm = self.ln_cross_in(x) +# k_cross = self.attn_global.key_module(xa) +# v_cross = self.attn_global.value_module(xa) +# q_cross = self.attn_global.query_module(q_norm) + +# cross_out = self.cross_attn(q=q_cross, k=k_cross, v=v_cross, mask=mask, is_causal=False) +# return self.global_cross_projection(cross_out) + +# local_input = self.ln_local_in(x) +# global_input = self.ln_global_in(x) + +# globe_out_raw = self.attn_global( +# global_input, +# xa=None, +# mask=mask, +# is_causal=is_causal +# ) +# globe_out = self.global_cross_projection(globe_out_raw) + +# base_freq_scale = self.span_pred(globe_out) + +# state = self._extract_rl_state(local_input) +# action = self.refiner.choose_action(state=state) +# refinement_scale = self._action_to_scale(action=action) +# final_span_scale = torch.clamp(base_freq_scale * refinement_scale.expand_as(base_freq_scale), min=0.0, max=1.0) + +# span_mean = final_span_scale.mean().item() +# with torch.no_grad(): +# current_win_size = max(1, int(self.sliding_window * span_mean)) +# current_span_len = max(1, int(self.max_span * span_mean)) +# local_out_raw = self._slide_win_local( +# x=local_input, +# win_size=current_win_size, +# span_len=current_span_len, +# span_scale=final_span_scale, +# mask=mask, +# is_causal=is_causal +# ) +# with torch.no_grad(): +# reward = self._calculate_rl_reward(output=local_out_raw) +# next_state = self._extract_rl_state(local_out_raw) +# self.refiner.update(state=state, action=action, reward=reward, next_state=next_state) + +# combined = torch.cat([local_out_raw, globe_out], dim=-1) +# output = self.self_projection(combined) + +# return output + +# def _calculate_rl_reward(self, output: Tensor) -> float: +# """Calculate quality metric (reward) for reinforcement learning.""" +# with torch.no_grad(): +# output_probs = torch.softmax(output, dim=-1) +# safe_probs = torch.clamp(output_probs, min=1e-10) +# entropy = -(safe_probs * torch.log(safe_probs)).sum(-1).mean() +# coverage = (output.abs() > 0.01).float().mean() +# reward = float(coverage - 0.1 * entropy) +# return reward + +# def _extract_rl_state(self, x: Tensor) -> int: +# """Extract discrete state features for RL agent from tensor.""" +# with torch.no_grad(): +# pooled = x.mean(dim=1) +# mean_state = pooled[0].mean() +# var_state = pooled[0].var(unbiased=False) +# state_features = torch.stack([mean_state, var_state]).cpu().numpy() +# state_id = self._discretize_state(state_features) +# return state_id + +# def _discretize_state(self, state: np.ndarray) -> int: +# """Convert continuous state numpy array to a discrete state ID.""" +# bins = np.linspace(-1, 1, num=10) +# state_discrete = np.digitize(state, bins) +# state_hash = sum(val * (10**i) for i, val in enumerate(state_discrete)) +# state_id = int(state_hash % self.refiner.states) +# return state_id + +# def _action_to_scale(self, action: int) -> Tensor: +# """Convert discrete RL action index to a continuous scale factor [0, 1].""" +# span_value = action / (self.refiner.actions - 1) +# scale_tensor = torch.tensor([span_value], device=self.span_pred.linear.weight.device, dtype=torch.float) +# return scale_tensor + +# def _focus(self, query: Tensor, key: Tensor, value: Tensor, +# span_scale: Tensor, mask: Optional[Tensor] = None, +# is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: +# """ +# Iterative attention refinement. Applies attention multiple times, +# adding the output back to the query. Uses manual attention calculation. 
+ +# Args: +# query, key, value: Input tensors (B, SeqLen_Window, D). +# span_scale: Scale factor (scalar or B, 1) influencing effective span. +# mask: Attention mask for the window. +# is_causal: Apply causal masking within the window. + +# Returns: +# Tuple (refined_output (B, SeqLen_Window, D), attention_weights (optional, None here)). +# """ +# max_iterations = 5 +# iteration = 0 +# prev_attn_out = torch.zeros_like(query) +# attn_out = torch.zeros_like(query) +# threshold = self.threshold.item() +# s_factor = self.s_factor.item() + +# q_current = query + +# while iteration < max_iterations: +# span_mean = span_scale.mean().item() +# target_span_len = max(1, int(self.max_span * span_mean)) +# eff_span = min(target_span_len, self.max_dist, q_current.size(1), key.size(1)) + +# if eff_span == 0: break + +# q_iter = q_current[:, :eff_span, :] +# k_iter = key[:, :eff_span, :] +# v_iter = value[:, :eff_span, :] + +# q_proj = self.attn_local.query_module(q_iter) +# k_proj = self.attn_local.key_module(k_iter) +# v_proj = self.attn_local.value_module(v_iter) + +# temperature = (1.0 + self.temp_scale * (1.0 - span_mean) +# if self.sharpen +# else 0.5 + self.temp_scale * span_mean) +# temperature = max(temperature, 1e-3) + +# iter_mask = None +# if mask is not None: +# if mask.dim() == 4: iter_mask = mask[:, :, :eff_span, :eff_span] +# elif mask.dim() == 2: iter_mask = mask[:eff_span, :eff_span] +# attn_output_iter, _ = calculate_attention( +# q_proj, k_proj, v_proj, +# mask=iter_mask, +# temperature=temperature, +# use_sdpa=False, +# is_causal=is_causal, +# dropout_p=self.dropout +# ) + +# attn_out_span = self.attn_local._reshape_to_output(attn_output_iter) +# projected_attn_out_span = self.attn_local.out_proj(attn_out_span) + +# current_iter_out = torch.zeros_like(q_current) +# current_iter_out[:, :eff_span, :] = projected_attn_out_span + +# diff = torch.abs(current_iter_out - prev_attn_out).mean() +# dynamic_threshold = threshold + s_factor * diff + +# if diff < dynamic_threshold and iteration > 0: +# attn_out = current_iter_out +# break + +# prev_attn_out = current_iter_out.clone() +# q_current = q_current + current_iter_out +# attn_out = current_iter_out + +# iteration += 1 + +# return attn_out, None + +# @autocast('cuda', enabled=torch.cuda.is_available()) +# def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, +# span_scale: Tensor, mask: Optional[Tensor] = None, +# is_causal: bool = False) -> Tensor: +# """ +# Process input with sliding window attention, using `_focus` for each window. + +# Args: +# x: Input tensor (Batch, SeqLen, Dims). +# win_size: Size of the attention window for queries. +# span_len: Max length of keys/values relative to query window start. +# span_scale: Span scale tensor (Batch, 1 or scalar) passed to _focus. +# mask: Full attention mask. +# is_causal: Apply causal masking within windows. + +# Returns: +# Output tensor (Batch, SeqLen, Dims). 
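+
+#         Example (sketch of the windowing arithmetic): with ctx=1000,
+#         win_size=256 and span_len=384, window i=2 covers queries
+#         [512:768) and attends to keys/values [384:768).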
+# """ +# batch, ctx, dims = x.size() +# output = torch.zeros_like(x) + +# num_windows = (ctx + win_size - 1) // win_size + +# for i in range(num_windows): +# q_start = i * win_size +# q_end = min(q_start + win_size, ctx) +# current_window_q_len = q_end - q_start +# if current_window_q_len == 0: continue + +# kv_start = max(0, q_end - span_len) +# kv_end = q_end +# query_win = x[:, q_start:q_end, :] +# key_win = x[:, kv_start:kv_end, :] +# value_win = x[:, kv_start:kv_end, :] + +# window_mask = None +# if mask is not None: +# if mask.dim() == 4: +# window_mask = mask[:, :, q_start:q_end, kv_start:kv_end] +# elif mask.dim() == 2: +# window_mask = mask[q_start:q_end, kv_start:kv_end] +# attn_out_win, _ = self._focus( +# query=query_win, +# key=key_win, +# value=value_win, +# span_scale=span_scale, +# mask=window_mask, +# is_causal=is_causal +# ) + +# output[:, q_start:q_end, :] = attn_out_win + +# return output + +# class CTCDecoder(nn.Module): +# def __init__(self, input_dim: int, vocab_size: int, dims: int = 256, num_layers: int = 2, dropout: float = 0.1): +# super().__init__() +# self.input_dim = input_dim +# self.vocab_size = vocab_size +# self.dims = dims + +# self.projection = nn.Linear(input_dim, dims) +# self.lstm = nn.LSTM(dims, dims, num_layers, dropout=dropout if num_layers > 1 else 0, batch_first=True, bidirectional=True) +# self.output = nn.Linear(dims * 2, vocab_size + 1) # +1 for CTC blank token +# self.dropout = nn.Dropout(dropout) + +# def forward(self, x: Tensor) -> Tensor: +# x = self.projection(x) # (batch, seq_len, dims) +# x = self.dropout(x) +# x, _ = self.lstm(x) # (batch, seq_len, dims * 2) +# x = self.dropout(x) +# logits = self.output(x) # (batch, seq_len, vocab_size + 1) +# return logits + +# class CTCWrapper(nn.Module): +# def __init__(self, model: Model, vocab_size: int, dims: int = 256, num_layers: int = 2): +# super().__init__() +# self.model = model +# self.ctc_decoder = CTCDecoder( +# input_dim=model.param.dims, +# vocab_size=vocab_size, +# dims=dims, +# num_layers=num_layers +# ) + +# def forward(self, input_ids=None, pitch=None, labels=None, input_lengths=None, label_lengths=None): +# outputs = self.model(input_ids=input_ids, pitch=pitch) +# x = outputs["logits"] # (batch, seq_len, vocab_size) +# ctc_logits = self.ctc_decoder(x) # (batch, seq_len, vocab_size + 1) +# loss = None +# if labels is not None and input_lengths is not None and label_lengths is not None: +# log_probs = torch.log_softmax(ctc_logits, dim=-1) +# log_probs = log_probs.transpose(0, 1) + +# loss = torch.nn.functional.ctc_loss( +# log_probs, +# labels, +# input_lengths, +# label_lengths, +# blank=0, +# reduction='mean' +# ) + +# return { +# "logits": ctc_logits, +# "loss": loss, +# "out": x +# } + +# def decode(self, logits: Tensor, input_lengths: Optional[Tensor] = None) -> List[List[int]]: +# probs = torch.softmax(logits, dim=-1) # (batch, seq_len, vocab_size + 1) +# predictions = torch.argmax(probs, dim=-1) # (batch, seq_len) + +# decoded_sequences = [] +# for i, pred in enumerate(predictions): +# seq = [] +# prev_token = None +# for j, token in enumerate(pred): +# if input_lengths is not None and j >= input_lengths[i]: +# break +# if token != 0 and token != prev_token: +# seq.append(token.item()) +# prev_token = token +# decoded_sequences.append(seq) +# return decoded_sequences + +# # ctc_model = CTCWrapper(model, vocab_size=40000, dims=256, num_layers=2) + +# # outputs = ctc_model( +# # input_ids=input_ids, +# # pitch=pitch, +# # labels=labels, +# # input_lengths=input_lengths, # 
Length of each audio sequence +# # label_lengths=label_lengths # Length of each text sequence +# # ) + +# # loss = outputs["loss"] + +# # outputs = ctc_model(input_ids=input_ids, pitch=pitch) +# # logits = outputs["logits"] + +# # # Decode to text +# # decoded_sequences = ctc_model.decode(logits, input_lengths=input_lengths) +# # ctc_model = CTCWrapper(model, vocab_size=param.vocab, dims=256, num_layers=2).to('cuda') + +# # print(f"CTC model parameters: {sum(p.numel() for p in ctc_model.parameters() if p.requires_grad):,}") + +# # from tensorboard import program +# # log_dir = "D:/newmodel/output/logs" +# # tb = program.TensorBoard() +# # tb.configure(argv=[None, '--logdir', log_dir]) +# # url = tb.launch() +# # print(f"TensorBoard started at {url}") + +# def compute_metricsB(pred, tokenizer): +# pred_ids = pred["predictions"] +# label_ids = pred["label_ids"] +# if isinstance(pred_ids, tuple): +# pred_ids = pred_ids[0] +# else: +# pred_ids = pred_ids +# if pred_ids.ndim == 3: +# pred_ids = np.argmax(pred_ids, axis=-1) +# label_ids[label_ids == -100] = tokenizer.pad_token_id +# pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) +# label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True) +# metrics = evaluate.load(path="wer") +# wer = metrics.compute(predictions=pred_str, references=label_str) +# return {"wer": wer} + +# def train_and_evaluate( +# model, +# tokenizer, +# train_loader, +# eval_loader, +# optimizer, +# scheduler, +# loss_fn, +# max_steps=10000, +# device="cuda", +# accumulation_steps=1, +# clear_cache=True, +# log_interval=10, +# eval_interval=100, +# save_interval=1000, +# checkpoint_dir="checkpoint_dir", +# log_dir="log_dir", +# ): +# model.to(device) +# global_step = 0 +# scaler = torch.GradScaler() +# writer = SummaryWriter(log_dir=log_dir) +# train_iterator = iter(train_loader) +# total_loss = 0 +# step_in_report = 0 +# dataset_epochs = 0 + +# progress_bar = tqdm( +# total=max_steps, desc="Training Progress", leave=True, colour="green" +# ) + +# model.train() +# optimizer.zero_grad() + +# while global_step < max_steps: +# try: +# batch = next(train_iterator) +# except StopIteration: +# train_iterator = iter(train_loader) +# batch = next(train_iterator) +# dataset_epochs += 1 +# print(f"Starting dataset epoch {dataset_epochs}") + +# if step_in_report > 0: +# avg_loss = total_loss / step_in_report +# logging.info( +# f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}" +# ) +# total_loss = 0 +# step_in_report = 0 + +# start_time = time.time() + +# input_features = batch["input_features"].to(device) +# input_ids = batch["input_ids"].to(device) +# labels = batch["labels"].long().to(device) + +# with torch.autocast(device_type="cuda"): +# input_features_encoded = model.encoder(input_features) +# decoder_output = model.decoder(input_ids, input_features_encoded) +# logits = decoder_output.view(-1, decoder_output.size(-1)) +# active_logits = logits.view(-1, decoder_output.size(-1)) +# active_labels = labels.view(-1) +# active_mask = active_labels != -100 +# active_logits = active_logits[active_mask] +# active_labels = active_labels[active_mask] +# loss = loss_fn(active_logits, active_labels) +# # model.adjust_freq(loss=loss.item()) +# total_loss += loss.item() +# loss = loss / accumulation_steps + +# scaler.scale(loss).backward() + +# if (global_step + 1) % accumulation_steps == 0: +# scaler.unscale_(optimizer) +# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) +# scaler.step(optimizer) +# 
scaler.update() +# optimizer.zero_grad() + +# if clear_cache: +# torch.cuda.empty_cache() + +# end_time = time.time() +# samples_per_sec = len(batch["input_features"]) / (end_time - start_time) + +# if global_step % log_interval == 0: +# writer.add_scalar( +# tag="Loss/train", +# scalar_value=total_loss / (global_step + 1), +# global_step=global_step, +# ) +# lr = scheduler.get_last_lr()[0] +# writer.add_scalar( +# tag="LearningRate", scalar_value=lr, global_step=global_step +# ) +# writer.add_scalar( +# tag="SamplesPerSec", +# scalar_value=samples_per_sec, +# global_step=global_step, +# ) + +# if global_step % eval_interval == 0: +# model.eval() +# eval_start_time = time.time() +# eval_loss = 0 +# all_predictions = [] +# all_labels = [] +# batch_count = 0 +# total_samples = 0 + +# with torch.no_grad(): +# for eval_batch in eval_loader: +# # for eval_batch in tqdm(eval_loader, desc=f"Evaluating (Step {global_step})", leave=True, colour='green'): +# input_features = eval_batch["input_features"].to(device) +# input_ids = eval_batch["input_ids"].to(device) +# labels = eval_batch["labels"].long().to(device) + +# batch = input_features.size(0) +# total_samples += batch + +# input_features_encoded = model.encoder(input_features) +# decoder_output = model.decoder(input_ids, input_features_encoded) +# logits = decoder_output.view(-1, decoder_output.size(-1)) +# loss = loss_fn(logits, labels.view(-1)) +# eval_loss += loss.item() +# all_predictions.extend( +# torch.argmax(decoder_output, dim=-1).cpu().numpy().tolist() +# ) +# all_labels.extend(labels.cpu().numpy().tolist()) +# batch_count += 1 + +# eval_time = time.time() - eval_start_time +# loss_avg = eval_loss / batch_count if batch_count > 0 else 0 +# predictions = { +# "predictions": np.array(all_predictions, dtype=object), +# "label_ids": np.array(all_labels, dtype=object), +# } +# metrics = compute_metrics(pred=predictions, tokenizer=tokenizer) + +# writer.add_scalar("Loss/eval", loss_avg, global_step) +# writer.add_scalar("WER", metrics["wer"], global_step) +# writer.add_scalar("EvalSamples", total_samples, global_step) +# writer.add_scalar("EvalTimeSeconds", eval_time, global_step) +# lr = scheduler.get_last_lr()[0] + +# print( +# f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}" +# ) + +# logging.info( +# f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}" +# ) +# # scheduler.step() +# model.train() + +# if global_step % save_interval == 0: +# checkpoint_path = os.path.join( +# checkpoint_dir, f"checkpoint_step_{global_step}.pt" +# ) +# torch.save(model.state_dict(), checkpoint_path) +# # print(f"Model saved at step {global_step} to {checkpoint_path}") +# logging.info(f"Model saved at step {global_step} to {checkpoint_path}") + +# lr = scheduler.get_last_lr()[0] +# scheduler.step() +# global_step += 1 +# step_in_report += 1 + +# avg_loss = total_loss / (global_step + 1) +# postfix_dict = { +# "loss": f"{avg_loss:.4f}", +# "lr": f"{lr:.6f}", +# "samp": f"{samples_per_sec:.1f}", +# } +# progress_bar.set_postfix(postfix_dict, refresh=True) +# progress_bar.update(1) + +# final_model_path = os.path.join(checkpoint_dir, "final_model.pt") +# torch.save(model.state_dict(), final_model_path) +# print( +# f"Training completed after {global_step} steps. 
Final model saved to {final_model_path}" +# ) +# writer.close() +# progress_bar.close() + +# def mainB(): + +# checkpoint_dir = "./output/checkpoints" +# os.makedirs(checkpoint_dir, exist_ok=True) +# log_dir = os.path.join("./output/logs", datetime.now().strftime(format="%m-%d_%H")) +# os.makedirs(name=log_dir, exist_ok=True) + +# logging.basicConfig( +# filename=os.path.join(log_dir, "training.log"), +# filemode="w", +# format="%(asctime)s - %(levelname)s - %(message)s", +# level=logging.INFO, +# ) + +# token = "" +# dataset = IterableDatasetDict() +# dataset["train"] = load_dataset( +# path="google/fleurs", +# name="en_us", +# split="train", +# streaming=True, +# token=token, +# trust_remote_code=True, +# ).select_columns(column_names=["audio", "transcription"]) + +# dataset["test"] = load_dataset( +# path="google/fleurs", +# name="en_us", +# split="test", +# streaming=True, +# token=token, +# trust_remote_code=True, +# ).select_columns(column_names=["audio", "transcription"]) + +# debug = None + +# param = Dimensions( +# mels=128, +# audio_ctx=1500, +# audio_head=4, +# encoder_idx=4, +# audio_dims=512, +# vocab=51865, +# text_ctx=512, +# text_head=4, +# decoder_idx=4, +# text_dims=512, +# decoder_start_token_id=0, +# pad_token_id=0, +# eos_token_id=0, +# act="gelu", +# ) + +# model = model + +# Collator = DataCollatorB( +# tokenizer=tokenizer, +# audio_ctx=param.audio_ctx, +# text_ctx=param.text_ctx, +# mels=param.mels, +# ) + +# train_dataloader = DataLoader( +# dataset=dataset["train"], batch_size=1, collate_fn=Collator, num_workers=0 +# ) + +# eval_dataloader = DataLoader( +# dataset=dataset["test"], batch_size=1, collate_fn=Collator, num_workers=0 +# ) + +# optimizer = torch.optim.AdamW( +# model.parameters(), lr=5e-4, weight_decay=0.01, eps=1e-6, betas=(0.9, 0.98) +# ) +# scheduler = torch.optim.lr_scheduler.LinearLR( +# optimizer, start_factor=0.25, total_iters=10000, last_epoch=-1 +# ) + +# loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100) + +# train_and_evaluate( +# model=model, +# tokenizer=tokenizer, +# train_loader=train_dataloader, +# eval_loader=eval_dataloader, +# optimizer=optimizer, +# scheduler=scheduler, +# loss_fn=loss_fn, +# max_steps=10000, +# device="cuda", +# accumulation_steps=1, +# clear_cache=False, +# log_interval=10, +# eval_interval=500, +# save_interval=10000, +# checkpoint_dir=checkpoint_dir, +# log_dir=log_dir, +# ) + +# def train_and_evaluate( +# model, tokenizer, train_loader, eval_loader, optimizer, scheduler, loss_fn, +# max_steps=10000, device='cuda', accumulation_steps=1, clear_cache=True, +# log_interval=10, eval_interval=100, save_interval=1000, +# checkpoint_dir="checkpoint_dir", log_dir="log_dir" +# ): +# model.to(device) +# global_step = 0 +# scaler = torch.GradScaler() +# writer = SummaryWriter(log_dir=log_dir) +# train_iterator = iter(train_loader) +# total_loss = 0 +# step_in_report = 0 +# dataset_epochs = 0 + +# progress_bar = tqdm(total=max_steps, desc="Training Progress", leave=True, colour='green') + +# model.train() +# optimizer.zero_grad() + +# while global_step < max_steps: +# try: +# batch = next(train_iterator) +# except StopIteration: +# train_iterator = iter(train_loader) +# batch = next(train_iterator) +# dataset_epochs += 1 +# print(f"Starting dataset epoch {dataset_epochs}") + +# if step_in_report > 0: +# avg_loss = total_loss / step_in_report +# logging.info(f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}") +# total_loss = 0 +# step_in_report = 0 + +# start_time = time.time() + +# batch = {k: 
v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + +# with torch.autocast(device_type="cuda"): +# output = model(**batch) if hasattr(model, '__call__') else model.forward(**batch) +# logits = output["logits"] if isinstance(output, dict) and "logits" in output else output +# labels = batch["labels"] +# active_logits = logits.view(-1, logits.size(-1)) +# active_labels = labels.view(-1) +# active_mask = active_labels != 0 +# active_logits = active_logits[active_mask] +# active_labels = active_labels[active_mask] +# loss = loss_fn(active_logits, active_labels) +# total_loss += loss.item() +# loss = loss / accumulation_steps + +# scaler.scale(loss).backward() + +# if (global_step + 1) % accumulation_steps == 0: +# scaler.unscale_(optimizer) +# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) +# scaler.step(optimizer) +# scaler.update() +# optimizer.zero_grad() +# if clear_cache: +# torch.cuda.empty_cache() + +# end_time = time.time() +# samples_per_sec = batch["spectrogram"].size(0) / (end_time - start_time) + +# if global_step % log_interval == 0: +# writer.add_scalar(tag='Loss/train', scalar_value=total_loss / (global_step + 1), global_step=global_step) +# lr = scheduler.get_last_lr()[0] +# writer.add_scalar(tag='LearningRate', scalar_value=lr, global_step=global_step) +# writer.add_scalar(tag='SamplesPerSec', scalar_value=samples_per_sec, global_step=global_step) + +# if global_step % eval_interval == 0: +# model.eval() +# eval_start_time = time.time() +# eval_loss = 0 +# all_predictions = [] +# all_labels = [] +# batch_count = 0 +# total_samples = 0 + +# with torch.no_grad(): +# for eval_batch in eval_loader: +# eval_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in eval_batch.items()} +# output = model(**eval_batch) if hasattr(model, '__call__') else model.forward(**eval_batch) +# logits = output["logits"] if isinstance(output, dict) and "logits" in output else output +# labels = eval_batch["labels"] +# batch_size = logits.size(0) +# total_samples += batch_size +# loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1)) +# eval_loss += loss.item() +# all_predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist()) +# all_labels.extend(labels.cpu().numpy().tolist()) +# batch_count += 1 + +# eval_time = time.time() - eval_start_time +# loss_avg = eval_loss / batch_count if batch_count > 0 else 0 +# predictions = {"predictions": np.array(all_predictions, dtype=object), "label_ids": np.array(all_labels, dtype=object)} +# metrics = compute_metrics(pred=predictions, tokenizer=tokenizer) + +# writer.add_scalar('Loss/eval', loss_avg, global_step) +# writer.add_scalar('WER', metrics['wer'], global_step) +# writer.add_scalar('EvalSamples', total_samples, global_step) +# writer.add_scalar('EvalTimeSeconds', eval_time, global_step) + +# lr = scheduler.get_last_lr()[0] +# print(f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}") +# logging.info(f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}") +# model.train() + +# if global_step % save_interval == 0: +# checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt') +# torch.save(model.state_dict(), checkpoint_path) +# logging.info(f"Model saved at step {global_step} to {checkpoint_path}") + +# lr = scheduler.get_last_lr()[0] +# scheduler.step() +# global_step += 1 +# step_in_report += 1 + +# avg_loss = total_loss / (global_step + 
1) +# postfix_dict = { +# 'loss': f'{avg_loss:.4f}', +# 'lr': f'{lr:.6f}', +# 'samp': f'{samples_per_sec:.1f}' +# } +# progress_bar.set_postfix(postfix_dict, refresh=True) +# progress_bar.update(1) + +# final_model_path = os.path.join(checkpoint_dir, 'final_model.pt') +# torch.save(model.state_dict(), final_model_path) +# print(f"Training completed after {global_step} steps. Final model saved to {final_model_path}") +# writer.close() +# progress_bar.close() + +# def get_optimizer(model, lr=5e-4, weight_decay=0.01): +# return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6, betas=(0.9, 0.98)) + +# def get_scheduler(optimizer, total_steps=10000): +# return torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.25, total_iters=total_steps, last_epoch=-1) + +# def get_loss_fn(): +# return torch.nn.CrossEntropyLoss(ignore_index=0) + +# def mainc(): +# token = "" +# log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S')) +# os.makedirs(log_dir, exist_ok=True) +# tokenizer = setup_tokenizer(token) + +# param = Dimensions( +# mels=128, aud_ctx=1500, aud_head=4, aud_dims=512, aud_idx=4, +# vocab=40000, text_ctx=512, text_head=4, text_dims=512, text_idx=4, +# act="swish", debug={}, cross_attn=True, features=["spectrogram"] +# ) + +# dataset_config = { +# "spectrogram": True, "waveforms": False, "pitch": False, "downsamples": False, +# "frequency": True, "hilbert": False, "hop_length": 128, "fmin": 150, "fmax": 2000, +# "n_mels": 128, "n_fft": 1024, "sampling_rate": 16000, "pad_mode": "constant", +# "center": True, "power": 2.0, "window_fn": torch.hann_window, "mel_scale": "htk", +# "norm": None, "normalized": False +# } + +# model = create_model(param) +# train_dataset, test_dataset = prepare_datasets( +# tokenizer=tokenizer, token=token, sanity_check=False, dataset_config=dataset_config +# ) + +# collator = DataCollator(tokenizer=tokenizer) +# train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collator, num_workers=0) +# eval_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collator, num_workers=0) + +# optimizer = get_optimizer(model) +# scheduler = get_scheduler(optimizer) +# loss_fn = get_loss_fn() + +# train_and_evaluate( +# model=model, +# tokenizer=tokenizer, +# train_loader=train_loader, +# eval_loader=eval_loader, +# optimizer=optimizer, +# scheduler=scheduler, +# loss_fn=loss_fn, +# max_steps=10000, +# device='cuda', +# accumulation_steps=1, +# clear_cache=False, +# log_interval=10, +# eval_interval=500, +# save_interval=10000, +# checkpoint_dir="./checkpoints", +# log_dir=log_dir +# ) + +# class attention(nn.Module): +# def __init__(self, dims: int, head: int): +# super(attention, self).__init__() +# self.dims = dims +# self.head = head +# self.head_dim = dims // head +# self.q = nn.Linear(dims, dims) +# self.k = nn.Linear(dims, dims, bias=False) +# self.v = nn.Linear(dims, dims) +# self.o = nn.Linear(dims, dims) + +# self.lna = nn.LayerNorm(dims, bias = False) +# self.lnb = nn.LayerNorm(dims, bias = False) +# self.lnc = nn.LayerNorm(self.head_dim, bias = False) +# self.lnd = nn.LayerNorm(self.head_dim, bias = False) + +# def _forward(self, x: Tensor, xa = None, mask = None): +# q = self.q(self.lna(x)) +# k = self.k(self.lnb(x if xa is None else xa)) +# v = self.v(self.lnb(x if xa is None else xa)) +# query = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3) +# key = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3) +# value = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3) + +# 
max_iterations = 5 +# iteration = 0 +# prev_attn_out = torch.zeros_like(query) +# attn_out = torch.zeros_like(query) +# threshold = self.threshold.item() +# s_factor = self.s_factor.item() + +# q_current = query + +# while iteration < max_iterations: + +# eff_span = min(x.shape[1], xa.shape[1], q_current.size(1), key.size(1)) + +# if eff_span == 0: break + +# q_iter = q_current[:, :eff_span, :] +# k_iter = key[:, :eff_span, :] +# v_iter = value[:, :eff_span, :] + +# q_proj = self.attn_local.query_module(q_iter) +# k_proj = self.attn_local.key_module(k_iter) +# v_proj = self.attn_local.value_module(v_iter) + +# temperature = (1.0 + self.temp_scale * (1.0 - xa.mean()) +# if self.sharpen +# else 0.5 + self.temp_scale * xa.mean()) +# temperature = max(temperature, 1e-3) + +# iter_mask = None +# if mask is not None: +# if mask.dim() == 4: iter_mask = mask[:, :, :eff_span, :eff_span] +# elif mask.dim() == 2: iter_mask = mask[:eff_span, :eff_span] + +# attn_output_iter, _ = calculate_attention( +# q_proj, k_proj, v_proj, +# mask=iter_mask, +# temperature=temperature, +# use_sdpa=False, +# dropout_p=self.dropout +# ) + +# attn_out_span = self.attn_local._reshape_to_output(attn_output_iter) +# projected_attn_out_span = self.attn_local.out_proj(attn_out_span) + +# current_iter_out = torch.zeros_like(q_current) +# current_iter_out[:, :eff_span, :] = projected_attn_out_span + +# diff = torch.abs(current_iter_out - prev_attn_out).mean() +# dynamic_threshold = threshold + s_factor * diff + +# if diff < dynamic_threshold and iteration > 0: +# attn_out = current_iter_out +# break + +# prev_attn_out = current_iter_out.clone() +# q_current = q_current + current_iter_out +# attn_out = current_iter_out + +# iteration += 1 + +# return attn_out, None + +# def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, +# span_scale: Tensor, mask: Optional[Tensor] = None, +# is_causal: bool = False) -> Tensor: +# """ +# Process input with sliding window attention, using `_focus` for each window. + +# Args: +# x: Input tensor (Batch, SeqLen, Dims). +# win_size: Size of the attention window for queries. +# span_len: Max length of keys/values relative to query window start. +# span_scale: Span scale tensor (Batch, 1 or scalar) passed to _focus. +# mask: Full attention mask. +# is_causal: Apply causal masking within windows. + +# Returns: +# Output tensor (Batch, SeqLen, Dims). +# """ +# batch, ctx, dims = x.size() +# output = torch.zeros_like(x) + +# num_windows = (ctx + win_size - 1) // win_size + +# for i in range(num_windows): +# q_start = i * win_size +# q_end = min(q_start + win_size, ctx) +# current_window_q_len = q_end - q_start +# if current_window_q_len == 0: continue + +# kv_start = max(0, q_end - span_len) +# kv_end = q_end +# query_win = x[:, q_start:q_end, :] +# key_win = x[:, kv_start:kv_end, :] +# value_win = x[:, kv_start:kv_end, :] + +# window_mask = None +# if mask is not None: +# if mask.dim() == 4: +# window_mask = mask[:, :, q_start:q_end, kv_start:kv_end] +# elif mask.dim() == 2: +# window_mask = mask[q_start:q_end, kv_start:kv_end] + +# attn_out_win, _ = self._focus( +# query=query_win, +# key=key_win, +# value=value_win, +# span_scale=span_scale, +# mask=window_mask, +# is_causal=is_causal +# ) + +# output[:, q_start:q_end, :] = attn_out_win + +# return output \ No newline at end of file