Sin2pi commited on
Commit
1db8188
·
verified ·
1 Parent(s): 7544e99

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +183 -192
model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pyworld as pw
2
  import os
3
  import math
@@ -34,20 +35,6 @@ dtype = torch.float32
34
  warnings.filterwarnings("ignore")
35
  logging.basicConfig(level=logging.ERROR)
36
 
37
- from rich.traceback import install
38
- install(show_locals=True)
39
-
40
- import pretty_errors
41
- pretty_errors.configure(
42
- separator_character = '*',
43
- filename_display = pretty_errors.FILENAME_EXTENDED,
44
- line_number_first = True,
45
- display_link = True,
46
- lines_before = 5,
47
- lines_after = 2,
48
- line_color = pretty_errors.RED + '> ' + pretty_errors.default_config.line_color,
49
- code_color = ' ' + pretty_errors.default_config.line_color,
50
- )
51
 
52
  extractor = None
53
  tokenizer = None
@@ -256,93 +243,12 @@ def get_dtype():
256
  def tox():
257
  return {"device": get_device(), "dtype": get_dtype()}
258
 
259
- def sinusoids(length, channels, max_timescale=10000):
260
- assert channels % 2 == 0
261
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
262
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
263
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
264
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
265
-
266
- def align_f0(f0, ctx):
267
- b, l = f0.shape
268
- if l == ctx:
269
- return f0.squeeze(0).float()
270
- frames_per_token = l / ctx
271
- idx = torch.arange(ctx, device=device, dtype=dtype)
272
- src_idx = (idx * frames_per_token).long().clamp(0, l-1)
273
- batch_idx = torch.arange(b, device=device, dtype=dtype).unsqueeze(1)
274
- f0 = f0[batch_idx, src_idx]
275
- return f0.squeeze(0).float()
276
-
277
- def align_f0(f0, target_length, method='nearest', device=device, dtype=dtype):
278
- if device is None:
279
- device = f0.device
280
- if dtype is None:
281
- dtype = f0.dtype
282
- original_shape = f0.shape
283
- squeeze_batch = False
284
- reshape_back = None
285
-
286
- if f0.dim() == 1:
287
- f0 = f0.unsqueeze(0)
288
- squeeze_batch = True
289
- elif f0.dim() == 2:
290
- pass
291
- elif f0.dim() == 3:
292
- batch_size, ctx, length = f0.shape
293
- f0 = f0.view(-1, length)
294
- reshape_back = (batch_size, ctx)
295
- else:
296
- raise ValueError(f"F0 tensor must be 1D, 2D, or 3D, got {f0.dim()}D")
297
- batch_size, current_length = f0.shape
298
- if current_length == target_length:
299
- result = f0
300
- elif method == 'nearest':
301
- frames_per_token = current_length / target_length
302
- target_indices = torch.arange(target_length, device=device, dtype=torch.float32)
303
- source_indices = (target_indices * frames_per_token).long().clamp(0, current_length - 1)
304
- batch_indices = torch.arange(batch_size, device=device, dtype=torch.long).unsqueeze(1)
305
- result = f0[batch_indices, source_indices]
306
- else:
307
- import torch.nn.functional as F
308
- f0_for_interp = f0.unsqueeze(1)
309
- mode_map = {'linear': 'linear', 'cubic': 'bicubic'}
310
- if method not in mode_map:
311
- raise ValueError(f"Method '{method}' not supported. Use 'nearest', 'linear', or 'cubic'")
312
-
313
- result = F.interpolate(
314
- f0_for_interp.float(),
315
- size=target_length,
316
- mode=mode_map[method],
317
- align_corners=False
318
- ).squeeze(1)
319
-
320
- if reshape_back is not None:
321
- result = result.view(reshape_back[0], reshape_back[1], target_length)
322
- elif squeeze_batch:
323
- result = result.squeeze(0)
324
- return result.to(dtype)
325
-
326
- # def update_base(self, f0):
327
- # f0 = f0.to(device, dtype)
328
- # f0_mean = f0.mean() + 1e-8
329
-
330
- # # Standard RoPE calculation (keep this)
331
- # theta_freqs = 1.0 / (f0_mean ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype)[:(self.dim // 2)].float() / self.dim))
332
-
333
- # # Direct f0-adapted mel scale (new part)
334
- # center_freq = f0_mean
335
- # min_freq = center_freq * 0.25 # Lower bound
336
- # max_freq = center_freq * 4.0 # Upper bound
337
-
338
- # # Direct mel calculation centered on f0
339
- # mel_min = 2595 * torch.log10(1 + min_freq/700)
340
- # mel_max = 2595 * torch.log10(1 + max_freq/700)
341
- # mel_freqs = 700 * (torch.pow(10, torch.linspace(mel_min, mel_max, self.dim//2, device=device, dtype=dtype) / 2595) - 1) / 1000
342
-
343
- # # Use a weighted combination
344
- # self.inv_freq.data.copy_(0.5 * theta_freqs + 0.5 * mel_freqs)
345
- # self.theta.data.copy_(f0_mean)
346
 
347
  class rotary(nn.Module):
348
  def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
@@ -366,15 +272,23 @@ class rotary(nn.Module):
366
  theta = torch.tensor(theta, device=device, dtype=dtype)
367
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
368
  self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
369
- inv_freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
370
  self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
371
 
372
  def update_base(self, f0):
373
  f0 = f0.squeeze(0).to(device, dtype)
374
  theta = f0.mean() + 1e-8
375
- inv_freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
376
  self.inv_freq.data.copy_(inv_freq)
377
- self.theta.data.copy_(theta)
 
 
 
 
 
 
 
 
378
 
379
  def get_pitch_bias(self, f0):
380
  if f0 is None:
@@ -386,14 +300,26 @@ class rotary(nn.Module):
386
  return f0_sim.unsqueeze(0).unsqueeze(0)
387
 
388
  def f0proj(self, f0):
389
- self.f0_proj = nn.Linear(1, self.head_dim // 2, device=device, dtype=dtype)
390
- f0 = f0.to(device, dtype)
391
- f0 = self.f0_proj(f0.unsqueeze(-1))
392
- return f0.to(device=device, dtype=dtype)
 
 
 
 
393
 
394
- def align_f0(self, f0, ctx):
 
395
  f0 = self.f0proj(f0)
396
- print(f"Aligning f0 with context: {ctx}, f0 shape: {f0}")
 
 
 
 
 
 
 
397
  if f0.dim() == 1:
398
  length = f0.shape[0]
399
  if length == ctx:
@@ -410,64 +336,48 @@ class rotary(nn.Module):
410
  idx = torch.arange(ctx, device=f0.device)
411
  idx = (idx * frames).long().clamp(0, length - 1)
412
  return f0[idx, :]
413
-
414
- # def orthogonal(self, dims, i, j, theta):
415
- # R = torch.eye(dims).to(theta.device)
416
- # R[i, i] = torch.cos(theta)
417
- # R[i, j] = -torch.sin(theta)
418
- # R[j, i] = torch.sin(theta)
419
- # R[j, j] = torch.cos(theta)
420
- # R = torch.eye(dims).to(theta.device) - 2 * torch.outer(R, R) / torch.dot(R, R)
421
- # return R
422
-
423
- # def orthogonal_regularization_term(self):
424
- # loss = torch.tensor(0.0, device=self.r_matrix.device)
425
- # if self.r_matrix.requires_grad:
426
- # product = torch.matmul(self.r_matrix, self.r_matrix.t())
427
- # identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)
428
- # loss = ((product - identity) ** 2).sum()
429
- # return self.orthogonal_reg_weight * loss
430
-
431
- def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
432
- f0 = enc.get("f0", None) if enc is not None else None
433
-
434
  if isinstance(x, int):
435
  ctx = x
436
  elif isinstance(x, torch.Tensor) and x.ndim == 3:
437
  batch, ctx, dims = x.shape
438
  else:
439
  batch, head, ctx, head_dim = x.shape
440
-
441
  t = torch.arange(ctx, device=device, dtype=dtype)
442
-
443
- if f0 is not None:
444
- freqs = self.inv_freq
445
- f0_mean = f0.mean()
446
- theta = f0_mean + 1e-8
447
- freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
448
-
449
- if "rotary1" in self.debug:
450
- print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
451
- else:
452
- freqs = self.inv_freq
453
  freqs = t[:, None] * freqs[None, :]
454
 
455
- # sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))
456
-
457
  if self.radii:
458
- if f0 is not None:
459
- radius = self.align_f0(f0, ctx)
460
- else:
461
- radius = freqs
462
- if "rotary2" in self.debug:
463
- print(f"{layer} radius: {radius} ctx: {ctx}")
464
  else:
465
  radius = freqs
466
- freqs = torch.polar(torch.ones_like(radius), freqs.unsqueeze(0))
467
 
468
  if "rotary3" in self.debug:
469
- print(f"{layer} radius: {f0.shape if f0 is not None else None} ctx: {ctx}")
470
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  self._counter += 1
472
  return freqs.unsqueeze(0)
473
 
@@ -484,8 +394,89 @@ class rotary(nn.Module):
484
  x1 = x1.view(orig_shape)
485
  return torch.cat([x1.type_as(x), x2], dim=-1)
486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  class MultiheadA(nn.Module):
488
- _seen = set()
489
  rbf = False
490
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
491
  zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
@@ -531,7 +522,7 @@ class MultiheadA(nn.Module):
531
  dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
532
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
533
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
534
-
535
  def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
536
  x = x.to(device, dtype)
537
  if xa is not None:
@@ -733,7 +724,6 @@ class Residual(nn.Module):
733
  else:
734
  print(f"Step {self._counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
735
  self._counter += 1
736
-
737
  return x
738
 
739
  class FEncoder(nn.Module):
@@ -930,42 +920,40 @@ class AudioEncoder(nn.Module):
930
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
931
  ),
932
  "envelope": nn.ModuleList(
933
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
934
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
935
- for _ in range(layer)] if "envelope" in features else None
936
- ),
937
- "phase": nn.ModuleList(
938
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
939
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
940
- for _ in range(layer)] if "phase" in features else None
941
- )
942
- })
943
 
944
  def forward(self, enc, layer="encoder"):
945
  enc = dict_to(enc, device, dtype)
946
-
947
  if self._counter < 1:
948
  s = enc.get("spectrogram")
949
  w = enc.get("waveform")
950
  p = default(enc.get("pitch"), enc.get("f0"))
951
  plot_waveform(x=s, w=w, p=p, hop_length=128)
952
 
953
- out = {}
954
- out.update(enc)
955
 
956
  for f in self.features:
957
  if f in enc and f in self.blocks:
958
  x = enc[f]
959
  for block in self.blocks[f]:
960
  x = block(x, enc=enc, layer=layer)
961
- out[f] = x
962
 
963
  if "encoder" in self.debug and self._counter % 100 == 0:
964
  names = list(x.keys())
965
  shapes = {k: v.shape for k, v in x.items()}
966
  print(f"Step {self._counter}: mode: {names}: shapes: {shapes}")
967
  self._counter += 1
968
- return out
969
 
970
  class TextDecoder(nn.Module):
971
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
@@ -1001,8 +989,8 @@ class TextDecoder(nn.Module):
1001
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
1002
  self.register_buffer("mask", mask, persistent=False)
1003
 
1004
- def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
1005
- enc = dict_to(enc, device, dtype)
1006
  x = x.to(device)
1007
  bln = self.blend
1008
 
@@ -1017,10 +1005,10 @@ class TextDecoder(nn.Module):
1017
  x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
1018
 
1019
  for f in order:
1020
- if f in enc:
1021
- xa = enc[f]
1022
  for block in self.blocks[f]:
1023
- out = block(x=x, xa=xa, mask=None, enc=enc, layer=layer)
1024
 
1025
  a = torch.sigmoid(bln[f])
1026
  x = a * out + (1 - a) * x
@@ -1068,7 +1056,13 @@ class Echo(nn.Module):
1068
  for name, module in self.encoder.named_modules():
1069
  if isinstance(module, (rotary)):
1070
  module.update_base(f0)
 
1071
 
 
 
 
 
 
1072
  def set_alignment_head(self, dump: bytes):
1073
  array = np.frombuffer(
1074
  gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
@@ -1109,12 +1103,10 @@ class Echo(nn.Module):
1109
  encoder_inputs["phase"] = phase
1110
  if f0 is not None:
1111
  encoder_inputs["f0"] = f0
1112
- if f0d is not None:
1113
- encoder_inputs["f0d"] = f0d
1114
-
1115
  if f0 is not None:
1116
  f0 = f0.squeeze(0)
1117
- self.update_base(f0)
1118
 
1119
  encoder_outputs = self.encoder(encoder_inputs)
1120
  logits = self.decoder(input_ids, encoder_outputs)
@@ -1695,7 +1687,9 @@ def get_training_args(
1695
  gradient_accumulation_steps=1,
1696
  eval_accumulation_steps=1,
1697
  eval_strategy="steps",
1698
- save_strategy="steps",
 
 
1699
  max_steps=max_steps,
1700
  save_steps=save_steps,
1701
  eval_steps=eval_steps,
@@ -1769,19 +1763,20 @@ def main():
1769
  text_dims=512,
1770
  text_idx=4,
1771
  act="swish",
1772
- debug={"rotary"},
1773
  cross_attn=True,
1774
  features = ["spectrogram"]
1775
  )
1776
 
1777
- sanity_check = False
 
1778
  training_args = sanity(sanity_check)
1779
  dataset_config = {
1780
  "spectrogram": True,
1781
  "waveforms": False,
1782
  "pitch": False,
1783
  "downsamples": False,
1784
- "frequency": False,
1785
  "hilbert": False,
1786
  "hop_length": 128,
1787
  "fmin": 150,
@@ -1798,7 +1793,7 @@ def main():
1798
  "normalized": False}
1799
 
1800
  model = create_model(param)
1801
-
1802
  global global_model
1803
  global_model = model
1804
 
@@ -1827,9 +1822,5 @@ def main():
1827
  if __name__ == "__main__":
1828
  main()
1829
 
1830
- from tensorboard import program
1831
- log_dir = "./output/logs"
1832
- tb = program.TensorBoard()
1833
- tb.configure(argv=[None, '--logdir', log_dir])
1834
- url = tb.launch()
1835
- print(f"TensorBoard started at {url}")
 
1
+
2
  import pyworld as pw
3
  import os
4
  import math
 
35
  warnings.filterwarnings("ignore")
36
  logging.basicConfig(level=logging.ERROR)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  extractor = None
40
  tokenizer = None
 
243
  def tox():
244
  return {"device": get_device(), "dtype": get_dtype()}
245
 
246
+ def sinusoids(length, num_chan, max=10000):
247
+ assert num_chan % 2 == 0
248
+ time_x = np.log(max) / (num_chan // 2 - 1)
249
+ inv_time = torch.exp(-time_x * torch.arange(num_chan // 2))
250
+ s_time = torch.arange(length)[:, np.newaxis] * inv_time[np.newaxis, :]
251
+ return torch.cat([torch.sin(s_time), torch.cos(s_time)], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  class rotary(nn.Module):
254
  def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
 
272
  theta = torch.tensor(theta, device=device, dtype=dtype)
273
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
274
  self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
275
+ inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
276
  self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
277
 
278
  def update_base(self, f0):
279
  f0 = f0.squeeze(0).to(device, dtype)
280
  theta = f0.mean() + 1e-8
281
+ inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
282
  self.inv_freq.data.copy_(inv_freq)
283
+ self.theta.data.copy_(theta)
284
+
285
+ def return_f0(self, f0=None):
286
+ if f0 is not None:
287
+ self.f0 = f0
288
+ return f0.squeeze(0).to(device, dtype)
289
+ elif hasattr(self, 'f0') and self.f0 is not None:
290
+ return self.f0.squeeze(0).to(device, dtype)
291
+ return None
292
 
293
  def get_pitch_bias(self, f0):
294
  if f0 is None:
 
300
  return f0_sim.unsqueeze(0).unsqueeze(0)
301
 
302
  def f0proj(self, f0):
303
+ if f0.ndim == 3:
304
+ f0 = f0.squeeze(0)
305
+ self.f0_proj = nn.Linear(1, self.head_dim // 2, device=device, dtype=dtype)
306
+ f0 = f0.to(device, dtype)
307
+ f0 = self.f0_proj(f0.unsqueeze(-1))
308
+ if f0.ndim == 3:
309
+ f0 = f0.squeeze(0)
310
+ return f0.to(device=device, dtype=dtype)
311
 
312
+ def align_f0(self, ctx):
313
+ f0 = self.return_f0()
314
  f0 = self.f0proj(f0)
315
+ if f0.dim() == 3:
316
+ batch, length, dims = f0.shape
317
+ if length == ctx:
318
+ return f0
319
+ frames = length / ctx
320
+ idx = torch.arange(ctx, device=f0.device)
321
+ idx = (idx * frames).long().clamp(0, length - 1)
322
+ return f0[:, idx, :]
323
  if f0.dim() == 1:
324
  length = f0.shape[0]
325
  if length == ctx:
 
336
  idx = torch.arange(ctx, device=f0.device)
337
  idx = (idx * frames).long().clamp(0, length - 1)
338
  return f0[idx, :]
339
+
340
+ def forward(self, x=None, f0=None, enc=None, layer=None, input_type="audio") -> Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  if isinstance(x, int):
342
  ctx = x
343
  elif isinstance(x, torch.Tensor) and x.ndim == 3:
344
  batch, ctx, dims = x.shape
345
  else:
346
  batch, head, ctx, head_dim = x.shape
 
347
  t = torch.arange(ctx, device=device, dtype=dtype)
348
+ freqs = self.inv_freq
 
 
 
 
 
 
 
 
 
 
349
  freqs = t[:, None] * freqs[None, :]
350
 
 
 
351
  if self.radii:
352
+ radius = self.align_f0(ctx)
353
+ if "rotary2" in self.debug:
354
+ print(f"{layer} radius: {radius} ctx: {ctx}")
 
 
 
355
  else:
356
  radius = freqs
357
+ freqs = torch.polar(torch.ones_like(radius), freqs)
358
 
359
  if "rotary3" in self.debug:
360
+ print(f"{layer} count {self._counter} f0: {f0.shape if f0 is not None else None} freqs: {freqs.shape} radius: {radius.shape} ctx: {ctx}")
361
+ print(f"freqs mean: {freqs.mean():.2f} inv_freq mean: {self.inv_freq.mean():.2f} theta: {self.theta.item():.2f} radius mean: {radius.mean():.2f} radius shape: {radius.shape} ctx: {ctx}")
362
+
363
+ if "rotary_detail" in self.debug:
364
+ print(f"\n==== Detailed RoPE Analysis ====")
365
+ print(f"Layer: {layer}, Context Length: {ctx}")
366
+ print(f"F0 stats: mean={self.theta.item():.2f}")
367
+ print(f"inv_freq range: [{self.inv_freq.min().item():.4f}, {self.inv_freq.max().item():.4f}]")
368
+
369
+ if self.radii:
370
+ print(f"Radius Shape: {radius.shape}, Mean: {radius.mean().item():.4f}")
371
+ print(f"Radius[0]: {radius[0][:5].cpu().numpy()}")
372
+ print(f"Radius[mid]: {radius[ctx//2][:5].cpu().numpy()}")
373
+ print(f"Radius[end]: {radius[-1][:5].cpu().numpy()}")
374
+
375
+ print(f"Final freqs shape: {freqs.shape}")
376
+ print(f"Freqs[0]: {freqs[0][:5].cpu().numpy()}")
377
+ print(f"Freqs[mid]: {freqs[ctx//2][:5].cpu().numpy()}")
378
+ print(f"Freqs[end]: {freqs[-1][:5].cpu().numpy()}")
379
+ print("================================\n")
380
+
381
  self._counter += 1
382
  return freqs.unsqueeze(0)
383
 
 
394
  x1 = x1.view(orig_shape)
395
  return torch.cat([x1.type_as(x), x2], dim=-1)
396
 
397
+
398
+ # class FocusA(nn.Module):
399
+ # def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
400
+ # super().__init__()
401
+ # self.dims = dims
402
+ # self.head = head
403
+ # self.max_dist = max_dist
404
+ # self.win_size = win_size
405
+ # self.max_span = max_span
406
+ # self.temp_scale = temp_scale
407
+ # self.iterations = iterations
408
+
409
+ # self.span_predictor = nn.Linear(dims, 1)
410
+
411
+ # self.attn_l = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
412
+ # self.attn_g = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
413
+
414
+ # self.ln_l = nn.LayerNorm(dims)
415
+ # self.ln_g = nn.LayerNorm(dims)
416
+ # self.projection = nn.Linear(2 * dims, dims)
417
+
418
+ # def _focus(self, que, key, val, span_scale):
419
+ # attn_out = que
420
+ # span_len = max(1, int(self.max_span * span_scale.mean().item()))
421
+ # span_len = min(span_len, que.size(1), key.size(1), val.size(1))
422
+
423
+ # for _ in range(self.iterations):
424
+ # temp = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
425
+ # q = que / temp
426
+ # k = key / temp
427
+ # v = val / temp
428
+ # output, _ = self.attn_l(q, k, v)
429
+ # que = que + output
430
+ # return que
431
+
432
+ # def _window(self, x, win_size, span_len, span_scale):
433
+ # batch_size, ctx, dims = x.size()
434
+ # output = torch.zeros_like(x)
435
+
436
+ # for i in range(0, ctx, win_size // 2):
437
+ # end = min(i + win_size, ctx)
438
+ # que = x[:, i:end]
439
+ # start = max(0, i - span_len)
440
+ # end_con = min(i + win_size + span_len, ctx)
441
+ # con = x[:, start:end_con]
442
+ # win_out = self._focus(que, con, con, span_scale)
443
+
444
+ # if i > 0:
445
+ # start_over = i
446
+ # end_over = min(i + win_size // 2, ctx)
447
+ # blend = torch.linspace(0, 1, end_over - start_over).view(1, -1, 1)
448
+ # blend = blend.to(x.device)
449
+ # output[:, start_over:end_over] = (
450
+ # (1 - blend) * output[:, start_over:end_over] +
451
+ # blend * win_out[:, :end_over-start_over])
452
+ # if end_over < end:
453
+ # output[:, end_over:end] = win_out[:, end_over-i:end-i]
454
+ # else:
455
+ # output[:, i:end] = win_out
456
+ # return output
457
+
458
+ # def forward(self, x, mask=None):
459
+ # l_x = self.ln_l(x)
460
+ # g_x = self.ln_g(x)
461
+ # g_out, g_attn = self.attn_g(g_x, g_x, g_x, need_weights=True)
462
+ # g_focus = g_attn.sum(dim=1)
463
+ # f_score = g_focus.max(dim=-1)[0]
464
+ # b_scale = torch.sigmoid(self.span_predictor(x.mean(dim=1)))
465
+ # var = (f_score - f_score.mean(dim=1, keepdim=True)).abs()
466
+ # a_span = b_scale * (1.0 + 0.5 * var.mean(dim=1, keepdim=True))
467
+
468
+ # l_out = self._window(
469
+ # l_x,
470
+ # win_size=self.win_size,
471
+ # span_len=max(1, int(self.max_span * a_span.mean().item())),
472
+ # span_scale=a_span
473
+ # )
474
+
475
+ # combined = torch.cat([l_out, g_out], dim=-1)
476
+ # return self.projection(combined)
477
+
478
  class MultiheadA(nn.Module):
479
+
480
  rbf = False
481
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
482
  zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
 
522
  dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
523
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
524
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
525
+
526
  def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
527
  x = x.to(device, dtype)
528
  if xa is not None:
 
724
  else:
725
  print(f"Step {self._counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
726
  self._counter += 1
 
727
  return x
728
 
729
  class FEncoder(nn.Module):
 
920
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
921
  ),
922
  "envelope": nn.ModuleList(
923
+ [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
924
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
925
+ for _ in range(layer)] if "envelope" in features else None
926
+ ),
927
+ "phase": nn.ModuleList(
928
+ [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
929
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
930
+ for _ in range(layer)] if "phase" in features else None
931
+ )})
 
932
 
933
  def forward(self, enc, layer="encoder"):
934
  enc = dict_to(enc, device, dtype)
935
+
936
  if self._counter < 1:
937
  s = enc.get("spectrogram")
938
  w = enc.get("waveform")
939
  p = default(enc.get("pitch"), enc.get("f0"))
940
  plot_waveform(x=s, w=w, p=p, hop_length=128)
941
 
942
+ xa = {}
 
943
 
944
  for f in self.features:
945
  if f in enc and f in self.blocks:
946
  x = enc[f]
947
  for block in self.blocks[f]:
948
  x = block(x, enc=enc, layer=layer)
949
+ xa[f] = x
950
 
951
  if "encoder" in self.debug and self._counter % 100 == 0:
952
  names = list(x.keys())
953
  shapes = {k: v.shape for k, v in x.items()}
954
  print(f"Step {self._counter}: mode: {names}: shapes: {shapes}")
955
  self._counter += 1
956
+ return xa
957
 
958
  class TextDecoder(nn.Module):
959
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
 
989
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
990
  self.register_buffer("mask", mask, persistent=False)
991
 
992
+ def forward(self, x, xa, enc=None, order=None, layer='decoder') -> Tensor:
993
+ xa = dict_to(xa, device, dtype)
994
  x = x.to(device)
995
  bln = self.blend
996
 
 
1005
  x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
1006
 
1007
  for f in order:
1008
+ if f in xa:
1009
+ ax = xa[f]
1010
  for block in self.blocks[f]:
1011
+ out = block(x=x, xa=ax, mask=None, enc=enc, layer=layer)
1012
 
1013
  a = torch.sigmoid(bln[f])
1014
  x = a * out + (1 - a) * x
 
1056
  for name, module in self.encoder.named_modules():
1057
  if isinstance(module, (rotary)):
1058
  module.update_base(f0)
1059
+ module.return_f0(f0)
1060
 
1061
+ for name, module in self.decoder.named_modules():
1062
+ if isinstance(module, (rotary)):
1063
+ module.update_base(f0)
1064
+ module.return_f0(f0)
1065
+
1066
  def set_alignment_head(self, dump: bytes):
1067
  array = np.frombuffer(
1068
  gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
 
1103
  encoder_inputs["phase"] = phase
1104
  if f0 is not None:
1105
  encoder_inputs["f0"] = f0
1106
+
 
 
1107
  if f0 is not None:
1108
  f0 = f0.squeeze(0)
1109
+ self.update_base(f0)
1110
 
1111
  encoder_outputs = self.encoder(encoder_inputs)
1112
  logits = self.decoder(input_ids, encoder_outputs)
 
1687
  gradient_accumulation_steps=1,
1688
  eval_accumulation_steps=1,
1689
  eval_strategy="steps",
1690
+ save_strategy="no",
1691
+ include_tokens_per_second=True,
1692
+ include_num_input_tokens_seen=True,
1693
  max_steps=max_steps,
1694
  save_steps=save_steps,
1695
  eval_steps=eval_steps,
 
1763
  text_dims=512,
1764
  text_idx=4,
1765
  act="swish",
1766
+ debug={"rotary_detail"},
1767
  cross_attn=True,
1768
  features = ["spectrogram"]
1769
  )
1770
 
1771
+ sanity_check = True
1772
+
1773
  training_args = sanity(sanity_check)
1774
  dataset_config = {
1775
  "spectrogram": True,
1776
  "waveforms": False,
1777
  "pitch": False,
1778
  "downsamples": False,
1779
+ "frequency": True,
1780
  "hilbert": False,
1781
  "hop_length": 128,
1782
  "fmin": 150,
 
1793
  "normalized": False}
1794
 
1795
  model = create_model(param)
1796
+
1797
  global global_model
1798
  global_model = model
1799
 
 
1822
  if __name__ == "__main__":
1823
  main()
1824
 
1825
+
1826
+