Sin2pi committed · verified
Commit 87411f1 · 1 Parent(s): 034d1a9

Update modelA.py

Files changed (1): modelA.py +212 -4
modelA.py CHANGED
@@ -32,7 +32,15 @@ dtype = torch.float32
  warnings.filterwarnings("ignore")
  logging.basicConfig(level=logging.ERROR)

+ PATH = 'E:/hf'
+ os.environ['HF_HOME'] = PATH
+ os.environ['HF_DATASETS_CACHE'] = PATH
+ os.environ['TORCH_HOME'] = PATH
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
  def get_activation(act: str) -> nn.Module:
+     """Get activation function by name."""
      act_map = {
          "gelu": nn.GELU(),
          "relu": nn.ReLU(),
@@ -496,6 +504,207 @@ class MultiheadA(nn.Module):
          self.counter += 1
          return self.o(wv), qk

+ class FocusWindow(nn.Module):
+     def __init__(self, dims: int, head: int, max_span: int = 512, max_dist: int = 256,
+                  feature_type: str = "waveform", debug: List[str] = []):
+         super().__init__()
+         self.dims = dims
+         self.head = head
+         self.head_dim = dims // head
+         self.max_span = max_span
+         self.max_dist = max_dist
+         self.feature_type = feature_type
+         self.debug = debug
+
+         # Adaptive parameters for focus control
+         self.threshold = nn.Parameter(torch.tensor(0.01))
+         self.s_factor = nn.Parameter(torch.tensor(0.1))
+         self.temp_scale = nn.Parameter(torch.tensor(1.0))
+         self.sharpen = True
+
+         # Feature-specific projections
+         self.q_proj = Linear(dims, dims)
+         self.k_proj = Linear(dims, dims)
+         self.v_proj = Linear(dims, dims)
+
+         # Bias strength controller
+         self.bias_strength = nn.Parameter(torch.tensor(0.5))
+
+         # Feature-specific window sizes
+         self.window_sizes = {
+             "spectrogram": 128,
+             "waveform": 256,
+             "pitch": 64,
+             "envelope": 64,
+             "phase": 64}
+
+         # Feature-specific span lengths
+         self.span_lengths = {
+             "spectrogram": 256,
+             "waveform": 512,
+             "pitch": 128,
+             "envelope": 128,
+             "phase": 128}
+
+     def _focus(self, q, k, v, span_scale, mask=None):
+         # Content richness (mean q/k energy) controls the iteration budget.
+         q_energy = torch.norm(q, dim=-1).mean()
+         k_energy = torch.norm(k, dim=-1).mean()
+         content_richness = (q_energy + k_energy) / 2
+
+         # Dynamic max iterations: richer content = more iterations, capped at 20
+         base_iterations = 3
+         max_iterations = int(base_iterations + content_richness * 12)
+         max_iterations = min(max_iterations, 20)
+
+         iteration = 0
+         prev_attn = torch.zeros_like(q)
+         attn_out = torch.zeros_like(q)
+         attn_weights = None
+
+         threshold = self.threshold.item()
+         s_factor = self.s_factor.item()
+
+         while iteration < max_iterations:
+             span_len = int(self.max_span * span_scale.mean().item())
+             span_len = min(span_len, q.size(1), k.size(1), v.size(1))
+             eff_span = min(span_len, self.max_dist)
+
+             if eff_span == 0:
+                 break
+
+             q_span = q[:, :eff_span, :]
+             k_span = k[:, :eff_span, :]
+             v_span = v[:, :eff_span, :]
+
+             batch, ctx, dims = q_span.size()
+
+             # Split heads on local views so q, k, v keep their full shapes
+             # across iterations.
+             qh = q_span.view(batch, ctx, self.head, -1).transpose(1, 2)
+             kh = k_span.view(batch, ctx, self.head, -1).transpose(1, 2)
+             vh = v_span.view(batch, ctx, self.head, -1).transpose(1, 2)
+
+             if self.sharpen:
+                 temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
+             else:
+                 temperature = 0.5 + self.temp_scale * span_scale.mean().item()
+
+             scale = (dims // self.head) ** -0.5
+             attn = torch.matmul(qh, kh.transpose(-1, -2)) * scale
+
+             if mask is not None:
+                 if mask.dim() == 4:
+                     q_len, k_len = qh.size(2), kh.size(2)
+                     mask_q_len = min(mask.size(2), q_len)
+                     mask_k_len = min(mask.size(3), k_len)
+
+                     mask_part = mask[:, :, :mask_q_len, :mask_k_len]
+                     if mask_part.dtype == torch.bool:
+                         attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len].masked_fill(
+                             mask_part, float("-inf"))
+                     else:
+                         attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len] + mask_part
+
+             attn = F.softmax(attn, dim=-1)
+
+             if mask is not None and mask.dim() == 4 and mask.dtype == torch.bool:
+                 q_len, k_len = qh.size(2), kh.size(2)
+                 mask_q_len = min(mask.size(2), q_len)
+                 mask_k_len = min(mask.size(3), k_len)
+
+                 # Zero out masked positions and renormalize the survivors.
+                 binary_mask = (~mask[:, :, :mask_q_len, :mask_k_len]).float()
+                 attn_to_mask = attn[:, :, :mask_q_len, :mask_k_len]
+                 attn_to_mask = attn_to_mask * binary_mask
+
+                 attn_sum = attn_to_mask.sum(dim=-1, keepdim=True)
+                 attn_to_mask = attn_to_mask / (attn_sum + 1e-6)
+
+                 attn[:, :, :mask_q_len, :mask_k_len] = attn_to_mask
+
+             attn_output = torch.matmul(attn, vh)
+             attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx, -1)
+
+             # Stop once the output changes less than the dynamic threshold.
+             diff = torch.abs(attn_out - prev_attn).mean()
+             dynamic_threshold = threshold + s_factor * diff
+
+             if diff < dynamic_threshold:
+                 break
+
+             prev_attn = attn_out
+             q = q_span + attn_out
+             iteration += 1
+
+         return attn_out, attn_weights
+
+     def slide_win(self, x, win_size, span_len, span_scale, mask=None):
+         batch, ctx, dims = x.size()
+         num_windows = (ctx + win_size - 1) // win_size
+         output = torch.zeros_like(x)
+
+         for i in range(num_windows):
+             start_idx = i * win_size
+             end_idx = min((i + 1) * win_size, ctx)
+
+             # Keys and values may extend span_len positions around the window.
+             k_start = max(0, start_idx - span_len + win_size)
+             k_end = min(start_idx + span_len, ctx)
+
+             q = x[:, start_idx:end_idx, :]
+             k = x[:, k_start:k_end, :]
+             v = x[:, k_start:k_end, :]
+
+             window_mask = None
+             if mask is not None:
+                 if mask.dim() == 4:
+                     window_mask = mask[:, :, start_idx:end_idx, k_start:k_end]
+
+                     if window_mask.size(1) == 1:
+                         window_mask = window_mask.expand(-1, self.head, -1, -1)
+
+             attn_out, _ = self._focus(
+                 q=q, k=k, v=v, span_scale=span_scale, mask=window_mask)
+
+             # attn_out may be shorter than the window when the effective span
+             # is clipped, so write back only the positions it covers.
+             output[:, start_idx:start_idx + attn_out.size(1), :] = attn_out
+
+         return output
+
+     def forward(self, x, feature_data=None, mask=None, return_bias=True):
+         q = self.q_proj(x)
+         k = self.k_proj(x if feature_data is None else feature_data)
+         v = self.v_proj(x if feature_data is None else feature_data)
+
+         # Create span scale based on feature characteristics
+         if feature_data is not None:
+             feature_energy = torch.norm(feature_data, dim=-1).mean(dim=-1, keepdim=True)
+             span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
+         else:
+             span_scale = torch.ones(x.size(0), 1, device=x.device)
+
+         # Get feature-specific window and span parameters
+         win_size = self.window_sizes.get(self.feature_type, 128)
+         span_len = self.span_lengths.get(self.feature_type, 256)
+
+         # Apply sliding window with focus attention
+         output = self.slide_win(
+             x=q,
+             win_size=win_size,
+             span_len=span_len,
+             span_scale=span_scale,
+             mask=mask)
+
+         if return_bias:
+             # Scale so the result can be added as a bias to the main attention
+             bias_strength = torch.sigmoid(self.bias_strength)
+             return bias_strength * output
+         else:
+             return output
+
  class t_gate(nn.Module):
      def __init__(self, dims, num_types=4, enabled=True):
          super().__init__()
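FocusWindow is not wired into anything else in this diff, so a standalone smoke test helps pin down its contract: with return_bias=True (the default) it returns a sigmoid-scaled tensor meant to be added as a bias to the main attention rather than replace it. A hedged sketch, assuming Linear in modelA.py aliases nn.Linear and dims is divisible by head:

    import torch

    fw = FocusWindow(dims=512, head=4, feature_type="spectrogram")
    x = torch.randn(2, 300, 512)       # (batch, ctx, dims) decoder states
    feat = torch.randn(2, 300, 512)    # projected spectrogram features

    bias = fw(x, feature_data=feat)                      # scaled by sigmoid(bias_strength)
    raw = fw(x, feature_data=feat, return_bias=False)    # unscaled attention output
    print(bias.shape, raw.shape)                         # both torch.Size([2, 300, 512])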
@@ -949,7 +1158,7 @@ class Echo(nn.Module):
          self.init_counts = {
              "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
              "Conv2d": 0, "SEBlock": 0, "SpeechTransformer": 0,
-             "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
+             "Residual": 0, "MultiheadA": 0,
              "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
              "WEncoder": 0, "PEncoder": 0}

@@ -1166,9 +1375,9 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **
      return train_dataset, test_dataset

      def filter_func(x):
-         return (0 < len(x["transcription"]) < 512 and
+         return (0 < len(x["transcription"]) < 2048 and
              len(x["audio"]["array"]) > 0 and
-             len(x["audio"]["array"]) < 1500 * 160)
+             len(x["audio"]["array"]) < 2048 * 160)

      raw_train = load_dataset(
          "google/fleurs", "en_us", token=token, split="train[:1000]", trust_remote_code=True)
@@ -1322,7 +1531,6 @@ def main():
          vocab=40000, ctx=2048, dims=512, head=4, layer=4,
          mels=128, act="swish",
          debug={},
-         cross_attn=True,
          features=["spectrogram"]
      )
