Sin2pi committed · Commit 83198ca · verified · 1 Parent(s): 09b0505

Update modelA.py

Files changed (1)
  1. modelA.py +230 -108
modelA.py CHANGED
@@ -273,6 +273,8 @@ def get_dtype():
273
  def tox():
274
  return {"device": get_device(), "dtype": get_dtype()}
275
 
 
 
276
  class Sinusoids(nn.Module):
277
  def __init__(self, length, channels, max_tscale=10000):
278
  super().__init__()
@@ -294,24 +296,29 @@ def sinusoids(length, channels, max_tscale=10000):
294
  scaled_t = torch.arange(length)[:, None] * inv_tscales[None, :]
295
  pos1 = torch.sin(scaled_t)
296
  pos2 = torch.cos(scaled_t)
297
- positions = torch.cat([pos1, pos2], dim=1)
298
  return nn.Parameter(positions.clone())
299
 
300
  class rotary(nn.Module):
301
- def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False, use_2d_axial=False, spec_shape=None, use_true_2d_relative=False, freq_bins=None):
302
  super(rotary, self).__init__()
303
  self.use_pbias = use_pbias
304
  self.dims = dims
305
  self.head = head
306
  self.head_dim = dims // head
307
  self.radii = radii
308
- self.dim = self.head_dim
309
  self.debug = debug
310
  self.counter = 0
311
- self.last_theta = None
312
-
313
- self.use_2d_axial = use_2d_axial
314
- if use_2d_axial and spec_shape is not None:
315
  time_frames, freq_bins = spec_shape
316
  self.time_frames = time_frames
317
  self.freq_bins = freq_bins
@@ -327,13 +334,13 @@ class rotary(nn.Module):
327
  self.theta = nn.Parameter(theta, requires_grad=True)
328
  self.theta_values = []
329
 
330
- self.use_true_2d_relative = use_true_2d_relative
331
  self.freq_bins = freq_bins
332
  self.true2d_dim = (dims // head) // 2
333
  self.omega_t = nn.Parameter(torch.randn(self.true2d_dim))
334
  self.omega_f = nn.Parameter(torch.randn(self.true2d_dim))
335
 
336
- def axial_freqs(self, seq_len):
337
  if not self.use_2d_axial:
338
  return None
339
  time_frames = self.time_frames
@@ -362,12 +369,19 @@ class rotary(nn.Module):
362
  f0_norm.unsqueeze(1)))
363
  return f0_sim.unsqueeze(0).unsqueeze(0)
364
 
365
  def theta_freqs(self, theta):
366
  if theta.dim() == 0:
367
  theta = theta.unsqueeze(0)
368
  freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
369
  torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
370
- self.dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
371
  return freq
372
 
373
  def _apply_radii(self, freqs, f0, ctx):
@@ -390,7 +404,6 @@ class rotary(nn.Module):
390
  f0 = f0.squeeze(0)
391
  if f0t is not None and f0t.dim() == 2:
392
  f0t = f0t.squeeze(0)
393
-
394
  if f0 is not None and f0.shape[0] == ctx:
395
  return f0
396
  elif f0t is not None and f0t.shape[0] == ctx:
@@ -400,7 +413,7 @@ class rotary(nn.Module):
400
 
401
  def forward(self, x=None, enc=None, layer=None, feature=None) -> Tensor:
402
  ctx = x
403
- if self.use_2d_axial and feature == "spectrogram":
404
  freqs_2d = self.axial_freqs(ctx)
405
  if freqs_2d is not None:
406
  return freqs_2d.unsqueeze(0)
@@ -408,8 +421,12 @@ class rotary(nn.Module):
408
  f0 = enc.get("f0") if enc is not None else None
409
  f0t = enc.get("f0t") if enc is not None else None
410
  f0 = self.check_f0(f0, f0t, ctx)
 
411
  theta = f0 + self.theta if f0 is not None else self.theta
 
 
412
  freqs = self.theta_freqs(theta)
 
413
  t = torch.arange(ctx, device=device, dtype=dtype)
414
  freqs = t[:, None] * freqs
415
  freqs, radius = self._apply_radii(freqs, f0, ctx)
@@ -433,6 +450,7 @@ class rotary(nn.Module):
433
  x1 = x1.view(orig_shape)
434
  return torch.cat([x1.type_as(x), x2], dim=-1)
435
 
 
436
  # @staticmethod
437
  # def apply_rotary(x, freqs):
438
  # # x: [batch, head, seq, head_dim]
@@ -497,28 +515,23 @@ class rotary(nn.Module):
497
  # k_rot = to_real(k_rot)
498
  # return q_rot, k_rot
499
 
 
500
  def parallel_slice(self, q, k, v, mask=None):
501
  batch, head, ctx, dims = q.shape
502
  head_dim = self.head_dim
503
  batch, ctx, dims = q.shape
504
  ctx_len = k.shape[1]
505
  head = dims // head_dim
506
-
507
  scores = torch.zeros(batch, head, ctx, ctx_len, device=q.device)
508
-
509
  for h in range(head):
510
  start_idx = h * head_dim
511
  end_idx = start_idx + head_dim
512
  q_h = q[:, :, start_idx:end_idx]
513
  k_h = k[:, :, start_idx:end_idx]
514
-
515
  scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
516
-
517
  if mask is not None:
518
  scores = scores + mask.unsqueeze(0).unsqueeze(0)
519
-
520
  attn_weights = F.softmax(scores, dim=-1)
521
-
522
  output = torch.zeros_like(q)
523
  for h in range(head):
524
  start_idx = h * head_dim
@@ -527,9 +540,60 @@ def parallel_slice(self, q, k, v, mask=None):
527
  output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
528
  return output
529
 
530
  class MultiheadA(nn.Module):
531
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
532
- zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False, use_true_2d_relative=False, freq_bins=None, radii=False, use_2d_axial=False, spec_shape=None, rbf=False):
533
 
534
  super(MultiheadA, self).__init__()
535
  self.dims = dims
@@ -538,7 +602,7 @@ class MultiheadA(nn.Module):
538
  self.debug = debug
539
  self.counter = 0
540
  self.use_pbias = use_pbias
541
- self.use_true_2d_relative = use_true_2d_relative
542
  self.freq_bins = freq_bins
543
  self.rbf = rbf
544
 
@@ -551,8 +615,7 @@ class MultiheadA(nn.Module):
551
  self.rotary_emb = rotary_emb
552
  self.minz = minz
553
  self.maxz = maxz
554
- self.zero_val = zero_val
555
- self.optim_attn = optim_attn
556
  self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
557
 
558
  if rotary_emb:
@@ -561,7 +624,7 @@ class MultiheadA(nn.Module):
561
  head=head,
562
  debug=debug,
563
  radii=radii,
564
- use_true_2d_relative=use_true_2d_relative,
565
  freq_bins=freq_bins,
566
  )
567
  else:
@@ -607,14 +670,14 @@ class MultiheadA(nn.Module):
607
  q2 = q.shape[2]
608
  k2 = k.shape[2]
609
 
610
- if self.use_true_2d_relative and feature == "spectrogram":
611
  seq_len = q2
612
  freq_bins = self.freq_bins
613
  idxs = torch.arange(seq_len, device=q.device)
614
  t_idx = idxs // freq_bins
615
  f_idx = idxs % freq_bins
616
- angle = self.rope.true2d_relative_angle(t_idx, f_idx, t_idx, f_idx)
617
- q_rot, k_rot = self.rope.true2d_apply_rotary(q, k, angle)
618
  scale = (self.dims // self.head) ** -0.25
619
  qk = (q_rot * scale * k_rot * scale).sum(-1)
620
  w = F.softmax(qk, dim=-1).to(q.dtype)
@@ -622,8 +685,8 @@ class MultiheadA(nn.Module):
622
  wv = wv.permute(0, 2, 1, 3).flatten(start_dim=2)
623
  return self.o(wv), qk
624
  else:
625
- q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
626
- k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
627
  else:
628
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
629
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -657,6 +720,8 @@ class MultiheadA(nn.Module):
657
  self.counter += 1
658
  return self.o(wv), qk
659
 
 
 
660
  class t_gate(nn.Module):
661
  def __init__(self, dims, num_types=4, enabled=True):
662
  super().__init__()
@@ -723,21 +788,26 @@ class c_gate(nn.Module):
723
  return self.integ(comb)
724
 
725
  class mlp_gate(nn.Module):
726
- def __init__(self, dims, enabled=True):
727
  super().__init__()
728
  self.enabled = enabled
729
  if enabled:
730
  self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
731
 
732
- def forward(self, x):
733
  if not self.enabled:
734
  return None
 
 
735
  return self.gate(x)
736
 
737
  class Residual(nn.Module):
738
  _seen = set()
739
  def __init__(self, ctx, dims, head, act, debug: List[str] = [],
740
- tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
741
  super().__init__()
742
 
743
  self.dims = dims
@@ -748,10 +818,12 @@ class Residual(nn.Module):
748
  self.debug = debug
749
  self.counter = 0
750
  self.dropout = 0.01
 
751
 
752
  self.blend = nn.Parameter(torch.tensor(0.5))
753
  act_fn = get_activation(act)
754
  self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
 
755
 
756
  if not any([tgate, mgate, cgate]):
757
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
@@ -764,16 +836,16 @@ class Residual(nn.Module):
764
  self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
765
  self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
766
  self.c_gate = c_gate(dims=dims, enabled=cgate)
767
- self.mlp_gate = mlp_gate(dims=dims, enabled=not any([tgate, mgate, cgate]))
768
 
769
  self.lna = RMSNorm(dims)
770
  self.lnb = RMSNorm(dims)
771
  self.lnc = RMSNorm(dims)
772
 
773
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature=None) -> Tensor:
774
-
775
  b = torch.sigmoid(self.blend)
776
- ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer, feature=feature)[0]
777
  bx = b * ax + (1 - b) * x
778
  cx = self.lnb(bx)
779
  dx = self.mlp(cx)
@@ -817,7 +889,7 @@ class FEncoder(nn.Module):
817
 
818
  self.norm = RMSNorm(dims)
819
 
820
- def apply_rope_to_features(self, x, layer=None, feature=None):
821
  batch, ctx, dims = x.shape
822
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
823
  if feature == "spectrogram" and self.rope is not None:
@@ -828,7 +900,7 @@ class FEncoder(nn.Module):
828
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
829
  return x
830
 
831
- def forward(self, x, enc=None, layer=None, feature=None):
832
  x = self.encoder(x).permute(0, 2, 1)
833
  if self.use_rope:
834
  x = self.apply_rope_to_features(x, layer=layer, feature=feature)
@@ -866,7 +938,7 @@ class WEncoder(nn.Module):
866
  self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
867
  self.norm = RMSNorm(dims)
868
 
869
- def apply_rope_to_features(self, x, layer=None, feature=None):
870
  if not self.use_rope or self.rope is None:
871
  return x
872
  batch, ctx, dims = x.shape
@@ -876,7 +948,7 @@ class WEncoder(nn.Module):
876
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
877
  return x
878
 
879
- def forward(self, x, enc=None, layer=None, feature=None):
880
  x = self.downsample(x)
881
  x = self.encoder(x)
882
  x = x.permute(0, 2, 1)
@@ -888,7 +960,7 @@ class WEncoder(nn.Module):
888
  return self.norm(x)
889
 
890
  class PEncoder(nn.Module):
891
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
892
  super().__init__()
893
 
894
  self.head = head
@@ -896,13 +968,14 @@ class PEncoder(nn.Module):
896
  self.dropout = 0.01
897
  self.use_rope = use_rope
898
  self.dims = dims
899
-
900
  act_fn = get_activation(act)
901
-
902
  self.encoder = nn.Sequential(
903
- Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
904
- Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
905
- Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)
 
906
 
907
  if use_rope:
908
  self.rope = rotary(
@@ -911,10 +984,10 @@ class PEncoder(nn.Module):
911
  debug=[])
912
  else:
913
  self.rope = None
914
- self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
915
  self.norm = RMSNorm(dims)
916
 
917
- def apply_rope_to_features(self, x, layer=None, feature=None):
918
  if not self.use_rope or self.rope is None:
919
  return x
920
  batch, ctx, dims = x.shape
@@ -924,20 +997,36 @@ class PEncoder(nn.Module):
924
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
925
  return x
926
 
927
- def forward(self, x, enc=None, layer=None, feature=None):
928
- x = self.encoder(x).permute(0, 2, 1)
929
  if self.use_rope:
930
- x = self.apply_rope_to_features(x, layer=layer)
931
  else:
932
- x = x + self.sinusoid_pos(x.shape[1], x.shape[-1]).to(x.device, x.dtype)
933
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
934
- return self.norm(x)
935
 
936
  class theBridge(nn.Module):
937
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
938
  debug: List[str], features: List[str], act: str = "gelu"):
939
  super(theBridge, self).__init__()
940
 
941
  self.debug = debug
942
  self.counter = 0
943
  self.dropout = 0.01
@@ -948,25 +1037,26 @@ class theBridge(nn.Module):
948
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
949
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
950
  self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
951
- self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
 
952
 
953
  with torch.no_grad():
954
  self.token.weight[0].zero_()
955
 
956
  self.block = nn.ModuleList([
957
- Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features)
958
  for _ in range(layer)])
959
 
960
  self.cross_attn = nn.ModuleList([
961
- Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features)
962
  for _ in range(layer)])
963
 
964
  self.cross_modal = nn.ModuleList([
965
- Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features)
966
  for _ in range(layer)])
967
 
968
- mask = torch.tril(torch.ones(ctx, ctx), diagonal=0).unsqueeze(0).unsqueeze(0)
969
- self.register_buffer("mask", mask, persistent=False)
970
 
971
  act_fn = get_activation(act)
972
  if features == ["spectrogram", "waveform", "pitch"]:
@@ -974,7 +1064,7 @@ class theBridge(nn.Module):
974
  else:
975
  cgate = False
976
 
977
- self.blocks = nn.ModuleDict({
978
  "spectrogram": nn.ModuleList(
979
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
980
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "spectrogram" in features else None),
@@ -982,7 +1072,7 @@ class theBridge(nn.Module):
982
  [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
983
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "waveform" in features else None),
984
  "pitch": nn.ModuleList(
985
- [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
986
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None),
987
  "envelope": nn.ModuleList(
988
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
@@ -991,29 +1081,30 @@ class theBridge(nn.Module):
991
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
992
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None)})
993
 
994
- self.norm = RMSNorm(dims)
995
 
996
- def forward(self, x, enc, layer='decoder', feature=None) -> Tensor:
 
 
997
  out = {}
998
  out.update(enc)
999
  enc = dict_to(enc, device, dtype)
1000
  _text_len = x.shape[1]
1001
-
1002
  x = self.token(x) + self.positional[:x.shape[1]]
1003
 
1004
  for f in enc:
1005
  if f in self.features:
1006
  xa = enc[f]
1007
- for block in self.blocks[f]:
1008
- xa = block(xa, enc=enc, layer=layer, feature=feature)
1009
- xa = xa + self.sinusoid_pos(xa.shape[1], xa.shape[-1]).to(xa.device, xa.dtype)
1010
  out[f] = xa
1011
 
1012
  for block in self.block:
1013
- x = block(x, xa=None, mask=self.mask, enc=enc, layer=layer)
 
 
1014
  if f in self.features:
1015
-
1016
- out = block(x, xa=xa, mask=self.mask, enc=enc, layer=layer)
1017
  if self.sequential:
1018
  x = out
1019
  else:
@@ -1024,9 +1115,9 @@ class theBridge(nn.Module):
1024
 
1025
  for block in self.cross_attn:
1026
  if f in self.features:
1027
- x = block(x, xa=xa, mask=self.mask, enc=enc, layer=layer)
1028
- xa = block(xa, xa=x, mask=self.mask, enc=enc, layer=layer)
1029
- out = block(x, xa=xa, mask=self.mask, enc=enc, layer=layer)
1030
  if self.sequential:
1031
  x = out
1032
  else:
@@ -1038,23 +1129,24 @@ class theBridge(nn.Module):
1038
  for block in self.cross_modal:
1039
  if f in self.features:
1040
  xcat = torch.cat([x, xa], dim=1)
1041
- x = block(xcat, xa=None, mask=self.mask, enc=enc, layer=layer)
1042
  x = x[:, :_text_len]
1043
  out[f] = x
 
1044
  if self.counter < 1 and "encoder" in self.debug:
1045
  shapes = {k: v.shape for k, v in enc.items()}
1046
  print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
1047
  self.counter += 1
1048
 
 
1049
  x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
1050
- x = self.norm(x)
1051
  return x, out
1052
 
1053
  class Echo(nn.Module):
1054
  def __init__(self, param: Dimensions):
1055
  super().__init__()
1056
  self.param = param
1057
-
1058
  self.processor = theBridge(
1059
  vocab=param.vocab,
1060
  mels=param.mels,
@@ -1100,11 +1192,9 @@ class Echo(nn.Module):
1100
  if input_ids is not None:
1101
  enc["input_ids"] = input_ids
1102
  feature = "input_ids"
1103
- else:
1104
- feature = "spectrogram"
1105
 
1106
- out, logits = self.processor(input_ids, enc, feature)
1107
- self.out=out
1108
 
1109
  loss = None
1110
  if labels is not None:
@@ -1279,13 +1369,11 @@ def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
1279
  import librosa
1280
  mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
1281
  mel_basis = torch.from_numpy(mel_basis).float()
1282
-
1283
  sp_mel = torch.matmul(sp, mel_basis.T)
1284
  ap_mel = torch.matmul(ap, mel_basis.T)
1285
-
1286
  return sp_mel, ap_mel
1287
 
1288
- def extract_features(batch, tokenizer, waveform=False, spec=True, f0=True, f0t=True, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, **dataset_config):
1289
  dataset_config = {
1290
  "hop_length": 256,
1291
  "f_min": 150,
@@ -1307,9 +1395,9 @@ def extract_features(batch, tokenizer, waveform=False, spec=True, f0=True, f0t=T
1307
  labels = tokenizer.encode(batch["transcription"])
1308
 
1309
  wav = wavnp = f0_np = t = None
1310
- spectrogram = f0_tensor = f0t_tensor = harmonic = aperiodic = None
1311
 
1312
- if waveform or spec or f0 or f0t or harmonics:
1313
  wav = load_wave(wave_data=audio, sample_rate=sr)
1314
  wavnp = wav.numpy().astype(np.float64)
1315
 
@@ -1321,37 +1409,43 @@ def extract_features(batch, tokenizer, waveform=False, spec=True, f0=True, f0t=T
1321
  spectrogram = (log_mel + 4.0) / 4.0
1322
  spectrogram = torch.tensor(spectrogram)
1323
 
1324
- if f0 or f0t or harmonics:
1325
  f0_np, t = pw.dio(wavnp, sample_rate,
1326
- frame_period=hop_length / sample_rate * 1000, f0_ceil=500, f0_floor=71.1)
1327
  f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
 
1328
 
1329
  if f0:
1330
  f0_tensor = torch.from_numpy(f0_np)
1331
- f0_tensor = torch.where(f0_tensor == 0.0, torch.zeros_like(f0_tensor), (f0_tensor - 71.0) / (500.0 - 71.0))
1332
-
 
1333
  if f0t:
1334
  audio_duration = len(wavnp) / sample_rate
1335
  T = len(labels)
1336
  tok_dur_sec = audio_duration / T
1337
- token_starts = np.arange(T) * tok_dur_sec
1338
  token_ends = token_starts + tok_dur_sec
1339
- start_idx = np.searchsorted(t, token_starts, side="left")
1340
- end_idx = np.searchsorted(t, token_ends, side="right")
1341
- pitch_tok = np.zeros(T, dtype=np.float32)
1342
  for i in range(T):
1343
  lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
1344
  segment = f0_np[lo:hi]
1345
  if mode == "mean":
1346
  pitch_tok[i] = segment.mean()
1347
  elif mode == "median":
1348
- pitch_tok[i] = np.median(segment)
1349
  else:
1350
  pitch_tok[i] = segment[-1]
1351
  pitch_tok[pitch_tok < 100.0] = 0.0
1352
  bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
1353
  f0t_tensor = torch.from_numpy(np.concatenate([[bos_pitch], pitch_tok]))
1354
- f0t_tensor = torch.where(f0t_tensor == 0.0, torch.zeros_like(f0t_tensor), (f0t_tensor - 71.0) / (500.0 - 71.0))
1355
 
1356
  if harmonics:
1357
  spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
@@ -1364,8 +1458,8 @@ def extract_features(batch, tokenizer, waveform=False, spec=True, f0=True, f0t=T
1364
  aperiodic = torch.where(aperiodic == 0.0, torch.zeros_like(aperiodic), aperiodic / 1.0)
1365
 
1366
  if debug:
1367
- print(f"['f0']: {f0_tensor.shape if f0_tensor is not None else None}")
1368
- print(f"['f0t']: {f0t_tensor.shape if f0t_tensor is not None else None}")
1369
  print(f"['harmonic']: {harmonic.shape if harmonic is not None else None}")
1370
  print(f"['aperiodic']: {aperiodic.shape if aperiodic is not None else None}")
1371
  print(f"['spectrogram']: {spectrogram.shape if spectrogram is not None else None}")
@@ -1377,6 +1471,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=True, f0=True, f0t=T
1377
  "spectrogram": spectrogram if spec else None,
1378
  "f0": f0_tensor if f0 else None,
1379
  "f0t": f0t_tensor if f0t else None,
 
1380
  "harmonic": harmonic if harmonics else None,
1381
  "aperiodic": aperiodic if harmonics else None,
1382
  "labels": labels,
@@ -1387,10 +1482,11 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
1387
  if sanity_check:
1388
  test = load_dataset(
1389
  "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True
1390
- ).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)
1391
  dataset = test.map(
1392
  lambda x: extract_features(x, tokenizer, **dataset_config),
1393
  remove_columns=test.column_names)
 
1394
  train_dataset = dataset
1395
  test_dataset = dataset
1396
  return train_dataset, test_dataset
@@ -1412,8 +1508,10 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
1412
  len(x["audio"]["array"]) > 0 and
1413
  len(x["audio"]["array"]) < 2048 * 160)
1414
 
1415
- raw_train = load_dataset("google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming)
1416
- raw_test = load_dataset("google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming)
 
 
1417
 
1418
  raw_train = raw_train.filter(filter_func)
1419
  raw_test = raw_test.filter(filter_func)
@@ -1428,8 +1526,8 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
1428
  lambda x: extract_features(x, tokenizer, **dataset_config),
1429
  remove_columns=raw_test.column_names)
1430
 
1431
- train_dataset.save_to_disk(cache_file_train) if sanity_check or streaming is False else None
1432
- test_dataset.save_to_disk(cache_file_test) if sanity_check or streaming is False else None
1433
  return train_dataset, test_dataset
1434
 
1435
  @dataclass
@@ -1520,10 +1618,13 @@ def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
1520
  return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]
1521
 
1522
  def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, optimizer=None, scheduler=None):
 
1523
  label_ids = pred.label_ids
1524
  pred_ids = pred.predictions[0]
1525
- label_ids = clean_batch(label_ids)
1526
- pred_ids = clean_batch(pred_ids)
 
 
1527
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1528
  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
1529
 
@@ -1542,12 +1643,16 @@ def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samp
1542
  else:
1543
  trainable_params = 0.0
1544
  efficiency_score = 0.0
1545
- return { "wer": float(wer), "efficiency_score": float(efficiency_score)}
1546
 
1547
  def preprocess_logits_for_metrics(logits, labels):
1548
  pred_ids = torch.argmax(logits, dim=-1)
1549
  labels = torch.where(labels == -100, 0, labels)
1550
  pred_ids = torch.where(pred_ids == -100, 0, pred_ids)
 
1551
  return pred_ids, labels
1552
 
1553
  def main():
@@ -1558,7 +1663,7 @@ def main():
1558
  train_dataset, test_dataset = prepare_datasets(
1559
  tokenizer,
1560
  token,
1561
- sanity_check=False,
1562
 
1563
  )
1564
 
@@ -1571,7 +1676,7 @@ def main():
1571
  layer=4,
1572
  act="swish",
1573
  debug={"radius", "encoder"},
1574
- features = ["spectrogram"],
1575
  )
1576
 
1577
  model = Echo(param).to('cuda')
@@ -1601,7 +1706,7 @@ def main():
1601
  )
1602
  from functools import partial
1603
  metrics_fn = partial(compute_metrics,
1604
- print_pred=True,
1605
  num_samples=2,
1606
  tokenizer=tokenizer, model=model)
1607
 
@@ -1620,10 +1725,27 @@ def main():
1620
  compute_metrics=metrics_fn,
1621
  optimizers=(optimizer, scheduler)
1622
  )
1623
- print(tokenizer.pad_token_id, tokenizer.bos_token_id, tokenizer.eos_token_id)
1624
  model.init_weights()
1625
  trainer.train()
1626
 
1627
  if __name__ == "__main__":
1628
  main()
1629
 
273
  def tox():
274
  return {"device": get_device(), "dtype": get_dtype()}
275
 
276
+
277
+
278
  class Sinusoids(nn.Module):
279
  def __init__(self, length, channels, max_tscale=10000):
280
  super().__init__()
 
296
  scaled_t = torch.arange(length)[:, None] * inv_tscales[None, :]
297
  pos1 = torch.sin(scaled_t)
298
  pos2 = torch.cos(scaled_t)
299
+ positions = torch.cat([pos1, pos2], dim=1)
300
  return nn.Parameter(positions.clone())
301
 
302
+ def accumulate_phase_mod(f0, t_frame, phi0=0.0):
303
+ omega = 2 * torch.pi * f0
304
+ dphi = omega * t_frame
305
+ phi = torch.cumsum(dphi, dim=0) + phi0
306
+ phi = torch.remainder(phi, 2 * torch.pi)
307
+ return phi
308
+
309
  class rotary(nn.Module):
310
+ def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None, relative=False, freq_bins=None):
311
  super(rotary, self).__init__()
312
  self.use_pbias = use_pbias
313
  self.dims = dims
314
  self.head = head
315
  self.head_dim = dims // head
316
  self.radii = radii
 
317
  self.debug = debug
318
  self.counter = 0
319
+
320
+ self.axial = axial
321
+ if axial and spec_shape is not None:
 
322
  time_frames, freq_bins = spec_shape
323
  self.time_frames = time_frames
324
  self.freq_bins = freq_bins
 
334
  self.theta = nn.Parameter(theta, requires_grad=True)
335
  self.theta_values = []
336
 
337
+ self.relative = relative
338
  self.freq_bins = freq_bins
339
  self.true2d_dim = (dims // head) // 2
340
  self.omega_t = nn.Parameter(torch.randn(self.true2d_dim))
341
  self.omega_f = nn.Parameter(torch.randn(self.true2d_dim))
342
 
343
+ def axial(self, seq_len):
344
  if not self.use_2d_axial:
345
  return None
346
  time_frames = self.time_frames
 
369
  f0_norm.unsqueeze(1)))
370
  return f0_sim.unsqueeze(0).unsqueeze(0)
371
 
372
+ def accumulate_phase_mod(self, f0, t_frame, phi0=0.0):
373
+ omega = 2 * torch.pi * f0
374
+ dphi = omega * t_frame
375
+ phi = torch.cumsum(dphi, dim=0) + phi0
376
+ phi = torch.remainder(phi, 2 * torch.pi)
377
+ return phi
378
+
379
  def theta_freqs(self, theta):
380
  if theta.dim() == 0:
381
  theta = theta.unsqueeze(0)
382
  freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
383
  torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
384
+ self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
385
  return freq
386
 
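Note: theta_freqs spaces the head_dim // 2 rotary frequencies on a mel-like curve; at theta = 220 the values run from 0 up to roughly 8 (0–8000 Hz divided by 1000), and other theta values scale that range linearly. A standalone check, not part of this commit; head_dim = 64 is an assumed size:

import torch

def theta_freqs(theta, head_dim=64):
    # same expression as the rotary.theta_freqs method above, written as a free function
    if theta.dim() == 0:
        theta = theta.unsqueeze(0)
    return (theta.unsqueeze(-1) / 220.0) * 700 * (
        torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000 / 700)),
                                     head_dim // 2, dtype=theta.dtype) / 2595) - 1) / 1000

f = theta_freqs(torch.tensor(220.0))
print(f.shape, f[0, 0].item(), f[0, -1].item())   # torch.Size([1, 32]), 0.0, ≈ 8.0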
387
  def _apply_radii(self, freqs, f0, ctx):
 
404
  f0 = f0.squeeze(0)
405
  if f0t is not None and f0t.dim() == 2:
406
  f0t = f0t.squeeze(0)
 
407
  if f0 is not None and f0.shape[0] == ctx:
408
  return f0
409
  elif f0t is not None and f0t.shape[0] == ctx:
 
413
 
414
  def forward(self, x=None, enc=None, layer=None, feature=None) -> Tensor:
415
  ctx = x
416
+ if self.axial and feature == "spectrogram":
417
  freqs_2d = self.axial_freqs(ctx)
418
  if freqs_2d is not None:
419
  return freqs_2d.unsqueeze(0)
 
421
  f0 = enc.get("f0") if enc is not None else None
422
  f0t = enc.get("f0t") if enc is not None else None
423
  f0 = self.check_f0(f0, f0t, ctx)
424
+
425
  theta = f0 + self.theta if f0 is not None else self.theta
426
+
427
+ theta = f0
428
  freqs = self.theta_freqs(theta)
429
+
430
  t = torch.arange(ctx, device=device, dtype=dtype)
431
  freqs = t[:, None] * freqs
432
  freqs, radius = self._apply_radii(freqs, f0, ctx)
 
450
  x1 = x1.view(orig_shape)
451
  return torch.cat([x1.type_as(x), x2], dim=-1)
452
 
453
+
454
  # @staticmethod
455
  # def apply_rotary(x, freqs):
456
  # # x: [batch, head, seq, head_dim]
 
515
  # k_rot = to_real(k_rot)
516
  # return q_rot, k_rot
517
 
518
+
519
  def parallel_slice(self, q, k, v, mask=None):
520
  batch, head, ctx, dims = q.shape
521
  head_dim = self.head_dim
522
  batch, ctx, dims = q.shape
523
  ctx_len = k.shape[1]
524
  head = dims // head_dim
 
525
  scores = torch.zeros(batch, head, ctx, ctx_len, device=q.device)
 
526
  for h in range(head):
527
  start_idx = h * head_dim
528
  end_idx = start_idx + head_dim
529
  q_h = q[:, :, start_idx:end_idx]
530
  k_h = k[:, :, start_idx:end_idx]
 
531
  scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
 
532
  if mask is not None:
533
  scores = scores + mask.unsqueeze(0).unsqueeze(0)
 
534
  attn_weights = F.softmax(scores, dim=-1)
 
535
  output = torch.zeros_like(q)
536
  for h in range(head):
537
  start_idx = h * head_dim
 
540
  output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
541
  return output
542
 
543
+ class curiosity(nn.Module):
544
+ def __init__(self, d, h, bias=True):
545
+ super().__init__()
546
+ self.h = h
547
+ self.dh = d // h
548
+ self.qkv = nn.Linear(d, d * 3, bias=bias)
549
+ self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
550
+ self.o = nn.Linear(d, d, bias=bias)
551
+ self.g = nn.Parameter(torch.zeros(h))
552
+
553
+ def split(self, x):
554
+ b, t, _ = x.shape
555
+ return x.view(b, t, self.h, self.dh).transpose(1, 2)
556
+
557
+ def merge(self, x):
558
+ b, h, t, dh = x.shape
559
+ return x.transpose(1, 2).contiguous().view(b, t, h * dh)
560
+
561
+ def forward(self, x, xa, mask=None):
562
+ q, k, v = self.qkv(x).chunk(3, -1)
563
+ qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
564
+ q, k, v = map(self.split, (q, k, v))
565
+ qa, ka, va = map(self.split, (qa, ka, va))
566
+ dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
567
+ dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
568
+ if mask is not None: dots = dots.masked_fill(mask, -9e15)
569
+ p = dots.softmax(-1)
570
+ pa = dots_aux.softmax(-1)
571
+ h_main = p @ v
572
+ h_aux = pa @ va
573
+ g = torch.sigmoid(self.g).view(1, -1, 1, 1)
574
+ out = self.merge(h_main * (1 - g) + h_aux * g)
575
+ return self.o(out)
576
+
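Note: curiosity attends from the same queries into both x and xa and blends the two results with a learned per-head sigmoid gate. A functional sketch of that blend, not part of this commit; the projections are omitted and the sizes are illustrative:

import torch
import torch.nn.functional as F

b, h, t, s, dh = 2, 4, 10, 50, 16                              # batch, heads, x length, xa length, head dim
q = torch.randn(b, h, t, dh)                                   # queries from x
k, v = torch.randn(b, h, t, dh), torch.randn(b, h, t, dh)      # keys/values from x
ka, va = torch.randn(b, h, s, dh), torch.randn(b, h, s, dh)    # keys/values from xa
p = F.softmax(q @ k.transpose(-2, -1) / dh ** 0.5, dim=-1)
pa = F.softmax(q @ ka.transpose(-2, -1) / dh ** 0.5, dim=-1)
g = torch.sigmoid(torch.zeros(h)).view(1, -1, 1, 1)            # the module's learned gate, 0.5 per head at init
out = (p @ v) * (1 - g) + (pa @ va) * g                        # [2, 4, 10, 16]; merged and projected by self.o in the module
print(out.shape)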
577
+ class OneShot(nn.Module):
578
+ def __init__(self, dims: int, head: int, scale: float = 0.3):
579
+ super().__init__()
580
+ self.head = head
581
+ self.hdim = dims // head
582
+ self.scale = scale
583
+ self.q_proj = Linear(dims, dims)
584
+ self.k_proj = Linear(dims, dims)
585
+
586
+ def forward(self, x: Tensor, guide: Tensor) -> Tensor | None:
587
+ B, Q, _ = x.shape
588
+ K = guide.size(1)
589
+ q = self.q_proj(x ).view(B, Q, self.head, self.hdim).transpose(1,2)
590
+ k = self.k_proj(guide).view(B, K, self.head, self.hdim).transpose(1,2)
591
+ bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.hdim)
592
+ return bias
593
+
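Note: OneShot returns only a per-head query–key bias between x and a guide sequence (no value aggregation). A shape sketch with the projections omitted, not part of this commit; the sizes and scale are illustrative:

import math
import torch

B, Q, K, head, hdim, scale = 2, 10, 50, 4, 16, 0.3
q = torch.randn(B, Q, head * hdim).view(B, Q, head, hdim).transpose(1, 2)   # projected x
k = torch.randn(B, K, head * hdim).view(B, K, head, hdim).transpose(1, 2)   # projected guide
bias = (q @ k.transpose(-1, -2)) * scale / math.sqrt(hdim)
print(bias.shape)   # torch.Size([2, 4, 10, 50])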
594
  class MultiheadA(nn.Module):
595
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
596
+ zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], use_pbias=False, relative=False, freq_bins=None, radii=True, axial=False, spec_shape=None, rbf=False):
597
 
598
  super(MultiheadA, self).__init__()
599
  self.dims = dims
 
602
  self.debug = debug
603
  self.counter = 0
604
  self.use_pbias = use_pbias
605
+ self.relative = relative
606
  self.freq_bins = freq_bins
607
  self.rbf = rbf
608
 
 
615
  self.rotary_emb = rotary_emb
616
  self.minz = minz
617
  self.maxz = maxz
618
+ self.zero_val = zero_val
 
619
  self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
620
 
621
  if rotary_emb:
 
624
  head=head,
625
  debug=debug,
626
  radii=radii,
627
+ relative=relative,
628
  freq_bins=freq_bins,
629
  )
630
  else:
 
670
  q2 = q.shape[2]
671
  k2 = k.shape[2]
672
 
673
+ if self.relative and feature == "spectrogram":
674
  seq_len = q2
675
  freq_bins = self.freq_bins
676
  idxs = torch.arange(seq_len, device=q.device)
677
  t_idx = idxs // freq_bins
678
  f_idx = idxs % freq_bins
679
+ angle = self.rope.relative(t_idx, f_idx, t_idx, f_idx)
680
+ q_rot, k_rot = self.rope.d2rotary(q, k, angle)
681
  scale = (self.dims // self.head) ** -0.25
682
  qk = (q_rot * scale * k_rot * scale).sum(-1)
683
  w = F.softmax(qk, dim=-1).to(q.dtype)
 
685
  wv = wv.permute(0, 2, 1, 3).flatten(start_dim=2)
686
  return self.o(wv), qk
687
  else:
688
+ q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer, feature=feature)))
689
+ k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer, feature=feature)))
690
  else:
691
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
692
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
720
  self.counter += 1
721
  return self.o(wv), qk
722
 
723
+
724
+
725
  class t_gate(nn.Module):
726
  def __init__(self, dims, num_types=4, enabled=True):
727
  super().__init__()
 
788
  return self.integ(comb)
789
 
790
  class mlp_gate(nn.Module):
791
+ def __init__(self, dims, head, enabled=True, one_shot=False):
792
  super().__init__()
793
  self.enabled = enabled
794
  if enabled:
795
  self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
796
+
797
+ if one_shot:
798
+ self.one_shot = OneShot(dims, head)
799
 
800
+ def forward(self, x, xa=None):
801
  if not self.enabled:
802
  return None
803
+ if self.one_shot:
804
+ x = self.one_shot(x, xa)
805
  return self.gate(x)
806
 
807
  class Residual(nn.Module):
808
  _seen = set()
809
  def __init__(self, ctx, dims, head, act, debug: List[str] = [],
810
+ tgate=True, mgate=False, cgate=False, mem_size=512, features=None, one_shot=False):
811
  super().__init__()
812
 
813
  self.dims = dims
 
818
  self.debug = debug
819
  self.counter = 0
820
  self.dropout = 0.01
821
+ self.one_shot = one_shot
822
 
823
  self.blend = nn.Parameter(torch.tensor(0.5))
824
  act_fn = get_activation(act)
825
  self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
826
+ self.one_shot = OneShot(dims, head) if one_shot else None
827
 
828
  if not any([tgate, mgate, cgate]):
829
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
 
836
  self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
837
  self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
838
  self.c_gate = c_gate(dims=dims, enabled=cgate)
839
+ self.mlp_gate = mlp_gate(dims=dims, head=head, enabled=not any([tgate, mgate, cgate]), one_shot=True)
840
 
841
  self.lna = RMSNorm(dims)
842
  self.lnb = RMSNorm(dims)
843
  self.lnc = RMSNorm(dims)
844
 
845
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature=None) -> Tensor:
846
+
847
  b = torch.sigmoid(self.blend)
848
+ ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer, feature=feature)[0]
849
  bx = b * ax + (1 - b) * x
850
  cx = self.lnb(bx)
851
  dx = self.mlp(cx)
 
889
 
890
  self.norm = RMSNorm(dims)
891
 
892
+ def apply_rope_to_features(self, x, layer="FEncoder", feature="spectrogram"):
893
  batch, ctx, dims = x.shape
894
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
895
  if feature == "spectrogram" and self.rope is not None:
 
900
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
901
  return x
902
 
903
+ def forward(self, x, enc=None, feature="spectrogram", layer="FEncoder"):
904
  x = self.encoder(x).permute(0, 2, 1)
905
  if self.use_rope:
906
  x = self.apply_rope_to_features(x, layer=layer, feature=feature)
 
938
  self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
939
  self.norm = RMSNorm(dims)
940
 
941
+ def apply_rope_to_features(self, x, layer="WEncoder", feature="waveform"):
942
  if not self.use_rope or self.rope is None:
943
  return x
944
  batch, ctx, dims = x.shape
 
948
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
949
  return x
950
 
951
+ def forward(self, x, enc=None, feature="waveform", layer="WEncoder"):
952
  x = self.downsample(x)
953
  x = self.encoder(x)
954
  x = x.permute(0, 2, 1)
 
960
  return self.norm(x)
961
 
962
  class PEncoder(nn.Module):
963
+ def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, one_shot=False):
964
  super().__init__()
965
 
966
  self.head = head
 
968
  self.dropout = 0.01
969
  self.use_rope = use_rope
970
  self.dims = dims
971
+ self.one_shot = one_shot
972
  act_fn = get_activation(act)
973
+
974
  self.encoder = nn.Sequential(
975
+ Conv1d(input_dims, dims, kernel_size=kernel_size, stride=1, padding=kernel_size//2), act_fn,
976
+ Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
977
+ Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
978
+
979
 
980
  if use_rope:
981
  self.rope = rotary(
 
984
  debug=[])
985
  else:
986
  self.rope = None
987
+ self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
988
  self.norm = RMSNorm(dims)
989
 
990
+ def apply_rope_to_features(self, x, layer="PEncoder", feature="pitch"):
991
  if not self.use_rope or self.rope is None:
992
  return x
993
  batch, ctx, dims = x.shape
 
997
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
998
  return x
999
 
1000
+ def forward(self, xa, enc=None, layer="PEncoder", feature="pitch"):
1001
+ xa = self.encoder(xa).permute(0, 2, 1)
1002
  if self.use_rope:
1003
+ xa = self.apply_rope_to_features(xa, layer=layer)
1004
  else:
1005
+ xa = xa + self.sinusoid_pos(xa.shape[1], xa.shape[-1], 10000).to(xa.device, xa.dtype)
1006
+ if self.one_shot:
1007
+ x = enc["input_ids"]
1008
+ xa = self.one_shot(x, xa)
1009
+ xa = nn.functional.dropout(xa, p=self.dropout, training=self.training)
1010
+ return self.norm(xa)
1011
+
1012
+ def win_mask(text_ctx, aud_ctx):
1013
+ mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
1014
+ audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device))
1015
+ full_mask = torch.cat([mask, audio_mask], dim=-1)
1016
+ return full_mask.unsqueeze(0).unsqueeze(0)
1017
+
1018
+ def causal_mask(seq_len, device):
1019
+ return torch.tril(torch.ones(seq_len, seq_len, device=device), diagonal=0).unsqueeze(0).unsqueeze(0)
1020
 
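Note: the two mask helpers above produce [1, 1, ·, ·] tensors ready to broadcast over batch and head. Standalone copies with the device passed explicitly, not part of this commit; the sizes are illustrative:

import torch

def causal_mask(seq_len, device="cpu"):
    return torch.tril(torch.ones(seq_len, seq_len, device=device), diagonal=0).unsqueeze(0).unsqueeze(0)

def win_mask(text_ctx, aud_ctx, device="cpu"):
    mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
    audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device))
    return torch.cat([mask, audio_mask], dim=-1).unsqueeze(0).unsqueeze(0)

print(causal_mask(4).shape)    # torch.Size([1, 1, 4, 4])
print(win_mask(4, 6).shape)    # torch.Size([1, 1, 4, 6])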
1021
  class theBridge(nn.Module):
1022
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
1023
  debug: List[str], features: List[str], act: str = "gelu"):
1024
  super(theBridge, self).__init__()
1025
 
1026
+ self.ctx = ctx
1027
+ self.dims = dims
1028
+ self.head = head
1029
+ self.head_dim = dims // head
1030
  self.debug = debug
1031
  self.counter = 0
1032
  self.dropout = 0.01
 
1037
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
1038
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
1039
  self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
1040
+ self.ln_dec = RMSNorm(dims)
1041
+ self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
1042
 
1043
  with torch.no_grad():
1044
  self.token.weight[0].zero_()
1045
 
1046
  self.block = nn.ModuleList([
1047
+ Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
1048
  for _ in range(layer)])
1049
 
1050
  self.cross_attn = nn.ModuleList([
1051
+ Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
1052
  for _ in range(layer)])
1053
 
1054
  self.cross_modal = nn.ModuleList([
1055
+ Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
1056
  for _ in range(layer)])
1057
 
1058
+ self.register_buffer("mask", causal_mask(ctx, device), persistent=False)
1059
+ self.register_buffer("mask_win", win_mask(ctx, ctx), persistent=False)
1060
 
1061
  act_fn = get_activation(act)
1062
  if features == ["spectrogram", "waveform", "pitch"]:
 
1064
  else:
1065
  cgate = False
1066
 
1067
+ self.blockA = nn.ModuleDict({
1068
  "spectrogram": nn.ModuleList(
1069
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1070
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "spectrogram" in features else None),
 
1072
  [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
1073
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "waveform" in features else None),
1074
  "pitch": nn.ModuleList(
1075
+ [PEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=3, act=act, one_shot=False)] +
1076
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None),
1077
  "envelope": nn.ModuleList(
1078
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
 
1081
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1082
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None)})
1083
 
 
1084
 
1085
+
1086
+ def forward(self, x, enc, feature, layer='theBridge') -> Tensor:
1087
+ f0 = enc.get("f0")
1088
  out = {}
1089
  out.update(enc)
1090
  enc = dict_to(enc, device, dtype)
1091
  _text_len = x.shape[1]
 
1092
  x = self.token(x) + self.positional[:x.shape[1]]
1093
 
1094
  for f in enc:
1095
  if f in self.features:
1096
  xa = enc[f]
1097
+ for block in self.blockA[f]:
1098
+ xa = block(xa, enc=out, feature=feature, layer="enc_self")
1099
+ xa = xa + self.sinusoid_pos(xa.shape[1], xa.shape[-1], 10000).to(xa.device, xa.dtype)
1100
  out[f] = xa
1101
 
1102
  for block in self.block:
1103
+ x = block(x, xa=None, mask=self.mask, enc=enc, feature=feature, layer="dec_self")
1104
+ out["input_ids"] = x
1105
+
1106
  if f in self.features:
1107
+ out = block(x, xa=xa, mask=self.mask, enc=enc, feature=feature, layer="dec_cross")
 
1108
  if self.sequential:
1109
  x = out
1110
  else:
 
1115
 
1116
  for block in self.cross_attn:
1117
  if f in self.features:
1118
+ x = block(x, xa=xa, mask=self.mask, enc=enc, feature=feature, layer="dec_cross")
1119
+ xa = block(xa, xa=x, mask=self.mask, enc=enc, feature=feature, layer="enc_cross")
1120
+ out = block(x, xa=xa, mask=self.mask, enc=enc, feature=feature, layer="dec_cross")
1121
  if self.sequential:
1122
  x = out
1123
  else:
 
1129
  for block in self.cross_modal:
1130
  if f in self.features:
1131
  xcat = torch.cat([x, xa], dim=1)
1132
+ x = block(xcat, xa=None, mask=self.mask, enc=enc, feature=feature, layer="cross_modal")
1133
  x = x[:, :_text_len]
1134
  out[f] = x
1135
+
1136
  if self.counter < 1 and "encoder" in self.debug:
1137
  shapes = {k: v.shape for k, v in enc.items()}
1138
  print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
1139
  self.counter += 1
1140
 
1141
+ x = self.ln_dec(x)
1142
  x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
 
1143
  return x, out
1144
 
1145
  class Echo(nn.Module):
1146
  def __init__(self, param: Dimensions):
1147
  super().__init__()
1148
  self.param = param
1149
+
1150
  self.processor = theBridge(
1151
  vocab=param.vocab,
1152
  mels=param.mels,
 
1192
  if input_ids is not None:
1193
  enc["input_ids"] = input_ids
1194
  feature = "input_ids"
 
 
1195
 
1196
+ logits, out = self.processor(input_ids, enc, feature)
1197
+ self.out = out
1198
 
1199
  loss = None
1200
  if labels is not None:
 
1369
  import librosa
1370
  mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
1371
  mel_basis = torch.from_numpy(mel_basis).float()
 
1372
  sp_mel = torch.matmul(sp, mel_basis.T)
1373
  ap_mel = torch.matmul(ap, mel_basis.T)
 
1374
  return sp_mel, ap_mel
1375
 
1376
+ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=True, f0t=True, pitch=True, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, **dataset_config):
1377
  dataset_config = {
1378
  "hop_length": 256,
1379
  "f_min": 150,
 
1395
  labels = tokenizer.encode(batch["transcription"])
1396
 
1397
  wav = wavnp = f0_np = t = None
1398
+ spectrogram = f0_tensor = f0t_tensor = harmonic = aperiodic = p_tensor = None
1399
 
1400
+ if waveform or spec or f0 or f0t or harmonics or pitch:
1401
  wav = load_wave(wave_data=audio, sample_rate=sr)
1402
  wavnp = wav.numpy().astype(np.float64)
1403
 
 
1409
  spectrogram = (log_mel + 4.0) / 4.0
1410
  spectrogram = torch.tensor(spectrogram)
1411
 
1412
+ if f0 or f0t or harmonics or pitch:
1413
  f0_np, t = pw.dio(wavnp, sample_rate,
1414
+ frame_period=hop_length / sample_rate * 1000)
1415
  f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
1416
+ t = torch.tensor(t)
1417
 
1418
  if f0:
1419
  f0_tensor = torch.from_numpy(f0_np)
1420
+ t_frame = torch.mean(t[1:] - t[:-1])
1421
+ f0_tensor = accumulate_phase_mod(f0_tensor, t_frame)
1422
+
1423
  if f0t:
1424
  audio_duration = len(wavnp) / sample_rate
1425
  T = len(labels)
1426
  tok_dur_sec = audio_duration / T
1427
+ token_starts = torch.arange(T) * tok_dur_sec
1428
  token_ends = token_starts + tok_dur_sec
1429
+ start_idx = torch.searchsorted(t, token_starts, side="left")
1430
+ end_idx = torch.searchsorted(t, token_ends, side="right")
1431
+ pitch_tok = torch.zeros(T, dtype=torch.float32)
1432
  for i in range(T):
1433
  lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
1434
  segment = f0_np[lo:hi]
1435
  if mode == "mean":
1436
  pitch_tok[i] = segment.mean()
1437
  elif mode == "median":
1438
+ pitch_tok[i] = torch.median(segment)
1439
  else:
1440
  pitch_tok[i] = segment[-1]
1441
  pitch_tok[pitch_tok < 100.0] = 0.0
1442
  bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
1443
  f0t_tensor = torch.from_numpy(np.concatenate([[bos_pitch], pitch_tok]))
1444
+ f0t_tensor = accumulate_phase_mod(f0t_tensor, t_frame)
1445
+
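Note: the f0t branch pools the frame-level F0 track into one value per token by bucketing frames between token start/end times with searchsorted. A small numeric sketch, not part of this commit; the frame times, pitch values and token count are made up:

import torch

t = torch.arange(10) * 0.1                               # frame timestamps (seconds), as returned by pw.dio
f0 = torch.tensor([0., 120, 125, 130, 0, 0, 200, 210, 205, 0])
T = 4                                                    # number of tokens over a 1.0 s clip
tok_dur_sec = 1.0 / T
token_starts = torch.arange(T) * tok_dur_sec
token_ends = token_starts + tok_dur_sec
start_idx = torch.searchsorted(t, token_starts, side="left")
end_idx = torch.searchsorted(t, token_ends, side="right")

pitch_tok = torch.zeros(T)
for i in range(T):
    lo, hi = start_idx[i], max(start_idx[i] + 1, end_idx[i])
    pitch_tok[i] = f0[lo:hi].mean()                      # "mean" mode; "median"/last-frame also handled above
print(pitch_tok)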
1446
+ if pitch:
1447
+ p_tensor = torch.from_numpy(f0_np)
1448
+ p_tensor = p_tensor.unsqueeze(0)
1449
 
1450
  if harmonics:
1451
  spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
 
1458
  aperiodic = torch.where(aperiodic == 0.0, torch.zeros_like(aperiodic), aperiodic / 1.0)
1459
 
1460
  if debug:
1461
+ print(f"['f0']: {f0_tensor.shape if f0 is not None else None}")
1462
+ print(f"['f0t']: {f0t_tensor.shape if f0t is not None else None}")
1463
  print(f"['harmonic']: {harmonic.shape if harmonic is not None else None}")
1464
  print(f"['aperiodic']: {aperiodic.shape if aperiodic is not None else None}")
1465
  print(f"['spectrogram']: {spectrogram.shape if spectrogram is not None else None}")
 
1471
  "spectrogram": spectrogram if spec else None,
1472
  "f0": f0_tensor if f0 else None,
1473
  "f0t": f0t_tensor if f0t else None,
1474
+ "pitch": p_tensor if pitch else None,
1475
  "harmonic": harmonic if harmonics else None,
1476
  "aperiodic": aperiodic if harmonics else None,
1477
  "labels": labels,
 
1482
  if sanity_check:
1483
  test = load_dataset(
1484
  "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True
1485
+ ).cast_column("audio", Audio(sampling_rate=sample_rate)).take(10)
1486
  dataset = test.map(
1487
  lambda x: extract_features(x, tokenizer, **dataset_config),
1488
  remove_columns=test.column_names)
1489
+
1490
  train_dataset = dataset
1491
  test_dataset = dataset
1492
  return train_dataset, test_dataset
 
1508
  len(x["audio"]["array"]) > 0 and
1509
  len(x["audio"]["array"]) < 2048 * 160)
1510
 
1511
+ raw_train = load_dataset(
1512
+ "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming).take(1000)
1513
+ raw_test = load_dataset(
1514
+ "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).take(100)
1515
 
1516
  raw_train = raw_train.filter(filter_func)
1517
  raw_test = raw_test.filter(filter_func)
 
1526
  lambda x: extract_features(x, tokenizer, **dataset_config),
1527
  remove_columns=raw_test.column_names)
1528
 
1529
+ train_dataset.save_to_disk(cache_file_train) if sanity_check is False else None
1530
+ test_dataset.save_to_disk(cache_file_test) if sanity_check is False else None
1531
  return train_dataset, test_dataset
1532
 
1533
  @dataclass
 
1618
  return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]
1619
 
1620
  def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, optimizer=None, scheduler=None):
1621
+
1622
  label_ids = pred.label_ids
1623
  pred_ids = pred.predictions[0]
1624
+
1625
+ label_ids = clean_batch(label_ids, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
1626
+ pred_ids = clean_batch(pred_ids, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
1627
+
1628
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1629
  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
1630
 
 
1643
  else:
1644
  trainable_params = 0.0
1645
  efficiency_score = 0.0
1646
+ return {
1647
+ "wer": float(wer),
1648
+ "efficiency_score": float(efficiency_score),
1649
+ }
1650
 
1651
  def preprocess_logits_for_metrics(logits, labels):
1652
  pred_ids = torch.argmax(logits, dim=-1)
1653
  labels = torch.where(labels == -100, 0, labels)
1654
  pred_ids = torch.where(pred_ids == -100, 0, pred_ids)
1655
+
1656
  return pred_ids, labels
1657
 
1658
  def main():
 
1663
  train_dataset, test_dataset = prepare_datasets(
1664
  tokenizer,
1665
  token,
1666
+ sanity_check=True,
1667
 
1668
  )
1669
 
 
1676
  layer=4,
1677
  act="swish",
1678
  debug={"radius", "encoder"},
1679
+ features = ["pitch"],
1680
  )
1681
 
1682
  model = Echo(param).to('cuda')
 
1706
  )
1707
  from functools import partial
1708
  metrics_fn = partial(compute_metrics,
1709
+ print_pred=False,
1710
  num_samples=2,
1711
  tokenizer=tokenizer, model=model)
1712
 
 
1725
  compute_metrics=metrics_fn,
1726
  optimizers=(optimizer, scheduler)
1727
  )
1728
+
1729
  model.init_weights()
1730
  trainer.train()
1731
 
1732
  if __name__ == "__main__":
1733
  main()
1734
 
1735
+
1736
+
1737
+
1738
+
1739
+
1740
+
1741
+
1742
+
1743
+
1744
+
1745
+
1746
+
1747
+
1748
+
1749
+
1750
+
1751
+