diff --git "a/echoutils.py" "b/echoutils.py"
--- "a/echoutils.py"
+++ "b/echoutils.py"
@@ -49,20 +49,19 @@ def sinusoids(ctx, dims, max_tscale=10000):
     positional_embedding = nn.Parameter(position, requires_grad=True)
     return positional_embedding
 
-def get_activation(act: str) -> nn.Module:
-    act_map = {
-        "gelu": nn.GELU(),
-        "relu": nn.ReLU(),
-        "sigmoid": nn.Sigmoid(),
-        "tanh": nn.Tanh(),
-        "swish": nn.SiLU(),
-        "tanhshrink": nn.Tanhshrink(),
-        "softplus": nn.Softplus(),
-        "softshrink": nn.Softshrink(),
-        "leaky_relu": nn.LeakyReLU(),
-        "elu": nn.ELU()
-    }
-    return act_map.get(act, nn.GELU())
+class SLSTM(nn.Module):
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = False, bias = True, batch_first = False):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers, bias=bias, batch_first=batch_first)  # permutes below assume a time-first LSTM (batch_first=False)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)  # (batch, dims, ctx) -> (ctx, batch, dims)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)  # back to (batch, dims, ctx)
+        return y
 
 def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
     q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
@@ -73,6 +72,55 @@ def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
     out = torch.matmul(weights, v)
     return out
 
+def scaled_relu(x, sequence_length):
+    relu_output = torch.relu(x)
+    return relu_output / sequence_length
+
+def taylor_softmax(x, order=2):
+    tapprox = 1.0
+    for i in range(1, order + 1):
+        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+        tapprox += x**i / factorial_i
+    return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)
+
+def taylor_masked(x, mask, order=2):
+    tapprox = torch.zeros_like(x)
+    unmasked = x.masked_select(mask)
+    approx_values = torch.ones_like(unmasked)  # zeroth-order term; higher orders are added in the loop
+    for i in range(1, order + 1):
+        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+        approx_values += unmasked**i / factorial_i
+    tapprox.masked_scatter_(mask, approx_values)
+    sum_approx = torch.sum(tapprox, dim=-1, keepdim=True)
+    toutput = tapprox / (sum_approx + 1e-9)
+    toutput = toutput * mask
+    return toutput
+
+def taylor_softmax2(x, mask=None, order=2):
+    if mask is None:
+        tapprox = torch.ones_like(x)  # zeroth-order term; higher orders are added in the loop
+        for i in range(1, order + 1):
+            factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+            tapprox += x**i / factorial_i
+        return tapprox / torch.sum(tapprox, dim=-1, keepdim=True)
+
+    else:
+        tapprox = torch.zeros_like(x)
+        unmasked = x.masked_select(mask)
+        tapprox = torch.ones_like(unmasked)  # zeroth-order term; higher orders are added in the loop
+        for i in range(1, order + 1):
+            factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
+            tapprox += unmasked**i / factorial_i
+
+        tapprox_full = torch.zeros_like(x)
+        tapprox_full.masked_scatter_(mask, tapprox)
+
+        sum_approx = torch.sum(tapprox_full, dim=-1, keepdim=True)
+        toutput = tapprox_full / (sum_approx + 1e-9)
+
+        toutput = toutput * mask.float()
+        return toutput
+
 def taylor_softmax_2nd_order(x):
     exp_approx = 1 + x + (x**2) / 2
     return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True)
@@ -101,11 +149,10 @@ def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
     return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
 
 def sliding_window_mask(q_len, k_len, window, device):
-    # mask[i, j] = 1 if j in [i-window+1, i], else 0
     idxs = torch.arange(q_len, device=device).unsqueeze(1)
     jdxs = torch.arange(k_len, device=device).unsqueeze(0)
     mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
-    return mask.float()  # shape: (q_len, k_len)
+    return mask.float()
 
 def mask_win(text_ctx, aud_ctx):
     mask = 
torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0) @@ -191,132 +238,6 @@ def create_qkv(dims, head, q, k, v, x, xa): return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous() return _shape(q), _shape(k), _shape(v) -# def calculate_attention(q, k, v, mask=None, temp=1.0): -# scaled_q = q -# if temp != 1.0 and temp > 0: -# scaled_q = q * (1.0 / temp)**.5 -# out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1) -# return out - -# class LocalOut(nn.Module): -# def __init__(self, dims: int, head: int): -# super().__init__() -# head_dim = dims // head -# self.head_dim = head_dim -# self.query_module = nn.Linear(head_dim, head_dim) -# self.key_module = nn.Linear(head_dim, head_dim) -# self.value_module = nn.Linear(head_dim, head_dim) -# self.out_proj = nn.Linear(head_dim, head_dim) - -# def _reshape_to_output(self, x): -# return x - -# class attentiona(nn.Module): -# def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0): -# super(attentiona, self).__init__() -# self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head) -# self.dims = dims -# self.head = head -# self.head_dim = dims // head -# self.max_iter = max_iter -# self.threshold = nn.Parameter(torch.tensor(threshold)) -# self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True) -# self.factor = nn.Parameter(torch.tensor(factor)) -# self.lnc = nn.LayerNorm(self.head_dim, bias=False) -# self.lnd = nn.LayerNorm(self.head_dim, bias=False) -# self.attn_local = LocalOut(self.head_dim) - -# def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None): -# z = default(xa, x) -# q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z)) - -# iteration = 0 -# temp = self.temp.item() -# prev_out = torch.zeros_like(q) -# attn_out = torch.zeros_like(q) -# threshold = self.threshold.item() -# factor = self.factor.item() -# qcur = q - -# while iteration < self.max_iter: -# eff_span = min(qcur.shape[1], k.shape[1]) -# if xa is not None: -# eff_span = min(eff_span, xa.shape[1]) -# if eff_span == 0: -# break - -# qiter = qcur[:, :, :eff_span, :] -# kiter = k[:, :, :eff_span, :] -# viter = v[:, :, :eff_span, :] -# q = self.attn_local.query_module(qiter) -# k = self.attn_local.key_module(kiter) -# v = self.attn_local.value_module(viter) - -# iter_mask = None -# if mask is not None: -# if mask.dim() == 4: -# iter_mask = mask[:, :, :eff_span, :eff_span] -# elif mask.dim() == 2: -# iter_mask = mask[:eff_span, :eff_span] - -# attn_iter = calculate_attention( -# self.lnc(q), self.lnd(k), v, -# mask=iter_mask, temp=temp) - -# iter_out = torch.zeros_like(qcur) -# iter_out[:, :, :eff_span, :] = attn_iter -# diff = torch.abs(iter_out - prev_out).mean() -# dthresh = threshold + factor * diff -# if diff < dthresh and iteration > 0: -# attn_out = iter_out -# break - -# prev_out = iter_out.clone() -# qcur = qcur + iter_out -# attn_out = iter_out -# iteration += 1 -# temp += 0.005 - -# output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2) -# return self.o(output), None - - # def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor: - - # batch, ctx, dims = x.shape - # output = torch.zeros_like(x) - # num_win = (ctx + win_size - 1) // win_size - - # for i in range(num_win): - # qstart = i * win_size - # qend = min(qstart + win_size, ctx) - # win_qlen = qend - 
qstart - # if win_qlen == 0: - # continue - - # kstart = max(0, qend - span_len) - # kend = qend - # qwin = x[:, qstart:qend, :] - # kwin = x[:, kstart:kend, :] - - # win_mask = None - # if mask is not None: - # if mask.dim() == 4: - # win_mask = mask[:, :, qstart:qend, kstart:kend] - # elif mask.dim() == 2: - # win_mask = mask[qstart:qend, kstart:kend] - - # attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask) - # output[:, qstart:qend, :] = attn_out - # return output - - # def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, - # use_sliding_win: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor: - # if use_sliding_win: - # return self._slide_win_local(x, win_size, span_len, mask) - # else: - # output, _ = self._focus(x, xa, mask) - # return output - class KVCache(nn.Module): def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): super().__init__() @@ -325,13 +246,12 @@ class KVCache(nn.Module): self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] assert input_pos.shape[0] == k_val.shape[2] k_out = self.k_cache v_out = self.v_cache - k_out[:, :, input_pos] = k_val # pyright: ignore[reportIndexIssue] - v_out[:, :, input_pos] = v_val # pyright: ignore[reportIndexIssue] + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val return k_out, v_out @@ -371,7 +291,7 @@ def track_xa(new_xa, operation=""): current_id = id(new_xa) if current_id != xa_id[0]: print(f"xa FLOW: {xa_id[0]} → {current_id} in {operation}") - xa_id[0] = current_id # pyright: ignore[reportArgumentType, reportCallIssue] + xa_id[0] = current_id else: print(f"xa REUSE: {current_id} in {operation}") return new_xa @@ -393,7 +313,7 @@ def get_activation(act: str) -> nn.Module: return act_map.get(act, nn.GELU()) def get_generation_config(param): - return GenerationConfig( # type: ignore + return GenerationConfig( max_length=param.text_ctx, pad_token_id=getattr(param, "pad_token_id", 0), bos_token_id=getattr(param, "bos_token_id", 1), @@ -430,7 +350,6 @@ class feature_encoder(nn.Module): act_fn = get_activation(act) if self.attend_feature: - # self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head) self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims)) else: self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None @@ -452,11 +371,10 @@ class feature_encoder(nn.Module): Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn) if use_rope: - # if spec_shape is not None: self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) - self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) # type: ignore + self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) else: - self.rope = None # type: ignore + self.rope = None self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) self.norm = RMSNorm(dims) @@ -469,7 +387,7 @@ class feature_encoder(nn.Module): x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3) freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) - x = self.rope.apply_rotary(x, freqs) # pyright: ignore[reportOptionalSubscript, reportAttributeAccessIssue] + x = self.rope.apply_rotary(x, freqs) x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, 
dims) return x @@ -484,10 +402,6 @@ class feature_encoder(nn.Module): enc_dict = feats if feats is not None else {} enc_dict = dict(enc_dict) enc_dict["f0"] = xp - # xp = self.mel_scalar(xp.mean()) - # print(f"Using pitch scalar: {xp}") - # max_tscale = xp*300 - # print(f"Using max_tscale: {max_tscale}") feats = enc_dict if x.dim() == 2: x = x.unsqueeze(0) @@ -529,13 +443,6 @@ class feature_encoder(nn.Module): x = nn.functional.dropout(x, p=self.dropout, training=self.training) x = self.norm(x) - # if self.attend_feature: - # xa = feats[feature] # pyright: ignore[reportOptionalSubscript] - # if xa is not None: - # q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head) - # out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True) - # x = x + out - x = nn.functional.dropout(x, p=self.dropout, training=self.training) x = self.norm(x) return x @@ -579,6 +486,7 @@ class curiosity(nn.Module): return x.transpose(1, 2).contiguous().view(b, t, h * dh) def forward(self, x, xa, mask=None): + q, k, v = self.qkv(x).chunk(3, -1) qa, ka, va = self.qkv_aux(xa).chunk(3, -1) q, k, v = map(self.split, (q, k, v)) @@ -667,17 +575,17 @@ class RMSNorm(nn.Module): self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore + self.weight = nn.Parameter(torch.empty(self.normalized_shape)) init.ones_(self.weight) else: self.register_parameter("weight", None) def forward(self, x): - return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore + return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5) -> Tensor: - return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore + return F.layer_norm(x, normalized_shape, weight, bias, eps) def get_device(): return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -714,14 +622,14 @@ class SelfCriticalRL(nn.Module): rewards = [] baseline = [] - for s, g, ref in zip(sampled_text, greedy_text, labels): # type: ignore + for s, g, ref in zip(sampled_text, greedy_text, labels): ref_text = self.tokenizer.decode(ref) rewards.append(self.reward_fn(s, ref_text)) baseline.append(self.reward_fn(g, ref_text)) rewards = torch.tensor(rewards, device=device, dtype=torch.float) baseline = torch.tensor(baseline, device=device, dtype=torch.float) advantage = rewards - baseline - logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"] # logits: [batch, sampled_seq_len, vocab_size] + logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"] log_probs = F.log_softmax(logits, dim=-1) log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1) log_probs_sum = log_probs_seq.sum(dim=1) @@ -775,7 +683,7 @@ def wer_reward(hyp, ref): else: d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1]) wer = d[-1][-1] / max(1, len(ref_words)) - return -wer # negative WER as reward + return -wer def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2): if isinstance(ids, torch.Tensor): @@ -841,7 +749,7 @@ def load_wave(wave_data, sample_rate=16000): waveform, sample_rate = torchaudio.load(uri=wave_data, normalize=False) elif isinstance(wave_data, dict): waveform = torch.tensor(data=wave_data["array"]).float() - sample_rate = wave_data["sampling_rate"] # noqa: F841 + 
sample_rate = wave_data["sampling_rate"] else: raise TypeError("Invalid wave_data format.") return waveform @@ -850,17 +758,12 @@ def world_to_mel(sp, ap, sample_rate=16000, n_mels=128): import librosa mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels) mel_basis = torch.from_numpy(mel_basis).float() - sp_mel = torch.matmul(sp, mel_basis.T) # (frames, 128) - ap_mel = torch.matmul(ap, mel_basis.T) # (frames, 128) + sp_mel = torch.matmul(sp, mel_basis.T) + ap_mel = torch.matmul(ap, mel_basis.T) return sp_mel, ap_mel def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False): - # import torch - # import torchaudio - # import torchaudio.functional as F - # import torchaudio.transforms as T - torch_windows = { 'hann': torch.hann_window, 'hamming': torch.hamming_window, @@ -885,9 +788,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens= audio = batch["audio"] sample_rate = audio["sampling_rate"] - # audio_length = len(audio["array"]) / audio["sampling_rate"] labels = tokenizer.encode(batch["transcription"]) - # sentence_length = len(batch["transcription"]) wav = load_wave(wave_data=audio, sample_rate=sample_rate) def crepe_predict(wav, sample_rate, viterbi=False): @@ -966,15 +867,6 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens= mfcc_tensor = transform(wav) return mfcc_tensor - # def compute_pitch(wav, sample_rate, hop_length=256): - # # pitch = F.detect_pitch_frequency(wav, sample_rate) - # # f0 = pitch - # import pyworld as pw - # wav_np = wav.numpy().astype(np.float64) - # f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length / sample_rate * 1000) - # f0 = pw.stonemask(wav_np, f0, t, sample_rate) - # return f0, t - def harmonics_and_aperiodics(wav, f0, t, sample_rate): import pyworld as pw wav_np = wav.numpy().astype(np.float64) @@ -1005,7 +897,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens= end_idx = torch.searchsorted(t2, token_ends, side="right") pitch_tok = torch.zeros(T, dtype=torch.float32) for i in range(T): - lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i]) # type: ignore + lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i]) segment = f0_np[lo:hi] if mode == "mean": pitch_tok[i] = segment.mean() @@ -1023,7 +915,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens= if phase_mod: tframe = torch.mean(t2[1:] - t2[:-1]) phi0 = 0.0 - omega = 2 * torch.pi * f0_tensor # type: ignore + omega = 2 * torch.pi * f0_tensor dphi = omega * tframe phi = torch.cumsum(dphi, dim=0) + phi0 phase = torch.remainder(phi, 2 * torch.pi) @@ -1031,9 +923,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens= phase = None if pitch: - p_tensor = torchaudio.functional.detect_pitch_frequency(wav, sample_rate) - # p_tensor = torch.from_numpy(f0_np) - # p_tensor = p_tensor.unsqueeze(0) + p_tensor = torchaudio.functional.detect_pitch_frequency(wav, sample_rate).unsqueeze(0) else: p_tensor = None @@ -1100,78 +990,6 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens= "dummy": dummy_tensor if dummy else None, } -# class PEncoder(nn.Module): # pitch encoder -# def __init__(self, dims: int, head: int, layer: int, kernel_size: int, act: str, -# max_seq_len: int, input_dims: int = 1, use_rope=False): -# 
super().__init__() - -# self.head = head -# self.head_dim = dims // head -# self.dims = dims -# self.use_rope=use_rope -# self.dropout_rate = 0.01 -# act_fn = get_activation(act) - -# self.positional_encoding = nn.Parameter(torch.randn(1, max_seq_len, dims)) - -# if use_rope: -# self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore -# else: -# self.rope = None -# self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) - -# self.attend_pitch = False -# if self.attend_pitch: -# self.mlp = nn.Sequential( -# nn.Linear(dims, dims), -# nn.ReLU(), -# nn.Linear(dims, dims), -# ) -# else: -# self.mlp = None - -# self.pitch_encoder = nn.Sequential( -# nn.Conv1d(input_dims, dims, kernel_size=kernel_size, stride=1, padding=kernel_size // 2), act_fn, -# nn.Conv1d(dims, dims, kernel_size=kernel_size - 2, stride=1, padding=(kernel_size - 2) // 2), act_fn, -# nn.Conv1d(dims, dims, kernel_size=kernel_size - 4, stride=1, padding=(kernel_size - 4) // 2, groups=dims), act_fn -# ) - -# def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"): -# batch, ctx, dims = x.shape -# x = x.view(batch, ctx, head, self.head_dim).permute(0, 2, 1, 3) -# freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) # type: ignore -# x = self.rope.apply_rotary(x, freqs)# type: ignore -# x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims) -# return x - -# self.norm = nn.LayerNorm(dims) - -# def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, -# feats: Optional[Any] = None, feature: str = "pitch", layer: str = "PEncoder", -# audio_duration: Optional[float] = None, sample_rate: Optional[int] = None, -# labels_len: Optional[int] = None, f0_np: Optional[np.ndarray] = None, -# t2_np: Optional[np.ndarray] = None, mode: str = "mean") -> Tensor: - -# if x.dim() == 2 and feature == "pitch": -# x_processed = x.unsqueeze(1) # Input to pitch_encoder: (batch_size, 1, num_pitch_tokens) -# x_processed = self.pitch_encoder(x_processed) # Output: (batch_size, dims, num_pitch_tokens) -# x = x_processed.permute(0, 2, 1) # Reassign to x for consistency - -# if self.use_rope: -# pass # Placeholder for RoPE application - -# seq_len = x.shape[1] -# # x = x + self.positional_encoding[:, :seq_len, :] -# x = x + sinusoids(x.shape[1], x.shape[-1], 36000).to(device, dtype) - -# if self.mlp is not None: -# x = self.mlp(x) - -# x = nn.functional.dropout(x, p=self.dropout_rate, training=self.training) -# x = self.norm(x) - -# return x - def plot_waveform(waveform, sr, title="Waveform", ax=None): waveform = waveform.numpy() @@ -1266,9 +1084,6 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st len(x["audio"]["array"]) > 0 and len(x["audio"]["array"]) < max_ctx * 160) - # raw_train = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="train", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription") - # raw_test = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="test", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription").take(1000) - raw_train = load_dataset( "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming).take(1000) raw_test = load_dataset( @@ -1325,7 +1140,7 @@ class FEncoder(nn.Module): if use_rope: if spec_shape is not None: - self.rope = rotary(dims=dims, head=head, 
radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) # type: ignore + self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) else: self.rope = None self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) @@ -1334,8 +1149,8 @@ class FEncoder(nn.Module): def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"): batch, ctx, dims = x.shape x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3) - freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore - x = self.rope.apply_rotary(x, freqs)# type: ignore + freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) + x = self.rope.apply_rotary(x, freqs) x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims) return x @@ -1351,7 +1166,7 @@ class FEncoder(nn.Module): x = self.norm(x) return x -class WEncoder(nn.Module): # waveform encoder +class WEncoder(nn.Module): def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None): super().__init__() @@ -1370,7 +1185,7 @@ class WEncoder(nn.Module): # waveform encoder if use_rope: if spec_shape is not None: - self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore + self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) else: self.rope = None self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) @@ -1379,13 +1194,13 @@ class WEncoder(nn.Module): # waveform encoder def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"): batch, ctx, dims = x.shape x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3) - freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore - x = self.rope.apply_rotary(x, freqs)# type: ignore + freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) + x = self.rope.apply_rotary(x, freqs) x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims) return x def forward(self, x, xa=None, mask=None, feats= None, feature="waveform", layer = "WEncoder"): - x = self.encoder(x).permute(0, 2, 1) # (batch, time, dims) + x = self.encoder(x).permute(0, 2, 1) if self.target_length and x.shape[1] != self.target_length: x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2) if self.use_rope: @@ -1396,7 +1211,7 @@ class WEncoder(nn.Module): # waveform encoder print(f"waveform encoder: {x.shape} {feature}") if "fencoder" in self.debug else None return self.norm(x) -class PEncoder(nn.Module): # pitch encoder +class PEncoder(nn.Module): def __init__(self, input_dims, dims, head, layer, act, use_rope=False, debug=[], one_shot=False, spec_shape=None): super().__init__() @@ -1427,7 +1242,7 @@ class PEncoder(nn.Module): # pitch encoder Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn) if use_rope: - self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore + self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) else: self.rope = None self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale) @@ -1436,17 +1251,14 @@ class PEncoder(nn.Module): # pitch encoder def 
rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"): batch, ctx, dims = x.shape x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3) - freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) # type: ignore - x = self.rope.apply_rotary(x, freqs)# type: ignore + freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) + x = self.rope.apply_rotary(x, freqs) x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims) return x def forward(self, x, xa=None, mask=None, feats= None, feature="pitch", layer="PEncoder"): - # f0=x - # freqs = self.rope(f0.shape[1], feats=feats, feature=feature, layer=layer) if x.dim() == 2: x = x.unsqueeze(0) - # if feature == "pitch": x = self.pitch_encoder(x).permute(0, 2, 1) if self.use_rope: @@ -1462,7 +1274,6 @@ class PEncoder(nn.Module): # pitch encoder out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True) x = x + out - # x = nn.functional.dropout(x, p=self.dropout, training=self.training) x = self.norm(x) print(f"Pitch encoder: {x.shape} {feature}") if "fencoder" in self.debug else None return x @@ -1483,7 +1294,7 @@ class DataCollator: for key in all_keys: if key == "labels": labels_list = [f["labels"] for f in features] - max_len = max(len(l) for l in labels_list) # noqa: E741 + max_len = max(len(l) for l in labels_list) all_ids, all_labels = [], [] for label in labels_list: @@ -1516,82 +1327,8 @@ class DataCollator: pad_item = item padded.append(pad_item) batch[key] = torch.stack(padded) - # if key == "spectrogram": - # batch["spectrogram"] = batch[key] return batch -# import tiktoken -# import torch -# from torch.utils.data import Dataset, DataLoader - -# class tokenize(Dataset): -# def __init__(self, txt, tokenizer, max_length, stride): - -# self.input_ids = [] -# self.labels = [] - -# token_ids = tokenizer.encode(txt, allowed_special={""}) - -# for i in range(0, len(token_ids) - max_length, stride): - -# input_chunk = token_ids[i:i + max_length] -# target_chunk = token_ids[i + 1: i + max_length + 1] -# self.input_ids.append(torch.tensor(input_chunk)) -# self.labels.append(torch.tensor(target_chunk)) - -# def __len__(self): -# return len(self.input_ids) - -# def __getitem__(self, idx): -# return self.input_ids[idx], self.labels[idx] - -# def create_dataloader_v1(txt, batch_size, max_length, stride, shuffle=True, drop_last=True, num_workers=0): -# tokenizer = tiktoken.get_encoding("gpt2") -# dataset = tokenize(txt, tokenizer, max_length, stride) -# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers) -# return dataloader - -# def custom_collate_fn(batch, tokenizer_pad_token_id): -# max_len_in_batch = max(len(seq) for seq in batch) - -# padded_input_ids = [] -# attention_masks = [] -# for seq in batch: -# padded_seq = F.pad(seq, (0, max_len_in_batch - len(seq)), value=tokenizer_pad_token_id) -# attention_mask = torch.ones(max_len_in_batch) -# attention_mask[len(seq):] = 0 -# padded_input_ids.append(padded_seq) -# attention_masks.append(attention_mask) - -# input_ids_tensor = torch.stack(padded_input_ids) -# attention_mask_tensor = torch.stack(attention_masks) -# labels_tensor = input_ids_tensor.clone() - -# return { -# 'input_ids': input_ids_tensor, -# 'attention_mask': attention_mask_tensor, -# 'labels': labels_tensor -# } - -# with open("the-verdict.txt", "r", encoding="utf-8") as f: -# raw_text = f.read() - -# vocab_size = 50257 -# output_dim = 256 -# context_length = 1024 - -# 
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim) -# pos_embedding_layer = torch.nn.Embedding(context_length, output_dim) - -# batch_size = 8 -# max_length = 4 -# dataloader = create_dataloader_v1( -# raw_text, -# batch_size=batch_size, -# max_length=max_length, -# stride=max_length -# ) - def levenshtein(reference_words, hypothesis_words): m, n = len(reference_words), len(hypothesis_words) dist_matrix = [[0 for _ in range(n+1)] for _ in range(m+1)] @@ -1721,1663 +1458,3 @@ def process_spectrogram_with_hilbert(spec): phase = torch.angle(analytic) return envelope, phase -# import torch -# import torch.nn as nn -# import torch.nn.functional as F -# from torch import Tensor -# from typing import Optional, Tuple -# import numpy as np -# from torch.nn.functional import scaled_dot_product_attention -# from torch.cuda.amp import autocast -# from torch.nn import LayerNorm, Linear -# import logging - -# logging.basicConfig(level=logging.WARNING) -# log = logging.getLogger(__name__) - -# class ProjectionModule(nn.Module): -# """ -# Projects input embeddings into query, key, or value representations -# for multi-head attention, handling scaling for Q/K. -# """ -# def __init__(self, dims: int, head: int, proj_type: str = "query", use_bias: bool = True): -# """ -# Args: -# dims: Input and output dimension. -# head: Number of attention heads. -# proj_type: Type of projection ("query", "key", "value"). -# use_bias: Whether to use bias in the linear layer. -# """ -# super().__init__() -# assert dims % head == 0, f"dims ({dims}) must be divisible by head ({head})" -# assert proj_type in ["query", "key", "value"], \ -# f"proj_type must be 'query', 'key', or 'value', got {proj_type}" - -# self.dims = dims -# self.head = head -# self.head_dim = dims // head -# self.proj_type = proj_type -# self.scale = self.head_dim ** -0.5 if proj_type != "value" else 1.0 -# self.proj = Linear(in_features=dims, out_features=dims, bias=use_bias) -# self.init_weights() - -# def init_weights(self): -# """Initialize projection weights.""" -# nn.init.normal_(tensor=self.proj.weight, std=0.02) -# if self.proj.bias is not None: -# nn.init.zeros_(tensor=self.proj.bias) - -# def forward(self, x: Tensor) -> Tensor: -# """ -# Applies projection, scaling (for Q/K), and reshapes for multi-head attention. - -# Args: -# x: Input tensor of shape (batch, seq_len, dims). - -# Returns: -# Projected tensor of shape (batch, head, seq_len, head_dim). -# """ -# batch, seq_len, _ = x.shape -# proj = self.proj(x) - -# proj = proj.view(batch, seq_len, self.head, self.head_dim).permute(0, 2, 1, 3) - -# if self.proj_type in ["query", "key"]: -# proj = proj * self.scale -# return proj - -# def calculate_attention( -# q: Tensor, -# k: Tensor, -# v: Tensor, -# mask: Optional[Tensor] = None, -# temperature: float = 1.0, -# use_sdpa: bool = True, -# is_causal: bool = False, -# dropout_p: float = 0.0 -# ) -> Tuple[Tensor, Optional[Tensor]]: -# """ -# Calculates scaled dot-product attention. - -# Uses torch.nn.functional.scaled_dot_product_attention if use_sdpa is True -# and inputs are compatible, otherwise falls back to manual implementation. - -# Args: -# q: Query tensor (Batch, Heads, SeqLen_Q, HeadDim). Already scaled if needed. -# k: Key tensor (Batch, Heads, SeqLen_K, HeadDim). Already scaled if needed. -# v: Value tensor (Batch, Heads, SeqLen_K, HeadDim). -# mask: Attention mask. Can be boolean (True means ignore) or float (-inf means ignore). -# Shape should be broadcastable to (Batch, Heads, SeqLen_Q, SeqLen_K). 
-# temperature: Softmax temperature scaling. Applied *before* softmax. -# use_sdpa: Flag to attempt using PyTorch's optimized SDPA implementation. -# is_causal: If True, applies a causal mask (for decoder self-attention). -# Used only if mask is None and use_sdpa is True. -# dropout_p: Dropout probability for attention weights. - -# Returns: -# A tuple containing: -# - attn_output: Attention output tensor (Batch, Heads, SeqLen_Q, HeadDim). -# - attn_weights: Attention weights tensor (Batch, Heads, SeqLen_Q, SeqLen_K), -# or None if SDPA implementation doesn't return them or if fallback used. -# *Note: SDPA's default doesn't return weights, requires specific backend support.* -# *Manual path always returns weights.* -# """ -# batch_size, num_heads, q_len, head_dim = q.shape -# k_len = k.size(2) - -# temp_scale = 1.0 / temperature if temperature > 0 else 1.0 - -# attn_output, attn_weights = None, None - -# if use_sdpa: -# try: -# if temperature != 1.0: -# raise NotImplementedError("SDPA does not directly support temperature scaling. Use manual path or scale Q.") - -# attn_output = scaled_dot_product_attention( -# q, k, v, -# attn_mask=mask, -# dropout_p=dropout_p, -# is_causal=is_causal and mask is None -# ) -# attn_weights = None -# return attn_output, attn_weights -# except (RuntimeError, NotImplementedError) as e: -# log.warning(f"SDPA failed or not used ({e}), falling back to manual attention.") -# attn_scores = torch.matmul(q, k.transpose(-2, -1)) * temp_scale - -# if mask is not None: -# if mask.dim() == 2: -# mask = mask.unsqueeze(0).unsqueeze(0) -# elif mask.dim() == 3: -# mask = mask.unsqueeze(1) - -# expected_mask_shape = (batch_size, num_heads, q_len, k_len) -# if mask.shape != expected_mask_shape: -# try: -# mask = mask.expand(expected_mask_shape) -# except RuntimeError: -# raise ValueError(f"Mask shape {mask.shape} is not compatible with attention scores shape {expected_mask_shape}") - -# if mask.dtype == torch.bool: -# attn_scores = attn_scores.masked_fill(mask, float("-inf")) -# else: -# attn_scores = attn_scores + mask - -# attn_weights = F.softmax(attn_scores, dim=-1) - -# if dropout_p > 0.0: -# attn_weights = F.dropout(attn_weights, p=dropout_p) - -# attn_output = torch.matmul(attn_weights, v) - -# return attn_output, attn_weights - -# class BaseAttention(nn.Module): -# """Base class for attention mechanisms with common functionality.""" -# use_sdpa = True - -# def __init__(self, dims: int, head: int, max_dist: int = 512, dropout: float = 0.0): -# """ -# Args: -# dims: Embedding dimension. -# head: Number of attention heads. -# max_dist: Maximum attention distance (used by some subclasses). -# dropout: Dropout probability for attention weights. -# """ -# super().__init__() -# assert dims % head == 0, f"dims ({dims}) must be divisible by head ({head})" -# self.dims = dims -# self.head = head -# self.head_dim = dims // head -# self.max_dist = max_dist -# self.dropout = dropout - -# def _shape(self, tensor: torch.Tensor) -> torch.Tensor: -# """ -# Reshape tensor from (batch, seq_len, dims) to -# (batch, head, seq_len, head_dim) for multi-head attention. -# """ -# batch, seq_len, _ = tensor.shape -# return tensor.view(batch, seq_len, self.head, self.head_dim).transpose(1, 2).contiguous() - -# def _reshape_to_output(self, attn_output: Tensor) -> Tensor: -# """ -# Reshape attention output from (batch, head, seq_len, head_dim) -# back to (batch, seq_len, dims). 
-# """ -# batch, _, seq_len, _ = attn_output.shape -# return attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.dims) - -# class AttentionCombiner(BaseAttention): -# """ -# Computes attention given Q, K, V projections and applies an output projection. -# Assumes Q, K, V inputs are already projected and appropriately shaped/scaled. -# """ -# def __init__(self, dims: int, head: int, use_bias: bool = True, dropout: float = 0.0): -# """ -# Args: -# dims: Embedding dimension. -# head: Number of attention heads. -# use_bias: Whether to use bias in the output projection. -# dropout: Dropout probability for attention weights. -# """ -# super().__init__(dims, head, dropout=dropout) -# self.out = Linear(in_features=dims, out_features=dims, bias=use_bias) -# self._init_weights() - -# def _init_weights(self): -# """Initialize output projection weights.""" -# nn.init.normal_(tensor=self.out.weight, std=0.02) -# if self.out.bias is not None: -# nn.init.zeros_(tensor=self.out.bias) - -# # @autocast('cuda', enabled=torch.cuda.is_available()) -# def forward(self, q: Tensor, k: Tensor, v: Tensor, -# mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor: -# """ -# Processes Q, K, V through attention and output projection. - -# Args: -# q: Query tensor (Batch, Heads, SeqLen_Q, HeadDim). Assumed scaled. -# k: Key tensor (Batch, Heads, SeqLen_K, HeadDim). Assumed scaled. -# v: Value tensor (Batch, Heads, SeqLen_K, HeadDim). -# mask: Attention mask. -# is_causal: Whether to apply causal masking (if mask is None). - -# Returns: -# Output tensor (Batch, SeqLen_Q, Dims). -# """ -# attn_output, _ = calculate_attention( -# q, k, v, mask=mask, -# temperature=1.0, -# use_sdpa=BaseAttention.use_sdpa, -# is_causal=is_causal, -# dropout_p = self.dropout -# ) - -# output = self._reshape_to_output(attn_output) -# return self.out(output) - -# class AdaptiveUpdateAttention(BaseAttention): -# """ -# Attention implementation where Key and Value representations are cached -# and only updated based on content-dependent predictors. Suitable for -# encoder layers or cross-attention where K/V context changes less frequently. - -# Note: Current implementation focuses on conditional update based on *current* -# input, not standard auto-regressive KV caching for generation. -# """ -# def __init__(self, dims: int, head: int, max_dist: int = 512, update_threshold: float = 0.5, dropout: float = 0.0): -# """ -# Args: -# dims: Embedding dimension. -# head: Number of attention heads. -# max_dist: Maximum attention distance (inherited, may not be directly used here). -# update_threshold: Threshold for sigmoid output of predictors to trigger update. -# dropout: Dropout probability for attention weights. 
-# """ -# super().__init__(dims, head, max_dist, dropout=dropout) - -# self.query_module = ProjectionModule(dims, head, "query") -# self.key_module = ProjectionModule(dims, head, "key") -# self.value_module = ProjectionModule(dims, head, "value") -# self.combiner = AttentionCombiner(dims, head, dropout=dropout) - -# self.key_update_predictor = nn.Sequential( -# Linear(dims, dims // 4), nn.ReLU(), Linear(dims // 4, 1), nn.Sigmoid()) -# self.value_update_predictor = nn.Sequential( -# Linear(dims, dims // 4), nn.ReLU(), Linear(dims // 4, 1), nn.Sigmoid()) - -# self.update_threshold = update_threshold -# self.stored_key_cache: Optional[Tensor] = None -# self.stored_value_cache: Optional[Tensor] = None -# self.reset_cache_on_forward = True - -# def _should_update(self, x: torch.Tensor, predictor: nn.Module) -> torch.Tensor: -# """Predict whether K or V should be updated based on content.""" -# avg_rep = x.mean(dim=1) -# update_prob = predictor(avg_rep) -# return update_prob > self.update_threshold - -# def forward(self, x: Tensor, xa: Optional[Tensor] = None, -# mask: Optional[Tensor] = None, -# is_causal: bool = False) -> Tensor: -# """ -# Process inputs with adaptive K/V update mechanism. - -# Args: -# x: Input tensor for queries (Batch, SeqLen_Q, Dims). -# xa: Optional input tensor for keys/values (for cross-attention). -# If None, uses x for self-attention (Batch, SeqLen_KV, Dims). -# mask: Attention mask. -# is_causal: Whether attention should be causal. - -# Returns: -# Output tensor (Batch, SeqLen_Q, Dims). -# """ -# if self.reset_cache_on_forward: -# self.stored_key_cache = None -# self.stored_value_cache = None - -# batch, ctx_q, _ = x.shape -# q = self.query_module(x) - -# kv_input = xa if xa is not None else x -# ctx_kv = kv_input.size(1) - -# update_k_batch = self._should_update(kv_input, self.key_update_predictor) -# update_v_batch = self._should_update(kv_input, self.value_update_predictor) - -# if self.stored_key_cache is None or self.stored_key_cache.shape[2] != ctx_kv or self.stored_key_cache.shape[0] != batch: -# k = self.key_module(kv_input) -# self.stored_key_cache = k -# elif update_k_batch.any(): -# new_k = self.key_module(kv_input) -# update_mask_k = update_k_batch.view(-1, 1, 1, 1).expand_as(self.stored_key_cache) -# k = torch.where(update_mask_k, new_k, self.stored_key_cache) -# self.stored_key_cache = k -# else: -# k = self.stored_key_cache - -# if self.stored_value_cache is None or self.stored_value_cache.shape[2] != ctx_kv or self.stored_value_cache.shape[0] != batch: -# v = self.value_module(kv_input) -# self.stored_value_cache = v -# elif update_v_batch.any(): -# new_v = self.value_module(kv_input) -# update_mask_v = update_v_batch.view(-1, 1, 1, 1).expand_as(self.stored_value_cache) -# v = torch.where(update_mask_v, new_v, self.stored_value_cache) -# self.stored_value_cache = v -# else: -# v = self.stored_value_cache - -# output = self.combiner(q, k, v, mask=mask, is_causal=is_causal) -# return output - -# class Refiner: -# """ -# Q-learning based agent to refine parameters (e.g., attention span). -# Operates outside the standard backpropagation loop. 
-# """ -# def __init__(self, states: int, actions: int, alpha: float = 0.1, gamma: float = 0.9, epsilon: float = 0.1): -# self.states = states -# self.actions = actions -# self.R = {} -# self.alpha = alpha -# self.gamma = gamma -# self.epsilon = epsilon -# self.default_value = 0.0 - -# def get_value(self, state: int, action: int) -> float: -# """Get Q-value for state-action pair.""" -# return self.R.get((state, action), self.default_value) - -# def set_value(self, state: int, action: int, value: float): -# """Set Q-value for state-action pair.""" -# self.R[(state, action)] = value - -# def choose_action(self, state: int) -> int: -# """Choose action using epsilon-greedy strategy.""" -# if np.random.random() < self.epsilon: -# return np.random.randint(self.actions) -# else: -# action_values = [self.get_value(state, a) for a in range(self.actions)] -# return np.argmax(action_values).item() - -# def update(self, state: int, action: int, reward: float, next_state: int): -# """Update Q-value using the Q-learning rule.""" -# next_values = [self.get_value(next_state, a) for a in range(self.actions)] -# best_next_value = max(next_values) if next_values else self.default_value - -# old_value = self.get_value(state, action) -# td_target = reward + self.gamma * best_next_value -# td_error = td_target - old_value -# new_value = old_value + self.alpha * td_error -# self.set_value(state, action, new_value) - -# class Predictor(nn.Module): -# """Neural predictor for estimating a scale value (e.g., for adaptive span).""" -# def __init__(self, dims: int): -# super().__init__() -# self.linear = Linear(in_features=dims, out_features=1) -# self._init_weights() - -# def _init_weights(self): -# """Initialize predictor weights.""" -# nn.init.xavier_normal_(self.linear.weight) -# if self.linear.bias is not None: -# nn.init.zeros_(self.linear.bias) - -# def forward(self, x: Tensor) -> Tensor: -# """ -# Predicts a scale factor (0-1) from input features. - -# Args: -# x: Input tensor (Batch, SeqLen, Dims) or (Batch, Dims). - -# Returns: -# Scale tensor (Batch, 1). -# """ -# if x.dim() > 2: -# x = x.mean(dim=1) -# scale = torch.sigmoid(self.linear(x)) -# return scale - -# class AdaptiveSpanAttention(BaseAttention): -# """ -# Attention mechanism where the span is dynamically adjusted based on a -# learnable parameter or predicted scale. This version focuses on slicing -# the input sequence to the effective span. - -# Note: This implementation attends only to the *first* `eff_span` tokens. -# For attending to a *relative* window, different logic (e.g., sliding window -# or masking) would be needed in `calculate_attention`. -# """ -# def __init__(self, dims: int, head: int, max_dist: int = 512, -# initial_span_scale: float = 1.0, learnable_scale: bool = True, -# sharpen: bool = True, temp_scale: float = 0.01, dropout: float = 0.0): -# """ -# Args: -# dims, head, max_dist, dropout: Standard BaseAttention params. -# initial_span_scale: Initial value for the span scale. -# learnable_scale: If True, span_scale is an nn.Parameter. -# sharpen, temp_scale: Parameters for dynamic temperature adjustment. 
-# """ -# super().__init__(dims, head, max_dist, dropout=dropout) -# self.sharpen = sharpen -# self.temp_scale = temp_scale -# if learnable_scale: -# self.span_scale = nn.Parameter(torch.tensor(initial_span_scale)) -# else: -# self.register_buffer("span_scale", torch.tensor(initial_span_scale)) - -# self.query_module = ProjectionModule(dims, head, "query") -# self.key_module = ProjectionModule(dims, head, "key") -# self.value_module = ProjectionModule(dims, head, "value") -# self.out_proj = Linear(dims, dims) - -# @autocast('cuda', enabled=torch.cuda.is_available()) -# def forward(self, x: Tensor, xa: Optional[Tensor] = None, -# mask: Optional[Tensor] = None, -# span_scale_override: Optional[Tensor] = None, -# is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: -# """ -# Computes attention over an adaptively determined span. - -# Args: -# x: Input tensor for Q (Batch, SeqLen_Q, Dims). -# xa: Optional input for K/V (Batch, SeqLen_KV, Dims). If None, use x. -# mask: External attention mask. -# span_scale_override: Optional tensor (Batch, 1) or scalar to override internal span_scale. -# is_causal: Whether to apply causal masking. - -# Returns: -# Tuple of (output tensor (Batch, SeqLen_Q, Dims), attention weights (optional)). -# """ -# kv_input = xa if xa is not None else x -# batch, ctx_q, _ = x.shape -# ctx_kv = kv_input.size(1) - -# current_span_scale = span_scale_override if span_scale_override is not None else self.span_scale -# if isinstance(current_span_scale, nn.Parameter): -# span_scale_val = current_span_scale.sigmoid() -# elif current_span_scale.numel() == 1: -# span_scale_val = current_span_scale.expand(batch, 1) -# else: -# span_scale_val = current_span_scale - -# span_mean = span_scale_val.mean().item() -# max_span_len = ctx_kv -# target_span_len = max(1, int(max_span_len * span_mean)) - -# eff_span = min(target_span_len, self.max_dist, ctx_q, ctx_kv) - -# if eff_span == 0: -# return (torch.zeros_like(x), None) - -# q_span = x[:, :eff_span, :] -# k_span = kv_input[:, :eff_span, :] -# v_span = kv_input[:, :eff_span, :] - -# q_proj = self.query_module(q_span) -# k_proj = self.key_module(k_span) -# v_proj = self.value_module(v_span) - -# temperature = (1.0 + self.temp_scale * (1.0 - span_mean) -# if self.sharpen -# else 0.5 + self.temp_scale * span_mean) -# temperature = max(temperature, 1e-3) - -# span_mask = None -# if mask is not None: -# if mask.dim() == 4: -# span_mask = mask[:, :, :eff_span, :eff_span] -# elif mask.dim() == 2: -# span_mask = mask[:eff_span, :eff_span] -# attn_output_span, attn_weights = calculate_attention( -# q_proj, k_proj, v_proj, -# mask=span_mask, -# temperature=temperature, -# use_sdpa=BaseAttention.use_sdpa, -# is_causal=is_causal, -# dropout_p=self.dropout -# ) - -# output_span = self._reshape_to_output(attn_output_span) -# projected_output_span = self.out_proj(output_span) - -# output = torch.zeros_like(x) -# output[:, :eff_span, :] = projected_output_span - -# return output, attn_weights - -# class MyelinatedLayer(BaseAttention): -# """ -# A complex Transformer layer featuring: -# - Integrated local/global attention (via IntegratedAttention). -# - Optional adapters within sub-layers. -# - Node importance prediction for sparsity. -# - MLP block. -# - Working memory component. -# - Potential layer skipping ("jumping") based on a learned policy. - -# (This version assumes IntegratedAttention is the core attention mechanism). 
-# """ -# def __init__(self, dims: int, head: int, num_layers: int = 3, -# sparsity_threshold: float = 0.1, max_dist: int = 512, -# dropout: float = 0.1, mlp_ratio: int = 4): -# super().__init__(dims, head, max_dist, dropout) -# self.num_layers = num_layers -# self.sparsity_threshold = sparsity_threshold - -# self.attention = IntegratedAttention(dims, head, max_dist=max_dist, dropout=dropout) - -# self.sub_layers = nn.ModuleList() -# self.node_predictors = nn.ModuleList([ -# nn.Sequential(LayerNorm(dims), Linear(dims, 1), nn.Sigmoid()) -# for _ in range(num_layers)]) - -# for i in range(num_layers): -# self.sub_layers.append(nn.ModuleDict({ -# 'ln': LayerNorm(dims), -# 'gate': nn.Sequential(Linear(dims, 1), nn.Sigmoid()), -# 'adapter': Linear(dims, dims) if i % 2 == 0 else None -# })) - -# self.policy_net = nn.Sequential(Linear(dims, 128), nn.ReLU(), Linear(128, num_layers)) -# self.jump_weights = nn.Parameter(torch.tensor([0.1, 0.05, 0.01])) - -# n_mlp = dims * mlp_ratio -# self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) -# self.mlp = nn.Sequential(Linear(dims, n_mlp), nn.GELU(), Linear(n_mlp, dims), nn.Dropout(dropout)) -# self.mlp_ln = LayerNorm(dims) - -# self.working_memory = nn.Parameter(torch.zeros(1, 1, dims)) -# self.memory_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) - -# self.last_memory_gate_values: Optional[Tensor] = None - -# def predict_node_importance(self, x: Tensor, layer_idx: int) -> Tensor: -# """Predict token importance mask (0.0 or 1.0) for sparsity.""" -# importance = self.node_predictors[layer_idx](x) -# is_important = (importance > self.sparsity_threshold).float() -# return is_important - -# def forward(self, x: Tensor, xa: Optional[Tensor] = None, -# mask: Optional[Tensor] = None, kv_cache: Optional[Tensor] = None, -# is_causal: bool = True) -> Tensor: -# batch, ctx, _ = x.shape -# working_memory = self.working_memory.expand(batch, 1, -1).to(x.device) -# original_x = x - -# pooled_representation = x.mean(dim=1) -# policy_logits = self.policy_net(pooled_representation) -# policy = F.softmax(policy_logits, dim=-1) - -# jump_history = [] -# i = 0 -# last_processed_output = x - -# while i < self.num_layers: -# layer = self.sub_layers[i] - -# node_importance_mask = self.predict_node_importance(x, i) - -# if node_importance_mask.mean() < 0.2 and i > 0: -# i += 1 -# jump_history.append(f"skip_low_imp->{i}") -# continue - -# norm_x = layer['ln'](x) - -# current_attn_mask = node_importance_mask.permute(0, 2, 1) -# if mask is not None: -# pass - -# attn_output = self.attention( -# norm_x * node_importance_mask, -# xa=xa, -# mask=mask, -# kv_cache=kv_cache, -# is_causal=is_causal -# ) - -# if layer['adapter'] is not None: -# attn_output = layer['adapter'](attn_output) - -# gate_value = layer['gate'](norm_x) -# x = x + gate_value * attn_output * node_importance_mask -# last_processed_output = x - -# memory_gate = self.memory_gate(x.mean(dim=1, keepdim=True)) -# current_mean_x = x.mean(dim=1, keepdim=True) -# working_memory = memory_gate * working_memory + (1 - memory_gate) * current_mean_x - -# if i < self.num_layers - 1: -# jump_prob_dist = policy[:, 1:] -# jump_prob = jump_prob_dist.sum(dim=-1) - -# should_jump_batch = torch.rand_like(jump_prob) < jump_prob - -# if should_jump_batch.any(): -# jump_len_probs = policy[should_jump_batch, :self.num_layers-i] -# sampled_jump_len = torch.multinomial(jump_len_probs, 1)[:, 0] + 1 - -# jump_length = sampled_jump_len.max().item() -# i_next = min(i + jump_length, self.num_layers) - -# skip_weight_idx = 
min(jump_length - 1, len(self.jump_weights) - 1) -# skip_weight = self.jump_weights[skip_weight_idx] - -# x = skip_weight * original_x + (1 - skip_weight) * working_memory.expand_as(x) + x * (1-skip_weight) -# jump_history.append(f"jump_{jump_length} S:{skip_weight.item():.2f} ->{i_next}") -# i = i_next -# continue - -# i += 1 - -# mlp_input = last_processed_output -# norm_mlp_input = self.mlp_ln(mlp_input) -# mlp_output = self.mlp(norm_mlp_input) -# mlp_gate_value = self.mlp_gate(norm_mlp_input) -# final_output = mlp_input + mlp_gate_value * mlp_output - -# if 'memory_gate' in locals(): -# self.last_memory_gate_values = memory_gate.detach().clone() - -# return final_output - -# class IntegratedAttention(BaseAttention): -# """ -# Integrates multiple attention strategies: -# - Local attention (sliding window or adaptive span via AdaptiveSpanAttention). -# - Global attention (potentially with adaptive updates via AdaptiveUpdateAttention). -# - Cross-attention capability. -# - RL-based refinement (`Refiner`) of the local attention span. -# - Iterative refinement (`_focus`) within local attention. -# """ -# def __init__(self, dims: int, head: int, max_dist: int = 512, -# win_size: int = 256, max_span: int = 384, temp_scale: float = 0.01, -# dropout: float = 0.1, -# rl_states: int = 10000, rl_actions: int = 10, rl_alpha: float = 0.1, -# rl_gamma: float = 0.9, rl_epsilon: float = 0.1): -# super().__init__(dims, head, max_dist, dropout=dropout) -# self.max_span = max_span -# self.sliding_window = win_size -# self.temp_scale = temp_scale -# self.sharpen = True - -# self.refiner = Refiner( -# states=rl_states, actions=rl_actions, alpha=rl_alpha, -# gamma=rl_gamma, epsilon=rl_epsilon) -# self.span_pred = Predictor(dims=dims) - -# self.attn_local = AdaptiveSpanAttention( -# dims=dims, head=head, max_dist=max_dist, sharpen=self.sharpen, -# temp_scale=temp_scale, learnable_scale=False, -# dropout=dropout) - -# self.attn_global = AdaptiveUpdateAttention( -# dims=dims, head=head, max_dist=max_dist, dropout=dropout) - -# self.cross_attn = AttentionCombiner(dims=dims, head=head, dropout=dropout) - -# self.self_projection = Linear(in_features=2 * dims, out_features=dims) -# self.global_cross_projection = Linear(in_features=dims, out_features=dims) - -# self.ln_local_in = LayerNorm(normalized_shape=dims) -# self.ln_global_in = LayerNorm(normalized_shape=dims) -# self.ln_cross_in = LayerNorm(normalized_shape=dims) - -# self.register_buffer("threshold", torch.tensor(1e-4), persistent=False) -# self.register_buffer("s_factor", torch.tensor(0.1), persistent=False) - -# def forward(self, x: Tensor, xa: Optional[Tensor] = None, -# mask: Optional[Tensor] = None, kv_cache: Optional[Tensor] = None, -# is_causal: bool = True) -> Tensor: -# """ -# Main forward pass distributing to cross or self-attention pathways. - -# Args: -# x: Primary input tensor (Batch, SeqLen_Q, Dims). -# xa: Context tensor for cross-attention (Batch, SeqLen_KV, Dims). -# mask: Attention mask (padding or causal). -# kv_cache: Key/Value cache for generation (specific usage depends on sub-modules). -# is_causal: Flag for causal masking in self-attention. - -# Returns: -# Output tensor (Batch, SeqLen_Q, Dims). 
-# """ -# batch, ctx_q, _ = x.shape - -# if xa is not None: -# q_norm = self.ln_cross_in(x) -# k_cross = self.attn_global.key_module(xa) -# v_cross = self.attn_global.value_module(xa) -# q_cross = self.attn_global.query_module(q_norm) - -# cross_out = self.cross_attn(q=q_cross, k=k_cross, v=v_cross, mask=mask, is_causal=False) -# return self.global_cross_projection(cross_out) - -# local_input = self.ln_local_in(x) -# global_input = self.ln_global_in(x) - -# globe_out_raw = self.attn_global( -# global_input, -# xa=None, -# mask=mask, -# is_causal=is_causal -# ) -# globe_out = self.global_cross_projection(globe_out_raw) - -# base_freq_scale = self.span_pred(globe_out) - -# state = self._extract_rl_state(local_input) -# action = self.refiner.choose_action(state=state) -# refinement_scale = self._action_to_scale(action=action) -# final_span_scale = torch.clamp(base_freq_scale * refinement_scale.expand_as(base_freq_scale), min=0.0, max=1.0) - -# span_mean = final_span_scale.mean().item() -# with torch.no_grad(): -# current_win_size = max(1, int(self.sliding_window * span_mean)) -# current_span_len = max(1, int(self.max_span * span_mean)) -# local_out_raw = self._slide_win_local( -# x=local_input, -# win_size=current_win_size, -# span_len=current_span_len, -# span_scale=final_span_scale, -# mask=mask, -# is_causal=is_causal -# ) -# with torch.no_grad(): -# reward = self._calculate_rl_reward(output=local_out_raw) -# next_state = self._extract_rl_state(local_out_raw) -# self.refiner.update(state=state, action=action, reward=reward, next_state=next_state) - -# combined = torch.cat([local_out_raw, globe_out], dim=-1) -# output = self.self_projection(combined) - -# return output - -# def _calculate_rl_reward(self, output: Tensor) -> float: -# """Calculate quality metric (reward) for reinforcement learning.""" -# with torch.no_grad(): -# output_probs = torch.softmax(output, dim=-1) -# safe_probs = torch.clamp(output_probs, min=1e-10) -# entropy = -(safe_probs * torch.log(safe_probs)).sum(-1).mean() -# coverage = (output.abs() > 0.01).float().mean() -# reward = float(coverage - 0.1 * entropy) -# return reward - -# def _extract_rl_state(self, x: Tensor) -> int: -# """Extract discrete state features for RL agent from tensor.""" -# with torch.no_grad(): -# pooled = x.mean(dim=1) -# mean_state = pooled[0].mean() -# var_state = pooled[0].var(unbiased=False) -# state_features = torch.stack([mean_state, var_state]).cpu().numpy() -# state_id = self._discretize_state(state_features) -# return state_id - -# def _discretize_state(self, state: np.ndarray) -> int: -# """Convert continuous state numpy array to a discrete state ID.""" -# bins = np.linspace(-1, 1, num=10) -# state_discrete = np.digitize(state, bins) -# state_hash = sum(val * (10**i) for i, val in enumerate(state_discrete)) -# state_id = int(state_hash % self.refiner.states) -# return state_id - -# def _action_to_scale(self, action: int) -> Tensor: -# """Convert discrete RL action index to a continuous scale factor [0, 1].""" -# span_value = action / (self.refiner.actions - 1) -# scale_tensor = torch.tensor([span_value], device=self.span_pred.linear.weight.device, dtype=torch.float) -# return scale_tensor - -# def _focus(self, query: Tensor, key: Tensor, value: Tensor, -# span_scale: Tensor, mask: Optional[Tensor] = None, -# is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: -# """ -# Iterative attention refinement. Applies attention multiple times, -# adding the output back to the query. Uses manual attention calculation. 
- -# Args: -# query, key, value: Input tensors (B, SeqLen_Window, D). -# span_scale: Scale factor (scalar or B, 1) influencing effective span. -# mask: Attention mask for the window. -# is_causal: Apply causal masking within the window. - -# Returns: -# Tuple (refined_output (B, SeqLen_Window, D), attention_weights (optional, None here)). -# """ -# max_iterations = 5 -# iteration = 0 -# prev_attn_out = torch.zeros_like(query) -# attn_out = torch.zeros_like(query) -# threshold = self.threshold.item() -# s_factor = self.s_factor.item() - -# q_current = query - -# while iteration < max_iterations: -# span_mean = span_scale.mean().item() -# target_span_len = max(1, int(self.max_span * span_mean)) -# eff_span = min(target_span_len, self.max_dist, q_current.size(1), key.size(1)) - -# if eff_span == 0: break - -# q_iter = q_current[:, :eff_span, :] -# k_iter = key[:, :eff_span, :] -# v_iter = value[:, :eff_span, :] - -# q_proj = self.attn_local.query_module(q_iter) -# k_proj = self.attn_local.key_module(k_iter) -# v_proj = self.attn_local.value_module(v_iter) - -# temperature = (1.0 + self.temp_scale * (1.0 - span_mean) -# if self.sharpen -# else 0.5 + self.temp_scale * span_mean) -# temperature = max(temperature, 1e-3) - -# iter_mask = None -# if mask is not None: -# if mask.dim() == 4: iter_mask = mask[:, :, :eff_span, :eff_span] -# elif mask.dim() == 2: iter_mask = mask[:eff_span, :eff_span] -# attn_output_iter, _ = calculate_attention( -# q_proj, k_proj, v_proj, -# mask=iter_mask, -# temperature=temperature, -# use_sdpa=False, -# is_causal=is_causal, -# dropout_p=self.dropout -# ) - -# attn_out_span = self.attn_local._reshape_to_output(attn_output_iter) -# projected_attn_out_span = self.attn_local.out_proj(attn_out_span) - -# current_iter_out = torch.zeros_like(q_current) -# current_iter_out[:, :eff_span, :] = projected_attn_out_span - -# diff = torch.abs(current_iter_out - prev_attn_out).mean() -# dynamic_threshold = threshold + s_factor * diff - -# if diff < dynamic_threshold and iteration > 0: -# attn_out = current_iter_out -# break - -# prev_attn_out = current_iter_out.clone() -# q_current = q_current + current_iter_out -# attn_out = current_iter_out - -# iteration += 1 - -# return attn_out, None - -# @autocast('cuda', enabled=torch.cuda.is_available()) -# def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, -# span_scale: Tensor, mask: Optional[Tensor] = None, -# is_causal: bool = False) -> Tensor: -# """ -# Process input with sliding window attention, using `_focus` for each window. - -# Args: -# x: Input tensor (Batch, SeqLen, Dims). -# win_size: Size of the attention window for queries. -# span_len: Max length of keys/values relative to query window start. -# span_scale: Span scale tensor (Batch, 1 or scalar) passed to _focus. -# mask: Full attention mask. -# is_causal: Apply causal masking within windows. - -# Returns: -# Output tensor (Batch, SeqLen, Dims). 
-# """ -# batch, ctx, dims = x.size() -# output = torch.zeros_like(x) - -# num_windows = (ctx + win_size - 1) // win_size - -# for i in range(num_windows): -# q_start = i * win_size -# q_end = min(q_start + win_size, ctx) -# current_window_q_len = q_end - q_start -# if current_window_q_len == 0: continue - -# kv_start = max(0, q_end - span_len) -# kv_end = q_end -# query_win = x[:, q_start:q_end, :] -# key_win = x[:, kv_start:kv_end, :] -# value_win = x[:, kv_start:kv_end, :] - -# window_mask = None -# if mask is not None: -# if mask.dim() == 4: -# window_mask = mask[:, :, q_start:q_end, kv_start:kv_end] -# elif mask.dim() == 2: -# window_mask = mask[q_start:q_end, kv_start:kv_end] -# attn_out_win, _ = self._focus( -# query=query_win, -# key=key_win, -# value=value_win, -# span_scale=span_scale, -# mask=window_mask, -# is_causal=is_causal -# ) - -# output[:, q_start:q_end, :] = attn_out_win - -# return output - -# class CTCDecoder(nn.Module): -# def __init__(self, input_dim: int, vocab_size: int, dims: int = 256, num_layers: int = 2, dropout: float = 0.1): -# super().__init__() -# self.input_dim = input_dim -# self.vocab_size = vocab_size -# self.dims = dims - -# self.projection = nn.Linear(input_dim, dims) -# self.lstm = nn.LSTM(dims, dims, num_layers, dropout=dropout if num_layers > 1 else 0, batch_first=True, bidirectional=True) -# self.output = nn.Linear(dims * 2, vocab_size + 1) # +1 for CTC blank token -# self.dropout = nn.Dropout(dropout) - -# def forward(self, x: Tensor) -> Tensor: -# x = self.projection(x) # (batch, seq_len, dims) -# x = self.dropout(x) -# x, _ = self.lstm(x) # (batch, seq_len, dims * 2) -# x = self.dropout(x) -# logits = self.output(x) # (batch, seq_len, vocab_size + 1) -# return logits - -# class CTCWrapper(nn.Module): -# def __init__(self, model: Model, vocab_size: int, dims: int = 256, num_layers: int = 2): -# super().__init__() -# self.model = model -# self.ctc_decoder = CTCDecoder( -# input_dim=model.param.dims, -# vocab_size=vocab_size, -# dims=dims, -# num_layers=num_layers -# ) - -# def forward(self, input_ids=None, pitch=None, labels=None, input_lengths=None, label_lengths=None): -# outputs = self.model(input_ids=input_ids, pitch=pitch) -# x = outputs["logits"] # (batch, seq_len, vocab_size) -# ctc_logits = self.ctc_decoder(x) # (batch, seq_len, vocab_size + 1) -# loss = None -# if labels is not None and input_lengths is not None and label_lengths is not None: -# log_probs = torch.log_softmax(ctc_logits, dim=-1) -# log_probs = log_probs.transpose(0, 1) - -# loss = torch.nn.functional.ctc_loss( -# log_probs, -# labels, -# input_lengths, -# label_lengths, -# blank=0, -# reduction='mean' -# ) - -# return { -# "logits": ctc_logits, -# "loss": loss, -# "out": x -# } - -# def decode(self, logits: Tensor, input_lengths: Optional[Tensor] = None) -> List[List[int]]: -# probs = torch.softmax(logits, dim=-1) # (batch, seq_len, vocab_size + 1) -# predictions = torch.argmax(probs, dim=-1) # (batch, seq_len) - -# decoded_sequences = [] -# for i, pred in enumerate(predictions): -# seq = [] -# prev_token = None -# for j, token in enumerate(pred): -# if input_lengths is not None and j >= input_lengths[i]: -# break -# if token != 0 and token != prev_token: -# seq.append(token.item()) -# prev_token = token -# decoded_sequences.append(seq) -# return decoded_sequences - -# # ctc_model = CTCWrapper(model, vocab_size=40000, dims=256, num_layers=2) - -# # outputs = ctc_model( -# # input_ids=input_ids, -# # pitch=pitch, -# # labels=labels, -# # input_lengths=input_lengths, # 
Length of each audio sequence -# # label_lengths=label_lengths # Length of each text sequence -# # ) - -# # loss = outputs["loss"] - -# # outputs = ctc_model(input_ids=input_ids, pitch=pitch) -# # logits = outputs["logits"] - -# # # Decode to text -# # decoded_sequences = ctc_model.decode(logits, input_lengths=input_lengths) -# # ctc_model = CTCWrapper(model, vocab_size=param.vocab, dims=256, num_layers=2).to('cuda') - -# # print(f"CTC model parameters: {sum(p.numel() for p in ctc_model.parameters() if p.requires_grad):,}") - -# # from tensorboard import program -# # log_dir = "D:/newmodel/output/logs" -# # tb = program.TensorBoard() -# # tb.configure(argv=[None, '--logdir', log_dir]) -# # url = tb.launch() -# # print(f"TensorBoard started at {url}") - -# def compute_metricsB(pred, tokenizer): -# pred_ids = pred["predictions"] -# label_ids = pred["label_ids"] -# if isinstance(pred_ids, tuple): -# pred_ids = pred_ids[0] -# else: -# pred_ids = pred_ids -# if pred_ids.ndim == 3: -# pred_ids = np.argmax(pred_ids, axis=-1) -# label_ids[label_ids == -100] = tokenizer.pad_token_id -# pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) -# label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True) -# metrics = evaluate.load(path="wer") -# wer = metrics.compute(predictions=pred_str, references=label_str) -# return {"wer": wer} - -# def train_and_evaluate( -# model, -# tokenizer, -# train_loader, -# eval_loader, -# optimizer, -# scheduler, -# loss_fn, -# max_steps=10000, -# device="cuda", -# accumulation_steps=1, -# clear_cache=True, -# log_interval=10, -# eval_interval=100, -# save_interval=1000, -# checkpoint_dir="checkpoint_dir", -# log_dir="log_dir", -# ): -# model.to(device) -# global_step = 0 -# scaler = torch.GradScaler() -# writer = SummaryWriter(log_dir=log_dir) -# train_iterator = iter(train_loader) -# total_loss = 0 -# step_in_report = 0 -# dataset_epochs = 0 - -# progress_bar = tqdm( -# total=max_steps, desc="Training Progress", leave=True, colour="green" -# ) - -# model.train() -# optimizer.zero_grad() - -# while global_step < max_steps: -# try: -# batch = next(train_iterator) -# except StopIteration: -# train_iterator = iter(train_loader) -# batch = next(train_iterator) -# dataset_epochs += 1 -# print(f"Starting dataset epoch {dataset_epochs}") - -# if step_in_report > 0: -# avg_loss = total_loss / step_in_report -# logging.info( -# f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}" -# ) -# total_loss = 0 -# step_in_report = 0 - -# start_time = time.time() - -# input_features = batch["input_features"].to(device) -# input_ids = batch["input_ids"].to(device) -# labels = batch["labels"].long().to(device) - -# with torch.autocast(device_type="cuda"): -# input_features_encoded = model.encoder(input_features) -# decoder_output = model.decoder(input_ids, input_features_encoded) -# logits = decoder_output.view(-1, decoder_output.size(-1)) -# active_logits = logits.view(-1, decoder_output.size(-1)) -# active_labels = labels.view(-1) -# active_mask = active_labels != -100 -# active_logits = active_logits[active_mask] -# active_labels = active_labels[active_mask] -# loss = loss_fn(active_logits, active_labels) -# # model.adjust_freq(loss=loss.item()) -# total_loss += loss.item() -# loss = loss / accumulation_steps - -# scaler.scale(loss).backward() - -# if (global_step + 1) % accumulation_steps == 0: -# scaler.unscale_(optimizer) -# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) -# scaler.step(optimizer) -# 
scaler.update() -# optimizer.zero_grad() - -# if clear_cache: -# torch.cuda.empty_cache() - -# end_time = time.time() -# samples_per_sec = len(batch["input_features"]) / (end_time - start_time) - -# if global_step % log_interval == 0: -# writer.add_scalar( -# tag="Loss/train", -# scalar_value=total_loss / (global_step + 1), -# global_step=global_step, -# ) -# lr = scheduler.get_last_lr()[0] -# writer.add_scalar( -# tag="LearningRate", scalar_value=lr, global_step=global_step -# ) -# writer.add_scalar( -# tag="SamplesPerSec", -# scalar_value=samples_per_sec, -# global_step=global_step, -# ) - -# if global_step % eval_interval == 0: -# model.eval() -# eval_start_time = time.time() -# eval_loss = 0 -# all_predictions = [] -# all_labels = [] -# batch_count = 0 -# total_samples = 0 - -# with torch.no_grad(): -# for eval_batch in eval_loader: -# # for eval_batch in tqdm(eval_loader, desc=f"Evaluating (Step {global_step})", leave=True, colour='green'): -# input_features = eval_batch["input_features"].to(device) -# input_ids = eval_batch["input_ids"].to(device) -# labels = eval_batch["labels"].long().to(device) - -# batch = input_features.size(0) -# total_samples += batch - -# input_features_encoded = model.encoder(input_features) -# decoder_output = model.decoder(input_ids, input_features_encoded) -# logits = decoder_output.view(-1, decoder_output.size(-1)) -# loss = loss_fn(logits, labels.view(-1)) -# eval_loss += loss.item() -# all_predictions.extend( -# torch.argmax(decoder_output, dim=-1).cpu().numpy().tolist() -# ) -# all_labels.extend(labels.cpu().numpy().tolist()) -# batch_count += 1 - -# eval_time = time.time() - eval_start_time -# loss_avg = eval_loss / batch_count if batch_count > 0 else 0 -# predictions = { -# "predictions": np.array(all_predictions, dtype=object), -# "label_ids": np.array(all_labels, dtype=object), -# } -# metrics = compute_metrics(pred=predictions, tokenizer=tokenizer) - -# writer.add_scalar("Loss/eval", loss_avg, global_step) -# writer.add_scalar("WER", metrics["wer"], global_step) -# writer.add_scalar("EvalSamples", total_samples, global_step) -# writer.add_scalar("EvalTimeSeconds", eval_time, global_step) -# lr = scheduler.get_last_lr()[0] - -# print( -# f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}" -# ) - -# logging.info( -# f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}" -# ) -# # scheduler.step() -# model.train() - -# if global_step % save_interval == 0: -# checkpoint_path = os.path.join( -# checkpoint_dir, f"checkpoint_step_{global_step}.pt" -# ) -# torch.save(model.state_dict(), checkpoint_path) -# # print(f"Model saved at step {global_step} to {checkpoint_path}") -# logging.info(f"Model saved at step {global_step} to {checkpoint_path}") - -# lr = scheduler.get_last_lr()[0] -# scheduler.step() -# global_step += 1 -# step_in_report += 1 - -# avg_loss = total_loss / (global_step + 1) -# postfix_dict = { -# "loss": f"{avg_loss:.4f}", -# "lr": f"{lr:.6f}", -# "samp": f"{samples_per_sec:.1f}", -# } -# progress_bar.set_postfix(postfix_dict, refresh=True) -# progress_bar.update(1) - -# final_model_path = os.path.join(checkpoint_dir, "final_model.pt") -# torch.save(model.state_dict(), final_model_path) -# print( -# f"Training completed after {global_step} steps. 
Final model saved to {final_model_path}" -# ) -# writer.close() -# progress_bar.close() - -# def mainB(): - -# checkpoint_dir = "./output/checkpoints" -# os.makedirs(checkpoint_dir, exist_ok=True) -# log_dir = os.path.join("./output/logs", datetime.now().strftime(format="%m-%d_%H")) -# os.makedirs(name=log_dir, exist_ok=True) - -# logging.basicConfig( -# filename=os.path.join(log_dir, "training.log"), -# filemode="w", -# format="%(asctime)s - %(levelname)s - %(message)s", -# level=logging.INFO, -# ) - -# token = "" -# dataset = IterableDatasetDict() -# dataset["train"] = load_dataset( -# path="google/fleurs", -# name="en_us", -# split="train", -# streaming=True, -# token=token, -# trust_remote_code=True, -# ).select_columns(column_names=["audio", "transcription"]) - -# dataset["test"] = load_dataset( -# path="google/fleurs", -# name="en_us", -# split="test", -# streaming=True, -# token=token, -# trust_remote_code=True, -# ).select_columns(column_names=["audio", "transcription"]) - -# debug = None - -# param = Dimensions( -# mels=128, -# audio_ctx=1500, -# audio_head=4, -# encoder_idx=4, -# audio_dims=512, -# vocab=51865, -# text_ctx=512, -# text_head=4, -# decoder_idx=4, -# text_dims=512, -# decoder_start_token_id=0, -# pad_token_id=0, -# eos_token_id=0, -# act="gelu", -# ) - -# model = model - -# Collator = DataCollatorB( -# tokenizer=tokenizer, -# audio_ctx=param.audio_ctx, -# text_ctx=param.text_ctx, -# mels=param.mels, -# ) - -# train_dataloader = DataLoader( -# dataset=dataset["train"], batch_size=1, collate_fn=Collator, num_workers=0 -# ) - -# eval_dataloader = DataLoader( -# dataset=dataset["test"], batch_size=1, collate_fn=Collator, num_workers=0 -# ) - -# optimizer = torch.optim.AdamW( -# model.parameters(), lr=5e-4, weight_decay=0.01, eps=1e-6, betas=(0.9, 0.98) -# ) -# scheduler = torch.optim.lr_scheduler.LinearLR( -# optimizer, start_factor=0.25, total_iters=10000, last_epoch=-1 -# ) - -# loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100) - -# train_and_evaluate( -# model=model, -# tokenizer=tokenizer, -# train_loader=train_dataloader, -# eval_loader=eval_dataloader, -# optimizer=optimizer, -# scheduler=scheduler, -# loss_fn=loss_fn, -# max_steps=10000, -# device="cuda", -# accumulation_steps=1, -# clear_cache=False, -# log_interval=10, -# eval_interval=500, -# save_interval=10000, -# checkpoint_dir=checkpoint_dir, -# log_dir=log_dir, -# ) - -# def train_and_evaluate( -# model, tokenizer, train_loader, eval_loader, optimizer, scheduler, loss_fn, -# max_steps=10000, device='cuda', accumulation_steps=1, clear_cache=True, -# log_interval=10, eval_interval=100, save_interval=1000, -# checkpoint_dir="checkpoint_dir", log_dir="log_dir" -# ): -# model.to(device) -# global_step = 0 -# scaler = torch.GradScaler() -# writer = SummaryWriter(log_dir=log_dir) -# train_iterator = iter(train_loader) -# total_loss = 0 -# step_in_report = 0 -# dataset_epochs = 0 - -# progress_bar = tqdm(total=max_steps, desc="Training Progress", leave=True, colour='green') - -# model.train() -# optimizer.zero_grad() - -# while global_step < max_steps: -# try: -# batch = next(train_iterator) -# except StopIteration: -# train_iterator = iter(train_loader) -# batch = next(train_iterator) -# dataset_epochs += 1 -# print(f"Starting dataset epoch {dataset_epochs}") - -# if step_in_report > 0: -# avg_loss = total_loss / step_in_report -# logging.info(f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}") -# total_loss = 0 -# step_in_report = 0 - -# start_time = time.time() - -# batch = {k: 
v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} - -# with torch.autocast(device_type="cuda"): -# output = model(**batch) if hasattr(model, '__call__') else model.forward(**batch) -# logits = output["logits"] if isinstance(output, dict) and "logits" in output else output -# labels = batch["labels"] -# active_logits = logits.view(-1, logits.size(-1)) -# active_labels = labels.view(-1) -# active_mask = active_labels != 0 -# active_logits = active_logits[active_mask] -# active_labels = active_labels[active_mask] -# loss = loss_fn(active_logits, active_labels) -# total_loss += loss.item() -# loss = loss / accumulation_steps - -# scaler.scale(loss).backward() - -# if (global_step + 1) % accumulation_steps == 0: -# scaler.unscale_(optimizer) -# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) -# scaler.step(optimizer) -# scaler.update() -# optimizer.zero_grad() -# if clear_cache: -# torch.cuda.empty_cache() - -# end_time = time.time() -# samples_per_sec = batch["spectrogram"].size(0) / (end_time - start_time) - -# if global_step % log_interval == 0: -# writer.add_scalar(tag='Loss/train', scalar_value=total_loss / (global_step + 1), global_step=global_step) -# lr = scheduler.get_last_lr()[0] -# writer.add_scalar(tag='LearningRate', scalar_value=lr, global_step=global_step) -# writer.add_scalar(tag='SamplesPerSec', scalar_value=samples_per_sec, global_step=global_step) - -# if global_step % eval_interval == 0: -# model.eval() -# eval_start_time = time.time() -# eval_loss = 0 -# all_predictions = [] -# all_labels = [] -# batch_count = 0 -# total_samples = 0 - -# with torch.no_grad(): -# for eval_batch in eval_loader: -# eval_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in eval_batch.items()} -# output = model(**eval_batch) if hasattr(model, '__call__') else model.forward(**eval_batch) -# logits = output["logits"] if isinstance(output, dict) and "logits" in output else output -# labels = eval_batch["labels"] -# batch_size = logits.size(0) -# total_samples += batch_size -# loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1)) -# eval_loss += loss.item() -# all_predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist()) -# all_labels.extend(labels.cpu().numpy().tolist()) -# batch_count += 1 - -# eval_time = time.time() - eval_start_time -# loss_avg = eval_loss / batch_count if batch_count > 0 else 0 -# predictions = {"predictions": np.array(all_predictions, dtype=object), "label_ids": np.array(all_labels, dtype=object)} -# metrics = compute_metrics(pred=predictions, tokenizer=tokenizer) - -# writer.add_scalar('Loss/eval', loss_avg, global_step) -# writer.add_scalar('WER', metrics['wer'], global_step) -# writer.add_scalar('EvalSamples', total_samples, global_step) -# writer.add_scalar('EvalTimeSeconds', eval_time, global_step) - -# lr = scheduler.get_last_lr()[0] -# print(f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}") -# logging.info(f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}") -# model.train() - -# if global_step % save_interval == 0: -# checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt') -# torch.save(model.state_dict(), checkpoint_path) -# logging.info(f"Model saved at step {global_step} to {checkpoint_path}") - -# lr = scheduler.get_last_lr()[0] -# scheduler.step() -# global_step += 1 -# step_in_report += 1 - -# avg_loss = total_loss / (global_step + 
1) -# postfix_dict = { -# 'loss': f'{avg_loss:.4f}', -# 'lr': f'{lr:.6f}', -# 'samp': f'{samples_per_sec:.1f}' -# } -# progress_bar.set_postfix(postfix_dict, refresh=True) -# progress_bar.update(1) - -# final_model_path = os.path.join(checkpoint_dir, 'final_model.pt') -# torch.save(model.state_dict(), final_model_path) -# print(f"Training completed after {global_step} steps. Final model saved to {final_model_path}") -# writer.close() -# progress_bar.close() - -# def get_optimizer(model, lr=5e-4, weight_decay=0.01): -# return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6, betas=(0.9, 0.98)) - -# def get_scheduler(optimizer, total_steps=10000): -# return torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.25, total_iters=total_steps, last_epoch=-1) - -# def get_loss_fn(): -# return torch.nn.CrossEntropyLoss(ignore_index=0) - -# def mainc(): -# token = "" -# log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S')) -# os.makedirs(log_dir, exist_ok=True) -# tokenizer = setup_tokenizer(token) - -# param = Dimensions( -# mels=128, aud_ctx=1500, aud_head=4, aud_dims=512, aud_idx=4, -# vocab=40000, text_ctx=512, text_head=4, text_dims=512, text_idx=4, -# act="swish", debug={}, cross_attn=True, features=["spectrogram"] -# ) - -# dataset_config = { -# "spectrogram": True, "waveforms": False, "pitch": False, "downsamples": False, -# "frequency": True, "hilbert": False, "hop_length": 128, "fmin": 150, "fmax": 2000, -# "n_mels": 128, "n_fft": 1024, "sampling_rate": 16000, "pad_mode": "constant", -# "center": True, "power": 2.0, "window_fn": torch.hann_window, "mel_scale": "htk", -# "norm": None, "normalized": False -# } - -# model = create_model(param) -# train_dataset, test_dataset = prepare_datasets( -# tokenizer=tokenizer, token=token, sanity_check=False, dataset_config=dataset_config -# ) - -# collator = DataCollator(tokenizer=tokenizer) -# train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collator, num_workers=0) -# eval_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collator, num_workers=0) - -# optimizer = get_optimizer(model) -# scheduler = get_scheduler(optimizer) -# loss_fn = get_loss_fn() - -# train_and_evaluate( -# model=model, -# tokenizer=tokenizer, -# train_loader=train_loader, -# eval_loader=eval_loader, -# optimizer=optimizer, -# scheduler=scheduler, -# loss_fn=loss_fn, -# max_steps=10000, -# device='cuda', -# accumulation_steps=1, -# clear_cache=False, -# log_interval=10, -# eval_interval=500, -# save_interval=10000, -# checkpoint_dir="./checkpoints", -# log_dir=log_dir -# ) - -# class attention(nn.Module): -# def __init__(self, dims: int, head: int): -# super(attention, self).__init__() -# self.dims = dims -# self.head = head -# self.head_dim = dims // head -# self.q = nn.Linear(dims, dims) -# self.k = nn.Linear(dims, dims, bias=False) -# self.v = nn.Linear(dims, dims) -# self.o = nn.Linear(dims, dims) - -# self.lna = nn.LayerNorm(dims, bias = False) -# self.lnb = nn.LayerNorm(dims, bias = False) -# self.lnc = nn.LayerNorm(self.head_dim, bias = False) -# self.lnd = nn.LayerNorm(self.head_dim, bias = False) - -# def _forward(self, x: Tensor, xa = None, mask = None): -# q = self.q(self.lna(x)) -# k = self.k(self.lnb(x if xa is None else xa)) -# v = self.v(self.lnb(x if xa is None else xa)) -# query = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3) -# key = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3) -# value = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3) - -# 
max_iterations = 5
-# iteration = 0
-# prev_attn_out = torch.zeros_like(query)
-# attn_out = torch.zeros_like(query)
-# threshold = self.threshold.item()
-# s_factor = self.s_factor.item()
-
-# q_current = query
-
-# while iteration < max_iterations:
-
-# eff_span = min(x.shape[1], xa.shape[1], q_current.size(1), key.size(1))
-
-# if eff_span == 0: break
-
-# q_iter = q_current[:, :eff_span, :]
-# k_iter = key[:, :eff_span, :]
-# v_iter = value[:, :eff_span, :]
-
-# q_proj = self.attn_local.query_module(q_iter)
-# k_proj = self.attn_local.key_module(k_iter)
-# v_proj = self.attn_local.value_module(v_iter)
-
-# temperature = (1.0 + self.temp_scale * (1.0 - xa.mean())
-# if self.sharpen
-# else 0.5 + self.temp_scale * xa.mean())
-# temperature = max(temperature, 1e-3)
-
-# iter_mask = None
-# if mask is not None:
-# if mask.dim() == 4: iter_mask = mask[:, :, :eff_span, :eff_span]
-# elif mask.dim() == 2: iter_mask = mask[:eff_span, :eff_span]
-
-# attn_output_iter, _ = calculate_attention(
-# q_proj, k_proj, v_proj,
-# mask=iter_mask,
-# temperature=temperature,
-# use_sdpa=False,
-# dropout_p=self.dropout
-# )
-
-# attn_out_span = self.attn_local._reshape_to_output(attn_output_iter)
-# projected_attn_out_span = self.attn_local.out_proj(attn_out_span)
-
-# current_iter_out = torch.zeros_like(q_current)
-# current_iter_out[:, :eff_span, :] = projected_attn_out_span
-
-# diff = torch.abs(current_iter_out - prev_attn_out).mean()
-# dynamic_threshold = threshold + s_factor * diff
-
-# if diff < dynamic_threshold and iteration > 0:
-# attn_out = current_iter_out
-# break
-
-# prev_attn_out = current_iter_out.clone()
-# q_current = q_current + current_iter_out
-# attn_out = current_iter_out
-
-# iteration += 1
-
-# return attn_out, None
-
-# def _slide_win_local(self, x: Tensor, win_size: int, span_len: int,
-# span_scale: Tensor, mask: Optional[Tensor] = None,
-# is_causal: bool = False) -> Tensor:
-# """
-# Process input with sliding window attention, using `_focus` for each window.
-
-# Args:
-# x: Input tensor (Batch, SeqLen, Dims).
-# win_size: Size of the attention window for queries.
-# span_len: Max length of keys/values relative to query window start.
-# span_scale: Span scale tensor (Batch, 1 or scalar) passed to _focus.
-# mask: Full attention mask.
-# is_causal: Apply causal masking within windows.
-
-# Returns:
-# Output tensor (Batch, SeqLen, Dims).
-# """
-# batch, ctx, dims = x.size()
-# output = torch.zeros_like(x)
-
-# num_windows = (ctx + win_size - 1) // win_size
-
-# for i in range(num_windows):
-# q_start = i * win_size
-# q_end = min(q_start + win_size, ctx)
-# current_window_q_len = q_end - q_start
-# if current_window_q_len == 0: continue
-
-# kv_start = max(0, q_end - span_len)
-# kv_end = q_end
-# query_win = x[:, q_start:q_end, :]
-# key_win = x[:, kv_start:kv_end, :]
-# value_win = x[:, kv_start:kv_end, :]
-
-# window_mask = None
-# if mask is not None:
-# if mask.dim() == 4:
-# window_mask = mask[:, :, q_start:q_end, kv_start:kv_end]
-# elif mask.dim() == 2:
-# window_mask = mask[q_start:q_end, kv_start:kv_end]
-
-# attn_out_win, _ = self._focus(
-# query=query_win,
-# key=key_win,
-# value=value_win,
-# span_scale=span_scale,
-# mask=window_mask,
-# is_causal=is_causal
-# )
-
-# output[:, q_start:q_end, :] = attn_out_win
-
-# return output
\ No newline at end of file
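
# NOTE (added sketch): the removed IntegratedAttention refines its local attention
# span with a small RL loop, turning a tensor into a discrete state id by pooling it
# to two scalar statistics, bucketing them, and hashing the buckets into a bounded
# table. A minimal, hypothetical stand-alone version of that idea; the function name
# and `num_states` argument are illustrative, not part of the module's API.
import numpy as np
import torch

def discretize_state_sketch(x: torch.Tensor, num_states: int = 10000) -> int:
    pooled = x.mean(dim=1)[0]  # (dims,) statistics taken from the first batch item
    feats = torch.stack([pooled.mean(), pooled.var(unbiased=False)]).cpu().numpy()
    bins = np.linspace(-1, 1, num=10)
    buckets = np.digitize(feats, bins)  # one bucket index per scalar feature
    state_hash = sum(int(v) * (10 ** i) for i, v in enumerate(buckets))
    return state_hash % num_states

# example: state_id = discretize_state_sketch(torch.randn(2, 50, 512))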
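
# NOTE (added sketch): both removed `_focus` variants apply attention repeatedly,
# feed each output back into the query, and stop once the change between iterations
# falls below a threshold that grows with the change itself. A minimal version of
# that convergence loop, assuming plain single-head tensors and using PyTorch's
# scaled_dot_product_attention as a stand-in for the per-iteration attention call:
import torch
import torch.nn.functional as F

def focus_sketch(q, k, v, max_iterations=5, threshold=1e-4, s_factor=0.1):
    prev_out = torch.zeros_like(q)
    out = torch.zeros_like(q)
    for iteration in range(max_iterations):
        out = F.scaled_dot_product_attention(q, k, v)
        diff = (out - prev_out).abs().mean()
        if iteration > 0 and diff < threshold + s_factor * diff:
            break  # converged: this iteration barely changed the output
        prev_out = out.clone()
        q = q + out  # feed the refined output back into the next query
    return out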
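
# NOTE (added sketch): the removed `_slide_win_local` methods process queries in
# fixed-size windows, each attending only to a trailing span of keys/values that
# ends at the window's last query position. A minimal, hypothetical single-head
# version of that windowing pattern (the real modules split heads and route each
# window through `_focus`):
import torch
import torch.nn.functional as F

def sliding_window_attention_sketch(x: torch.Tensor, win_size: int, span_len: int) -> torch.Tensor:
    # x: (batch, ctx, dims)
    batch, ctx, dims = x.shape
    output = torch.zeros_like(x)
    num_windows = (ctx + win_size - 1) // win_size
    for i in range(num_windows):
        q_start = i * win_size
        q_end = min(q_start + win_size, ctx)
        kv_start = max(0, q_end - span_len)  # look back at most span_len positions
        q = x[:, q_start:q_end, :]
        kv = x[:, kv_start:q_end, :]
        output[:, q_start:q_end, :] = F.scaled_dot_product_attention(q, kv, kv)
    return output

# example: y = sliding_window_attention_sketch(torch.randn(1, 10, 8), win_size=4, span_len=6)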
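
# NOTE (added sketch): the removed CTCWrapper pairs a CTC loss (blank id 0,
# log-probs transposed to (time, batch, vocab + 1)) with a greedy decode that
# collapses consecutive repeats and drops blanks. Minimal stand-alone versions of
# those two pieces; the function names are illustrative:
import torch
import torch.nn.functional as F

def ctc_loss_sketch(logits, labels, input_lengths, label_lengths):
    # logits: (batch, time, vocab + 1) with the CTC blank at index 0
    log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)  # (time, batch, vocab + 1)
    return F.ctc_loss(log_probs, labels, input_lengths, label_lengths, blank=0, reduction="mean")

def ctc_greedy_decode_sketch(logits):
    preds = logits.argmax(dim=-1)  # (batch, time) frame-wise argmax
    decoded = []
    for seq in preds:
        out, prev = [], None
        for tok in seq.tolist():
            if tok != 0 and tok != prev:  # drop blanks and repeated tokens
                out.append(tok)
            prev = tok
        decoded.append(out)
    return decoded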
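
# NOTE (added sketch): the removed train_and_evaluate loops share one core step:
# autocast forward, loss scaled by the accumulation factor, backward through a
# GradScaler, and unscale/clip/step only every `accumulation_steps` micro-batches.
# A minimal, hypothetical version of that step; the batch layout and loss_fn are
# placeholders, not the project's actual interfaces:
import torch

def train_step_sketch(model, batch, loss_fn, optimizer, scaler, step, accumulation_steps=1):
    with torch.autocast(device_type="cuda"):
        logits = model(**batch)["logits"]
        labels = batch["labels"]
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    scaler.scale(loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    return loss.item()

# usage: scaler = torch.GradScaler(); loss = train_step_sketch(model, batch, loss_fn, optimizer, scaler, step)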