Sin2pi committed on
Commit
09b0505
verified · 1 Parent(s): 45a9c69

Update modelA.py

Files changed (1)
  1. modelA.py +341 -222
modelA.py CHANGED
@@ -32,6 +32,13 @@ dtype = torch.float32
32
  warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
 
35
  def get_activation(act: str) -> nn.Module:
36
  """Get activation function by name."""
37
  act_map = {
@@ -266,32 +273,33 @@ def get_dtype():
266
  def tox():
267
  return {"device": get_device(), "dtype": get_dtype()}
268
 
269
- class sinus(nn.Module):
270
- def __init__(self, ctx: int, dims: int):
271
  super().__init__()
272
-
273
- position = torch.arange(start=0, end=ctx, dtype=dtype).unsqueeze(dim=1)
274
- div_term = torch.exp(input=torch.arange(start=0, end=dims, step=2, dtype=dtype) * -(math.log(10000.0) / dims))
275
- features = torch.zeros(ctx, dims)
276
- features[:, 0::2] = torch.sin(position * div_term)
277
- features[:, 1::2] = torch.cos(position* div_term)
278
- self.register_buffer('sinusoid', tensor=features)
279
- self.positional_embeddings = nn.Parameter(self.sinusoid.clone())
280
  def forward(self, positions):
281
- position_embeddings = self.positional_embeddings[positions]
282
- return position_embeddings
283
 
284
  def sinusoids(length, channels, max_tscale=10000):
285
  assert channels % 2 == 0
286
  log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
287
  inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
288
- scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
289
- return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
290
 
291
  class rotary(nn.Module):
292
- def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False):
293
  super(rotary, self).__init__()
294
-
295
  self.use_pbias = use_pbias
296
  self.dims = dims
297
  self.head = head
@@ -302,11 +310,43 @@ class rotary(nn.Module):
302
  self.counter = 0
303
  self.last_theta = None
304
 
305
  self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
306
  theta = (torch.tensor(10000, device=device, dtype=dtype))
307
  self.theta = nn.Parameter(theta, requires_grad=True)
308
  self.theta_values = []
309
 
310
  def mel_scale_scalar(self, freq: float) -> float:
311
  return 1127.0 * math.log(1.0 + freq / 700.0)
312
 
@@ -328,7 +368,6 @@ class rotary(nn.Module):
328
  freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
329
  torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
330
  self.dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
331
-
332
  return freq
333
 
334
  def _apply_radii(self, freqs, f0, ctx):
@@ -347,32 +386,37 @@ class rotary(nn.Module):
347
  return torch.polar(torch.ones_like(freqs), freqs), None
348
 
349
  def check_f0(self, f0, f0t, ctx):
350
- if f0 is not None and f0.shape[1] == ctx:
351
  return f0
352
- elif f0t is not None and f0t.shape[1] == ctx:
353
  return f0t
354
  else:
355
  return None
356
 
357
  def forward(self, x=None, enc=None, layer=None, feature=None) -> Tensor:
358
- ctx=x
359
  f0 = enc.get("f0") if enc is not None else None
360
  f0t = enc.get("f0t") if enc is not None else None
361
-
362
  f0 = self.check_f0(f0, f0t, ctx)
363
- if f0 is not None:
364
- if f0.dim() == 2:
365
- f0 = f0.squeeze(0)
366
- theta = f0 + self.theta
367
- else:
368
- theta = self.theta
369
  freqs = self.theta_freqs(theta)
370
  t = torch.arange(ctx, device=device, dtype=dtype)
371
  freqs = t[:, None] * freqs
372
  freqs, radius = self._apply_radii(freqs, f0, ctx)
373
-
374
  if "radius" in self.debug and self.counter == 10:
375
- print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
 
376
  self.counter += 1
377
  return freqs.unsqueeze(0)
378
 
@@ -389,19 +433,114 @@ class rotary(nn.Module):
389
  x1 = x1.view(orig_shape)
390
  return torch.cat([x1.type_as(x), x2], dim=-1)
391
 
392
- class MultiheadA(nn.Module):
393
 
394
- rbf = False
395
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
396
- zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
 
397
  super(MultiheadA, self).__init__()
398
-
399
  self.dims = dims
400
  self.head = head
401
  self.head_dim = dims // head
402
  self.debug = debug
403
  self.counter = 0
404
  self.use_pbias = use_pbias
405
 
406
  self.q = nn.Linear(dims, dims).to(device, dtype)
407
  self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
@@ -421,8 +560,10 @@ class MultiheadA(nn.Module):
421
  dims=dims,
422
  head=head,
423
  debug=debug,
424
- radii=True,
425
- )
426
  else:
427
  self.rope = None
428
 
@@ -466,8 +607,23 @@ class MultiheadA(nn.Module):
466
  q2 = q.shape[2]
467
  k2 = k.shape[2]
468
 
469
- q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
470
- k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
471
  else:
472
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
473
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -482,6 +638,9 @@ class MultiheadA(nn.Module):
482
  if pbias is not None:
483
  qk = qk + pbias[:,:,:q2,:q2]
484
 
485
  token_ids = k[:, :, :, 0]
486
  zscale = torch.ones_like(token_ids)
487
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
@@ -643,31 +802,28 @@ class FEncoder(nn.Module):
643
  if use_rope:
644
  if spec_shape is not None:
645
  self.rope = rotary(
646
- dims=self.head_dim,
647
- head=self.head,
648
  use_2d_axial=True,
649
  spec_shape=spec_shape, debug=[])
650
  else:
651
  self.rope = rotary(
652
- dims=self.head_dim,
653
- head=self.head,
654
  use_2d_axial=False, debug=[])
655
  else:
656
  self.rope = None
657
- self.positional = lambda length: sinusoids(length, dims)
658
 
659
  self.norm = RMSNorm(dims)
660
- self._norm = RMSNorm(dims)
661
 
662
  def apply_rope_to_features(self, x, layer=None, feature=None):
663
- if feature in ["envelope", "phase"]:
664
- feature = "spectrogram"
665
  batch, ctx, dims = x.shape
666
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
667
- if feature_type == "spectrogram" and self.rope is not None:
668
- rope_freqs = self.rope(ctx, layer=layer, input_type="spectrogram")
669
  else:
670
- rope_freqs = self.rope(ctx, layer=layer, input_type="audio")
671
  x = self.rope.apply_rotary(x, rope_freqs)
672
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
673
  return x
@@ -677,10 +833,9 @@ class FEncoder(nn.Module):
677
  if self.use_rope:
678
  x = self.apply_rope_to_features(x, layer=layer, feature=feature)
679
  else:
680
- x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
681
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
682
- x = self._norm(x)
683
- return x
684
 
685
  class WEncoder(nn.Module):
686
  def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
@@ -693,7 +848,6 @@ class WEncoder(nn.Module):
693
  self.dims = dims
694
 
695
  act_fn = get_activation(act)
696
-
697
  self.downsample = nn.Sequential(
698
  Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
699
  Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn,
@@ -709,7 +863,7 @@ class WEncoder(nn.Module):
709
  debug=[])
710
  else:
711
  self.rope = None
712
- self.positional = lambda length: sinusoids(length, dims)
713
  self.norm = RMSNorm(dims)
714
 
715
  def apply_rope_to_features(self, x, layer=None, feature=None):
@@ -729,7 +883,7 @@ class WEncoder(nn.Module):
729
  if self.use_rope:
730
  x = self.apply_rope_to_features(x, layer=layer)
731
  else:
732
- x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
733
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
734
  return self.norm(x)
735
 
@@ -757,7 +911,7 @@ class PEncoder(nn.Module):
757
  debug=[])
758
  else:
759
  self.rope = None
760
- self.positional = lambda length: sinusoids(length, dims)
761
  self.norm = RMSNorm(dims)
762
 
763
  def apply_rope_to_features(self, x, layer=None, feature=None):
@@ -775,20 +929,15 @@ class PEncoder(nn.Module):
775
  if self.use_rope:
776
  x = self.apply_rope_to_features(x, layer=layer)
777
  else:
778
- x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
779
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
780
- x = self.norm(x)
781
- return x
782
 
783
  class theBridge(nn.Module):
784
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
785
  debug: List[str], features: List[str], act: str = "gelu"):
786
  super(theBridge, self).__init__()
787
 
788
- self.ctx = ctx
789
- self.dims = dims
790
- self.head = head
791
- self.head_dim = dims // head
792
  self.debug = debug
793
  self.counter = 0
794
  self.dropout = 0.01
@@ -798,30 +947,26 @@ class theBridge(nn.Module):
798
 
799
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
800
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
801
- self.sinusoid = lambda length: sinusoids(length, dims)
802
  self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
803
- self.ln_dec = RMSNorm(dims)
804
 
805
  with torch.no_grad():
806
  self.token.weight[0].zero_()
807
 
808
  self.block = nn.ModuleList([
809
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
810
  for _ in range(layer)])
811
 
812
  self.cross_attn = nn.ModuleList([
813
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
814
  for _ in range(layer)])
815
 
816
  self.cross_modal = nn.ModuleList([
817
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
818
  for _ in range(layer)])
819
 
820
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0).unsqueeze(0).unsqueeze(0)
821
  self.register_buffer("mask", mask, persistent=False)
822
- self.register_buffer("mask_win", self.window_mask(ctx, ctx), persistent=False)
823
- self.register_buffer("mask_cat", self.modal_mask(ctx, ctx), persistent=False)
824
- self.register_buffer("mask_cross", self.cross_mask(ctx, ctx), persistent=False)
825
 
826
  act_fn = get_activation(act)
827
  if features == ["spectrogram", "waveform", "pitch"]:
@@ -846,83 +991,64 @@ class theBridge(nn.Module):
846
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
847
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None)})
848
 
849
- def window_mask(self, text_ctx, aud_ctx):
850
- mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
851
- audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device))
852
- full_mask = torch.cat([mask, audio_mask], dim=-1)
853
- return full_mask.unsqueeze(0).unsqueeze(0)
854
-
855
- def modal_mask(self, text_len, audio_len):
856
- combined_mask = torch.ones(text_len + audio_len, text_len + audio_len, device=device)
857
- combined_mask[:text_len, :text_len] = torch.tril(torch.ones(text_len, text_len, device=device))
858
- combined_mask[:text_len, text_len:] = torch.tril(torch.ones(text_len, audio_len, device=device))
859
- return combined_mask.unsqueeze(0).unsqueeze(0)
860
-
861
- def cross_mask(self, text_len, audio_len):
862
- mask = torch.tril(torch.ones(text_len, text_len, device=device))
863
- audio_mask = torch.tril(torch.ones(text_len, audio_len, device=device))
864
- full_mask = torch.cat([mask, audio_mask], dim=-1)
865
- return full_mask.unsqueeze(0).unsqueeze(0)
866
 
867
  def forward(self, x, enc, layer='decoder', feature=None) -> Tensor:
868
  enc = dict_to(enc, device, dtype)
869
  _text_len = x.shape[1]
870
 
871
  x = self.token(x) + self.positional[:x.shape[1]]
872
- x = F.dropout(x, p=self.dropout, training=self.training)
873
-
874
  for f in enc:
875
  if f in self.features:
876
  xa = enc[f]
877
  for block in self.blocks[f]:
878
  xa = block(xa, enc=enc, layer=layer, feature=feature)
879
 
880
- for block in self.block:
881
- mask = self.mask[:x.shape[1], :x.shape[1]]
882
- x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
883
- if feature in self.features:
884
- xa = xa + self.sinusoid(xa.shape[1])
885
- mask = self.mask_win[:x.shape[1], :xa.shape[1]]
886
- out = block(x, xa=xa, mask=mask, enc=enc, layer=layer)
887
  if self.sequential:
888
  x = out
889
  else:
890
  a = torch.sigmoid(self.blend)
891
  x = a * out + (1 - a) * x
892
 
893
  for block in self.cross_attn:
894
- if feature in self.features:
895
- xa = xa + self.sinusoid(xa.shape[1])
896
- mask_x = self.cross_mask(x.shape[1], xa.shape[1])
897
- mask_xa = self.cross_mask(xa.shape[1], x.shape[1])
898
- x = block(x, xa=xa, mask=mask_x, enc=enc, layer=layer)
899
- xa = block(xa, xa=x, mask=mask_xa, enc=enc, layer=layer)
900
- out = block(x, xa=xa, mask=mask_x, enc=enc, layer=layer)
901
  if self.sequential:
902
  x = out
903
  else:
904
  a = torch.sigmoid(self.blend)
905
  x = a * out + (1 - a) * x
906
 
907
  for block in self.cross_modal:
908
- if feature in enc:
909
- xa = xa + self.sinusoid(xa.shape[1])
910
  xcat = torch.cat([x, xa], dim=1)
911
- mask = self.mask_cat(x.shape[1], xa.shape[1])
912
- x = block(xcat, xa=None, mask=mask, enc=enc, layer=layer)
913
  x = x[:, :_text_len]
914
-
915
  if self.counter < 1 and "encoder" in self.debug:
916
- s = enc.get("spectrogram")
917
- w = enc.get("waveform")
918
- p = default(enc.get("pitch"), enc.get("f0"))
919
- plot_waveform(x=s, w=w, p=p, hop_length=128)
920
  shapes = {k: v.shape for k, v in enc.items()}
921
  print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
922
  self.counter += 1
923
-
924
- x = self.ln_dec(x)
925
- return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
 
926
 
927
  class Echo(nn.Module):
928
  def __init__(self, param: Dimensions):
@@ -977,7 +1103,8 @@ class Echo(nn.Module):
977
  else:
978
  feature = "spectrogram"
979
 
980
- logits = self.processor(input_ids, enc, feature)
 
981
 
982
  loss = None
983
  if labels is not None:
@@ -1089,7 +1216,7 @@ class Echo(nn.Module):
1089
  "eos_token_id": self.eos_token_id,
1090
  })
1091
  return Config()
1092
-
1093
  def setup_tokenizer(token: str):
1094
  from tokenizers import Tokenizer
1095
  tokenizer = Tokenizer.from_file("./tokenizer.json")
@@ -1117,7 +1244,6 @@ def setup_tokenizer(token: str):
1117
  results.append(tokenizer.decode(ids))
1118
  return results
1119
 
1120
-
1121
  def save_pretrained(save_dir):
1122
  os.makedirs(save_dir, exist_ok=True)
1123
  tokenizer.save(f"{save_dir}/tokenizer.json")
@@ -1159,8 +1285,7 @@ def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
1159
 
1160
  return sp_mel, ap_mel
1161
 
1162
- def extract_features(batch, tokenizer, waveform=False, spec=True, f0=False, f0t=False, sample_rate=16000, hop_length=256, mode="mean", debug=True, **dataset_config):
1163
-
1164
  dataset_config = {
1165
  "hop_length": 256,
1166
  "f_min": 150,
@@ -1175,113 +1300,100 @@ def extract_features(batch, tokenizer, waveform=False, spec=True, f0=False, f0t=
1175
  "mel_scale": "htk",
1176
  "norm": None,
1177
  "normalized": False,
1178
- }
1179
 
1180
  audio = batch["audio"]
1181
  sr = audio["sampling_rate"]
1182
- wave = load_wave(wave_data=audio, sample_rate=sr)
1183
  labels = tokenizer.encode(batch["transcription"])
1184
 
1185
- if waveform:
1186
  wav = load_wave(wave_data=audio, sample_rate=sr)
1187
- else:
1188
- wav = None
1189
 
1190
  if spec:
1191
- transform = torchaudio.transforms.MelSpectrogram( **dataset_config)
1192
- mel_spectrogram = transform(wave)
1193
  log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1194
  log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1195
- spec = (log_mel + 4.0) / 4.0
1196
- spec = torch.tensor(spec)
1197
- else:
1198
- spec = None
1199
 
1200
- if f0:
1201
- wavnp = wave.numpy().astype(np.float64)
1202
  f0_np, t = pw.dio(wavnp, sample_rate,
1203
- frame_period = hop_length / sample_rate * 1000)
1204
  f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
1205
- f0 = torch.from_numpy(f0_np)
1206
-
1207
- if f0t:
1208
- audio_duration = len(wavnp) / sample_rate
1209
- T = len(labels)
1210
- tok_dur_sec = audio_duration / T
1211
- token_starts = np.arange(T) * tok_dur_sec
1212
- token_ends = token_starts + tok_dur_sec
1213
- start_idx = np.searchsorted(t, token_starts, side="left")
1214
- end_idx = np.searchsorted(t, token_ends, side="right")
1215
- pitch_tok = np.zeros(T, dtype=np.float32)
1216
- for i in range(T):
1217
- lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
1218
- segment = f0_np[lo:hi]
1219
- pitch_tok[i] = segment.mean() if mode=="mean" else (np.median(segment) if mode=="median" else segment[-1])
1220
- pitch_tok[pitch_tok < 100.0] = 0.0
1221
-
1222
- bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
1223
- f0t = torch.from_numpy(np.concatenate([[bos_pitch], pitch_tok]))
1224
- f0t = torch.from_numpy(pitch_tok)
1225
- f0 = torch.from_numpy(f0_np)
1226
-
1227
- spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
1228
- apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
1229
- sp = torch.from_numpy(spnp)
1230
- ap = torch.from_numpy(apnp)
1231
- sp = sp[:, :128].contiguous().T
1232
- ap = ap[:, :128].contiguous().T
1233
- f0t = torch.where(f0t == 0.0, torch.zeros_like(f0t), (f0t - 71.0) / (400.0 - 71.0))
1234
- sp = torch.where(sp == 0.0, torch.zeros_like(sp), sp / 1.0)
1235
- ap= torch.where(ap == 0.0, torch.zeros_like(ap), ap / 1.0)
1236
 
1237
- else:
1238
- f0t = None
1239
- sp = None
1240
- ap = None
1241
- t = None
1242
- token_starts = None
1243
- else:
1244
- f0t = None
1245
- f0 = None
1246
- sp = None
1247
- ap = None
1248
- t = None
1249
- token_starts = None
1250
 
1251
  if debug:
1252
- print(f"['f0']: {f0t if f0t is not None else None}")
1253
- print(f"['f0']: {f0.shape if f0 is not None else None}")
1254
- print(f"['f0t']: {f0t.shape if f0t is not None else None}")
1255
- print(f"['harmonic']: {sp.shape if sp is not None else None}")
1256
- print(f"['aperiodic']: {ap.shape if ap is not None else None}")
1257
- print(f"['spec']: {spec.shape if spec is not None else None}")
1258
- print(f"['wav']: {wav.shape if wav is not None else None}")
1259
 
1260
  return {
1261
- "f0": f0,
1262
- "f0t": f0t,
1263
- "pitch": f0,
1264
- "harmonic": sp,
1265
- "aperiodic": ap,
 
1266
  "labels": labels,
1267
- "waveform": wav,
1268
- "spectrogram": spec,
1269
-
1270
  }
1271
 
1272
- def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
1273
 
1274
  if sanity_check:
1275
  test = load_dataset(
1276
- "google/fleurs", "en_us", token=token, split="test[:10]", trust_remote_code=True
1277
- ).cast_column("audio", Audio(sample_rate=sample_rate))
1278
-
1279
  dataset = test.map(
1280
  lambda x: extract_features(x, tokenizer, **dataset_config),
1281
  remove_columns=test.column_names)
1282
- dataset = dataset(remove_columns=["audio", "transcription"]).with_format(type="torch")
1283
  train_dataset = dataset
1284
  test_dataset = dataset
 
1285
  else:
1286
 
1287
  cache_dir = "./processed_datasets"
@@ -1300,10 +1412,8 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **
1300
  len(x["audio"]["array"]) > 0 and
1301
  len(x["audio"]["array"]) < 2048 * 160)
1302
 
1303
- raw_train = load_dataset(
1304
- "google/fleurs", "en_us", token=token, split="train[:1000]", trust_remote_code=True)
1305
- raw_test = load_dataset(
1306
- "google/fleurs", "en_us", token=token, split="test[:100]", trust_remote_code=True)
1307
 
1308
  raw_train = raw_train.filter(filter_func)
1309
  raw_test = raw_test.filter(filter_func)
@@ -1318,8 +1428,8 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **
1318
  lambda x: extract_features(x, tokenizer, **dataset_config),
1319
  remove_columns=raw_test.column_names)
1320
 
1321
- train_dataset.save_to_disk(cache_file_train)
1322
- test_dataset.save_to_disk(cache_file_test)
1323
  return train_dataset, test_dataset
1324
 
1325
  @dataclass
@@ -1401,22 +1511,21 @@ def wer_batch(references, hypotheses):
1401
  total_words += len(ref_words)
1402
  return (total_errors / total_words) * 100 if total_words > 0 else 0.0
1403
 
1404
- def clean_ids(ids, pad_token_id=0):
1405
  if isinstance(ids, torch.Tensor):
1406
  ids = ids.tolist()
1407
- return [int(id) for id in ids if id != -100 and id != pad_token_id]
1408
 
1409
- def clean_batch(batch_ids, pad_token_id=0):
1410
- return [clean_ids(seq, pad_token_id) for seq in batch_ids]
1411
 
1412
  def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, optimizer=None, scheduler=None):
1413
-
1414
  label_ids = pred.label_ids
1415
  pred_ids = pred.predictions[0]
1416
- label_ids = clean_batch(label_ids, pad_token_id=tokenizer.pad_token_id)
1417
- pred_ids = clean_batch(pred_ids, pad_token_id=tokenizer.pad_token_id)
1418
- label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
1419
- pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
1420
 
1421
  if print_pred:
1422
  for i in range(min(num_samples, len(pred_ids))):
@@ -1433,17 +1542,25 @@ def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samp
1433
  else:
1434
  trainable_params = 0.0
1435
  efficiency_score = 0.0
1436
- return {
1437
- "wer": float(wer),
1438
- "efficiency_score": float(efficiency_score),
1439
- }
 
 
1441
  def main():
1442
  token = ""
1443
  log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
1444
  os.makedirs(log_dir, exist_ok=True)
1445
  tokenizer = setup_tokenizer(token)
1446
- train_dataset, test_dataset = prepare_datasets(tokenizer, token)
1447
 
1448
  param = Dimensions(
1449
  vocab=40000,
@@ -1453,7 +1570,7 @@ def main():
1453
  head=4,
1454
  layer=4,
1455
  act="swish",
1456
- debug={"decoder", "radius"},
1457
  features = ["spectrogram"],
1458
  )
1459
 
@@ -1472,20 +1589,20 @@ def main():
1472
  logging_steps=10,
1473
  logging_dir=log_dir,
1474
  eval_strategy="steps",
1475
- save_strategy="steps",
1476
  report_to=["tensorboard"],
1477
  push_to_hub=False,
1478
  disable_tqdm=False,
1479
  save_total_limit=1,
1480
  label_names=["labels"],
1481
  save_safetensors=False,
1482
- eval_on_start=False,
1483
  batch_eval_metrics=False,
1484
  )
1485
  from functools import partial
1486
  metrics_fn = partial(compute_metrics,
1487
  print_pred=True,
1488
- num_samples=1,
1489
  tokenizer=tokenizer, model=model)
1490
 
1491
  optimizer = torch.optim.AdamW(model.parameters(), lr=0.00025, eps=1e-8, weight_decay=0.025, betas=(0.9, 0.999),
@@ -1499,9 +1616,11 @@ def main():
1499
  train_dataset=train_dataset,
1500
  eval_dataset=test_dataset,
1501
  data_collator=DataCollator(tokenizer=tokenizer),
 
1502
  compute_metrics=metrics_fn,
1503
  optimizers=(optimizer, scheduler)
1504
  )
 
1505
  model.init_weights()
1506
  trainer.train()
1507
 
 
32
  warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
 
35
+ PATH = 'E:/hf'
36
+ os.environ['HF_HOME'] = PATH
37
+ os.environ['HF_DATASETS_CACHE'] = PATH
38
+ os.environ['TORCH_HOME'] = PATH
39
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
40
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
41
+
42
  def get_activation(act: str) -> nn.Module:
43
  """Get activation function by name."""
44
  act_map = {
 
273
  def tox():
274
  return {"device": get_device(), "dtype": get_dtype()}
275
 
276
+ class Sinusoids(nn.Module):
277
+ def __init__(self, length, channels, max_tscale=10000):
278
  super().__init__()
279
+ assert channels % 2 == 0
280
+ log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
281
+ inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
282
+ scaled_t = torch.arange(length)[:, None] * inv_tscales[None, :]
283
+ pos1 = torch.sin(scaled_t)
284
+ pos2 = torch.cos(scaled_t)
285
+ positions = torch.cat([pos1, pos2], dim=1)
286
+ self.embedding = nn.Embedding.from_pretrained(positions, freeze=False)
287
  def forward(self, positions):
288
+ return self.embedding(positions)
 
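Note: nn.Embedding.from_pretrained(..., freeze=False) above turns the precomputed sinusoid table into a trainable positional lookup. A minimal sketch of that pattern, with illustrative shapes only:

    import torch
    import torch.nn as nn

    table = torch.randn(10, 8)                       # stand-in for the sinusoid table
    emb = nn.Embedding.from_pretrained(table, freeze=False)
    positions = torch.arange(4)
    out = emb(positions)                             # gathers table rows by position id
    assert out.shape == (4, 8)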
289
 
290
  def sinusoids(length, channels, max_tscale=10000):
291
  assert channels % 2 == 0
292
  log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
293
  inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
294
+ scaled_t = torch.arange(length)[:, None] * inv_tscales[None, :]
295
+ pos1 = torch.sin(scaled_t)
296
+ pos2 = torch.cos(scaled_t)
297
+ positions = torch.cat([pos1, pos2], dim=1)
298
+ return positions
299
 
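For reference, sinusoids reduces to the standard fixed sinusoidal encoding; a quick standalone shape check (arbitrary sizes):

    import numpy as np
    import torch

    length, channels = 6, 8
    log_tscale_increment = np.log(10000) / (channels // 2 - 1)
    inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
    scaled_t = torch.arange(length)[:, None] * inv_tscales[None, :]
    pe = torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
    assert pe.shape == (length, channels)            # [ctx, dims]: sin half, then cos half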
300
  class rotary(nn.Module):
301
+ def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False, use_2d_axial=False, spec_shape=None, use_true_2d_relative=False, freq_bins=None):
302
  super(rotary, self).__init__()
 
303
  self.use_pbias = use_pbias
304
  self.dims = dims
305
  self.head = head
 
310
  self.counter = 0
311
  self.last_theta = None
312
 
313
+ self.use_2d_axial = use_2d_axial
314
+ if use_2d_axial and spec_shape is not None:
315
+ time_frames, freq_bins = spec_shape
316
+ self.time_frames = time_frames
317
+ self.freq_bins = freq_bins
318
+ time_theta = 50.0
319
+ time_freqs = 1.0 / (time_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
320
+ self.register_buffer('time_freqs', time_freqs)
321
+ freq_theta = 100.0
322
+ freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
323
+ self.register_buffer('freq_freqs', freq_freqs)
324
+
325
  self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
326
  theta = (torch.tensor(10000, device=device, dtype=dtype))
327
  self.theta = nn.Parameter(theta, requires_grad=True)
328
  self.theta_values = []
329
 
330
+ self.use_true_2d_relative = use_true_2d_relative
331
+ self.freq_bins = freq_bins
332
+ self.true2d_dim = (dims // head) // 2
333
+ self.omega_t = nn.Parameter(torch.randn(self.true2d_dim))
334
+ self.omega_f = nn.Parameter(torch.randn(self.true2d_dim))
335
+
336
+ def axial_freqs(self, seq_len):
337
+ if not self.use_2d_axial:
338
+ return None
339
+ time_frames = self.time_frames
340
+ freq_bins = self.freq_bins
341
+ t = torch.arange(seq_len, device=device, dtype=dtype)
342
+ t_x = (t % time_frames).float()
343
+ t_y = torch.div(t, time_frames, rounding_mode='floor').float()
344
+ freqs_x = torch.outer(t_x, self.time_freqs)
345
+ freqs_y = torch.outer(t_y, self.freq_freqs)
346
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
347
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
348
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
349
+
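A standalone sketch of the axial scheme above, assuming the same frame-major flattening: each position splits into (time, mel-bin) coordinates and each axis gets its own unit-modulus rotations (sizes illustrative):

    import torch

    time_frames, head_dim, seq_len = 4, 8, 8
    time_freqs = 1.0 / (50.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len)
    t_x = (t % time_frames).float()                                  # time coordinate
    t_y = torch.div(t, time_frames, rounding_mode='floor').float()   # mel-bin coordinate
    freqs_cis_x = torch.polar(torch.ones(seq_len, head_dim // 2), torch.outer(t_x, time_freqs))
    freqs_cis_y = torch.polar(torch.ones(seq_len, head_dim // 2), torch.outer(t_y, time_freqs))
    both = torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
    assert both.shape == (seq_len, head_dim)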
350
  def mel_scale_scalar(self, freq: float) -> float:
351
  return 1127.0 * math.log(1.0 + freq / 700.0)
352
 
 
368
  freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
369
  torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
370
  self.dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
 
371
  return freq
372
 
373
  def _apply_radii(self, freqs, f0, ctx):
 
386
  return torch.polar(torch.ones_like(freqs), freqs), None
387
 
388
  def check_f0(self, f0, f0t, ctx):
389
+ if f0 is not None and f0.dim() == 2:
390
+ f0 = f0.squeeze(0)
391
+ if f0t is not None and f0t.dim() == 2:
392
+ f0t = f0t.squeeze(0)
393
+
394
+ if f0 is not None and f0.shape[0] == ctx:
395
  return f0
396
+ elif f0t is not None and f0t.shape[0] == ctx:
397
  return f0t
398
  else:
399
  return None
400
 
401
  def forward(self, x=None, enc=None, layer=None, feature=None) -> Tensor:
402
+ ctx = x
403
+ if self.use_2d_axial and feature == "spectrogram":
404
+ freqs_2d = self.axial_freqs(ctx)
405
+ if freqs_2d is not None:
406
+ return freqs_2d.unsqueeze(0)
407
+
408
  f0 = enc.get("f0") if enc is not None else None
409
  f0t = enc.get("f0t") if enc is not None else None
 
410
  f0 = self.check_f0(f0, f0t, ctx)
411
+ theta = f0 + self.theta if f0 is not None else self.theta
412
  freqs = self.theta_freqs(theta)
413
  t = torch.arange(ctx, device=device, dtype=dtype)
414
  freqs = t[:, None] * freqs
415
  freqs, radius = self._apply_radii(freqs, f0, ctx)
416
+
417
  if "radius" in self.debug and self.counter == 10:
418
+ print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [ctx] {ctx}")
419
+
420
  self.counter += 1
421
  return freqs.unsqueeze(0)
422
 
 
433
  x1 = x1.view(orig_shape)
434
  return torch.cat([x1.type_as(x), x2], dim=-1)
435
 
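Because apply_rotary multiplies pairs by unit-modulus complex factors, it preserves per-position norms; a small self-contained check of that property (generic shapes, not this module's exact call path):

    import torch

    ctx, head_dim = 4, 8
    x = torch.randn(1, 1, ctx, head_dim)
    freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.arange(ctx).float()[:, None] * freqs[None, :]
    rot = torch.polar(torch.ones_like(angles), angles)   # |rot| == 1 everywhere
    xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    out = torch.view_as_real(xc * rot).flatten(-2)
    assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5)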
436
+ # @staticmethod
437
+ # def apply_rotary(x, freqs):
438
+ # # x: [batch, head, seq, head_dim]
439
+ # # freqs: [1, seq, head_dim] or [1, seq, 2*head_dim] for 2D
440
+ # if freqs.shape[-1] == x.shape[-1]:
441
+ # # 1D rotary
442
+ # x1 = x
443
+ # orig_shape = x1.shape
444
+ # if x1.ndim == 2:
445
+ # x1 = x1.unsqueeze(0)
446
+ # x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
447
+ # x1 = torch.view_as_complex(x1) * freqs
448
+ # x1 = torch.view_as_real(x1).flatten(-2)
449
+ # x1 = x1.view(orig_shape)
450
+ # return x1.type_as(x)
451
+ # else:
452
+ # # 2D rotary: split x and apply to each axis
453
+ # head_dim = x.shape[-1] // 2
454
+ # x_time = x[..., :head_dim]
455
+ # x_freq = x[..., head_dim:]
456
+ # f_time = freqs[..., :head_dim]
457
+ # f_freq = freqs[..., head_dim:]
458
+ # # Apply rotary to each axis
459
+ # def apply_axis(xa, freqs):
460
+ # orig_shape = xa.shape
461
+ # xa = xa.float().reshape(*xa.shape[:-1], -1, 2).contiguous()
462
+ # xa = torch.view_as_complex(xa) * freqs
463
+ # xa = torch.view_as_real(xa).flatten(-2)
464
+ # xa = xa.view(orig_shape)
465
+ # return xa.type_as(x)
466
+ # x_time = apply_axis(x_time, f_time)
467
+ # x_freq = apply_axis(x_freq, f_freq)
468
+ # return torch.cat([x_time, x_freq], dim=-1)
469
+
470
+ def true2d_relative_angle(self, t_q, f_q, t_k, f_k):
471
+ # t_q, f_q, t_k, f_k: [seq]; uncommented because the true-2D branch in MultiheadA calls these methods
472
+ delta_t = t_q[:, None] - t_k[None, :] # [seq, seq]
473
+ delta_f = f_q[:, None] - f_k[None, :] # [seq, seq]
474
+ angle = delta_t[..., None] * self.omega_t + delta_f[..., None] * self.omega_f # [seq, seq, true2d_dim]
475
+ angle = torch.cat([angle, angle], dim=-1) # [seq, seq, head_dim]
476
+ return angle
477
+
478
+ def true2d_apply_rotary(self, q, k, freqs):
479
+ # q, k: [batch, head, seq, head_dim]
480
+ # freqs: real angles from true2d_relative_angle; converted to unit-modulus complex below
481
+ b, h, seq, d = q.shape
482
+ d2 = d // 2
483
+ if not torch.is_complex(freqs):
+ freqs = torch.polar(torch.ones_like(freqs[..., :d2]), freqs[..., :d2]) # [seq, seq, d2] complex
+ q_exp = q.unsqueeze(3).expand(b, h, seq, seq, d)
484
+ k_exp = k.unsqueeze(2).expand(b, h, seq, seq, d)
485
+ # Convert to complex
486
+ def to_complex(x):
487
+ return torch.complex(x[..., 0::2], x[..., 1::2]) # [b, h, seq, seq, d2]
488
+ q_c = to_complex(q_exp)
489
+ k_c = to_complex(k_exp)
490
+ # Multiply by freqs ([seq, seq, d2] complex)
491
+ q_rot = q_c * freqs
492
+ k_rot = k_c * freqs
493
+ # Back to real
494
+ def to_real(x):
495
+ return torch.stack([x.real, x.imag], dim=-1).flatten(-2)
496
+ q_rot = to_real(q_rot)
497
+ k_rot = to_real(k_rot)
498
+ return q_rot, k_rot
499
+
500
+ def parallel_slice(self, q, k, v, mask=None):
501
+ batch, ctx, dims = q.shape # q, k, v arrive with heads fused in the last dim
502
+ head_dim = self.head_dim
504
+ ctx_len = k.shape[1]
505
+ head = dims // head_dim
506
+
507
+ scores = torch.zeros(batch, head, ctx, ctx_len, device=q.device)
508
+
509
+ for h in range(head):
510
+ start_idx = h * head_dim
511
+ end_idx = start_idx + head_dim
512
+ q_h = q[:, :, start_idx:end_idx]
513
+ k_h = k[:, :, start_idx:end_idx]
514
+
515
+ scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
516
+
517
+ if mask is not None:
518
+ scores = scores + mask.unsqueeze(0).unsqueeze(0)
519
+
520
+ attn_weights = F.softmax(scores, dim=-1)
521
+
522
+ output = torch.zeros_like(q)
523
+ for h in range(head):
524
+ start_idx = h * head_dim
525
+ end_idx = start_idx + head_dim
526
+ v_h = v[:, :, start_idx:end_idx]
527
+ output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
528
+ return output
529
 
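parallel_slice computes attention head by head over slices of the fused feature dimension; slicing plus bmm matches the usual reshape-to-heads matmul, for example:

    import math
    import torch

    batch, ctx, dims, head = 2, 5, 16, 4
    head_dim = dims // head
    q, k = torch.randn(batch, ctx, dims), torch.randn(batch, ctx, dims)
    fused = (q.view(batch, ctx, head, head_dim).transpose(1, 2)
             @ k.view(batch, ctx, head, head_dim).transpose(1, 2).transpose(-1, -2)) / math.sqrt(head_dim)
    sliced = torch.stack([torch.bmm(q[:, :, h * head_dim:(h + 1) * head_dim],
                                    k[:, :, h * head_dim:(h + 1) * head_dim].transpose(1, 2))
                          for h in range(head)], dim=1) / math.sqrt(head_dim)
    assert torch.allclose(fused, sliced, atol=1e-5)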
530
+ class MultiheadA(nn.Module):
531
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
532
+ zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False, use_true_2d_relative=False, freq_bins=None, radii=False, use_2d_axial=False, spec_shape=None, rbf=False):
533
+
534
  super(MultiheadA, self).__init__()
 
535
  self.dims = dims
536
  self.head = head
537
  self.head_dim = dims // head
538
  self.debug = debug
539
  self.counter = 0
540
  self.use_pbias = use_pbias
541
+ self.use_true_2d_relative = use_true_2d_relative
542
+ self.freq_bins = freq_bins
543
+ self.rbf = rbf
544
 
545
  self.q = nn.Linear(dims, dims).to(device, dtype)
546
  self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
 
560
  dims=dims,
561
  head=head,
562
  debug=debug,
563
+ radii=radii,
564
+ use_true_2d_relative=use_true_2d_relative,
565
+ freq_bins=freq_bins,
566
+ )
567
  else:
568
  self.rope = None
569
 
 
607
  q2 = q.shape[2]
608
  k2 = k.shape[2]
609
 
610
+ if self.use_true_2d_relative and feature == "spectrogram":
611
+ seq_len = q2
612
+ freq_bins = self.freq_bins
613
+ idxs = torch.arange(seq_len, device=q.device)
614
+ t_idx = idxs // freq_bins
615
+ f_idx = idxs % freq_bins
616
+ angle = self.rope.true2d_relative_angle(t_idx, f_idx, t_idx, f_idx)
617
+ q_rot, k_rot = self.rope.true2d_apply_rotary(q, k, angle)
618
+ scale = (self.dims // self.head) ** -0.25
619
+ qk = (q_rot * scale * k_rot * scale).sum(-1)
620
+ w = F.softmax(qk, dim=-1).to(q.dtype)
621
+ wv = torch.einsum('bhij,bhjd->bhid', w, v) # v: [batch, head, seq, head_dim]
622
+ wv = wv.permute(0, 2, 1, 3).flatten(start_dim=2)
623
+ return self.o(wv), qk
624
+ else:
625
+ q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
626
+ k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
627
  else:
628
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
629
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
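In the true-2D branch above, flattened spectrogram positions are mapped back to (time frame, mel bin) pairs with plain div/mod before relative angles are formed:

    import torch

    freq_bins = 4
    idxs = torch.arange(10)
    t_idx = idxs // freq_bins        # time frame of each flattened position
    f_idx = idxs % freq_bins         # mel bin of each flattened position
    assert bool((t_idx * freq_bins + f_idx == idxs).all())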
638
  if pbias is not None:
639
  qk = qk + pbias[:,:,:q2,:q2]
640
 
641
+ if mask is not None:
642
+ mask = mask[..., :q2, :q2]
643
+
644
  token_ids = k[:, :, :, 0]
645
  zscale = torch.ones_like(token_ids)
646
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
 
802
  if use_rope:
803
  if spec_shape is not None:
804
  self.rope = rotary(
805
+ dims=dims,
806
+ head=head,
807
  use_2d_axial=True,
808
  spec_shape=spec_shape, debug=[])
809
  else:
810
  self.rope = rotary(
811
+ dims=dims,
812
+ head=head,
813
  use_2d_axial=False, debug=[])
814
  else:
815
  self.rope = None
816
+ self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
817
 
818
  self.norm = RMSNorm(dims)
 
819
 
820
  def apply_rope_to_features(self, x, layer=None, feature=None):
 
 
821
  batch, ctx, dims = x.shape
822
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
823
+ if feature == "spectrogram" and self.rope is not None:
824
+ rope_freqs = self.rope(ctx, layer=layer, feature="spectrogram")
825
  else:
826
+ rope_freqs = self.rope(ctx, layer=layer, feature="audio")
827
  x = self.rope.apply_rotary(x, rope_freqs)
828
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
829
  return x
 
833
  if self.use_rope:
834
  x = self.apply_rope_to_features(x, layer=layer, feature=feature)
835
  else:
836
+ x = x + self.sinusoid_pos(x.shape[1], x.shape[-1]).to(x.device, x.dtype)
837
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
838
+ return self.norm(x)
 
839
 
840
  class WEncoder(nn.Module):
841
  def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
 
848
  self.dims = dims
849
 
850
  act_fn = get_activation(act)
 
851
  self.downsample = nn.Sequential(
852
  Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
853
  Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn,
 
863
  debug=[])
864
  else:
865
  self.rope = None
866
+ self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
867
  self.norm = RMSNorm(dims)
868
 
869
  def apply_rope_to_features(self, x, layer=None, feature=None):
 
883
  if self.use_rope:
884
  x = self.apply_rope_to_features(x, layer=layer)
885
  else:
886
+ x = x + self.sinusoid_pos(x.shape[1], x.shape[-1]).to(x.device, x.dtype)
887
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
888
  return self.norm(x)
889
 
 
911
  debug=[])
912
  else:
913
  self.rope = None
914
+ self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
915
  self.norm = RMSNorm(dims)
916
 
917
  def apply_rope_to_features(self, x, layer=None, feature=None):
 
929
  if self.use_rope:
930
  x = self.apply_rope_to_features(x, layer=layer)
931
  else:
932
+ x = x + self.sinusoid_pos(x.shape[1], x.shape[-1]).to(x.device, x.dtype)
933
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
934
+ return self.norm(x)
 
935
 
936
  class theBridge(nn.Module):
937
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
938
  debug: List[str], features: List[str], act: str = "gelu"):
939
  super(theBridge, self).__init__()
940
 
941
  self.debug = debug
942
  self.counter = 0
943
  self.dropout = 0.01
 
947
 
948
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
949
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
 
950
  self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
951
+ self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
952
 
953
  with torch.no_grad():
954
  self.token.weight[0].zero_()
955
 
956
  self.block = nn.ModuleList([
957
+ Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features)
958
  for _ in range(layer)])
959
 
960
  self.cross_attn = nn.ModuleList([
961
+ Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features)
962
  for _ in range(layer)])
963
 
964
  self.cross_modal = nn.ModuleList([
965
+ Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features)
966
  for _ in range(layer)])
967
 
968
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0).unsqueeze(0).unsqueeze(0)
969
  self.register_buffer("mask", mask, persistent=False)
970
 
971
  act_fn = get_activation(act)
972
  if features == ["spectrogram", "waveform", "pitch"]:
 
991
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
992
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None)})
993
 
994
+ self.norm = RMSNorm(dims)
995
 
996
  def forward(self, x, enc, layer='decoder', feature=None) -> Tensor:
997
+ cache = {}
998
+ cache.update(enc)
999
  enc = dict_to(enc, device, dtype)
1000
  _text_len = x.shape[1]
1001
 
1002
  x = self.token(x) + self.positional[:x.shape[1]]
1003
+
 
1004
  for f in enc:
1005
  if f in self.features:
1006
  xa = enc[f]
1007
  for block in self.blocks[f]:
1008
  xa = block(xa, enc=enc, layer=layer, feature=feature)
1009
+ xa = xa + self.sinusoid_pos(xa.shape[1], xa.shape[-1]).to(xa.device, xa.dtype)
1010
+ cache[f] = xa
1011
 
1012
+ for block in self.block:
1013
+ x = block(x, xa=None, mask=self.mask, enc=enc, layer=layer)
1014
+ if f in self.features:
1015
+
1016
+ out = block(x, xa=xa, mask=self.mask, enc=enc, layer=layer)
1017
  if self.sequential:
1018
  x = out
1019
  else:
1020
  a = torch.sigmoid(self.blend)
1021
  x = a * out + (1 - a) * x
1022
+ x = x + self.positional[:x.shape[1]]
1023
+ cache[f] = x
1024
 
1025
  for block in self.cross_attn:
1026
+ if f in self.features:
1027
+ x = block(x, xa=xa, mask=self.mask, enc=enc, layer=layer)
1028
+ xa = block(xa, xa=x, mask=self.mask, enc=enc, layer=layer)
1029
+ out = block(x, xa=xa, mask=self.mask, enc=enc, layer=layer)
1030
  if self.sequential:
1031
  x = out
1032
  else:
1033
  a = torch.sigmoid(self.blend)
1034
  x = a * out + (1 - a) * x
1035
+ x = x + self.positional[:x.shape[1]]
1036
+ cache[f] = x
1037
 
1038
  for block in self.cross_modal:
1039
+ if f in self.features:
 
1040
  xcat = torch.cat([x, xa], dim=1)
1041
+ x = block(xcat, xa=None, mask=self.mask, enc=enc, layer=layer)
 
1042
  x = x[:, :_text_len]
1043
+ cache[f] = x
1044
  if self.counter < 1 and "encoder" in self.debug:
1045
  shapes = {k: v.shape for k, v in enc.items()}
1046
  print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
1047
  self.counter += 1
1048
+
1049
+ x = self.norm(x)
1050
+ x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
1051
+ return x, cache
1052
 
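The learned blend gate used throughout theBridge.forward mixes each block output with the running stream as a convex combination; in isolation:

    import torch

    blend = torch.nn.Parameter(torch.tensor(0.5))
    a = torch.sigmoid(blend)                  # gate in (0, 1); sigmoid(0.5) is about 0.62
    y, x = torch.randn(2, 3, 8), torch.randn(2, 3, 8)
    mixed = a * y + (1 - a) * x
    assert mixed.shape == (2, 3, 8)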
1053
  class Echo(nn.Module):
1054
  def __init__(self, param: Dimensions):
 
1103
  else:
1104
  feature = "spectrogram"
1105
 
1106
+ logits, out = self.processor(input_ids, enc, feature)
1107
+ self.out = out
1108
 
1109
  loss = None
1110
  if labels is not None:
 
1216
  "eos_token_id": self.eos_token_id,
1217
  })
1218
  return Config()
1219
+
1220
  def setup_tokenizer(token: str):
1221
  from tokenizers import Tokenizer
1222
  tokenizer = Tokenizer.from_file("./tokenizer.json")
 
1244
  results.append(tokenizer.decode(ids))
1245
  return results
1246
 
 
1247
  def save_pretrained(save_dir):
1248
  os.makedirs(save_dir, exist_ok=True)
1249
  tokenizer.save(f"{save_dir}/tokenizer.json")
 
1285
 
1286
  return sp_mel, ap_mel
1287
 
1288
+ def extract_features(batch, tokenizer, waveform=False, spec=True, f0=True, f0t=True, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, **dataset_config):
 
1289
  dataset_config = {
1290
  "hop_length": 256,
1291
  "f_min": 150,
 
1300
  "mel_scale": "htk",
1301
  "norm": None,
1302
  "normalized": False,
1303
+ }
1304
 
1305
  audio = batch["audio"]
1306
  sr = audio["sampling_rate"]
 
1307
  labels = tokenizer.encode(batch["transcription"])
1308
 
1309
+ wav = wavnp = f0_np = t = None
1310
+ spectrogram = f0_tensor = f0t_tensor = harmonic = aperiodic = None
1311
+
1312
+ if waveform or spec or f0 or f0t or harmonics:
1313
  wav = load_wave(wave_data=audio, sample_rate=sr)
1314
+ wavnp = wav.numpy().astype(np.float64)
 
1315
 
1316
  if spec:
1317
+ transform = torchaudio.transforms.MelSpectrogram(**dataset_config)
1318
+ mel_spectrogram = transform(wav)
1319
  log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1320
  log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1321
+ spectrogram = (log_mel + 4.0) / 4.0
1322
+ spectrogram = spectrogram.clone().detach()
 
 
1323
 
1324
+ if f0 or f0t or harmonics:
 
1325
  f0_np, t = pw.dio(wavnp, sample_rate,
1326
+ frame_period=hop_length / sample_rate * 1000, f0_ceil=500, f0_floor=71.1)
1327
  f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
1328
 
1329
+ if f0:
1330
+ f0_tensor = torch.from_numpy(f0_np)
1331
+ f0_tensor = torch.where(f0_tensor == 0.0, torch.zeros_like(f0_tensor), (f0_tensor - 71.0) / (500.0 - 71.0))
1332
+
1333
+ if f0t:
1334
+ audio_duration = len(wavnp) / sample_rate
1335
+ T = len(labels)
1336
+ tok_dur_sec = audio_duration / T
1337
+ token_starts = np.arange(T) * tok_dur_sec
1338
+ token_ends = token_starts + tok_dur_sec
1339
+ start_idx = np.searchsorted(t, token_starts, side="left")
1340
+ end_idx = np.searchsorted(t, token_ends, side="right")
1341
+ pitch_tok = np.zeros(T, dtype=np.float32)
1342
+ for i in range(T):
1343
+ lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
1344
+ segment = f0_np[lo:hi]
1345
+ if mode == "mean":
1346
+ pitch_tok[i] = segment.mean()
1347
+ elif mode == "median":
1348
+ pitch_tok[i] = np.median(segment)
1349
+ else:
1350
+ pitch_tok[i] = segment[-1]
1351
+ pitch_tok[pitch_tok < 100.0] = 0.0
1352
+ bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
1353
+ f0t_tensor = torch.from_numpy(np.concatenate([[bos_pitch], pitch_tok]))
1354
+ f0t_tensor = torch.where(f0t_tensor == 0.0, torch.zeros_like(f0t_tensor), (f0t_tensor - 71.0) / (500.0 - 71.0))
1355
+
1356
+ if harmonics:
1357
+ spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
1358
+ apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
1359
+ harmonic = torch.from_numpy(spnp)
1360
+ aperiodic = torch.from_numpy(apnp)
1361
+ harmonic = harmonic[:, :128].contiguous().T
1362
+ aperiodic = aperiodic[:, :128].contiguous().T
1363
+ harmonic = torch.where(harmonic == 0.0, torch.zeros_like(harmonic), harmonic / 1.0)
1364
+ aperiodic = torch.where(aperiodic == 0.0, torch.zeros_like(aperiodic), aperiodic / 1.0)
1365
 
1366
  if debug:
1367
+ print(f"['f0']: {f0_tensor.shape if f0_tensor is not None else None}")
1368
+ print(f"['f0t']: {f0t_tensor.shape if f0t_tensor is not None else None}")
1369
+ print(f"['harmonic']: {harmonic.shape if harmonic is not None else None}")
1370
+ print(f"['aperiodic']: {aperiodic.shape if aperiodic is not None else None}")
1371
+ print(f"['spectrogram']: {spectrogram.shape if spectrogram is not None else None}")
1372
+ print(f"['waveform']: {wav.shape if wav is not None else None}")
1373
+ print(f"['labels']: {len(labels) if labels is not None else None}")
1374
 
1375
  return {
1376
+ "waveform": wav if waveform else None,
1377
+ "spectrogram": spectrogram if spec else None,
1378
+ "f0": f0_tensor if f0 else None,
1379
+ "f0t": f0t_tensor if f0t else None,
1380
+ "harmonic": harmonic if harmonics else None,
1381
+ "aperiodic": aperiodic if harmonics else None,
1382
  "labels": labels,
1383
  }
1384
 
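The f0t path pools frame-level pitch into per-token values by bucketing WORLD's frame times with searchsorted; a compact standalone version of that alignment (synthetic pitch track, 4 tokens over 2 s):

    import numpy as np

    t = np.arange(0, 2.0, 0.016)                 # frame times, as returned by pw.dio
    f0_np = np.full_like(t, 120.0)               # synthetic pitch track
    T, duration = 4, 2.0
    tok_dur_sec = duration / T
    token_starts = np.arange(T) * tok_dur_sec
    token_ends = token_starts + tok_dur_sec
    lo = np.searchsorted(t, token_starts, side="left")
    hi = np.searchsorted(t, token_ends, side="right")
    pitch_tok = np.array([f0_np[l:max(l + 1, h)].mean() for l, h in zip(lo, hi)])
    assert pitch_tok.shape == (T,)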
1385
+ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False, **dataset_config):
1386
 
1387
  if sanity_check:
1388
  test = load_dataset(
1389
+ "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True
1390
+ ).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)
 
1391
  dataset = test.map(
1392
  lambda x: extract_features(x, tokenizer, **dataset_config),
1393
  remove_columns=test.column_names)
 
1394
  train_dataset = dataset
1395
  test_dataset = dataset
1396
+ return train_dataset, test_dataset
1397
  else:
1398
 
1399
  cache_dir = "./processed_datasets"
 
1412
  len(x["audio"]["array"]) > 0 and
1413
  len(x["audio"]["array"]) < 2048 * 160)
1414
 
1415
+ raw_train = load_dataset("google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming)
1416
+ raw_test = load_dataset("google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming)
1417
 
1418
  raw_train = raw_train.filter(filter_func)
1419
  raw_test = raw_test.filter(filter_func)
 
1428
  lambda x: extract_features(x, tokenizer, **dataset_config),
1429
  remove_columns=raw_test.column_names)
1430
 
1431
+ if not streaming:
1432
+ train_dataset.save_to_disk(cache_file_train)
+ test_dataset.save_to_disk(cache_file_test)
1433
  return train_dataset, test_dataset
1434
 
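Splits written with save_to_disk can be reloaded on later runs instead of re-extracting features; a sketch, where the paths are hypothetical stand-ins for the cache_file_train/cache_file_test values defined above:

    from datasets import load_from_disk

    cache_file_train = "./processed_datasets/train"   # hypothetical path
    cache_file_test = "./processed_datasets/test"     # hypothetical path
    train_dataset = load_from_disk(cache_file_train)
    test_dataset = load_from_disk(cache_file_test)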
1435
  @dataclass
 
1511
  total_words += len(ref_words)
1512
  return (total_errors / total_words) * 100 if total_words > 0 else 0.0
1513
 
1514
+ def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
1515
  if isinstance(ids, torch.Tensor):
1516
  ids = ids.tolist()
1517
+ return [int(id) for id in ids if id != -100 and id != pad_token_id and id != bos_token_id and id != eos_token_id]
1518
 
1519
+ def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
1520
+ return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]
1521
 
1522
  def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, optimizer=None, scheduler=None):
 
1523
  label_ids = pred.label_ids
1524
  pred_ids = pred.predictions[0]
1525
+ label_ids = clean_batch(label_ids)
1526
+ pred_ids = clean_batch(pred_ids)
1527
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1528
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
1529
 
1530
  if print_pred:
1531
  for i in range(min(num_samples, len(pred_ids))):
 
1542
  else:
1543
  trainable_params = 0.0
1544
  efficiency_score = 0.0
1545
+ return { "wer": float(wer), "efficiency_score": float(efficiency_score)}
1546
+
1547
+ def preprocess_logits_for_metrics(logits, labels):
1548
+ pred_ids = torch.argmax(logits, dim=-1)
1549
+ labels = torch.where(labels == -100, 0, labels)
1550
+ pred_ids = torch.where(pred_ids == -100, 0, pred_ids)
1551
+ return pred_ids, labels
1552
 
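preprocess_logits_for_metrics collapses the [batch, ctx, vocab] logits to token ids before they reach compute_metrics, which keeps evaluation memory flat; in isolation:

    import torch

    logits = torch.randn(2, 5, 10)                   # [batch, ctx, vocab]
    labels = torch.tensor([[1, 2, -100, 4, 0], [3, 5, 6, -100, 0]])
    pred_ids = torch.argmax(logits, dim=-1)          # [batch, ctx]
    labels = torch.where(labels == -100, 0, labels)  # map ignore index to pad id
    assert pred_ids.shape == labels.shape == (2, 5)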
1553
  def main():
1554
  token = ""
1555
  log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
1556
  os.makedirs(log_dir, exist_ok=True)
1557
  tokenizer = setup_tokenizer(token)
1558
+ train_dataset, test_dataset = prepare_datasets(
1559
+ tokenizer,
1560
+ token,
1561
+ sanity_check=False,
1562
+
1563
+ )
1564
 
1565
  param = Dimensions(
1566
  vocab=40000,
 
1570
  head=4,
1571
  layer=4,
1572
  act="swish",
1573
+ debug={"radius", "encoder"},
1574
  features = ["spectrogram"],
1575
  )
1576
 
 
1589
  logging_steps=10,
1590
  logging_dir=log_dir,
1591
  eval_strategy="steps",
1592
+ save_strategy="no",
1593
  report_to=["tensorboard"],
1594
  push_to_hub=False,
1595
  disable_tqdm=False,
1596
  save_total_limit=1,
1597
  label_names=["labels"],
1598
  save_safetensors=False,
1599
+ eval_on_start=True,
1600
  batch_eval_metrics=False,
1601
  )
1602
  from functools import partial
1603
  metrics_fn = partial(compute_metrics,
1604
  print_pred=True,
1605
+ num_samples=2,
1606
  tokenizer=tokenizer, model=model)
1607
 
1608
  optimizer = torch.optim.AdamW(model.parameters(), lr=0.00025, eps=1e-8, weight_decay=0.025, betas=(0.9, 0.999),
 
1616
  train_dataset=train_dataset,
1617
  eval_dataset=test_dataset,
1618
  data_collator=DataCollator(tokenizer=tokenizer),
1619
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
1620
  compute_metrics=metrics_fn,
1621
  optimizers=(optimizer, scheduler)
1622
  )
1623
+ print(tokenizer.pad_token_id, tokenizer.bos_token_id, tokenizer.eos_token_id)
1624
  model.init_weights()
1625
  trainer.train()
1626