Sin2pi committed
Commit d56f7a6 · verified · 1 Parent(s): 8d65545

Update modelA.py

Files changed (1)
  1. modelA.py +221 -55
modelA.py CHANGED
@@ -32,6 +32,13 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
 def get_activation(act: str) -> nn.Module:
 """Get activation function by name."""
 act_map = {
@@ -193,7 +200,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
 axs[0].legend(loc='upper right', fontsize='small')
 axs[-1].set_xlabel("t (s)")
 fig.suptitle(title, fontsize=16)
- plt.tight_layout(rect=[0, 0, 1, 0.97]) # type: ignore
 plt.show()
 return fig
 
@@ -245,17 +252,17 @@ class RMSNorm(nn.Module):
 self.eps = eps
 self.elementwise_affine = elementwise_affine
 if self.elementwise_affine:
- self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore
 init.ones_(self.weight)
 else:
 self.register_parameter("weight", None)
 def forward(self, x):
- return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore
 
 def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
 weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
 eps: float = 1e-5) -> Tensor:
- return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore
 
 def get_device():
 return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -339,7 +346,7 @@ class rotary(nn.Module):
 elif isinstance(x, torch.Tensor) and x.ndim == 3:
 batch, ctx, dims = x.shape
 else:
- batch, head, ctx, head_dim = x.shape # type: ignore
 
 if f0 is not None:
 if f0.dim() == 2:
@@ -365,7 +372,6 @@ class rotary(nn.Module):
 radius_mean = radius.mean() if 'radius' in locals() else 0.0
 print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
 print(f" [{layer}] [Radius] {radius}")
- # self.theta_values.append(theta.item())
 self.counter += 1
 return freqs.unsqueeze(0)
 
@@ -386,7 +392,8 @@ class MultiheadA(nn.Module):
 
 rbf = False
 def __init__(self, dims: int, head: int, rotary_emb: bool = True,
- zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
 super(MultiheadA, self).__init__()
 
 self.dims = dims
@@ -419,6 +426,16 @@ class MultiheadA(nn.Module):
 else:
 self.rope = None
 
 def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
 q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
 k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
@@ -440,7 +457,7 @@ class MultiheadA(nn.Module):
 rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
 return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
 
- def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:
 
 x = x.to(device, dtype)
 if xa is not None:
@@ -459,8 +476,8 @@ class MultiheadA(nn.Module):
 q2 = q.shape[2]
 k2 = k.shape[2]
 
- q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer))) # type: ignore
- k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer))) # type: ignore
 else:
 q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -468,10 +485,26 @@ class MultiheadA(nn.Module):
 
 qk = (q * scale) @ (k * scale).transpose(-1, -2)
 
 if self.rbf:
 qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
 if self.use_pbias:
- pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None) # type: ignore
 if pbias is not None:
 qk = qk + pbias[:,:,:q2,:q2]
 
@@ -481,9 +514,6 @@ class MultiheadA(nn.Module):
 zscale[token_ids.float() == self.pad_token] = fzero
 
 if mask is not None:
- # mask = mask[:q2, :q2]#torch.tril(torch.ones(q2, q2, device=q.device))
- # audio_mask = torch.ones(q2, k2 - q2, device=q.device)
- # mask = torch.cat([mask, audio_mask], dim=-1)
 mask = mask.unsqueeze(0).unsqueeze(0)
 qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
 
@@ -499,7 +529,7 @@ class MultiheadA(nn.Module):
 class FocusWindow(nn.Module):
 
 def __init__(self, dims: int, head: int, max_span: int = 512, max_dist: int = 256,
- feature_type: str = "waveform", debug: List[str] = []):
 super().__init__()
 self.dims = dims
 self.head = head
@@ -508,22 +538,19 @@ class FocusWindow(nn.Module):
 self.max_dist = max_dist
 self.feature_type = feature_type
 self.debug = debug
-
- # Adaptive parameters for focus control
 self.threshold = nn.Parameter(torch.tensor(0.01))
 self.s_factor = nn.Parameter(torch.tensor(0.1))
 self.temp_scale = nn.Parameter(torch.tensor(1.0))
 self.sharpen = True
 
- # Feature-specific projections
 self.q_proj = Linear(dims, dims)
 self.k_proj = Linear(dims, dims)
 self.v_proj = Linear(dims, dims)
 
- # Bias strength controller
 self.bias_strength = nn.Parameter(torch.tensor(0.5))
 
- # Feature-specific window sizes
 self.window_sizes = {
 "spectrogram": 128,
 "waveform": 256,
@@ -532,7 +559,6 @@ class FocusWindow(nn.Module):
 "phase": 64
 }
 
- # Feature-specific span lengths
 self.span_lengths = {
 "spectrogram": 256,
 "waveform": 512,
@@ -541,16 +567,32 @@ class FocusWindow(nn.Module):
 "phase": 128
 }
 
- def _focus(self, q, k, v, span_scale, mask=None):
 
 q_energy = torch.norm(q, dim=-1).mean()
 k_energy = torch.norm(k, dim=-1).mean()
 content_richness = (q_energy + k_energy) / 2
 
- # Dynamic max iterations: more interesting content = more iterations
 base_iterations = 3
 max_iterations = int(base_iterations + content_richness * 12)
- max_iterations = min(max_iterations, 20) # Cap at 20
 
 iteration = 0
 prev_attn = torch.zeros_like(q)
@@ -570,13 +612,13 @@ class FocusWindow(nn.Module):
 
 q_span = q[:, :eff_span, :]
 k_span = k[:, :eff_span, :]
- v_span = k[:, :eff_span, :]
 
 batch, ctx, dims = q_span.size()
 
- q = q_span.view(batch, ctx, self.head, -1).transpose(1, 2)
- k = k_span.view(batch, ctx, self.head, -1).transpose(1, 2)
- v = v_span.view(batch, ctx, self.head, -1).transpose(1, 2)
 
 if self.sharpen:
 temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
@@ -584,11 +626,11 @@ class FocusWindow(nn.Module):
 temperature = 0.5 + self.temp_scale * span_scale.mean().item()
 
 scale = (dims // self.head) ** -0.5
- attn = torch.matmul(q, k.transpose(-1, -2)) * scale
 
 if mask is not None:
 if mask.dim() == 4:
- q_len, k_len = q.size(2), k.size(2)
 mask_q_len = min(mask.size(2), q_len)
 mask_k_len = min(mask.size(3), k_len)
 
@@ -603,7 +645,7 @@ class FocusWindow(nn.Module):
 attn = F.softmax(attn, dim=-1)
 
 if mask is not None and mask.dtype == torch.bool:
- q_len, k_len = q.size(2), k.size(2)
 mask_q_len = min(mask.size(2), q_len)
 mask_k_len = min(mask.size(3), k_len)
 
@@ -616,8 +658,11 @@ class FocusWindow(nn.Module):
 
 attn[:, :, :mask_q_len, :mask_k_len] = attn_to_mask
 
- attn_output = torch.matmul(attn, v)
- attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx, -1)
 
 diff = torch.abs(attn_out - prev_attn).mean()
 dynamic_threshold = threshold + s_factor * diff
@@ -626,7 +671,6 @@ class FocusWindow(nn.Module):
 break
 
 prev_attn = attn_out
- q = q + attn_out
 iteration += 1
 
 return attn_out, attn_weights
@@ -644,9 +688,9 @@ class FocusWindow(nn.Module):
 k_start = max(0, start_idx - span_len + win_size)
 k_end = min(start_idx + span_len, ctx)
 
- q = x[:, start_idx:end_idx, :]
- k = x[:, k_start:k_end, :]
- k = k
 
 window_mask = None
 if mask is not None:
@@ -656,32 +700,37 @@ class FocusWindow(nn.Module):
 if window_mask.size(1) == 1:
 window_mask = window_mask.expand(-1, self.head, -1, -1)
 
- attn_out, _ = self._focus(
- q=q, k=k, v=v, span_scale=span_scale, mask=window_mask
- )
 
 output[:, start_idx:end_idx, :] = attn_out
 
 return output
 
- def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=True):
 q = self.q_proj(x)
 k = self.k_proj(x if xa is None else xa)
 v = self.v_proj(x if xa is None else xa)
 
- # Create span scale based on feature characteristics
 if xa is not None:
- # Feature-specific span scaling
 feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)
 span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
 else:
 span_scale = torch.ones(x.size(0), 1, device=x.device)
 
- # Get feature-specific parameters
 win_size = self.window_sizes.get(self.feature_type, 128)
 span_len = self.span_lengths.get(self.feature_type, 256)
 
- # Apply sliding window with focus attention
 output = self.slide_win(
 x=q,
 win_size=win_size,
@@ -689,14 +738,133 @@ class FocusWindow(nn.Module):
 span_scale=span_scale,
 mask=mask
 )
-
 if return_bias:
- # Return as bias for main attention
 bias_strength = torch.sigmoid(self.bias_strength)
 return bias_strength * output
 else:
 return output
 
 class t_gate(nn.Module):
 def __init__(self, dims, num_types=4, enabled=True):
 super().__init__()
@@ -821,6 +989,7 @@ class Residual(nn.Module):
 bx = b * ax + (1 - b) * x
 cx = self.lnb(bx)
 dx = self.mlp(cx)
 ex = self.t_gate(cx) if not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
 fx = x + ex + dx
 gx = self.lnc(fx)
@@ -1066,7 +1235,7 @@ class SpeechTransformer(nn.Module):
 for f in self.features:
 if f in enc and f in self.blocks:
 xa = enc[f]
- for block in self.blocks[f]: # type: ignore
 xa = block(xa, enc=enc, layer=layer)
 out[f] = xa
 xa = xa + self.audio_embedding[:xa.shape[1]]
@@ -1327,7 +1496,6 @@ def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **data
 log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
 spec = (log_mel + 4.0) / 4.0
 spec = torch.tensor(spec)
- # batch["spectrogram"] = spec
 
 wav_np = wav.numpy().astype(np.float64)
 f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
@@ -1340,8 +1508,6 @@ def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **data
 "spectrogram": spec,
 "f0": f0,
 "labels": labels,
- # "waveform": wav,
- # "pitch": f0,
 }
 
 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
@@ -1569,11 +1735,11 @@ def main():
 trainer = Seq2SeqTrainer(
 args=training_args,
 model=model,
- train_dataset=train_dataset, # type: ignore
- eval_dataset=test_dataset, # type: ignore
- data_collator=DataCollator(tokenizer=tokenizer), # type: ignore
 compute_metrics=metrics_fn,
- optimizers=(optimizer, scheduler) # type: ignore
 )
 model.init_weights()
 trainer.train()
 
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
+ PATH = 'E:/hf'
+ os.environ['HF_HOME'] = PATH
+ os.environ['HF_DATASETS_CACHE'] = PATH
+ os.environ['TORCH_HOME'] = PATH
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
 def get_activation(act: str) -> nn.Module:
 """Get activation function by name."""
 act_map = {
 
 axs[0].legend(loc='upper right', fontsize='small')
 axs[-1].set_xlabel("t (s)")
 fig.suptitle(title, fontsize=16)
+ plt.tight_layout(rect=[0, 0, 1, 0.97])
 plt.show()
 return fig
 
 self.eps = eps
 self.elementwise_affine = elementwise_affine
 if self.elementwise_affine:
+ self.weight = nn.Parameter(torch.empty(self.normalized_shape))
 init.ones_(self.weight)
 else:
 self.register_parameter("weight", None)
 def forward(self, x):
+ return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
 
 def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
 weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
 eps: float = 1e-5) -> Tensor:
+ return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 def get_device():
 return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 elif isinstance(x, torch.Tensor) and x.ndim == 3:
 batch, ctx, dims = x.shape
 else:
+ batch, head, ctx, head_dim = x.shape
 
 if f0 is not None:
 if f0.dim() == 2:
 
 radius_mean = radius.mean() if 'radius' in locals() else 0.0
 print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
 print(f" [{layer}] [Radius] {radius}")
 self.counter += 1
 return freqs.unsqueeze(0)
 
 rbf = False
 def __init__(self, dims: int, head: int, rotary_emb: bool = True,
+ zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [],
+ optim_attn=False, use_pbias=False, use_smart_sensor=False, use_focus_bias=False):
 super(MultiheadA, self).__init__()
 
 self.dims = dims
 
 else:
 self.rope = None
 
+ self.use_smart_sensor = use_smart_sensor
+ if use_smart_sensor:
+ self.head_gate = nn.Parameter(torch.ones(head))
+ self.guidance_strength = nn.Parameter(torch.tensor(0.3))
+ self.lr_scale = nn.Parameter(torch.tensor(1.0))
+
+ self.use_focus_bias = use_focus_bias
+ if use_focus_bias:
+ self.focus_bias_strength = nn.Parameter(torch.tensor(0.3))
+
 def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
 q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
 k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
 
 rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
 return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
 
+ def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True, focus_bias=None, head_weights=None, cross_guidance=None, attention_lr=None) -> tuple:
 
 x = x.to(device, dtype)
 if xa is not None:
 
 q2 = q.shape[2]
 k2 = k.shape[2]
 
+ q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
+ k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
 else:
 q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
 qk = (q * scale) @ (k * scale).transpose(-1, -2)
 
+ if self.use_focus_bias and focus_bias is not None:
+ bias_strength = torch.sigmoid(self.focus_bias_strength)
+ qk = qk + bias_strength * focus_bias
+
+ if self.use_smart_sensor and head_weights is not None:
+ head_gate = torch.sigmoid(self.head_gate) * head_weights
+ qk = qk * head_gate.unsqueeze(-1).unsqueeze(-1)
+
+ if self.use_smart_sensor and cross_guidance is not None:
+ guidance_strength = torch.sigmoid(self.guidance_strength)
+ qk = qk + guidance_strength * cross_guidance
+
+ if self.use_smart_sensor and attention_lr is not None:
+ lr_scale = torch.sigmoid(self.lr_scale)
+ self.register_buffer("predicted_lr", attention_lr * lr_scale)
+
 if self.rbf:
 qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
 if self.use_pbias:
+ pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
 if pbias is not None:
 qk = qk + pbias[:,:,:q2,:q2]
 
 zscale[token_ids.float() == self.pad_token] = fzero
 
 if mask is not None:
 mask = mask.unsqueeze(0).unsqueeze(0)
 qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
 
 class FocusWindow(nn.Module):
 
 def __init__(self, dims: int, head: int, max_span: int = 512, max_dist: int = 256,
+ feature_type: str = "waveform", debug: List[str] = [], learn_lr: bool = False, base_lr: float = 0.001):
 super().__init__()
 self.dims = dims
 self.head = head
 
 self.max_dist = max_dist
 self.feature_type = feature_type
 self.debug = debug
+ self.learn_lr = learn_lr
+ self.base_lr = base_lr
 self.threshold = nn.Parameter(torch.tensor(0.01))
 self.s_factor = nn.Parameter(torch.tensor(0.1))
 self.temp_scale = nn.Parameter(torch.tensor(1.0))
 self.sharpen = True
 
 self.q_proj = Linear(dims, dims)
 self.k_proj = Linear(dims, dims)
 self.v_proj = Linear(dims, dims)
 
 self.bias_strength = nn.Parameter(torch.tensor(0.5))
 
 self.window_sizes = {
 "spectrogram": 128,
 "waveform": 256,
 
 "phase": 64
 }
 
 self.span_lengths = {
 "spectrogram": 256,
 "waveform": 512,
 
 "phase": 128
 }
 
+ self.head_router = nn.Sequential(
+ Linear(dims, dims),
+ nn.SiLU(),
+ Linear(dims, head)
+ )
+
+ self.lr_predictor = nn.Sequential(
+ Linear(dims, dims // 4),
+ nn.SiLU(),
+ Linear(dims // 4, 1),
+ nn.Sigmoid()
+ )
+
+ def predict_attention_lr(self, x, feature_data=None):
+ lr_factor = self.lr_predictor(x.mean(dim=1))
+ return self.base_lr * lr_factor
 
+ def _focus(self, q, k, v, span_scale, mask=None):
+
 q_energy = torch.norm(q, dim=-1).mean()
 k_energy = torch.norm(k, dim=-1).mean()
 content_richness = (q_energy + k_energy) / 2
 
 base_iterations = 3
 max_iterations = int(base_iterations + content_richness * 12)
+ max_iterations = min(max_iterations, 20)
 
 iteration = 0
 prev_attn = torch.zeros_like(q)
 
 q_span = q[:, :eff_span, :]
 k_span = k[:, :eff_span, :]
+ v_span = v[:, :eff_span, :]
 
 batch, ctx, dims = q_span.size()
 
+ q_head = q_span.view(batch, ctx, self.head, -1).transpose(1, 2)
+ k_head = k_span.view(batch, ctx, self.head, -1).transpose(1, 2)
+ v_head = v_span.view(batch, ctx, self.head, -1).transpose(1, 2)
 
 if self.sharpen:
 temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
 
 temperature = 0.5 + self.temp_scale * span_scale.mean().item()
 
 scale = (dims // self.head) ** -0.5
+ attn = torch.matmul(q_head, k_head.transpose(-1, -2)) * scale
 
 if mask is not None:
 if mask.dim() == 4:
+ q_len, k_len = q_head.size(2), k_head.size(2)
 mask_q_len = min(mask.size(2), q_len)
 mask_k_len = min(mask.size(3), k_len)
 
 attn = F.softmax(attn, dim=-1)
 
 if mask is not None and mask.dtype == torch.bool:
+ q_len, k_len = q_head.size(2), k_head.size(2)
 mask_q_len = min(mask.size(2), q_len)
 mask_k_len = min(mask.size(3), k_len)
 
 attn[:, :, :mask_q_len, :mask_k_len] = attn_to_mask
 
+ attn_output = torch.matmul(attn, v_head)
+ attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx, dims)
+
+ q = q.clone()
+ q[:, :eff_span, :] = q_span + attn_out
 
 diff = torch.abs(attn_out - prev_attn).mean()
 dynamic_threshold = threshold + s_factor * diff
 
 break
 
 prev_attn = attn_out
 
 iteration += 1
 
 return attn_out, attn_weights
 
 k_start = max(0, start_idx - span_len + win_size)
 k_end = min(start_idx + span_len, ctx)
 
+ q = x[:, start_idx:end_idx, :]
+ k = x[:, k_start:k_end, :]
+ v = x[:, k_start:k_end, :]
 
 window_mask = None
 if mask is not None:
 
 if window_mask.size(1) == 1:
 window_mask = window_mask.expand(-1, self.head, -1, -1)
 
+ attn_out, _ = self._focus(q=q, k=k, v=v, span_scale=span_scale, mask=window_mask)
 
 output[:, start_idx:end_idx, :] = attn_out
 
 return output
 
+ def predict_head_importance(self, x, xa=None):
+ if xa is not None:
+ combined = x + 0.1 * xa
+ else:
+ combined = x
+ head_importance = self.head_router(combined.mean(dim=1))
+ return head_importance
+
+ def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=False, return_head_weights=False, learn_lr=False):
+
+ print(f"🎯 FocusWindow running! Input: {x.shape}, Feature: {xa.shape if xa is not None else None}")
+
 q = self.q_proj(x)
 k = self.k_proj(x if xa is None else xa)
 v = self.v_proj(x if xa is None else xa)
 
 if xa is not None:
 feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)
 span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
 else:
 span_scale = torch.ones(x.size(0), 1, device=x.device)
 
 win_size = self.window_sizes.get(self.feature_type, 128)
 span_len = self.span_lengths.get(self.feature_type, 256)
 
 output = self.slide_win(
 x=q,
 win_size=win_size,
 
 span_scale=span_scale,
 mask=mask
 )
+
+ if learn_lr:
+ lr_factor = self.lr_predictor(output.mean(dim=1))
+ return output, lr_factor
+
+ if return_head_weights:
+ head_weights = self.predict_head_importance(x, xa)
+ return output, head_weights
+
 if return_bias:
 bias_strength = torch.sigmoid(self.bias_strength)
 return bias_strength * output
 else:
 return output
 
+ class CrossFeatureFocusAttention(nn.Module):
+ def __init__(self, dims: int, head: int, features: List[str] = ["spectrogram", "pitch"]):
+ super().__init__()
+ self.dims = dims
+ self.head = head
+ self.features = features
+
+ self.cross_attn_layers = nn.ModuleDict({
+ feature: nn.MultiheadAttention(dims, head, batch_first=True)
+ for feature in features
+ })
+
+ self.feature_fusion = nn.Sequential(
+ Linear(dims * len(features), dims),
+ nn.SiLU(),
+ Linear(dims, dims)
+ )
+
+ def forward(self, x, enc, mask=None):
+ if enc is None:
+ return None
+
+ cross_features = []
+ for feature in self.features:
+ if feature in enc:
+ feature_data = enc[feature]
+ if feature_data is not None:
+ attn_out, _ = self.cross_attn_layers[feature](
+ x, feature_data, feature_data,
+ attn_mask=mask
+ )
+ cross_features.append(attn_out)
+
+ if not cross_features:
+ return None
+
+ if len(cross_features) > 1:
+ fused = torch.cat(cross_features, dim=-1)
+ return self.feature_fusion(fused)
+ else:
+ return cross_features[0]
+
+ class AdaptiveAttentionLR(nn.Module):
+ def __init__(self, dims: int, head: int):
+ super().__init__()
+ self.dims = dims
+ self.head = head
+
+ self.lr_predictor = nn.Sequential(
+ Linear(dims, dims // 4),
+ nn.SiLU(),
+ Linear(dims // 4, 1),
+ nn.Sigmoid()
+ )
+
+ self.quality_estimator = nn.Sequential(
+ Linear(dims, dims // 2),
+ nn.SiLU(),
+ Linear(dims // 2, 1),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x, feature_data=None, mask=None):
+ quality = self.quality_estimator(x.mean(dim=1))
+
+ lr_factor = self.lr_predictor(x.mean(dim=1))
+
+ adaptive_lr = quality * lr_factor
+
+ return adaptive_lr, adaptive_lr
+
+ class SmartSensorResidual(nn.Module):
+ def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
+ use_smart_sensor=True):
+ super().__init__()
+ self.ctx = ctx
+ self.dims = dims
+ self.head = head
+ self.act = act
+ self.debug = debug
+
+ if use_smart_sensor:
+ self.focus_attn = FocusWindow(dims, head, feature_type="waveform")
+ self.cross_feature_guide = CrossFeatureFocusAttention(dims, head,
+ features=["spectrogram", "pitch"])
+ self.adaptive_lr = AdaptiveAttentionLR(dims, head)
+
+ self.attna = MultiheadA(dims, head, debug=debug)
+ self.lna = RMSNorm(dims)
+
+ def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio"):
+ if hasattr(self, 'focus_attn') and enc is not None:
+ focus_output, head_weights = self.focus_attn(x, enc.get("waveform"), mask,
+ return_head_weights=True)
+
+ cross_guidance = self.cross_feature_guide(x, enc, mask)
+
+ _, attention_lr = self.adaptive_lr(x, enc.get("waveform"), mask)
+
+ x = x + self.attna(
+ self.lna(x),
+ xa=None,
+ mask=mask,
+ head_weights=head_weights,
+ cross_guidance=cross_guidance,
+ attention_lr=attention_lr,
+ enc=enc,
+ layer=layer
+ )[0]
+
+ return x
+
 class t_gate(nn.Module):
 def __init__(self, dims, num_types=4, enabled=True):
 super().__init__()
 
 bx = b * ax + (1 - b) * x
 cx = self.lnb(bx)
 dx = self.mlp(cx)
+
 ex = self.t_gate(cx) if not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
 fx = x + ex + dx
 gx = self.lnc(fx)
 
 for f in self.features:
 if f in enc and f in self.blocks:
 xa = enc[f]
+ for block in self.blocks[f]:
 xa = block(xa, enc=enc, layer=layer)
 out[f] = xa
 xa = xa + self.audio_embedding[:xa.shape[1]]
 
 log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
 spec = (log_mel + 4.0) / 4.0
 spec = torch.tensor(spec)
 
 wav_np = wav.numpy().astype(np.float64)
 f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
 
 "spectrogram": spec,
 "f0": f0,
 "labels": labels,
 }
 
 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
 
 trainer = Seq2SeqTrainer(
 args=training_args,
 model=model,
+ train_dataset=train_dataset,
+ eval_dataset=test_dataset,
+ data_collator=DataCollator(tokenizer=tokenizer),
 compute_metrics=metrics_fn,
+ optimizers=(optimizer, scheduler)
 )
 model.init_weights()
 trainer.train()
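
For context, the core mechanism this commit wires in is an additive attention bias: a FocusWindow-style module produces windowed scores that are blended into the attention logits through a learned sigmoid strength before the softmax. Below is a minimal, self-contained sketch of that pattern; the names and sizes (ToyFocusBias, win, strength, dims=64) are illustrative assumptions for demonstration, not the modelA.py API.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyFocusBias(nn.Module):
    # Illustrative stand-in for a FocusWindow-style module: it scores a short
    # local window around each position and scales the result with a learned
    # sigmoid strength (analogous to focus_bias_strength in the commit).
    def __init__(self, dims: int, win: int = 4):
        super().__init__()
        self.proj = nn.Linear(dims, dims)
        self.win = win
        self.strength = nn.Parameter(torch.tensor(0.3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, ctx, dims) -> bias of shape (batch, 1, ctx, ctx)
        h = self.proj(x)
        scores = h @ h.transpose(-1, -2) / h.shape[-1] ** 0.5
        idx = torch.arange(x.shape[1], device=x.device)
        local = (idx[None, :] - idx[:, None]).abs() <= self.win
        scores = scores.masked_fill(~local, 0.0)  # keep only the local window
        return torch.sigmoid(self.strength) * scores.unsqueeze(1)

batch, ctx, dims, head = 2, 16, 64, 4
x = torch.randn(batch, ctx, dims)
q = k = x.view(batch, ctx, head, dims // head).permute(0, 2, 1, 3)
qk = q @ k.transpose(-1, -2) / (dims // head) ** 0.5  # (batch, head, ctx, ctx) logits
qk = qk + ToyFocusBias(dims)(x)                       # bias broadcasts across heads
attn = F.softmax(qk, dim=-1)
print(attn.shape)  # torch.Size([2, 4, 16, 16])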