Sin2pi committed on
Commit 75becfe · verified · 1 Parent(s): 25a309a

Update modelA.py

Files changed (1)
  1. modelA.py +309 -529
modelA.py CHANGED
@@ -33,6 +33,7 @@ warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
 
35
  def get_activation(act: str) -> nn.Module:
 
36
  act_map = {
37
  "gelu": nn.GELU(),
38
  "relu": nn.ReLU(),
@@ -50,11 +51,11 @@ def get_activation(act: str) -> nn.Module:
50
  @dataclass
51
  class Dimensions:
52
  vocab: int
 
53
  ctx: int
54
  dims: int
55
  head: int
56
  layer: int
57
- mels: int
58
  act: str
59
  debug: List[str]
60
  features: List[str]
@@ -197,6 +198,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
197
  return fig
198
 
199
  def valid(default_value, *items):
 
200
  for item in items:
201
  if item is not None:
202
  return item
@@ -264,6 +266,21 @@ def get_dtype():
264
  def tox():
265
  return {"device": get_device(), "dtype": get_dtype()}
266
 
267
  def sinusoids(length, channels, max_tscale=10000):
268
  assert channels % 2 == 0
269
  log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
@@ -323,46 +340,39 @@ class rotary(nn.Module):
323
  idx = torch.arange(ctx, device=f0.device)
324
  idx = (idx * F).long().clamp(0, L - 1)
325
  radius = radius[idx]
326
- return torch.polar(radius.unsqueeze(-1), freqs)
 
 
327
  else:
328
- return torch.polar(torch.ones_like(freqs), freqs)
329
 
330
- def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
331
- f0 = enc.get("f0") if enc is not None else None
332
-
333
- if isinstance(x, int):
334
- ctx = x
335
- elif isinstance(x, torch.Tensor) and x.ndim == 2:
336
- batch, ctx = x.shape
337
- elif isinstance(x, torch.Tensor) and x.ndim == 3:
338
- batch, ctx, dims = x.shape
339
  else:
340
- batch, head, ctx, head_dim = x.shape
341
 
 
342
  if f0 is not None:
343
  if f0.dim() == 2:
344
  f0 = f0.squeeze(0)
345
  theta = f0 + self.theta
346
  else:
347
  theta = self.theta
348
-
349
  freqs = self.theta_freqs(theta)
350
  t = torch.arange(ctx, device=device, dtype=dtype)
351
  freqs = t[:, None] * freqs
352
-
353
- if self.radii and f0 is not None:
354
- radius = f0.to(device, dtype)
355
- freqs = torch.polar(radius.unsqueeze(-1), freqs)
356
- else:
357
- radius = torch.ones_like(freqs)
358
- freqs = torch.polar(radius, freqs)
359
-
360
  if "radius" in self.debug and self.counter == 10:
361
- theta_value = theta.mean()
362
- radius_shape = radius.shape if 'radius' in locals() else "N/A"
363
- radius_mean = radius.mean() if 'radius' in locals() else 0.0
364
- print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
365
- print(f" [{layer}] [Radius] {radius}")
366
  self.counter += 1
367
  return freqs.unsqueeze(0)
368
 
@@ -383,8 +393,7 @@ class MultiheadA(nn.Module):
383
 
384
  rbf = False
385
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
386
- zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [],
387
- optim_attn=False, use_pbias=False, use_smart_sensor=False, use_focus_bias=False):
388
  super(MultiheadA, self).__init__()
389
 
390
  self.dims = dims
@@ -417,16 +426,6 @@ class MultiheadA(nn.Module):
417
  else:
418
  self.rope = None
419
 
420
- self.use_smart_sensor = use_smart_sensor
421
- if use_smart_sensor:
422
- self.head_gate = nn.Parameter(torch.ones(head))
423
- self.guidance_strength = nn.Parameter(torch.tensor(0.3))
424
- self.lr_scale = nn.Parameter(torch.tensor(1.0))
425
-
426
- self.use_focus_bias = use_focus_bias
427
- if use_focus_bias:
428
- self.focus_bias_strength = nn.Parameter(torch.tensor(0.3))
429
-
430
  def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
431
  q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
432
  k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
@@ -448,7 +447,7 @@ class MultiheadA(nn.Module):
448
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
449
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
450
 
451
- def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True, focus_bias=None, head_weights=None, cross_guidance=None, attention_lr=None) -> tuple:
452
 
453
  x = x.to(device, dtype)
454
  if xa is not None:
@@ -476,26 +475,10 @@ class MultiheadA(nn.Module):
476
 
477
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
478
 
479
- if self.use_focus_bias and focus_bias is not None:
480
- bias_strength = torch.sigmoid(self.focus_bias_strength)
481
- qk = qk + bias_strength * focus_bias
482
-
483
- if self.use_smart_sensor and head_weights is not None:
484
- head_gate = torch.sigmoid(self.head_gate) * head_weights
485
- qk = qk * head_gate.unsqueeze(-1).unsqueeze(-1)
486
-
487
- if self.use_smart_sensor and cross_guidance is not None:
488
- guidance_strength = torch.sigmoid(self.guidance_strength)
489
- qk = qk + guidance_strength * cross_guidance
490
-
491
- if self.use_smart_sensor and attention_lr is not None:
492
- lr_scale = torch.sigmoid(self.lr_scale)
493
- self.register_buffer("predicted_lr", attention_lr * lr_scale)
494
-
495
  if self.rbf:
496
  qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
497
  if self.use_pbias:
498
- pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
499
  if pbias is not None:
500
  qk = qk + pbias[:,:,:q2,:q2]
501
 
@@ -504,10 +487,8 @@ class MultiheadA(nn.Module):
504
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
505
  zscale[token_ids.float() == self.pad_token] = fzero
506
 
507
- if mask is not None:
508
- mask = mask.unsqueeze(0).unsqueeze(0)
509
  qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
510
-
511
  qk = qk * zscale.unsqueeze(-2)
512
  w = F.softmax(qk, dim=-1).to(q.dtype)
513
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
@@ -517,340 +498,6 @@ class MultiheadA(nn.Module):
517
  self.counter += 1
518
  return self.o(wv), qk
519
 
520
- class FocusWindow(nn.Module):
521
-
522
- def __init__(self, dims: int, head: int, max_span: int = 512, max_dist: int = 256,
523
- feature_type: str = "waveform", debug: List[str] = [], learn_lr: bool = False, base_lr: float = 0.001):
524
- super().__init__()
525
- self.dims = dims
526
- self.head = head
527
- self.head_dim = dims // head
528
- self.max_span = max_span
529
- self.max_dist = max_dist
530
- self.feature_type = feature_type
531
- self.debug = debug
532
- self.learn_lr = learn_lr
533
- self.base_lr = base_lr
534
- self.threshold = nn.Parameter(torch.tensor(0.01))
535
- self.s_factor = nn.Parameter(torch.tensor(0.1))
536
- self.temp_scale = nn.Parameter(torch.tensor(1.0))
537
- self.sharpen = True
538
-
539
- self.q_proj = Linear(dims, dims)
540
- self.k_proj = Linear(dims, dims)
541
- self.v_proj = Linear(dims, dims)
542
-
543
- self.bias_strength = nn.Parameter(torch.tensor(0.5))
544
-
545
- self.window_sizes = {
546
- "spectrogram": 128,
547
- "waveform": 256,
548
- "pitch": 64,
549
- "envelope": 64,
550
- "phase": 64
551
- }
552
-
553
- self.span_lengths = {
554
- "spectrogram": 256,
555
- "waveform": 512,
556
- "pitch": 128,
557
- "envelope": 128,
558
- "phase": 128
559
- }
560
-
561
- self.head_router = nn.Sequential(
562
- Linear(dims, dims),
563
- nn.SiLU(),
564
- Linear(dims, head)
565
- )
566
-
567
- self.lr_predictor = nn.Sequential(
568
- Linear(dims, dims // 4),
569
- nn.SiLU(),
570
- Linear(dims // 4, 1),
571
- nn.Sigmoid()
572
- )
573
-
574
- def predict_attention_lr(self, x, feature_data=None):
575
- lr_factor = self.lr_predictor(x.mean(dim=1))
576
- return self.base_lr * lr_factor
577
-
578
- def _focus(self, q, k, v, span_scale, mask=None):
579
-
580
- q_energy = torch.norm(q, dim=-1).mean()
581
- k_energy = torch.norm(k, dim=-1).mean()
582
- content_richness = (q_energy + k_energy) / 2
583
-
584
- base_iterations = 3
585
- max_iterations = int(base_iterations + content_richness * 12)
586
- max_iterations = min(max_iterations, 20)
587
-
588
- iteration = 0
589
- prev_attn = torch.zeros_like(q)
590
- attn_out = torch.zeros_like(q)
591
- attn_weights = None
592
-
593
- threshold = self.threshold.item()
594
- s_factor = self.s_factor.item()
595
-
596
- while iteration < max_iterations:
597
- span_len = int(self.max_span * span_scale.mean().item())
598
- span_len = min(span_len, q.size(1), k.size(1), k.size(1))
599
- eff_span = min(span_len, self.max_dist)
600
-
601
- if eff_span == 0:
602
- break
603
-
604
- q_span = q[:, :eff_span, :]
605
- k_span = k[:, :eff_span, :]
606
- v_span = v[:, :eff_span, :]
607
-
608
- batch, ctx, dims = q_span.size()
609
-
610
- q_head = q_span.view(batch, ctx, self.head, -1).transpose(1, 2)
611
- k_head = k_span.view(batch, ctx, self.head, -1).transpose(1, 2)
612
- v_head = v_span.view(batch, ctx, self.head, -1).transpose(1, 2)
613
-
614
- if self.sharpen:
615
- temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
616
- else:
617
- temperature = 0.5 + self.temp_scale * span_scale.mean().item()
618
-
619
- scale = (dims // self.head) ** -0.5
620
- attn = torch.matmul(q_head, k_head.transpose(-1, -2)) * scale
621
-
622
- if mask is not None:
623
- if mask.dim() == 4:
624
- q_len, k_len = q_head.size(2), k_head.size(2)
625
- mask_q_len = min(mask.size(2), q_len)
626
- mask_k_len = min(mask.size(3), k_len)
627
-
628
- mask_part = mask[:, :, :mask_q_len, :mask_k_len]
629
- if mask_part.dtype == torch.bool:
630
- attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len].masked_fill(
631
- mask_part, float("-inf")
632
- )
633
- else:
634
- attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len] + mask_part
635
-
636
- attn = F.softmax(attn, dim=-1)
637
-
638
- if mask is not None and mask.dtype == torch.bool:
639
- q_len, k_len = q_head.size(2), k_head.size(2)
640
- mask_q_len = min(mask.size(2), q_len)
641
- mask_k_len = min(mask.size(3), k_len)
642
-
643
- binary_mask = (~mask[:, :, :mask_q_len, :mask_k_len]).float()
644
- attn_to_mask = attn[:, :, :mask_q_len, :mask_k_len]
645
- attn_to_mask = attn_to_mask * binary_mask
646
-
647
- attn_sum = attn_to_mask.sum(dim=-1, keepdim=True)
648
- attn_to_mask = attn_to_mask / (attn_sum + 1e-6)
649
-
650
- attn[:, :, :mask_q_len, :mask_k_len] = attn_to_mask
651
-
652
- attn_output = torch.matmul(attn, v_head)
653
- attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx, dims)
654
-
655
- q = q.clone()
656
- q[:, :eff_span, :] = q_span + attn_out
657
-
658
- diff = torch.abs(attn_out - prev_attn).mean()
659
- dynamic_threshold = threshold + s_factor * diff
660
-
661
- if diff < dynamic_threshold:
662
- break
663
-
664
- prev_attn = attn_out
665
- iteration += 1
666
-
667
- return attn_out, attn_weights
668
-
669
- def slide_win(self, x, win_size, span_len, span_scale, mask=None):
670
- batch, ctx, dims = x.size()
671
- num_windows = (ctx + win_size - 1) // win_size
672
- output = torch.zeros_like(x)
673
-
674
- for i in range(num_windows):
675
- start_idx = i * win_size
676
- end_idx = min((i + 1) * win_size, ctx)
677
- window_size = end_idx - start_idx
678
-
679
- k_start = max(0, start_idx - span_len + win_size)
680
- k_end = min(start_idx + span_len, ctx)
681
-
682
- q = x[:, start_idx:end_idx, :]
683
- k = x[:, k_start:k_end, :]
684
- v = x[:, k_start:k_end, :]
685
-
686
- window_mask = None
687
- if mask is not None:
688
- if mask.dim() == 4:
689
- window_mask = mask[:, :, start_idx:end_idx, k_start:k_end]
690
-
691
- if window_mask.size(1) == 1:
692
- window_mask = window_mask.expand(-1, self.head, -1, -1)
693
-
694
- attn_out, _ = self._focus(q=q, k=k, v=v, span_scale=span_scale, mask=window_mask)
695
-
696
- output[:, start_idx:end_idx, :] = attn_out
697
-
698
- return output
699
-
700
- def predict_head_importance(self, x, xa=None):
701
- if xa is not None:
702
- combined = x + 0.1 * xa
703
- else:
704
- combined = x
705
- head_importance = self.head_router(combined.mean(dim=1))
706
- return head_importance
707
-
708
- def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=False, return_head_weights=False, learn_lr=False):
709
-
710
- q = self.q_proj(x)
711
- k = self.k_proj(x if xa is None else xa)
712
- v = self.v_proj(x if xa is None else xa)
713
-
714
- if xa is not None:
715
- feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)
716
- span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
717
- else:
718
- span_scale = torch.ones(x.size(0), 1, device=x.device)
719
-
720
- win_size = self.window_sizes.get(self.feature_type, 128)
721
- span_len = self.span_lengths.get(self.feature_type, 256)
722
-
723
- output = self.slide_win(
724
- x=q,
725
- win_size=win_size,
726
- span_len=span_len,
727
- span_scale=span_scale,
728
- mask=mask
729
- )
730
-
731
- if learn_lr:
732
- lr_factor = self.lr_predictor(output.mean(dim=1))
733
- return output, lr_factor
734
-
735
- if return_head_weights:
736
- head_weights = self.predict_head_importance(x, xa)
737
- return output, head_weights
738
-
739
- if return_bias:
740
- bias_strength = torch.sigmoid(self.bias_strength)
741
- return bias_strength * output
742
- else:
743
- return output
744
-
745
- class CrossFeatureFocusAttention(nn.Module):
746
- def __init__(self, dims: int, head: int, features: List[str] = ["spectrogram", "pitch"]):
747
- super().__init__()
748
- self.dims = dims
749
- self.head = head
750
- self.features = features
751
-
752
- self.cross_attn_layers = nn.ModuleDict({
753
- feature: nn.MultiheadAttention(dims, head, batch_first=True)
754
- for feature in features
755
- })
756
-
757
- self.feature_fusion = nn.Sequential(
758
- Linear(dims * len(features), dims),
759
- nn.SiLU(),
760
- Linear(dims, dims)
761
- )
762
-
763
- def forward(self, x, enc, mask=None):
764
- if enc is None:
765
- return None
766
-
767
- cross_features = []
768
- for feature in self.features:
769
- if feature in enc:
770
- feature_data = enc[feature]
771
- if feature_data is not None:
772
- attn_out, _ = self.cross_attn_layers[feature](
773
- x, feature_data, feature_data,
774
- attn_mask=mask
775
- )
776
- cross_features.append(attn_out)
777
-
778
- if not cross_features:
779
- return None
780
-
781
- if len(cross_features) > 1:
782
- fused = torch.cat(cross_features, dim=-1)
783
- return self.feature_fusion(fused)
784
- else:
785
- return cross_features[0]
786
-
787
- class AdaptiveAttentionLR(nn.Module):
788
- def __init__(self, dims: int, head: int):
789
- super().__init__()
790
- self.dims = dims
791
- self.head = head
792
-
793
- self.lr_predictor = nn.Sequential(
794
- Linear(dims, dims // 4),
795
- nn.SiLU(),
796
- Linear(dims // 4, 1),
797
- nn.Sigmoid()
798
- )
799
-
800
- self.quality_estimator = nn.Sequential(
801
- Linear(dims, dims // 2),
802
- nn.SiLU(),
803
- Linear(dims // 2, 1),
804
- nn.Sigmoid()
805
- )
806
-
807
- def forward(self, x, feature_data=None, mask=None):
808
- quality = self.quality_estimator(x.mean(dim=1))
809
- lr_factor = self.lr_predictor(x.mean(dim=1))
810
- adaptive_lr = quality * lr_factor
811
- return adaptive_lr, adaptive_lr
812
-
813
- class SmartSensorResidual(nn.Module):
814
- def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
815
- use_smart_sensor=True):
816
- super().__init__()
817
- self.ctx = ctx
818
- self.dims = dims
819
- self.head = head
820
- self.act = act
821
- self.debug = debug
822
-
823
- if use_smart_sensor:
824
- self.focus_attn = FocusWindow(dims, head, feature_type="waveform")
825
- self.cross_feature_guide = CrossFeatureFocusAttention(dims, head,
826
- features=["spectrogram", "pitch"])
827
- self.adaptive_lr = AdaptiveAttentionLR(dims, head)
828
-
829
- self.attna = MultiheadA(dims, head, debug=debug)
830
- self.lna = RMSNorm(dims)
831
-
832
- def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio"):
833
- if hasattr(self, 'focus_attn') and enc is not None:
834
- focus_output, head_weights = self.focus_attn(x, enc.get("waveform"), mask,
835
- return_head_weights=True)
836
-
837
- cross_guidance = self.cross_feature_guide(x, enc, mask)
838
-
839
- _, attention_lr = self.adaptive_lr(x, enc.get("waveform"), mask)
840
-
841
- x = x + self.attna(
842
- self.lna(x),
843
- xa=None,
844
- mask=mask,
845
- head_weights=head_weights,
846
- cross_guidance=cross_guidance,
847
- attention_lr=attention_lr,
848
- enc=enc,
849
- layer=layer
850
- )[0]
851
-
852
- return x
853
-
854
  class t_gate(nn.Module):
855
  def __init__(self, dims, num_types=4, enabled=True):
856
  super().__init__()
@@ -931,7 +578,7 @@ class mlp_gate(nn.Module):
931
  class Residual(nn.Module):
932
  _seen = set()
933
  def __init__(self, ctx, dims, head, act, debug: List[str] = [],
934
- tgate=True, mgate=False, cgate=False, mem_size=512, features=None, focus=True):
935
  super().__init__()
936
 
937
  self.dims = dims
@@ -945,9 +592,7 @@ class Residual(nn.Module):
945
 
946
  self.blend = nn.Parameter(torch.tensor(0.5))
947
  act_fn = get_activation(act)
948
-
949
  self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
950
- self.focus = FocusWindow(dims, head, debug=debug) if focus else None
951
 
952
  if not any([tgate, mgate, cgate]):
953
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
@@ -966,16 +611,13 @@ class Residual(nn.Module):
966
  self.lnb = RMSNorm(dims)
967
  self.lnc = RMSNorm(dims)
968
 
969
- def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
970
-
971
- focus = self.focus(x, xa=xa, mask=mask, enc=enc, layer=layer) if self.focus is not None else 0
972
 
973
  b = torch.sigmoid(self.blend)
974
- ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0] + focus
975
  bx = b * ax + (1 - b) * x
976
  cx = self.lnb(bx)
977
  dx = self.mlp(cx)
978
-
979
  ex = self.t_gate(cx) if self.t_gate is not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
980
  fx = x + ex + dx
981
  gx = self.lnc(fx)
@@ -1017,9 +659,9 @@ class FEncoder(nn.Module):
1017
  self.norm = RMSNorm(dims)
1018
  self._norm = RMSNorm(dims)
1019
 
1020
- def apply_rope_to_features(self, x, layer=None, feature_type="audio"):
1021
- if feature_type in ["envelope", "phase"]:
1022
- feature_type = "spectrogram"
1023
  batch, ctx, dims = x.shape
1024
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
1025
  if feature_type == "spectrogram" and self.rope is not None:
@@ -1030,10 +672,10 @@ class FEncoder(nn.Module):
1030
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
1031
  return x
1032
 
1033
- def forward(self, x, enc=None, layer=None, feature_type="audio"):
1034
  x = self.encoder(x).permute(0, 2, 1)
1035
  if self.use_rope:
1036
- x = self.apply_rope_to_features(x, layer=layer, feature_type=feature_type)
1037
  else:
1038
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
1039
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -1070,17 +712,17 @@ class WEncoder(nn.Module):
1070
  self.positional = lambda length: sinusoids(length, dims)
1071
  self.norm = RMSNorm(dims)
1072
 
1073
- def apply_rope_to_features(self, x, layer=None):
1074
  if not self.use_rope or self.rope is None:
1075
  return x
1076
  batch, ctx, dims = x.shape
1077
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
1078
- rope_freqs = self.rope(ctx, layer=layer, input_type="waveform")
1079
  x = self.rope.apply_rotary(x, rope_freqs)
1080
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
1081
  return x
1082
 
1083
- def forward(self, x, enc=None, layer=None, feature_type="waveform"):
1084
  x = self.downsample(x)
1085
  x = self.encoder(x)
1086
  x = x.permute(0, 2, 1)
@@ -1118,17 +760,17 @@ class PEncoder(nn.Module):
1118
  self.positional = lambda length: sinusoids(length, dims)
1119
  self.norm = RMSNorm(dims)
1120
 
1121
- def apply_rope_to_features(self, x, layer=None):
1122
  if not self.use_rope or self.rope is None:
1123
  return x
1124
  batch, ctx, dims = x.shape
1125
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
1126
- rope_freqs = self.rope(ctx, layer=layer, input_type="pitch")
1127
  x = self.rope.apply_rotary(x, rope_freqs)
1128
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
1129
  return x
1130
 
1131
- def forward(self, x, enc=None, layer=None, feature_type="pitch"):
1132
  x = self.encoder(x).permute(0, 2, 1)
1133
  if self.use_rope:
1134
  x = self.apply_rope_to_features(x, layer=layer)
@@ -1138,110 +780,148 @@ class PEncoder(nn.Module):
1138
  x = self.norm(x)
1139
  return x
1140
 
1141
- class SpeechTransformer(nn.Module):
1142
- _seen = set()
1143
- def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
1144
- super(SpeechTransformer, self).__init__()
1145
-
 
1146
  self.dims = dims
1147
  self.head = head
1148
- self.ctx = ctx
1149
  self.head_dim = dims // head
1150
  self.debug = debug
1151
  self.counter = 0
 
1152
  self.features = features
1153
- self.dropout = 0.01
1154
- self.sequential = "sequential" in debug
1155
- act_fn = get_activation(act)
1156
 
1157
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
1158
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
1159
- self.register_buffer("audio_embedding", sinusoids(ctx, dims))
 
1160
1161
  if features == ["spectrogram", "waveform", "pitch"]:
1162
  cgate=True
1163
  else:
1164
  cgate = False
1165
 
1166
  self.blocks = nn.ModuleDict({
1167
-
1168
  "spectrogram": nn.ModuleList(
1169
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1170
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
1171
- if "spectrogram" in features else None),
1172
-
1173
  "waveform": nn.ModuleList(
1174
  [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
1175
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
1176
- if "waveform" in features else None),
1177
-
1178
  "pitch": nn.ModuleList(
1179
  [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
1180
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
1181
- if "pitch" in features else None),
1182
-
1183
  "envelope": nn.ModuleList(
1184
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1185
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
1186
- if "envelope" in features else None),
1187
-
1188
  "phase": nn.ModuleList(
1189
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1190
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
1191
- if "phase" in features else None),
1192
- })
1193
-
1194
- self.block = nn.ModuleList([
1195
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features, focus=False)
1196
- for _ in range(layer)])
1197
-
1198
- self.blend = nn.Parameter(torch.tensor(0.5))
1199
- self.ln_dec = RMSNorm(dims)
1200
-
1201
- def get_mask(text_ctx, aud_ctx):
1202
- mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
1203
- audio_mask = torch.ones(text_ctx, aud_ctx - text_ctx, device=device)
1204
- full_mask = torch.cat([mask, audio_mask], dim=-1)
1205
- return full_mask
1206
- self.register_buffer("mask_ax", get_mask(ctx, ctx), persistent=False)
1207
-
1208
- mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
1209
- self.register_buffer("mask", mask, persistent=False)
1210
-
1211
- def forward(self, enc, layer="encoder"):
1212
  enc = dict_to(enc, device, dtype)
 
1213
 
1214
- x = enc.get("input_ids").long()
1215
  x = self.token(x) + self.positional[:x.shape[1]]
1216
  x = F.dropout(x, p=self.dropout, training=self.training)
1217
 
1218
- out = {}
1219
- out.update(enc)
1220
-
1221
- for f in self.features:
1222
- if f in enc and f in self.blocks:
1223
  xa = enc[f]
1224
  for block in self.blocks[f]:
1225
- xa = block(xa, enc=enc, layer=layer)
1226
- out[f] = xa
1227
- xa = xa + self.audio_embedding[:xa.shape[1]]
1228
 
1229
  for block in self.block:
1230
  mask = self.mask[:x.shape[1], :x.shape[1]]
1231
- x = block(x, xa=None, mask=mask, enc=None, layer=layer)
 
1232
 
1233
- for f in self.features:
1234
- if f in enc:
1235
- mask = self.mask_ax[:x.shape[1], :xa.shape[1]]
1236
- for block in self.block:
1237
- out = block(x, xa=xa, mask=mask, enc=None, layer=layer)
 
 
 
1238
  if self.sequential:
1239
  x = out
1240
  else:
1241
  a = torch.sigmoid(self.blend)
1242
  x = a * out + (1 - a) * x
1243
 
1244
- x = self.ln_dec(x)
 
1245
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
1246
 
1247
  class Echo(nn.Module):
@@ -1249,17 +929,17 @@ class Echo(nn.Module):
1249
  super().__init__()
1250
  self.param = param
1251
 
1252
- self.SpeechTransformer = SpeechTransformer(
1253
  vocab=param.vocab,
1254
  mels=param.mels,
1255
  ctx=param.ctx,
1256
  dims=param.dims,
1257
  head=param.head,
1258
  layer=param.layer,
1259
- debug=param.debug,
1260
  features=param.features,
1261
  act=param.act,
1262
- )
 
1263
 
1264
  def forward(self,
1265
  labels=None,
@@ -1268,33 +948,42 @@ class Echo(nn.Module):
1268
  spectrogram: Optional[torch.Tensor]=None,
1269
  pitch: Optional[torch.Tensor]=None,
1270
  f0: Optional[torch.Tensor]=None,
1271
- envelope: Optional[torch.Tensor]=None,
1272
- phase: Optional[torch.Tensor]=None,
 
1273
  ) -> Dict[str, Optional[torch.Tensor]]:
1274
 
1275
- encoder_inputs = {}
1276
  if spectrogram is not None:
1277
- encoder_inputs["spectrogram"] = spectrogram
 
1278
  if waveform is not None:
1279
- encoder_inputs["waveform"] = waveform
 
1280
  if pitch is not None:
1281
- encoder_inputs["pitch"] = pitch
1282
- if envelope is not None:
1283
- encoder_inputs["envelope"] = envelope
1284
- if phase is not None:
1285
- encoder_inputs["phase"] = phase
1286
  if f0 is not None:
1287
- encoder_inputs["f0"] = f0
 
1288
  if input_ids is not None:
1289
- encoder_inputs["input_ids"] = input_ids
 
 
 
1290
 
1291
- logits = self.SpeechTransformer(encoder_inputs)
1292
 
1293
  loss = None
1294
  if labels is not None:
1295
  loss = F.cross_entropy(
1296
  logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
1297
-
1298
  return {"logits": logits, "loss": loss}
1299
 
1300
  @property
@@ -1335,8 +1024,6 @@ class Echo(nn.Module):
1335
  self.init_counts["Conv2d"] += 1
1336
  elif isinstance(module, MultiheadA):
1337
  self.init_counts["MultiheadA"] += 1
1338
- elif isinstance(module, SpeechTransformer):
1339
- self.init_counts["SpeechTransformer"] += 1
1340
  elif isinstance(module, Residual):
1341
  self.init_counts["Residual"] += 1
1342
 
@@ -1361,24 +1048,24 @@ class Echo(nn.Module):
1361
  batch_size = x.shape[0]
1362
  break
1363
  ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
1364
- encoder_inputs = {}
1365
  if spectrogram is not None:
1366
- encoder_inputs["spectrogram"] = spectrogram
1367
  if waveform is not None:
1368
- encoder_inputs["waveform"] = waveform
1369
  if pitch is not None:
1370
- encoder_inputs["pitch"] = pitch
1371
  if envelope is not None:
1372
- encoder_inputs["envelope"] = envelope
1373
  if phase is not None:
1374
- encoder_inputs["phase"] = phase
1375
  if f0 is not None:
1376
- encoder_inputs["f0"] = f0
1377
 
1378
  for i in range(max_length - 1):
1379
  with torch.no_grad():
1380
- encoder_inputs["input_ids"] = ids
1381
- logits = self.SpeechTransformer(encoder_inputs)
1382
  next_token_logits = logits[:, -1, :]
1383
  if i < min_length:
1384
  next_token_logits[:, eos_token_id] = 0
@@ -1441,6 +1128,15 @@ def setup_tokenizer(token: str):
1441
  tokenizer.eos_token_id = 2
1442
  return tokenizer
1443
 
1444
  def load_wave(wave_data, sample_rate):
1445
  if isinstance(wave_data, str):
1446
  waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
@@ -1452,11 +1148,17 @@ def load_wave(wave_data, sample_rate):
1452
 
1453
  return waveform
1454
 
1455
- def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **dataset_config):
1456
 
1457
- audio = batch["audio"]
1458
- sr = audio["sampling_rate"]
1459
- wav = load_wave(wave_data=audio, sample_rate=sr)
1460
 
1461
  dataset_config = {
1462
  "hop_length": 256,
@@ -1471,29 +1173,99 @@ def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **data
1471
  "window_fn": torch.hann_window,
1472
  "mel_scale": "htk",
1473
  "norm": None,
1474
- "normalized": False}
1475
-
1476
- transform = torchaudio.transforms.MelSpectrogram(
1477
- **dataset_config
1478
- )
1479
-
1480
- mel_spectrogram = transform(wav)
1481
- log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1482
- log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1483
- spec = (log_mel + 4.0) / 4.0
1484
- spec = torch.tensor(spec)
1485
-
1486
- wav_np = wav.numpy().astype(np.float64)
1487
- f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
1488
- f0 = pw.stonemask(wav_np, f0, t, sample_rate)
1489
- f0 = torch.from_numpy(f0)
1490
 
 
 
 
1491
  labels = tokenizer.encode(batch["transcription"])
1492
 
1493
  return {
1494
- "spectrogram": spec,
1495
  "f0": f0,
1496
  "labels": labels,
 
 
 
1497
  }
1498
 
1499
  def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
@@ -1580,9 +1352,11 @@ class DataCollator:
1580
  batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
1581
  batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
1582
 
1583
- elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
1584
-
1585
  items = [f[key] for f in features if key in f]
 
 
 
1586
  items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
1587
  max_len = max(item.shape[-1] for item in items)
1588
  padded = []
@@ -1675,11 +1449,17 @@ def main():
1675
  os.makedirs(log_dir, exist_ok=True)
1676
  tokenizer = setup_tokenizer(token)
1677
  train_dataset, test_dataset = prepare_datasets(tokenizer, token)
 
1678
  param = Dimensions(
1679
- vocab=40000, ctx=2048, dims=512, head=4, layer=4,
1680
- mels=128, act="swish",
1681
- debug={},
1682
- features=["spectrogram"]
1683
  )
1684
 
1685
  model = Echo(param).to('cuda')
 
33
  logging.basicConfig(level=logging.ERROR)
34
 
35
  def get_activation(act: str) -> nn.Module:
36
+ """Get activation function by name."""
37
  act_map = {
38
  "gelu": nn.GELU(),
39
  "relu": nn.ReLU(),
 
51
  @dataclass
52
  class Dimensions:
53
  vocab: int
54
+ mels: int
55
  ctx: int
56
  dims: int
57
  head: int
58
  layer: int
 
59
  act: str
60
  debug: List[str]
61
  features: List[str]
 
198
  return fig
199
 
200
  def valid(default_value, *items):
201
+ """Get first non-None item"""
202
  for item in items:
203
  if item is not None:
204
  return item
 
266
  def tox():
267
  return {"device": get_device(), "dtype": get_dtype()}
268
 
269
+ class sinus(nn.Module):
270
+ def __init__(self, ctx: int, dims: int):
271
+ super().__init__()
272
+
273
+ position = torch.arange(start=0, end=ctx, dtype=dtype).unsqueeze(dim=1)
274
+ div_term = torch.exp(input=torch.arange(start=0, end=dims, step=2, dtype=dtype) * -(math.log(10000.0) / dims))
275
+ features = torch.zeros(ctx, dims)
276
+ features[:, 0::2] = torch.sin(position * div_term)
277
+ features[:, 1::2] = torch.cos(position * div_term)
278
+ self.register_buffer('sinusoid', tensor=features)
279
+ self.positional_embeddings = nn.Parameter(self.sinusoid.clone())
280
+ def forward(self, positions):
281
+ position_embeddings = self.positional_embeddings[positions]
282
+ return position_embeddings
283
+
284
  def sinusoids(length, channels, max_tscale=10000):
285
  assert channels % 2 == 0
286
  log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
 
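The new sinus module above precomputes the classic transformer sin/cos table and then wraps it in an nn.Parameter, so positions start at the fixed sinusoidal encoding but remain trainable. A minimal standalone sketch of that pattern (class and variable names here are illustrative, not from modelA.py):

import math
import torch
import torch.nn as nn

class LearnedSinusoid(nn.Module):
    """Sinusoidal table used to initialize a trainable positional embedding."""
    def __init__(self, ctx: int, dims: int):
        super().__init__()
        pos = torch.arange(ctx, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, dims, 2, dtype=torch.float32) * -(math.log(10000.0) / dims))
        table = torch.zeros(ctx, dims)
        table[:, 0::2] = torch.sin(pos * div)   # even channels: sine
        table[:, 1::2] = torch.cos(pos * div)   # odd channels: cosine
        self.emb = nn.Parameter(table)          # trainable, sinusoid-initialized

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return self.emb[positions]

emb = LearnedSinusoid(ctx=1500, dims=512)
print(emb(torch.arange(10)).shape)  # torch.Size([10, 512])
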
340
  idx = torch.arange(ctx, device=f0.device)
341
  idx = (idx * F).long().clamp(0, L - 1)
342
  radius = radius[idx]
343
+ return torch.polar(radius.unsqueeze(-1), freqs), radius
344
+ else:
345
+ return torch.polar(radius.unsqueeze(-1), freqs), radius
346
  else:
347
+ return torch.polar(torch.ones_like(freqs), freqs), None
348
 
349
+ def check_f0(self, f0, f0t, ctx):
350
+ if f0 is not None and f0.shape[1] == ctx:
351
+ return f0
352
+ elif f0t is not None and f0t.shape[1] == ctx:
353
+ return f0t
354
  else:
355
+ return None
356
+
357
+ def forward(self, x=None, enc=None, layer=None, feature=None) -> Tensor:
358
+ ctx = x
359
+ f0 = enc.get("f0") if enc is not None else None
360
+ f0t = enc.get("f0t") if enc is not None else None
361
 
362
+ f0 = self.check_f0(f0, f0t, ctx)
363
  if f0 is not None:
364
  if f0.dim() == 2:
365
  f0 = f0.squeeze(0)
366
  theta = f0 + self.theta
367
  else:
368
  theta = self.theta
 
369
  freqs = self.theta_freqs(theta)
370
  t = torch.arange(ctx, device=device, dtype=dtype)
371
  freqs = t[:, None] * freqs
372
+ freqs, radius = self._apply_radii(freqs, f0, ctx)
373
+
374
  if "radius" in self.debug and self.counter == 10:
375
+ print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
376
  self.counter += 1
377
  return freqs.unsqueeze(0)
378
 
 
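The rotary changes above route pitch into the embedding twice: check_f0 picks whichever of f0/f0t matches the context length, the pitch offsets the base theta, and _apply_radii optionally uses f0 as the magnitude passed to torch.polar, so voiced frames rotate with a non-unit radius. A self-contained sketch of that idea (shapes and the 10000.0 base are assumptions, not the file's exact math):

from typing import Optional
import torch

def pitch_rotary(ctx: int, head_dim: int, f0: Optional[torch.Tensor], theta: float = 10000.0):
    # shift the frequency base by mean pitch when a matching f0 track exists
    base = theta + (f0.mean() if f0 is not None else 0.0)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(ctx).float()
    freqs = t[:, None] * inv_freq[None, :]              # (ctx, head_dim//2) phase angles
    if f0 is not None and f0.shape[0] == ctx:
        radius = f0.unsqueeze(-1).expand_as(freqs)      # pitch as rotation magnitude
        return torch.polar(radius, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)   # unit radius: plain RoPE

rot = pitch_rotary(ctx=8, head_dim=64, f0=torch.full((8,), 120.0))
print(rot.shape, rot.dtype)  # torch.Size([8, 32]) torch.complex64
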
393
 
394
  rbf = False
395
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
396
+ zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
 
397
  super(MultiheadA, self).__init__()
398
 
399
  self.dims = dims
 
426
  else:
427
  self.rope = None
428
 
429
  def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
430
  q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
431
  k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
 
447
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
448
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
449
 
450
+ def forward(self, x: Tensor, xa = None, mask = None, enc = None, layer = None, feature=None) -> tuple:
451
 
452
  x = x.to(device, dtype)
453
  if xa is not None:
 
475
 
476
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
477
 
478
  if self.rbf:
479
  qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
480
  if self.use_pbias:
481
+ pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
482
  if pbias is not None:
483
  qk = qk + pbias[:,:,:q2,:q2]
484
 
 
487
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
488
  zscale[token_ids.float() == self.pad_token] = fzero
489
 
490
+ if xa is not None and mask is not None:
 
491
  qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
 
492
  qk = qk * zscale.unsqueeze(-2)
493
  w = F.softmax(qk, dim=-1).to(q.dtype)
494
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 
498
  self.counter += 1
499
  return self.o(wv), qk
500
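Instead of hard -inf masking, the forward above down-weights pad positions with a learned scale: zscale is 1 everywhere except pad columns, which get a clamped softplus value near 1e-7, and the attention logits are multiplied by it. A reduced sketch of that soft pad handling (a fixed fzero stands in for the learned parameter):

import torch
import torch.nn.functional as F

def soft_pad_attention(qk: torch.Tensor, token_ids: torch.Tensor,
                       pad_token: int = 0, fzero: float = 1e-7) -> torch.Tensor:
    # qk: (batch, head, q_len, k_len); token_ids: (batch, k_len)
    zscale = torch.ones_like(token_ids, dtype=qk.dtype)
    zscale[token_ids == pad_token] = fzero       # flatten pad-column logits toward zero
    qk = qk * zscale[:, None, None, :]           # broadcast over heads and queries
    return F.softmax(qk, dim=-1)

qk = torch.randn(1, 4, 6, 6)
ids = torch.tensor([[5, 9, 3, 0, 0, 0]])
print(soft_pad_attention(qk, ids)[0, 0, 0])
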
501
  class t_gate(nn.Module):
502
  def __init__(self, dims, num_types=4, enabled=True):
503
  super().__init__()
 
578
  class Residual(nn.Module):
579
  _seen = set()
580
  def __init__(self, ctx, dims, head, act, debug: List[str] = [],
581
+ tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
582
  super().__init__()
583
 
584
  self.dims = dims
 
592
 
593
  self.blend = nn.Parameter(torch.tensor(0.5))
594
  act_fn = get_activation(act)
 
595
  self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
 
596
 
597
  if not any([tgate, mgate, cgate]):
598
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
 
611
  self.lnb = RMSNorm(dims)
612
  self.lnc = RMSNorm(dims)
613
 
614
+ def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature=None) -> Tensor:
 
 
615
 
616
  b = torch.sigmoid(self.blend)
617
+ ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer, feature=feature)[0]
618
  bx = b * ax + (1 - b) * x
619
  cx = self.lnb(bx)
620
  dx = self.mlp(cx)
 
621
  ex = self.t_gate(cx) if self.t_gate is not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
622
  fx = x + ex + dx
623
  gx = self.lnc(fx)
 
659
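The residual update above is a learned convex blend rather than a plain skip: with b = sigmoid(blend), the stream becomes b * (x + attn(x)) + (1 - b) * x, so each layer can dial its own contribution up or down during training. The same gate in isolation:

import torch
import torch.nn as nn

class BlendSkip(nn.Module):
    # learned convex combination of the branch output and the identity path
    def __init__(self):
        super().__init__()
        self.blend = nn.Parameter(torch.tensor(0.5))

    def forward(self, x: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
        b = torch.sigmoid(self.blend)   # ~0.62 at init, learned thereafter
        return b * branch + (1 - b) * x

skip = BlendSkip()
x = torch.randn(2, 10, 512)
print(skip(x, x + 1.0).shape)  # torch.Size([2, 10, 512])
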
  self.norm = RMSNorm(dims)
660
  self._norm = RMSNorm(dims)
661
 
662
+ def apply_rope_to_features(self, x, layer=None, feature=None):
663
+ if feature in ["envelope", "phase"]:
664
+ feature = "spectrogram"
665
  batch, ctx, dims = x.shape
666
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
667
  if feature == "spectrogram" and self.rope is not None:
 
672
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
673
  return x
674
 
675
+ def forward(self, x, enc=None, layer=None, feature=None):
676
  x = self.encoder(x).permute(0, 2, 1)
677
  if self.use_rope:
678
+ x = self.apply_rope_to_features(x, layer=layer, feature=feature)
679
  else:
680
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
681
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 
712
  self.positional = lambda length: sinusoids(length, dims)
713
  self.norm = RMSNorm(dims)
714
 
715
+ def apply_rope_to_features(self, x, layer=None, feature=None):
716
  if not self.use_rope or self.rope is None:
717
  return x
718
  batch, ctx, dims = x.shape
719
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
720
+ rope_freqs = self.rope(ctx, layer=layer, feature=feature)
721
  x = self.rope.apply_rotary(x, rope_freqs)
722
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
723
  return x
724
 
725
+ def forward(self, x, enc=None, layer=None, feature=None):
726
  x = self.downsample(x)
727
  x = self.encoder(x)
728
  x = x.permute(0, 2, 1)
 
760
  self.positional = lambda length: sinusoids(length, dims)
761
  self.norm = RMSNorm(dims)
762
 
763
+ def apply_rope_to_features(self, x, layer=None, feature=None):
764
  if not self.use_rope or self.rope is None:
765
  return x
766
  batch, ctx, dims = x.shape
767
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
768
+ rope_freqs = self.rope(ctx, layer=layer, feature=feature)
769
  x = self.rope.apply_rotary(x, rope_freqs)
770
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
771
  return x
772
 
773
+ def forward(self, x, enc=None, layer=None, feature=None):
774
  x = self.encoder(x).permute(0, 2, 1)
775
  if self.use_rope:
776
  x = self.apply_rope_to_features(x, layer=layer)
 
780
  x = self.norm(x)
781
  return x
782
 
783
+ class theBridge(nn.Module):
784
+ def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
785
+ debug: List[str], features: List[str], act: str = "gelu"):
786
+ super(theBridge, self).__init__()
787
+
788
+ self.ctx = ctx
789
  self.dims = dims
790
  self.head = head
 
791
  self.head_dim = dims // head
792
  self.debug = debug
793
  self.counter = 0
794
+ self.dropout = 0.01
795
  self.features = features
796
+ self.do_blend = "no_blend" not in self.debug
797
+ self.sequential = "sequential" in self.debug
 
798
 
799
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
800
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
801
+ self.sinusoid = lambda length: sinusoids(length, dims)
802
+ self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
803
+ self.ln_dec = RMSNorm(dims)
804
+
805
+ with torch.no_grad():
806
+ self.token.weight[0].zero_()
807
+
808
+ self.block = nn.ModuleList([
809
+ Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
810
+ for _ in range(layer)])
811
+
812
+ self.cross_attn = nn.ModuleList([
813
+ Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
814
+ for _ in range(layer)])
815
+
816
+ self.cross_modal = nn.ModuleList([
817
+ Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
818
+ for _ in range(layer)])
819
 
820
+ mask = torch.tril(torch.ones(ctx, ctx), diagonal=0).unsqueeze(0).unsqueeze(0)
821
+ self.register_buffer("mask", mask, persistent=False)
822
+ self.register_buffer("mask_win", self.window_mask(ctx, ctx), persistent=False)
823
+ self.register_buffer("mask_cat", self.modal_mask(ctx, ctx), persistent=False)
824
+ self.register_buffer("mask_cross", self.cross_mask(ctx, ctx), persistent=False)
825
+
826
+ act_fn = get_activation(act)
827
  if features == ["spectrogram", "waveform", "pitch"]:
828
  cgate=True
829
  else:
830
  cgate = False
831
 
832
  self.blocks = nn.ModuleDict({
 
833
  "spectrogram": nn.ModuleList(
834
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
835
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "spectrogram" in features else None),
 
 
836
  "waveform": nn.ModuleList(
837
  [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
838
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "waveform" in features else None),
 
 
839
  "pitch": nn.ModuleList(
840
  [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
841
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None),
 
 
842
  "envelope": nn.ModuleList(
843
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
844
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "envelope" in features else None),
 
 
845
  "phase": nn.ModuleList(
846
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
847
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None)})
848
+
849
+ def window_mask(self, text_ctx, aud_ctx):
850
+ mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
851
+ audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device))
852
+ full_mask = torch.cat([mask, audio_mask], dim=-1)
853
+ return full_mask.unsqueeze(0).unsqueeze(0)
854
+
855
+ def modal_mask(self, text_len, audio_len):
856
+ combined_mask = torch.ones(text_len + audio_len, text_len + audio_len, device=device)
857
+ combined_mask[:text_len, :text_len] = torch.tril(torch.ones(text_len, text_len, device=device))
858
+ combined_mask[:text_len, text_len:] = torch.tril(torch.ones(text_len, audio_len, device=device))
859
+ return combined_mask.unsqueeze(0).unsqueeze(0)
860
+
861
+ def cross_mask(self, text_len, audio_len):
862
+ mask = torch.tril(torch.ones(text_len, text_len, device=device))
863
+ audio_mask = torch.tril(torch.ones(text_len, audio_len, device=device))
864
+ full_mask = torch.cat([mask, audio_mask], dim=-1)
865
+ return full_mask.unsqueeze(0).unsqueeze(0)
866
+
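All three mask builders above share one shape trick: a lower-triangular text-to-text block concatenated with a triangular text-to-audio block, broadcast to (1, 1, q, k). A compact illustration of the concatenated layout:

import torch

def causal_plus_audio(text_len: int, audio_len: int) -> torch.Tensor:
    # (text_len, text_len + audio_len): causal over text, triangular over audio
    text = torch.tril(torch.ones(text_len, text_len))
    audio = torch.tril(torch.ones(text_len, audio_len))
    return torch.cat([text, audio], dim=-1).unsqueeze(0).unsqueeze(0)

m = causal_plus_audio(4, 6)
print(m.shape)   # torch.Size([1, 1, 4, 10])
print(m[0, 0])   # each text row sees a causal text prefix plus a growing audio prefix
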
867
+ def forward(self, x, enc, layer='decoder', feature=None) -> Tensor:
 
868
  enc = dict_to(enc, device, dtype)
869
+ _text_len = x.shape[1]
870
 
 
871
  x = self.token(x) + self.positional[:x.shape[1]]
872
  x = F.dropout(x, p=self.dropout, training=self.training)
873
 
874
+ for f in enc:
875
+ if f in self.features:
 
 
 
876
  xa = enc[f]
877
  for block in self.blocks[f]:
878
+ xa = block(xa, enc=enc, layer=layer, feature=feature)
 
 
879
 
880
  for block in self.block:
881
  mask = self.mask[:, :, :x.shape[1], :x.shape[1]]
882
+ x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
883
+ if feature in self.features:
884
+ xa = xa + self.sinusoid(xa.shape[1])
885
+ mask = self.mask_win[:, :, :x.shape[1], :xa.shape[1]]
886
+ out = block(x, xa=xa, mask=mask, enc=enc, layer=layer)
887
+ if self.sequential:
888
+ x = out
889
+ else:
890
+ a = torch.sigmoid(self.blend)
891
+ x = a * out + (1 - a) * x
892
 
893
+ for block in self.cross_attn:
894
+ if feature in self.features:
895
+ xa = xa + self.sinusoid(xa.shape[1])
896
+ mask_x = self.cross_mask(x.shape[1], xa.shape[1])
897
+ mask_xa = self.cross_mask(xa.shape[1], x.shape[1])
898
+ x = block(x, xa=xa, mask=mask_x, enc=enc, layer=layer)
899
+ xa = block(xa, xa=x, mask=mask_xa, enc=enc, layer=layer)
900
+ out = block(x, xa=xa, mask=mask_x, enc=enc, layer=layer)
901
  if self.sequential:
902
  x = out
903
  else:
904
  a = torch.sigmoid(self.blend)
905
  x = a * out + (1 - a) * x
906
 
907
+ for block in self.cross_modal:
908
+ if feature in enc:
909
+ xa = xa + self.sinusoid(xa.shape[1])
910
+ xcat = torch.cat([x, xa], dim=1)
911
+ mask = self.mask_cat(x.shape[1], xa.shape[1])
912
+ x = block(xcat, xa=None, mask=mask, enc=enc, layer=layer)
913
+ x = x[:, :_text_len]
914
+
915
+ if self.counter < 1 and "encoder" in self.debug:
916
+ s = enc.get("spectrogram")
917
+ w = enc.get("waveform")
918
+ p = default(enc.get("pitch"), enc.get("f0"))
919
+ plot_waveform(x=s, w=w, p=p, hop_length=128)
920
+ shapes = {k: v.shape for k, v in enc.items()}
921
+ print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
922
+ self.counter += 1
923
+
924
+ x = self.ln_dec(x)
925
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
926
 
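Reading theBridge.forward as a whole: each feature stream is encoded by its own block stack, then the text passes through three stages — causal self-attention, bidirectional cross-attention against the audio states, and a joint pass over the concatenated [text; audio] sequence that is cropped back to the text length. A toy rendering of that staging with stock nn.MultiheadAttention (all names and sizes illustrative):

import torch
import torch.nn as nn

dims, head = 64, 4
self_attn = nn.MultiheadAttention(dims, head, batch_first=True)
cross_attn = nn.MultiheadAttention(dims, head, batch_first=True)
joint_attn = nn.MultiheadAttention(dims, head, batch_first=True)

x = torch.randn(1, 12, dims)    # text states
xa = torch.randn(1, 30, dims)   # encoded audio states

causal = torch.triu(torch.ones(12, 12, dtype=torch.bool), diagonal=1)
x, _ = self_attn(x, x, x, attn_mask=causal)   # stage 1: causal text
x2, _ = cross_attn(x, xa, xa)                 # stage 2: text queries audio...
xa2, _ = cross_attn(xa, x, x)                 # ...and audio queries text
xcat = torch.cat([x2, xa2], dim=1)            # stage 3: joint pass over [text; audio]
out, _ = joint_attn(xcat, xcat, xcat)
out = out[:, :12]                             # crop back to the text length
print(out.shape)  # torch.Size([1, 12, 64])
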
927
  class Echo(nn.Module):
 
929
  super().__init__()
930
  self.param = param
931
 
932
+ self.processor = theBridge(
933
  vocab=param.vocab,
934
  mels=param.mels,
935
  ctx=param.ctx,
936
  dims=param.dims,
937
  head=param.head,
938
  layer=param.layer,
 
939
  features=param.features,
940
  act=param.act,
941
+ debug=param.debug,
942
+ )
943
 
944
  def forward(self,
945
  labels=None,
 
948
  spectrogram: Optional[torch.Tensor]=None,
949
  pitch: Optional[torch.Tensor]=None,
950
  f0: Optional[torch.Tensor]=None,
951
+ f0t: Optional[torch.Tensor]=None,
952
+ harmonic: Optional[torch.Tensor]=None,
953
+ aperiodic: Optional[torch.Tensor]=None,
954
  ) -> Dict[str, Optional[torch.Tensor]]:
955
 
956
+ enc = {}
957
  if spectrogram is not None:
958
+ enc["spectrogram"] = spectrogram
959
+ feature = "spectrogram"
960
  if waveform is not None:
961
+ enc["waveform"] = waveform
962
+ feature = "waveform"
963
  if pitch is not None:
964
+ enc["pitch"] = pitch
965
+ feature = "pitch"
 
 
 
966
  if f0 is not None:
967
+ enc["f0"] = f0
968
+ if f0t is not None:
969
+ enc["f0t"] = f0t
970
+ if harmonic is not None:
971
+ enc["harmonic"] = harmonic
972
+ if aperiodic is not None:
973
+ enc["aperiodic"] = aperiodic
974
  if input_ids is not None:
975
+ enc["input_ids"] = input_ids
976
+ feature = "input_ids"
977
+ else:
978
+ feature = "spectrogram"
979
 
980
+ logits = self.processor(input_ids, enc, feature=feature)
981
 
982
  loss = None
983
  if labels is not None:
984
  loss = F.cross_entropy(
985
  logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
986
+
987
  return {"logits": logits, "loss": loss}
988
 
989
  @property
 
1024
  self.init_counts["Conv2d"] += 1
1025
  elif isinstance(module, MultiheadA):
1026
  self.init_counts["MultiheadA"] += 1
 
 
1027
  elif isinstance(module, Residual):
1028
  self.init_counts["Residual"] += 1
1029
 
 
1048
  batch_size = x.shape[0]
1049
  break
1050
  ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
1051
+ feature = {}
1052
  if spectrogram is not None:
1053
+ feature["spectrogram"] = spectrogram
1054
  if waveform is not None:
1055
+ feature["waveform"] = waveform
1056
  if pitch is not None:
1057
+ feature["pitch"] = pitch
1058
  if envelope is not None:
1059
+ feature["envelope"] = envelope
1060
  if phase is not None:
1061
+ feature["phase"] = phase
1062
  if f0 is not None:
1063
+ feature["f0"] = f0
1064
 
1065
  for i in range(max_length - 1):
1066
  with torch.no_grad():
1067
+ feature["input_ids"] = ids
1068
+ logits = self.processor(ids, feature)
1069
  next_token_logits = logits[:, -1, :]
1070
  if i < min_length:
1071
  next_token_logits[:, eos_token_id] = 0
 
1128
  tokenizer.eos_token_id = 2
1129
  return tokenizer
1130
 
1131
+ def tokenize_pitch(pitch_features, target_length):
1132
+ pitch_len = pitch_features.shape[-1]
1133
+ token_len = target_length
1134
+ if pitch_len > token_len:
1135
+ pitch_tokens = F.adaptive_avg_pool1d(pitch_features, token_len)
1136
+ else:
1137
+ pitch_tokens = F.interpolate(pitch_features, token_len)
1138
+ return pitch_tokens
1139
+
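tokenize_pitch resamples a frame-rate pitch track to the token count: average-pooling when the track is longer than the target, interpolating when it is shorter. For example:

import torch
import torch.nn.functional as F

long_track = torch.rand(1, 1, 700)                # (batch, channel, frames)
short_track = torch.rand(1, 1, 80)
pooled = F.adaptive_avg_pool1d(long_track, 120)   # 700 frames -> 120 tokens
stretched = F.interpolate(short_track, size=120)  # 80 frames -> 120 tokens
print(pooled.shape, stretched.shape)  # both torch.Size([1, 1, 120])
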
1140
  def load_wave(wave_data, sample_rate):
1141
  if isinstance(wave_data, str):
1142
  waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
 
1148
 
1149
  return waveform
1150
 
1151
+ def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
1152
+ import librosa
1153
+ mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
1154
+ mel_basis = torch.from_numpy(mel_basis).float()
1155
+
1156
+ sp_mel = torch.matmul(sp, mel_basis.T)
1157
+ ap_mel = torch.matmul(ap, mel_basis.T)
1158
+
1159
+ return sp_mel, ap_mel
1160
 
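world_to_mel assumes sp and ap come from a WORLD analysis of the waveform, as in extract_features below. A typical pyworld call sequence producing those matrices (a frame_period of 16 ms matches hop_length 256 at 16 kHz):

import numpy as np
import pyworld as pw

sr = 16000
wav = np.random.randn(sr).astype(np.float64)   # 1 s of audio; WORLD expects float64
f0, t = pw.dio(wav, sr, frame_period=16.0)     # coarse f0 track + frame times
f0 = pw.stonemask(wav, f0, t, sr)              # refined f0
sp = pw.cheaptrick(wav, f0, t, sr)             # spectral envelope, (frames, fft//2+1)
ap = pw.d4c(wav, f0, t, sr)                    # aperiodicity, same shape
print(f0.shape, sp.shape, ap.shape)
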
1161
+ def extract_features(batch, tokenizer, waveform=False, spec=True, f0=False, f0t=False, sample_rate=16000, hop_length=256, mode="mean", debug=True, **dataset_config):
 
 
1162
 
1163
  dataset_config = {
1164
  "hop_length": 256,
 
1173
  "window_fn": torch.hann_window,
1174
  "mel_scale": "htk",
1175
  "norm": None,
1176
+ "normalized": False,
1177
+ }
1178
 
1179
+ audio = batch["audio"]
1180
+ sr = audio["sampling_rate"]
1181
+ wave = load_wave(wave_data=audio, sample_rate=sr)
1182
  labels = tokenizer.encode(batch["transcription"])
1183
 
1184
+ if waveform:
1185
+ wav = load_wave(wave_data=audio, sample_rate=sr)
1186
+ else:
1187
+ wav = None
1188
+
1189
+ if spec:
1190
+ transform = torchaudio.transforms.MelSpectrogram(**dataset_config)
1191
+ mel_spectrogram = transform(wave)
1192
+ log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1193
+ log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1194
+ spec = (log_mel + 4.0) / 4.0
1195
+ spec = torch.tensor(spec)
1196
+ else:
1197
+ spec = None
1198
+
1199
+ if f0:
1200
+ wavnp = wave.numpy().astype(np.float64)
1201
+ f0_np, t = pw.dio(wavnp, sample_rate,
1202
+ frame_period = hop_length / sample_rate * 1000)
1203
+ f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
1204
+ f0 = torch.from_numpy(f0_np)
1205
+
1206
+ if f0t:
1207
+ audio_duration = len(wavnp) / sample_rate
1208
+ T = len(labels)
1209
+ tok_dur_sec = audio_duration / T
1210
+ token_starts = np.arange(T) * tok_dur_sec
1211
+ token_ends = token_starts + tok_dur_sec
1212
+ start_idx = np.searchsorted(t, token_starts, side="left")
1213
+ end_idx = np.searchsorted(t, token_ends, side="right")
1214
+ pitch_tok = np.zeros(T, dtype=np.float32)
1215
+ for i in range(T):
1216
+ lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
1217
+ segment = f0_np[lo:hi]
1218
+ pitch_tok[i] = segment.mean() if mode=="mean" else (np.median(segment) if mode=="median" else segment[-1])
1219
+ pitch_tok[pitch_tok < 100.0] = 0.0
1220
+
1221
+ bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
1222
+ f0t = torch.from_numpy(np.concatenate([[bos_pitch], pitch_tok]))
1224
+ f0 = torch.from_numpy(f0_np)
1225
+
1226
+ spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
1227
+ apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
1228
+ sp = torch.from_numpy(spnp)
1229
+ ap = torch.from_numpy(apnp)
1230
+ sp = sp[:, :128].contiguous().T
1231
+ ap = ap[:, :128].contiguous().T
1232
+ f0t = torch.where(f0t == 0.0, torch.zeros_like(f0t), (f0t - 71.0) / (400.0 - 71.0))
1233
+ sp = torch.where(sp == 0.0, torch.zeros_like(sp), sp / 1.0)
1234
+ ap = torch.where(ap == 0.0, torch.zeros_like(ap), ap / 1.0)
1235
+
1236
+ else:
1237
+ f0t = None
1238
+ sp = None
1239
+ ap = None
1240
+ t = None
1241
+ token_starts = None
1242
+ else:
1243
+ f0t = None
1244
+ f0 = None
1245
+ sp = None
1246
+ ap = None
1247
+ t = None
1248
+ token_starts = None
1249
+
1250
+ if debug:
1251
+ print(f"['f0']: {f0t if f0t is not None else None}")
1252
+ print(f"['f0']: {f0.shape if f0 is not None else None}")
1253
+ print(f"['f0t']: {f0t.shape if f0t is not None else None}")
1254
+ print(f"['harmonic']: {sp.shape if sp is not None else None}")
1255
+ print(f"['aperiodic']: {ap.shape if ap is not None else None}")
1256
+ print(f"['spec']: {spec.shape if spec is not None else None}")
1257
+ print(f"['wav']: {wav.shape if wav is not None else None}")
1258
+
1259
  return {
 
1260
  "f0": f0,
1261
+ "f0t": f0t,
1262
+ "pitch": f0,
1263
+ "harmonic": sp,
1264
+ "aperiodic": ap,
1265
  "labels": labels,
1266
+ "waveform": wav,
1267
+ "spectrogram": spec,
1268
+
1269
  }
1270
 
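The f0t branch above assumes tokens are uniformly spaced in time, mean-pools the frame-level f0 inside each token's span, and zeroes values under 100 Hz as unvoiced. Condensed to a standalone helper (same logic, illustrative names):

import numpy as np

def align_f0_to_tokens(f0: np.ndarray, t: np.ndarray, n_tokens: int, duration: float) -> np.ndarray:
    # assume uniform token timing; pool f0 frames falling inside each token span
    tok_dur = duration / n_tokens
    starts = np.arange(n_tokens) * tok_dur
    lo = np.searchsorted(t, starts, side="left")
    hi = np.searchsorted(t, starts + tok_dur, side="right")
    out = np.zeros(n_tokens, dtype=np.float32)
    for i in range(n_tokens):
        seg = f0[lo[i]:max(lo[i] + 1, hi[i])]
        out[i] = seg.mean()
    out[out < 100.0] = 0.0                       # treat low values as unvoiced
    return out

f0 = np.abs(np.random.randn(400)) * 80 + 80      # fake 400-frame pitch track
t = np.linspace(0, 6.4, 400)                     # 6.4 s of frame timestamps
print(align_f0_to_tokens(f0, t, n_tokens=50, duration=6.4).shape)  # (50,)
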
1271
  def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
 
1352
  batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
1353
  batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
1354
 
1355
+ elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0"]:
 
1356
  items = [f[key] for f in features if key in f]
1357
+ items = [item for item in items if item is not None]
1358
+ if not items:
1359
+ continue
1360
  items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
1361
  max_len = max(item.shape[-1] for item in items)
1362
  padded = []
 
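The collator hunk above pads each audio-side feature along its last axis to the batch maximum before stacking; the padding loop itself falls outside this diff, but the pattern is presumably something like:

import torch
import torch.nn.functional as F

def pad_to_max(items):
    # right-pad the last dimension of every tensor to the batch max, then stack
    max_len = max(item.shape[-1] for item in items)
    padded = [F.pad(item, (0, max_len - item.shape[-1])) for item in items]
    return torch.stack(padded)

batch = pad_to_max([torch.randn(128, 300), torch.randn(128, 412)])
print(batch.shape)  # torch.Size([2, 128, 412])
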
1449
  os.makedirs(log_dir, exist_ok=True)
1450
  tokenizer = setup_tokenizer(token)
1451
  train_dataset, test_dataset = prepare_datasets(tokenizer, token)
1452
+
1453
  param = Dimensions(
1454
+ vocab=40000,
1455
+ mels=128,
1456
+ ctx=1500,
1457
+ dims=512,
1458
+ head=4,
1459
+ layer=4,
1460
+ act="swish",
1461
+ debug={"decoder", "radius"},
1462
+ features = ["spectrogram"],
1463
  )
1464
 
1465
  model = Echo(param).to('cuda')