Sin2pi commited on
Commit
f5a4351
·
verified ·
1 Parent(s): 6d7136d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +113 -201
model.py CHANGED
@@ -12,7 +12,6 @@ import torch.nn.functional as F
12
  import torch.nn.init as init
13
  from torch import nn, Tensor
14
  import numpy as np
15
- from einops import rearrange
16
  import matplotlib.pyplot as plt
17
  from typing import Optional, Dict, Union, List, Tuple, Any
18
  from functools import partial
@@ -242,37 +241,34 @@ def get_dtype():
242
  def tox():
243
  return {"device": get_device(), "dtype": get_dtype()}
244
 
245
- def sinusoids(length, num_chan, max=10000):
246
- assert num_chan % 2 == 0
247
- time_x = np.log(max) / (num_chan // 2 - 1)
248
- inv_time = torch.exp(-time_x * torch.arange(num_chan // 2))
249
- s_time = torch.arange(length)[:, np.newaxis] * inv_time[np.newaxis, :]
250
- return torch.cat([torch.sin(s_time), torch.cos(s_time)], dim=1)
251
-
252
  class rotary(nn.Module):
253
  def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False, spec_shape=None):
254
  super(rotary, self).__init__()
255
 
256
- self.pitch_scale = 0.1
257
- self.use_pbias = use_pbias
258
- self.spec_shape = spec_shape
259
  self.dims = dims
260
  self.head = head
261
  self.head_dim = dims // head
262
- self.radii = radii
263
- self.theta = theta
264
  self.max_ctx = max_ctx
 
 
 
 
 
265
  self.debug = debug
266
  self.counter = 0
267
  self.last_theta = None
268
- dim = self.head_dim
269
- self.dim = dim
270
 
271
- self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
272
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
273
- self.base_radius = nn.Parameter(torch.ones(1))
274
-
275
- freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
276
  self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
277
 
278
  def return_f0(self, f0=None):
@@ -291,12 +287,11 @@ class rotary(nn.Module):
291
  self.theta.data.copy_(theta)
292
 
293
  def get_pitch_bias(self, f0):
294
- if f0 is None:
295
- return None
296
  f0_flat = f0.squeeze().float()
297
  f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
298
  f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
299
- f0_norm.unsqueeze(1)) * self.pitch_scale)
300
  return f0_sim.unsqueeze(0).unsqueeze(0)
301
 
302
  def f0proj(self, f0):
@@ -311,7 +306,7 @@ class rotary(nn.Module):
311
 
312
  def synth_f0(self, f0, ctx):
313
  f0 = self.f0proj(f0)
314
- print(f"Aligning f0 with context: {ctx}, f0 shape: {f0}")
315
  if f0.dim() == 1:
316
  length = f0.shape[0]
317
  if length == ctx:
@@ -319,7 +314,7 @@ class rotary(nn.Module):
319
  frames = length / ctx
320
  idx = torch.arange(ctx, device=f0.device)
321
  # return torch.arange(1, ctx+1, device=f0.device, dtype=torch.float)
322
- return f0[id]
323
 
324
  def align_f0(self, ctx, f0):
325
  # f0 = self.return_f0()
@@ -351,6 +346,14 @@ class rotary(nn.Module):
351
 
352
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
353
  f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
 
 
 
 
 
 
 
 
354
 
355
  if isinstance(x, int):
356
  ctx = x
@@ -359,48 +362,41 @@ class rotary(nn.Module):
359
  else:
360
  batch, head, ctx, head_dim = x.shape
361
  t = torch.arange(ctx, device=device, dtype=dtype)
 
362
  if f0 is not None:
363
  f0_mean = f0.mean()
364
- theta = f0_mean + 1e-8
365
  else:
366
  theta = self.theta
367
-
368
  freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
369
  self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
 
 
 
370
 
371
  freqs = t[:, None] * freqs[None, :]
372
-
373
- if self.radii:
374
- radius = self.align_f0(ctx)
375
- if "rotary2" in self.debug and self._counter == 5:
376
- print(f"{layer} radius: {radius} ctx: {ctx}")
 
 
 
 
377
  else:
378
- radius = freqs
379
- freqs = torch.polar(torch.ones_like(radius), freqs)
380
-
381
- if "rotary3" in self.debug and self._counter == 5:
382
- print(f"{layer} count {self._counter} f0: {f0.shape if f0 is not None else None} freqs: {freqs.shape} radius: {radius.shape} ctx: {ctx}")
383
- print(f"freqs mean: {freqs.mean():.2f} inv_freq mean: {self.inv_freq.mean():.2f} theta: {self.theta.item():.2f} radius mean: {radius.mean():.2f} radius shape: {radius.shape} ctx: {ctx}")
384
-
385
- if "rotary_detail" in self.debug and self._counter == 5:
386
- print(f"\n==== Detailed RoPE Analysis ====")
387
- print(f"Layer: {layer}, Context Length: {ctx}")
388
- print(f"F0 stats: mean={self.theta.item():.2f}")
389
- print(f"inv_freq range: [{self.inv_freq.min().item():.4f}, {self.inv_freq.max().item():.4f}]")
390
-
391
- if self.radii:
392
- print(f"Radius Shape: {radius.shape}, Mean: {radius.mean().item():.4f}")
393
- print(f"Radius[0]: {radius[0][:5].cpu().numpy()}")
394
- print(f"Radius[mid]: {radius[ctx//2][:5].cpu().numpy()}")
395
- print(f"Radius[end]: {radius[-1][:5].cpu().numpy()}")
396
-
397
- print(f"Final freqs shape: {freqs.shape}")
398
- print(f"Freqs[0]: {freqs[0][:5].cpu().detach().numpy()}")
399
- print(f"Freqs[mid]: {freqs[ctx//2][:5].cpu().detach().numpy()}")
400
- print(f"Freqs[end]: {freqs[-1][:5].cpu().detach().numpy()}")
401
- print("================================\n")
402
 
403
- self._counter += 1
 
 
 
 
 
404
  return freqs.unsqueeze(0)
405
 
406
  @staticmethod
@@ -416,88 +412,8 @@ class rotary(nn.Module):
416
  x1 = x1.view(orig_shape)
417
  return torch.cat([x1.type_as(x), x2], dim=-1)
418
 
419
- # class FocusA(nn.Module):
420
- # def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
421
- # super().__init__()
422
- # self.dims = dims
423
- # self.head = head
424
- # self.max_dist = max_dist
425
- # self.win_size = win_size
426
- # self.max_span = max_span
427
- # self.temp_scale = temp_scale
428
- # self.iterations = iterations
429
-
430
- # self.span_predictor = nn.Linear(dims, 1)
431
-
432
- # self.attn_l = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
433
- # self.attn_g = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
434
-
435
- # self.ln_l = nn.LayerNorm(dims)
436
- # self.ln_g = nn.LayerNorm(dims)
437
- # self.projection = nn.Linear(2 * dims, dims)
438
-
439
- # def _focus(self, que, key, val, span_scale):
440
- # attn_out = que
441
- # span_len = max(1, int(self.max_span * span_scale.mean().item()))
442
- # span_len = min(span_len, que.size(1), key.size(1), val.size(1))
443
-
444
- # for _ in range(self.iterations):
445
- # temp = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
446
- # q = que / temp
447
- # k = key / temp
448
- # v = val / temp
449
- # output, _ = self.attn_l(q, k, v)
450
- # que = que + output
451
- # return que
452
-
453
- # def _window(self, x, win_size, span_len, span_scale):
454
- # batch_size, ctx, dims = x.size()
455
- # output = torch.zeros_like(x)
456
-
457
- # for i in range(0, ctx, win_size // 2):
458
- # end = min(i + win_size, ctx)
459
- # que = x[:, i:end]
460
- # start = max(0, i - span_len)
461
- # end_con = min(i + win_size + span_len, ctx)
462
- # con = x[:, start:end_con]
463
- # win_out = self._focus(que, con, con, span_scale)
464
-
465
- # if i > 0:
466
- # start_over = i
467
- # end_over = min(i + win_size // 2, ctx)
468
- # blend = torch.linspace(0, 1, end_over - start_over).view(1, -1, 1)
469
- # blend = blend.to(x.device)
470
- # output[:, start_over:end_over] = (
471
- # (1 - blend) * output[:, start_over:end_over] +
472
- # blend * win_out[:, :end_over-start_over])
473
- # if end_over < end:
474
- # output[:, end_over:end] = win_out[:, end_over-i:end-i]
475
- # else:
476
- # output[:, i:end] = win_out
477
- # return output
478
-
479
- # def forward(self, x, mask=None):
480
- # l_x = self.ln_l(x)
481
- # g_x = self.ln_g(x)
482
- # g_out, g_attn = self.attn_g(g_x, g_x, g_x, need_weights=True)
483
- # g_focus = g_attn.sum(dim=1)
484
- # f_score = g_focus.max(dim=-1)[0]
485
- # b_scale = torch.sigmoid(self.span_predictor(x.mean(dim=1)))
486
- # var = (f_score - f_score.mean(dim=1, keepdim=True)).abs()
487
- # a_span = b_scale * (1.0 + 0.5 * var.mean(dim=1, keepdim=True))
488
-
489
- # l_out = self._window(
490
- # l_x,
491
- # win_size=self.win_size,
492
- # span_len=max(1, int(self.max_span * a_span.mean().item())),
493
- # span_scale=a_span
494
- # )
495
-
496
- # combined = torch.cat([l_out, g_out], dim=-1)
497
- # return self.projection(combined)
498
-
499
  class MultiheadA(nn.Module):
500
-
501
  rbf = False
502
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
503
  zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
@@ -507,7 +423,7 @@ class MultiheadA(nn.Module):
507
  self.head = head
508
  self.head_dim = dims // head
509
  self.debug = debug
510
- self._counter = 0
511
 
512
  self.q = Linear(dims, dims).to(device, dtype)
513
  self.k = Linear(dims, dims, bias=False).to(device, dtype)
@@ -543,7 +459,7 @@ class MultiheadA(nn.Module):
543
  dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
544
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
545
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
546
-
547
  def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
548
  x = x.to(device, dtype)
549
  if xa is not None:
@@ -595,9 +511,9 @@ class MultiheadA(nn.Module):
595
  w = F.softmax(qk, dim=-1).to(q.dtype)
596
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
597
 
598
- if "multihead" in self.debug and self._counter % 100 == 0:
599
  print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
600
- self._counter += 1
601
  return self.o(wv), qk.detach()
602
 
603
  class t_gate(nn.Module):
@@ -667,7 +583,7 @@ class Residual(nn.Module):
667
  self.cross_attn = cross_attn
668
  self.features = features
669
  self.debug = debug
670
- self._counter = 0
671
  self.dropout = 0.01
672
 
673
  self.t_gate = tgate
@@ -734,17 +650,18 @@ class Residual(nn.Module):
734
  else:
735
  x = x + mlp_out
736
 
737
- if "residual" in self.debug and self._counter % 100 == 0:
738
- print(f"Step {self._counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
739
  if self.t_gate:
740
- print(f"Step {self._counter}: Using t_gate: {self.t_gate}")
741
  elif self.m_gate:
742
- print(f"Step {self._counter}: Using m_gate: {self.m_gate}")
743
  elif self.c_gate:
744
- print(f"Step {self._counter}: Using c_gate: {self.c_gate}")
745
  else:
746
- print(f"Step {self._counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
747
- self._counter += 1
 
748
  return x
749
 
750
  class FEncoder(nn.Module):
@@ -915,7 +832,7 @@ class AudioEncoder(nn.Module):
915
  self.ctx = ctx
916
  self.head_dim = dims // head
917
  self.debug = debug
918
- self._counter = 0
919
  self.features = features
920
  self.dropout = 0.01
921
 
@@ -943,38 +860,42 @@ class AudioEncoder(nn.Module):
943
  "envelope": nn.ModuleList(
944
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
945
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
946
- for _ in range(layer)] if "envelope" in features else None
947
  ),
948
  "phase": nn.ModuleList(
949
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
950
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
951
  for _ in range(layer)] if "phase" in features else None
952
- )})
 
953
 
954
- def forward(self, enc, layer="encoder"):
955
  enc = dict_to(enc, device, dtype)
956
-
957
- if self._counter < 1:
958
  s = enc.get("spectrogram")
959
  w = enc.get("waveform")
960
  p = default(enc.get("pitch"), enc.get("f0"))
961
  plot_waveform(x=s, w=w, p=p, hop_length=128)
962
 
963
- xa = {}
 
 
 
 
964
 
965
- for f in self.features:
966
  if f in enc and f in self.blocks:
967
  x = enc[f]
968
  for block in self.blocks[f]:
969
  x = block(x, enc=enc, layer=layer)
970
- xa[f] = x
971
 
972
- if "encoder" in self.debug and self._counter % 100 == 0:
973
- names = list(x.keys())
974
- shapes = {k: v.shape for k, v in x.items()}
975
- print(f"Step {self._counter}: mode: {names}: shapes: {shapes}")
976
- self._counter += 1
977
- return xa
978
 
979
  class TextDecoder(nn.Module):
980
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
@@ -986,7 +907,7 @@ class TextDecoder(nn.Module):
986
  self.head = head
987
  self.head_dim = dims // head
988
  self.debug = debug
989
- self._counter = 0
990
  self.dropout = 0.01
991
  self.sequential = sequential
992
  self.features = features
@@ -1010,8 +931,8 @@ class TextDecoder(nn.Module):
1010
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
1011
  self.register_buffer("mask", mask, persistent=False)
1012
 
1013
- def forward(self, x, xa, enc=None, order=None, layer='decoder') -> Tensor:
1014
- xa = dict_to(xa, device, dtype)
1015
  x = x.to(device)
1016
  bln = self.blend
1017
 
@@ -1026,17 +947,17 @@ class TextDecoder(nn.Module):
1026
  x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
1027
 
1028
  for f in order:
1029
- if f in xa:
1030
- ax = xa[f]
1031
  for block in self.blocks[f]:
1032
- out = block(x=x, xa=ax, mask=None, enc=enc, layer=layer)
1033
 
1034
  a = torch.sigmoid(bln[f])
1035
  x = a * out + (1 - a) * x
1036
 
1037
- if "decoder" in self.debug and self._counter % 100 == 0:
1038
- print(f"Step {self._counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
1039
- self._counter += 1
1040
 
1041
  x = self.ln_dec(x)
1042
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
@@ -1076,14 +997,14 @@ class Echo(nn.Module):
1076
  def update_base(self, f0):
1077
  for name, module in self.encoder.named_modules():
1078
  if isinstance(module, (rotary)):
1079
- module.return_f0(f0)
1080
  module.update_base(f0)
 
1081
 
1082
  for name, module in self.decoder.named_modules():
1083
  if isinstance(module, (rotary)):
1084
- module.return_f0(f0)
1085
  module.update_base(f0)
1086
-
 
1087
  def set_alignment_head(self, dump: bytes):
1088
  array = np.frombuffer(
1089
  gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
@@ -1125,9 +1046,10 @@ class Echo(nn.Module):
1125
  if f0 is not None:
1126
  encoder_inputs["f0"] = f0
1127
 
1128
- if f0 is not None:
1129
- f0 = f0.squeeze(0)
1130
- self.update_base(f0)
 
1131
 
1132
  encoder_outputs = self.encoder(encoder_inputs)
1133
  logits = self.decoder(input_ids, encoder_outputs)
@@ -1225,22 +1147,11 @@ class Echo(nn.Module):
1225
  print(f"DECODER GRAD: {name} = {norm:.6f}")
1226
  return None
1227
 
1228
- def reset_counter(self):
1229
- self._counter = 0
1230
  print("Counter reset to 0.")
1231
 
1232
  metric = evaluate.load(path="wer")
1233
-
1234
- def align_f0(f0, ctx):
1235
- ctx = torch.tensor(ctx)
1236
- bat, length = f0.shape
1237
- if length == ctx:
1238
- return f0
1239
- frames = length / ctx
1240
- idx = torch.arange(ctx, device=f0.device)
1241
- idx = (idx * frames).long()
1242
- batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
1243
- return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
1244
 
1245
  @dataclass
1246
  class DataCollator:
@@ -1486,14 +1397,14 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=
1486
  f0, t = pw.dio(wav_np, sampling_rate,
1487
  frame_period=hop_length/sampling_rate*1000)
1488
  f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1489
- f0 = torch.from_numpy(f0).float()
1490
  batch["pitch"] = f0
1491
 
1492
  if frequency:
1493
  wav_np = wav.numpy().astype(np.float64)
1494
  f0, t = pw.dio(wav_np, sampling_rate, frame_period=hop_length/sampling_rate*1000)
1495
  f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1496
- f0 = torch.from_numpy(f0).float()
1497
  batch["f0"] = f0
1498
 
1499
  if spectrogram and waveforms and pitch:
@@ -1708,9 +1619,7 @@ def get_training_args(
1708
  gradient_accumulation_steps=1,
1709
  eval_accumulation_steps=1,
1710
  eval_strategy="steps",
1711
- save_strategy="no",
1712
- include_tokens_per_second=True,
1713
- include_num_input_tokens_seen=True,
1714
  max_steps=max_steps,
1715
  save_steps=save_steps,
1716
  eval_steps=eval_steps,
@@ -1784,13 +1693,13 @@ def main():
1784
  text_dims=512,
1785
  text_idx=4,
1786
  act="swish",
1787
- debug={"rotary_detail"},
1788
  cross_attn=True,
1789
  features = ["spectrogram"]
1790
  )
1791
 
1792
- sanity_check = True
1793
-
1794
  training_args = sanity(sanity_check)
1795
  dataset_config = {
1796
  "spectrogram": True,
@@ -1814,7 +1723,7 @@ def main():
1814
  "normalized": False}
1815
 
1816
  model = create_model(param)
1817
-
1818
  global global_model
1819
  global_model = model
1820
 
@@ -1843,3 +1752,6 @@ def main():
1843
  if __name__ == "__main__":
1844
  main()
1845
 
 
 
 
 
12
  import torch.nn.init as init
13
  from torch import nn, Tensor
14
  import numpy as np
 
15
  import matplotlib.pyplot as plt
16
  from typing import Optional, Dict, Union, List, Tuple, Any
17
  from functools import partial
 
241
  def tox():
242
  return {"device": get_device(), "dtype": get_dtype()}
243
 
244
+ def sinusoids(length, channels, max_timescale=10000):
245
+ assert channels % 2 == 0
246
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
247
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
248
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
249
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
250
+
251
  class rotary(nn.Module):
252
  def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False, spec_shape=None):
253
  super(rotary, self).__init__()
254
 
 
 
 
255
  self.dims = dims
256
  self.head = head
257
  self.head_dim = dims // head
258
+ self.dim = self.head_dim
 
259
  self.max_ctx = max_ctx
260
+ self.theta = theta
261
+ self.radii = radii
262
+ self.pitch_scale = 0.1
263
+ self.use_pbias = use_pbias
264
+ self.spec_shape = spec_shape
265
  self.debug = debug
266
  self.counter = 0
267
  self.last_theta = None
 
 
268
 
269
+ # self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
270
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
271
+ freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
 
 
272
  self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
273
 
274
  def return_f0(self, f0=None):
 
287
  self.theta.data.copy_(theta)
288
 
289
  def get_pitch_bias(self, f0):
290
+ f0 = self.return_f0()
 
291
  f0_flat = f0.squeeze().float()
292
  f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
293
  f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
294
+ f0_norm.unsqueeze(1)))
295
  return f0_sim.unsqueeze(0).unsqueeze(0)
296
 
297
  def f0proj(self, f0):
 
306
 
307
  def synth_f0(self, f0, ctx):
308
  f0 = self.f0proj(f0)
309
+
310
  if f0.dim() == 1:
311
  length = f0.shape[0]
312
  if length == ctx:
 
314
  frames = length / ctx
315
  idx = torch.arange(ctx, device=f0.device)
316
  # return torch.arange(1, ctx+1, device=f0.device, dtype=torch.float)
317
+ return f0[idx]
318
 
319
  def align_f0(self, ctx, f0):
320
  # f0 = self.return_f0()
 
346
 
347
  def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
348
  f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
349
+ if f0 is not None and f0.dim() == 2:
350
+ if f0.shape[0] == 1:
351
+ f0 = f0.squeeze(0)
352
+ else:
353
+ f0 = f0.view(-1)
354
+
355
+ if "rot1" in self.debug and self.counter % 100 == 0:
356
+ print(f"Rotary forward: {x if x is not None else None}, f0: {f0.shape if f0 is not None else None}")
357
 
358
  if isinstance(x, int):
359
  ctx = x
 
362
  else:
363
  batch, head, ctx, head_dim = x.shape
364
  t = torch.arange(ctx, device=device, dtype=dtype)
365
+
366
  if f0 is not None:
367
  f0_mean = f0.mean()
368
+ theta = f0_mean + self.theta
369
  else:
370
  theta = self.theta
 
371
  freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
372
  self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
373
+
374
+ if "rot2" in self.debug and self.counter % 100 == 0:
375
+ print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
376
 
377
  freqs = t[:, None] * freqs[None, :]
378
+ if self.radii and f0 is not None:
379
+ radius = f0.to(device, dtype)
380
+ L = radius.shape[0]
381
+ if L != ctx:
382
+ F = L / ctx
383
+ idx = torch.arange(ctx, device=f0.device)
384
+ idx = (idx * F).long().clamp(0, L - 1)
385
+ radius = radius[idx]
386
+ radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
387
  else:
388
+ radius = torch.ones_like(freqs)
389
+ freqs = torch.polar(radius, freqs)
390
+
391
+ if "rot3" in self.debug and self.counter % 100 == 0:
392
+ print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
+ if "theta" in self.debug and self.counter % 100 == 0:
395
+ if self.last_theta is None or abs(self.last_theta - theta.item()) > 1.0:
396
+ self.last_theta = theta.item()
397
+ print(f"[Theta] {self.last_theta:.2f}")
398
+
399
+ self.counter += 1
400
  return freqs.unsqueeze(0)
401
 
402
  @staticmethod
 
412
  x1 = x1.view(orig_shape)
413
  return torch.cat([x1.type_as(x), x2], dim=-1)
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  class MultiheadA(nn.Module):
416
+ _seen = set()
417
  rbf = False
418
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
419
  zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
 
423
  self.head = head
424
  self.head_dim = dims // head
425
  self.debug = debug
426
+ self.counter = 0
427
 
428
  self.q = Linear(dims, dims).to(device, dtype)
429
  self.k = Linear(dims, dims, bias=False).to(device, dtype)
 
459
  dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
460
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
461
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
462
+
463
  def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
464
  x = x.to(device, dtype)
465
  if xa is not None:
 
511
  w = F.softmax(qk, dim=-1).to(q.dtype)
512
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
513
 
514
+ if "multihead" in self.debug and self.counter % 100 == 0:
515
  print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
516
+ self.counter += 1
517
  return self.o(wv), qk.detach()
518
 
519
  class t_gate(nn.Module):
 
583
  self.cross_attn = cross_attn
584
  self.features = features
585
  self.debug = debug
586
+ self.counter = 0
587
  self.dropout = 0.01
588
 
589
  self.t_gate = tgate
 
650
  else:
651
  x = x + mlp_out
652
 
653
+ if "residual" in self.debug and self.counter % 100 == 0:
654
+ print(f"Step {self.counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
655
  if self.t_gate:
656
+ print(f"Step {self.counter}: Using t_gate: {self.t_gate}")
657
  elif self.m_gate:
658
+ print(f"Step {self.counter}: Using m_gate: {self.m_gate}")
659
  elif self.c_gate:
660
+ print(f"Step {self.counter}: Using c_gate: {self.c_gate}")
661
  else:
662
+ print(f"Step {self.counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
663
+ self.counter += 1
664
+
665
  return x
666
 
667
  class FEncoder(nn.Module):
 
832
  self.ctx = ctx
833
  self.head_dim = dims // head
834
  self.debug = debug
835
+ self.counter = 0
836
  self.features = features
837
  self.dropout = 0.01
838
 
 
860
  "envelope": nn.ModuleList(
861
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
862
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
863
+ for _ in range(layer)] if "envelope" in features else None
864
  ),
865
  "phase": nn.ModuleList(
866
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
867
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
868
  for _ in range(layer)] if "phase" in features else None
869
+ )
870
+ })
871
 
872
+ def forward(self, enc, order=None, layer="encoder"):
873
  enc = dict_to(enc, device, dtype)
874
+
875
+ if self.counter < 1:
876
  s = enc.get("spectrogram")
877
  w = enc.get("waveform")
878
  p = default(enc.get("pitch"), enc.get("f0"))
879
  plot_waveform(x=s, w=w, p=p, hop_length=128)
880
 
881
+ if order is None:
882
+ order = self.features
883
+
884
+ out = {}
885
+ out.update(enc)
886
 
887
+ for f in order:
888
  if f in enc and f in self.blocks:
889
  x = enc[f]
890
  for block in self.blocks[f]:
891
  x = block(x, enc=enc, layer=layer)
892
+ out[f] = x
893
 
894
+ if "encoder" in self.debug and self.counter % 100 == 0:
895
+ shapes = {k: v.shape for k, v in enc.items()}
896
+ print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}, order: {order}")
897
+ self.counter += 1
898
+ return out
 
899
 
900
  class TextDecoder(nn.Module):
901
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
 
907
  self.head = head
908
  self.head_dim = dims // head
909
  self.debug = debug
910
+ self.counter = 0
911
  self.dropout = 0.01
912
  self.sequential = sequential
913
  self.features = features
 
931
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
932
  self.register_buffer("mask", mask, persistent=False)
933
 
934
+ def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
935
+ enc = dict_to(enc, device, dtype)
936
  x = x.to(device)
937
  bln = self.blend
938
 
 
947
  x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
948
 
949
  for f in order:
950
+ if f in enc:
951
+ xa = enc[f]
952
  for block in self.blocks[f]:
953
+ out = block(x=x, xa=xa, mask=None, enc=enc, layer=layer)
954
 
955
  a = torch.sigmoid(bln[f])
956
  x = a * out + (1 - a) * x
957
 
958
+ if "decoder" in self.debug and self.counter % 100 == 0:
959
+ print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
960
+ self.counter += 1
961
 
962
  x = self.ln_dec(x)
963
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
 
997
  def update_base(self, f0):
998
  for name, module in self.encoder.named_modules():
999
  if isinstance(module, (rotary)):
 
1000
  module.update_base(f0)
1001
+ module.return_f0(f0)
1002
 
1003
  for name, module in self.decoder.named_modules():
1004
  if isinstance(module, (rotary)):
 
1005
  module.update_base(f0)
1006
+ module.return_f0(f0)
1007
+
1008
  def set_alignment_head(self, dump: bytes):
1009
  array = np.frombuffer(
1010
  gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
 
1046
  if f0 is not None:
1047
  encoder_inputs["f0"] = f0
1048
 
1049
+
1050
+ # if f0 is not None:
1051
+ # f0 = f0.squeeze(0)
1052
+ # self.update_base(f0)
1053
 
1054
  encoder_outputs = self.encoder(encoder_inputs)
1055
  logits = self.decoder(input_ids, encoder_outputs)
 
1147
  print(f"DECODER GRAD: {name} = {norm:.6f}")
1148
  return None
1149
 
1150
+ def resetcounter(self):
1151
+ self.counter = 0
1152
  print("Counter reset to 0.")
1153
 
1154
  metric = evaluate.load(path="wer")
 
 
 
 
 
 
 
 
 
 
 
1155
 
1156
  @dataclass
1157
  class DataCollator:
 
1397
  f0, t = pw.dio(wav_np, sampling_rate,
1398
  frame_period=hop_length/sampling_rate*1000)
1399
  f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1400
+ f0 = torch.from_numpy(f0)
1401
  batch["pitch"] = f0
1402
 
1403
  if frequency:
1404
  wav_np = wav.numpy().astype(np.float64)
1405
  f0, t = pw.dio(wav_np, sampling_rate, frame_period=hop_length/sampling_rate*1000)
1406
  f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1407
+ f0 = torch.from_numpy(f0)
1408
  batch["f0"] = f0
1409
 
1410
  if spectrogram and waveforms and pitch:
 
1619
  gradient_accumulation_steps=1,
1620
  eval_accumulation_steps=1,
1621
  eval_strategy="steps",
1622
+ save_strategy="steps",
 
 
1623
  max_steps=max_steps,
1624
  save_steps=save_steps,
1625
  eval_steps=eval_steps,
 
1693
  text_dims=512,
1694
  text_idx=4,
1695
  act="swish",
1696
+ debug={},
1697
  cross_attn=True,
1698
  features = ["spectrogram"]
1699
  )
1700
 
1701
+ sanity_check = False
1702
+
1703
  training_args = sanity(sanity_check)
1704
  dataset_config = {
1705
  "spectrogram": True,
 
1723
  "normalized": False}
1724
 
1725
  model = create_model(param)
1726
+
1727
  global global_model
1728
  global_model = model
1729
 
 
1752
  if __name__ == "__main__":
1753
  main()
1754
 
1755
+
1756
+
1757
+