Sin2pi commited on
Commit
3d6f600
·
verified ·
1 Parent(s): eb7dee5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +121 -223
model.py CHANGED
@@ -1,7 +1,6 @@
1
-
2
  import pyworld as pw
3
  import os
4
- import math
5
  import warnings
6
  import logging
7
  import gzip
@@ -39,7 +38,6 @@ warnings.filterwarnings("ignore")
39
  logging.basicConfig(level=logging.ERROR)
40
  tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), "dtype": torch.float32}
41
 
42
-
43
  extractor = None
44
  tokenizer = None
45
  optimizer = None
@@ -149,7 +147,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
149
  axs[current_ax].set_ylabel("Mel Bin")
150
  axs[current_ax].set_xlim([0, max_time])
151
  axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
152
- fig.colorbar(im, ax=axs[current_ax])
153
  current_ax += 1
154
 
155
  if p is not None:
@@ -257,6 +255,17 @@ def sinusoids(length, channels, max_timescale=10000):
257
  scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
258
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
259
 
 
 
 
 
 
 
 
 
 
 
 
260
  class rotary(nn.Module):
261
  _seen = set()
262
  def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
@@ -264,33 +273,31 @@ class rotary(nn.Module):
264
  super().__init__()
265
 
266
  self.use_pbias = use_pbias
267
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
268
- self.device = device
269
- dtype = torch.float32
270
- self.dtype = dtype
271
  self.debug = debug
272
  self._counter = 0
273
  self.dims = dims
274
  self.max_ctx = max_ctx
275
  self.radii = radii
276
- max_freq = 10.0
277
- f0_scale_factor = 0.5
278
  self.learned_adaptation: bool = False
279
  pitch_scale = 1.0
280
-
 
281
  if self.learned_adaptation:
282
- self.f0_scale = nn.Parameter(torch.tensor(f0_scale_factor))
283
  else:
284
- self.register_buffer('f0_scale', torch.tensor(f0_scale_factor))
285
 
286
- self.theta = nn.Parameter(torch.tensor(float(theta)), requires_grad=learned_theta)
287
- self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale), requires_grad=learned_pitch)
288
  freqs = 1. / (theta ** (torch.arange(0, dims, 2)[:(dims // 2)].float() / dims))
289
- self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
 
290
 
291
- if radii:
292
- radius = 1.0
293
- self.radius = nn.Parameter(torch.ones(radius), requires_grad=learned_radius)
294
 
295
  def get_pitch_bias(self, f0):
296
  if f0 is None:
@@ -318,14 +325,26 @@ class rotary(nn.Module):
318
  rotary.fwd_sim = fwd_sim
319
 
320
  def align_f0(self, f0, ctx):
321
- bat, length = f0.shape
322
- if length == ctx:
323
- return f0
324
- frames = length / ctx
325
- idx = torch.arange(ctx, device=f0.device)
326
- idx = (idx * frames).long()
327
- batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
328
- return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  def scale_f0(self, f0):
331
  f0_min = f0.min(dim=1, keepdim=True)[0]
@@ -370,13 +389,14 @@ class rotary(nn.Module):
370
  freqs = torch.outer(positions, freqs)
371
  return torch.polar(torch.ones_like(freqs), freqs)
372
 
373
- def forward(self, x=None, f0=None, block=None) -> Tensor:
 
374
  if isinstance(x, int):
375
  ctx = x
376
  else:
377
  batch, ctx, dims = x.shape
378
  t = torch.arange(ctx, device=self.device).float()
379
-
380
  if self.learned_adaptation:
381
  freqs = self.get_f0_adapted_freqs(ctx, f0)
382
  x_complex = torch.view_as_complex(
@@ -394,16 +414,17 @@ class rotary(nn.Module):
394
 
395
  freqs = torch.einsum('i,j->ij', t, freqs)
396
  freqs = freqs.float()
397
-
398
  if self.radii and f0 is not None:
399
- if block == "decoder":
400
- dec = f0 + 1e-8
401
- else:
402
- enc = f0 + 1e-8
403
-
404
- radius = default(dec, enc)
405
-
406
- freqs = torch.polar(radius.unsqueeze(-1), freqs.unsqueeze(0))
 
407
  else:
408
  freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
409
  if "rotary" in self.debug:
@@ -415,9 +436,9 @@ class rotary(nn.Module):
415
  print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
416
  elif abs(self._prev_f0_theta - theta) > 100.0:
417
  print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
418
- print(f"f0_mean: {f0_mean} Hz, freqs: {freqs.shape}, ctx: {ctx}, dims: {self.dims}, block: {block}")
419
  if self.radii:
420
- print(f"radius: {radius} Hz, enc: {enc if enc else dec} Hz, ctx: {ctx}")
421
  self._prev_f0_theta = theta
422
  rotary._seen.add(key)
423
  self._counter += 1
@@ -558,7 +579,7 @@ class MultiheadA(nn.Module):
558
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
559
 
560
  def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None,
561
- return_attn: bool = False, f0: Tensor = None, block = None) -> tuple:
562
 
563
  batch, ctx, dims = x.shape
564
  scale = (self.dims // self.head) ** -0.25
@@ -569,8 +590,8 @@ class MultiheadA(nn.Module):
569
  v = self.v(z).to(x.dtype)
570
 
571
  if self.rotary_emb:
572
- qf = self.rope(q.size(1), f0=f0, block=block)
573
- kf = self.rope(k.size(1), f0=f0, block=block)
574
 
575
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
576
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -668,9 +689,6 @@ class c_gate(nn.Module):
668
  comb = torch.cat([s, w, p], dim=-1)
669
  return self.integ(comb)
670
 
671
-
672
-
673
-
674
  class Residual(nn.Module):
675
  _seen = set()
676
  def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
@@ -718,12 +736,12 @@ class Residual(nn.Module):
718
  if not any([t_gate, m_gate, c_gate]):
719
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
720
 
721
- def forward(self, x, xa=None, mask=None, f0=None, mode=None, block=None):
722
  bln = self.blend
723
- x = x + self.attna(self.lna(x), mask=mask, f0=f0, block=block)[0]
724
 
725
  if self.attnb and xa is not None:
726
- c = self.attnb(self.lnb(x), xa, f0=f0, mask=None, block=block)[0]
727
  b = torch.sigmoid(bln)
728
  x = b * x + (1 - b) * c
729
 
@@ -763,8 +781,6 @@ class Residual(nn.Module):
763
 
764
  return x
765
 
766
-
767
-
768
  class PEncoder(nn.Module):
769
  def __init__(self, input_dims, dims, head, layer, kernel_size, act):
770
  super().__init__()
@@ -780,7 +796,7 @@ class PEncoder(nn.Module):
780
  Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
781
  Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2),act_fn)
782
 
783
- def forward(self, x, f0=None, block=None):
784
  x = self.encoder(x).permute(0, 2, 1)
785
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
786
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -809,7 +825,7 @@ class WEncoder(nn.Module):
809
  self.positional = lambda length: sinusoids(length, dims)
810
  self.norm = RMSNorm(dims)
811
 
812
- def forward(self, x, f0=None, block=None):
813
  x = self.downsample(x)
814
  x = self.encoder(x)
815
  x = x.permute(0, 2, 1)
@@ -836,132 +852,13 @@ class FEncoder(nn.Module):
836
  self.norm = RMSNorm(dims)
837
  self._norm = RMSNorm(dims)
838
 
839
- def forward(self, x, f0=None, block=None):
840
  x = self.encoder(x).permute(0, 2, 1)
841
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
842
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
843
  x = self._norm(x)
844
  return x
845
 
846
-
847
-
848
-
849
- class F0Encoder(nn.Module):
850
- def __init__(self, ctx, dims, head, layers, dropout, use_norm=True, wav = None):
851
- super().__init__()
852
- self.dims = dims
853
- self.layers = layers
854
- sampling_rate = 16000
855
- hop_length = 128
856
-
857
- self.inputection = nn.Linear(dims, dims)
858
-
859
-
860
- self.blocks = nn.ModuleList()
861
- for _ in range(layers):
862
- layer = rotary(
863
- dims=dims,
864
- debug=[],
865
- radii=False,
866
- learned_pitch=False,
867
- learned_freq=False,
868
- learned_theta=False,
869
- learned_radius=False,
870
- )
871
- self.blocks.append(layer)
872
-
873
- self.norm = nn.LayerNorm(dims) if use_norm else nn.Identity()
874
- self.outputection = nn.Linear(dims, dims)
875
-
876
- def extract_f0(self, wav, sampling_rate, hop_length):
877
- wav = wav.shape[1] if isinstance(wav, torch.Tensor) else wav
878
- wav_np = wav.numpy().astype(np.float64)
879
- f0, t = pw.dio(wav_np, sampling_rate,
880
- frame_period=hop_length/sampling_rate*1000)
881
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
882
- f0 = torch.from_numpy(f0).float()
883
- f0 = f0.unsqueeze(-1)
884
-
885
- def align_f0(self, f0, ctx):
886
- bat, f0_length = f0.shape
887
- if f0_length == ctx:
888
- return f0
889
- frames_per_token = f0_length / ctx
890
- idx = torch.arange(ctx, device=f0.device)
891
- idx = (idx * frames_per_token).long()
892
- batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
893
- return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
894
-
895
- def scale_f0(self, f0):
896
- f0_min = f0.min(dim=1, keepdim=True)[0]
897
- f0_max = f0.max(dim=1, keepdim=True)[0]
898
- denom = f0_max - f0_min + 1e-8
899
- normalized_f0 = (f0 - f0_min) / denom
900
- normalized_f0 = torch.clamp(normalized_f0, 0.0, 1.0)
901
- return normalized_f0
902
-
903
- def process_f0(f0, threshold=0.05):
904
- thresholded_f0 = torch.where(f0 < threshold, torch.zeros_like(f0), f0)
905
- return thresholded_f0
906
-
907
- def map_perceptual(self, f0_mean, theta=10000.0):
908
- if f0_mean >= theta:
909
- return torch.log(f0_mean / theta)
910
- else:
911
- return -torch.log(theta / f0_mean)
912
-
913
- def linear_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
914
- mapped_freq = ((freq - min_freq) / (max_freq - min_freq)) * target_max
915
- return mapped_freq
916
-
917
- def log_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
918
- log_freq = torch.log(freq)
919
- log_min_freq = torch.log(min_freq)
920
- log_max_freq = torch.log(max_freq)
921
- mapped_log_freq = ((log_freq - log_min_freq) / (log_max_freq - log_min_freq)) * torch.log(torch.tensor(target_max, device=self.device))
922
- return mapped_log_freq
923
-
924
- def get_f0_adapted_freqs(self, ctx, f0=None):
925
- f0_min: float = 40.0,
926
- f0_max: float = 400.0,
927
- base_freq: float = 1.0,
928
- positions = torch.arange(ctx, device=device, dtype=torch.float)
929
- freqs = base_freq.clone()
930
- if f0 is not None:
931
- f0_norm = torch.clamp((f0 - f0_min) / (f0_max - f0_min), 0.0, 1.0)
932
- freq_mod = torch.pow(torch.linspace(0.5, 1.5, self.dims//2, device=device),
933
- f0_norm.unsqueeze(-1) * self.f0_scale)
934
- freqs = freqs * freq_mod
935
- freqs = torch.outer(positions, freqs)
936
- return torch.polar(torch.ones_like(freqs), freqs)
937
-
938
- def get_freqs(self, f0=None, pitch_scale=1.0, learned=False, theta=10000.0):
939
- f0_mean=f0.mean()+1e-8
940
- pitch_scale = nn.Parameter(torch.tensor(pitch_scale), requires_grad=learned)
941
- theta=f0_mean*pitch_scale
942
- freqs = 1. / (theta ** (torch.arange(0, self.dims, 2)[:(self.dims // 2)].float() / self.dims))
943
- freqs = nn.Parameter(freqs, requires_grad = learned)
944
- return freqs
945
-
946
- def get_radii(self, f0=None, ctx=None, learned=False):
947
- f0 = f0 + 1e-8
948
- radius = self.align_f0(f0, ctx)
949
- radius = F.softplus(self.radius) * radius
950
- return radius
951
-
952
- def forward(self, x, f0, radius=None, freqs=None):
953
-
954
- radius = self.get_radii(f0=f0, ctx=x.shape[1], learned=False) if radius is None else radius
955
-
956
- f0 = self.inputection(f0)
957
-
958
- for block in self.blocks:
959
- x = x + block(x)
960
-
961
- x = self.norm(x)
962
- f0 = self.outputection(x)
963
- return f0
964
-
965
  class AudioEncoder(nn.Module):
966
  _seen = set()
967
  def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str],
@@ -984,6 +881,10 @@ class AudioEncoder(nn.Module):
984
  self.dropout = 0.01
985
  self.f0_rotary = f0_rotary
986
 
 
 
 
 
987
  act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
988
  "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
989
  act_fn = act_map.get(act, nn.GELU())
@@ -1005,7 +906,7 @@ class AudioEncoder(nn.Module):
1005
  "pitch": nn.ModuleList(
1006
  [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
1007
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
1008
- ),
1009
  "spec_envelope": nn.ModuleList(
1010
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1011
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug) for _ in range(layer)] if "spec_envelope" in features else None
@@ -1015,24 +916,31 @@ class AudioEncoder(nn.Module):
1015
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug) for _ in range(layer)] if "spec_phase" in features else None),
1016
  })
1017
 
1018
- def forward(self, x, f0=None, block="encoder"):
 
 
 
 
1019
  if self._counter < 1:
1020
  s = x.get("spectrogram")
1021
  w = x.get("waveform")
1022
  p = f0 if f0 is not None else x.get("pitch")
1023
  plot_waveform(x=s, w=w, p=p, hop_length=128)
 
1024
  enc = {}
1025
- if self.f0_rotary:
1026
- f0 = f0 if f0 is not None else x.get("pitch")
1027
- else:
1028
- f0 = None
 
 
1029
  for y in self.features:
1030
  if y in x and y in self.blocks:
1031
  f = x[y]
1032
  for block in self.blocks[y]:
1033
- f = block(f, f0=f0, block=block)
1034
  enc[y] = f
1035
-
1036
  if "encoder" in self.debug and self._counter % 100 == 0:
1037
  names = list(x.keys())
1038
  shapes = {k: v.shape for k, v in x.items()}
@@ -1044,8 +952,6 @@ class AudioEncoder(nn.Module):
1044
  self._counter += 1
1045
  return enc
1046
 
1047
-
1048
-
1049
  class TextDecoder(nn.Module):
1050
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
1051
  debug: List[str], features: List[str], f0_rotary: bool = False, sequential=False):
@@ -1073,7 +979,7 @@ class TextDecoder(nn.Module):
1073
  self.token.weight[0].zero_()
1074
  self.positional = nn.Parameter(data=torch.empty(ctx, dims), requires_grad=True)
1075
 
1076
- self._blocks = nn.ModuleList([
1077
  Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
1078
  for _ in range(layer)])
1079
 
@@ -1087,25 +993,21 @@ class TextDecoder(nn.Module):
1087
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
1088
  self.register_buffer("mask", mask, persistent=False)
1089
 
1090
- def forward(self, x, enc, order=None, f0d=None, block='decoder') -> Tensor:
1091
  bln = self.blend
1092
  x = x.to(device)
1093
- if self.f0_rotary:
1094
- f0d = f0d
1095
- else:
1096
- f0d = None
1097
  if order is None:
1098
  order = self.features
1099
  mask = self.mask[:x.shape[1], :x.shape[1]]
1100
  x = self.token(x) + self.positional[:x.shape[1]]
1101
  x = F.dropout(x, p=self.dropout, training=self.training)
1102
- for block in self._blocks:
1103
- x = block(x, f0=f0d, mask=mask, block=block)
1104
  for f in order:
1105
  if f in enc:
1106
  xa = enc[f]
1107
  for block in self.blocks[f]:
1108
- out = block(x=x, xa=xa, f0=f0d, mask=None, block=block)
1109
  a = torch.sigmoid(bln[f])
1110
  x = a * out + (1 - a) * x
1111
  x = self.ln_dec(x)
@@ -1124,8 +1026,6 @@ class Echo(nn.Module):
1124
  self.param = param
1125
  self.count = 0
1126
 
1127
-
1128
-
1129
  self.encoder = AudioEncoder(
1130
  mels=param.mels,
1131
  ctx=param.aud_ctx,
@@ -1175,7 +1075,7 @@ class Echo(nn.Module):
1175
  spectrogram: torch.Tensor=None,
1176
  pitch: Optional[torch.Tensor]=None,
1177
  f0: Optional[torch.Tensor]=None,
1178
- f0d: Optional[torch.Tensor]=None,
1179
  envelope: Optional[torch.Tensor]=None,
1180
  phase: Optional[torch.Tensor]=None,
1181
  ) -> Dict[str, torch.Tensor]:
@@ -1192,11 +1092,11 @@ class Echo(nn.Module):
1192
  encoder_inputs["envelope"] = envelope
1193
  if phase is not None:
1194
  encoder_inputs["phase"] = phase
 
 
1195
 
1196
-
1197
- encoder_outputs = self.encoder(encoder_inputs, f0=f0)
1198
- logits = self.decoder(input_ids, encoder_outputs, f0d=f0d)
1199
-
1200
 
1201
  loss = None
1202
  if labels is not None:
@@ -1385,19 +1285,19 @@ class DataCollator:
1385
  all_f0 = torch.cat([f["f0"] for f in features])
1386
  batch["f0"] = all_f0.unsqueeze(0)
1387
 
1388
- if "f0" in features[0] and features[0]["f0"] is not None:
1389
- f0_labels = batch.get("labels", None)
1390
- aligned_features = []
1391
- for feature in features:
1392
- f0 = feature["f0"]
1393
- length = f0.shape
1394
- if length != f0_labels.shape[-1]:
1395
- ctx = f0_labels.shape[-1]
1396
- aligned_features.append(align_f0(f0.unsqueeze(0), ctx))
1397
- else:
1398
- aligned_features.append(f0)
1399
- all_aligned_f0 = torch.cat(aligned_features)
1400
- batch["f0d"] = all_aligned_f0
1401
 
1402
  if "envelope" in features[0] and features[0]["envelope"] is not None:
1403
  env_list = [f["envelope"] for f in features]
@@ -1498,7 +1398,6 @@ def load_wave(wave_data, sample_rate):
1498
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
1499
  waveform = resampler(waveform)
1500
 
1501
-
1502
  return waveform.flatten()
1503
 
1504
  def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
@@ -1548,7 +1447,6 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=
1548
  batch["envelope"] = torch.stack(envelope_list)
1549
  batch["phase"] = torch.stack(phase_list)
1550
 
1551
-
1552
  wav_1d = wav.unsqueeze(0)
1553
 
1554
  if waveforms:
@@ -1569,7 +1467,7 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=
1569
  f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1570
  f0 = f0
1571
  batch["f0"] = torch.from_numpy(f0).float()
1572
-
1573
  if spectrogram and waveforms and pitch:
1574
  spec_mean = batch["spectrogram"].mean()
1575
  spec_std = batch["spectrogram"].std() + 1e-6
@@ -1805,8 +1703,7 @@ def get_training_args(
1805
  save_safetensors=False,
1806
  eval_on_start=eval_on_start,
1807
  batch_eval_metrics=batch_eval_metrics,
1808
- max_grad_norm=max_grad_norm,
1809
-
1810
  )
1811
 
1812
  def main():
@@ -1859,20 +1756,20 @@ def main():
1859
  text_dims=512,
1860
  text_idx=4,
1861
  act="swish",
1862
- debug={"rotary"},#{"encoder", "decoder", "residual", "rotary"},
1863
  cross_attn=True,
1864
- f0_rotary=True,
1865
- features = ["spectrogram", "waveform", "pitch"],
1866
  )
1867
 
1868
  sanity_check = False
1869
  training_args = sanity(sanity_check)
1870
  dataset_config = {
1871
  "spectrogram": True,
1872
- "waveforms": True,
1873
- "pitch": True,
1874
  "downsamples": False,
1875
- "frequency": True,
1876
  "hilbert": False,
1877
  "hop_length": 128,
1878
  "fmin": 150,
@@ -1917,3 +1814,4 @@ def main():
1917
 
1918
  if __name__ == "__main__":
1919
  main()
 
 
 
1
  import pyworld as pw
2
  import os
3
+ import math, random
4
  import warnings
5
  import logging
6
  import gzip
 
38
  logging.basicConfig(level=logging.ERROR)
39
  tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), "dtype": torch.float32}
40
 
 
41
  extractor = None
42
  tokenizer = None
43
  optimizer = None
 
147
  axs[current_ax].set_ylabel("Mel Bin")
148
  axs[current_ax].set_xlim([0, max_time])
149
  axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
150
+ # fig.colorbar(im, ax=axs[current_ax])
151
  current_ax += 1
152
 
153
  if p is not None:
 
255
  scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
256
  return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
257
 
258
+ class ParameterCycler:
259
+ def __init__(self, parameters):
260
+ self.parameters = parameters
261
+ self.current_idx = 0
262
+ def toggle_requires_grad(self):
263
+ x = random.randint(0, len(self.parameters) - 1)
264
+ for x, param in enumerate(self.parameters):
265
+ param.requires_grad = (x == self.current_idx)
266
+ print(f"Parameter {x}: requires_grad={param.requires_grad}")
267
+ self.current_idx = (self.current_idx + 1) % len(self.parameters)
268
+
269
  class rotary(nn.Module):
270
  _seen = set()
271
  def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
 
273
  super().__init__()
274
 
275
  self.use_pbias = use_pbias
276
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
277
+ self.dtype = torch.float32
 
 
278
  self.debug = debug
279
  self._counter = 0
280
  self.dims = dims
281
  self.max_ctx = max_ctx
282
  self.radii = radii
283
+ f0_factor = 0.5
 
284
  self.learned_adaptation: bool = False
285
  pitch_scale = 1.0
286
+ radius = 1
287
+
288
  if self.learned_adaptation:
289
+ self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=self.device, dtype=self.dtype), requires_grad=True)
290
  else:
291
+ self.register_buffer('f0_scale', torch.tensor(f0_factor))
292
 
293
+ self.theta = nn.Parameter(torch.tensor(theta, device=self.device, dtype=self.dtype), requires_grad=True)
294
+ self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale, device=self.device, dtype=self.dtype), requires_grad=True)
295
  freqs = 1. / (theta ** (torch.arange(0, dims, 2)[:(dims // 2)].float() / dims))
296
+ self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
297
+ self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
298
 
299
+ # self.cycler = ParameterCycler(parameters=[self.theta, self.pitch_scale, self.freqs])
300
+ # self.reset_parameters()
 
301
 
302
  def get_pitch_bias(self, f0):
303
  if f0 is None:
 
325
  rotary.fwd_sim = fwd_sim
326
 
327
  def align_f0(self, f0, ctx):
328
+ b, l = f0.shape
329
+ if l == ctx:
330
+ return f0.squeeze(0).float()
331
+ frames_per_token = l / ctx
332
+ idx = torch.arange(ctx, device=self.device, dtype=torch.float32)
333
+ src_idx = (idx * frames_per_token).long().clamp(0, l-1)
334
+ batch_idx = torch.arange(b, device=self.device, dtype=torch.float32).unsqueeze(1)
335
+ f0 = f0[batch_idx, src_idx]
336
+ return f0.squeeze(0).float()
337
+
338
+ # def align_f0(self, f0, ctx):
339
+ # b, l = f0.shape
340
+ # if l == ctx:
341
+ # return f0.squeeze(0).float()
342
+ # frames = l / ctx
343
+ # idx = torch.arange(ctx, device=f0.device)
344
+ # f0 = (idx * frames).long()
345
+ # # b_idx = torch.arange(b, device=f0.device).unsqueeze(1)
346
+ # # f0 = f0[b_idx, idx.unsqueeze(0).expand(b, -1)]
347
+ # return f0.squeeze(0).float()
348
 
349
  def scale_f0(self, f0):
350
  f0_min = f0.min(dim=1, keepdim=True)[0]
 
389
  freqs = torch.outer(positions, freqs)
390
  return torch.polar(torch.ones_like(freqs), freqs)
391
 
392
+ def forward(self, x=None, f0=None, layer=None) -> Tensor:
393
+ # self.cycler.toggle_requires_grad()
394
  if isinstance(x, int):
395
  ctx = x
396
  else:
397
  batch, ctx, dims = x.shape
398
  t = torch.arange(ctx, device=self.device).float()
399
+
400
  if self.learned_adaptation:
401
  freqs = self.get_f0_adapted_freqs(ctx, f0)
402
  x_complex = torch.view_as_complex(
 
414
 
415
  freqs = torch.einsum('i,j->ij', t, freqs)
416
  freqs = freqs.float()
417
+
418
  if self.radii and f0 is not None:
419
+
420
+ radius = self.align_f0(f0, ctx)
421
+
422
+ # radius = torch.clamp(radius, min=50.0, max=500.0) # Clamp to voice range
423
+ # radius = radius / 500.0 # Normalize to [0.1, 1.0] range
424
+ # radius = radius.float()
425
+
426
+ radius = radius.float()
427
+ freqs = torch.polar(radius.unsqueeze(-1), freqs)
428
  else:
429
  freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
430
  if "rotary" in self.debug:
 
436
  print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
437
  elif abs(self._prev_f0_theta - theta) > 100.0:
438
  print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
439
+ print(f"f0_mean: {f0_mean} Hz, freqs: {freqs.shape}, ctx: {ctx}, dims: {self.dims}, block: {layer}")
440
  if self.radii:
441
+ print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
442
  self._prev_f0_theta = theta
443
  rotary._seen.add(key)
444
  self._counter += 1
 
579
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
580
 
581
  def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None,
582
+ return_attn: bool = False, f0: Tensor = None, layer = None) -> tuple:
583
 
584
  batch, ctx, dims = x.shape
585
  scale = (self.dims // self.head) ** -0.25
 
590
  v = self.v(z).to(x.dtype)
591
 
592
  if self.rotary_emb:
593
+ qf = self.rope(q.size(1), f0=f0, layer=layer)
594
+ kf = self.rope(k.size(1), f0=f0, layer=layer)
595
 
596
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
597
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
689
  comb = torch.cat([s, w, p], dim=-1)
690
  return self.integ(comb)
691
 
 
 
 
692
  class Residual(nn.Module):
693
  _seen = set()
694
  def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
 
736
  if not any([t_gate, m_gate, c_gate]):
737
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
738
 
739
+ def forward(self, x, xa=None, mask=None, f0=None, mode=None, layer=None):
740
  bln = self.blend
741
+ x = x + self.attna(self.lna(x), mask=mask, f0=f0, layer=layer)[0]
742
 
743
  if self.attnb and xa is not None:
744
+ c = self.attnb(self.lnb(x), xa, f0=f0, mask=None, layer=layer)[0]
745
  b = torch.sigmoid(bln)
746
  x = b * x + (1 - b) * c
747
 
 
781
 
782
  return x
783
 
 
 
784
  class PEncoder(nn.Module):
785
  def __init__(self, input_dims, dims, head, layer, kernel_size, act):
786
  super().__init__()
 
796
  Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
797
  Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2),act_fn)
798
 
799
+ def forward(self, x, f0=None, layer=None):
800
  x = self.encoder(x).permute(0, 2, 1)
801
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
802
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 
825
  self.positional = lambda length: sinusoids(length, dims)
826
  self.norm = RMSNorm(dims)
827
 
828
+ def forward(self, x, f0=None, layer=None):
829
  x = self.downsample(x)
830
  x = self.encoder(x)
831
  x = x.permute(0, 2, 1)
 
852
  self.norm = RMSNorm(dims)
853
  self._norm = RMSNorm(dims)
854
 
855
+ def forward(self, x, f0=None, layer=None):
856
  x = self.encoder(x).permute(0, 2, 1)
857
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
858
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
859
  x = self._norm(x)
860
  return x
861
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862
  class AudioEncoder(nn.Module):
863
  _seen = set()
864
  def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str],
 
881
  self.dropout = 0.01
882
  self.f0_rotary = f0_rotary
883
 
884
+ self.rope = rotary(
885
+ dims=self.head_dim,
886
+ )
887
+
888
  act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
889
  "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
890
  act_fn = act_map.get(act, nn.GELU())
 
906
  "pitch": nn.ModuleList(
907
  [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
908
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
909
+ ),
910
  "spec_envelope": nn.ModuleList(
911
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
912
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug) for _ in range(layer)] if "spec_envelope" in features else None
 
916
  [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug) for _ in range(layer)] if "spec_phase" in features else None),
917
  })
918
 
919
+ self.f0 = nn.ModuleList([
920
+ FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
921
+ for _ in range(layer)])
922
+
923
+ def forward(self, x, f0=None, layer="encoder"):
924
  if self._counter < 1:
925
  s = x.get("spectrogram")
926
  w = x.get("waveform")
927
  p = f0 if f0 is not None else x.get("pitch")
928
  plot_waveform(x=s, w=w, p=p, hop_length=128)
929
+
930
  enc = {}
931
+
932
+ # if f0 is not None:
933
+ # f0 = self.f0(f0)
934
+
935
+ #self.rope(x=f, f0=f0, layer=layer)
936
+
937
  for y in self.features:
938
  if y in x and y in self.blocks:
939
  f = x[y]
940
  for block in self.blocks[y]:
941
+ f = block(f, f0=f0, layer=layer)
942
  enc[y] = f
943
+
944
  if "encoder" in self.debug and self._counter % 100 == 0:
945
  names = list(x.keys())
946
  shapes = {k: v.shape for k, v in x.items()}
 
952
  self._counter += 1
953
  return enc
954
 
 
 
955
  class TextDecoder(nn.Module):
956
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
957
  debug: List[str], features: List[str], f0_rotary: bool = False, sequential=False):
 
979
  self.token.weight[0].zero_()
980
  self.positional = nn.Parameter(data=torch.empty(ctx, dims), requires_grad=True)
981
 
982
+ self.block = nn.ModuleList([
983
  Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
984
  for _ in range(layer)])
985
 
 
993
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
994
  self.register_buffer("mask", mask, persistent=False)
995
 
996
+ def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
997
  bln = self.blend
998
  x = x.to(device)
 
 
 
 
999
  if order is None:
1000
  order = self.features
1001
  mask = self.mask[:x.shape[1], :x.shape[1]]
1002
  x = self.token(x) + self.positional[:x.shape[1]]
1003
  x = F.dropout(x, p=self.dropout, training=self.training)
1004
+ for block in self.block:
1005
+ x = block(x, xa=None, mask=mask, layer=layer)
1006
  for f in order:
1007
  if f in enc:
1008
  xa = enc[f]
1009
  for block in self.blocks[f]:
1010
+ out = block(x=x, xa=xa, mask=None, layer=layer)
1011
  a = torch.sigmoid(bln[f])
1012
  x = a * out + (1 - a) * x
1013
  x = self.ln_dec(x)
 
1026
  self.param = param
1027
  self.count = 0
1028
 
 
 
1029
  self.encoder = AudioEncoder(
1030
  mels=param.mels,
1031
  ctx=param.aud_ctx,
 
1075
  spectrogram: torch.Tensor=None,
1076
  pitch: Optional[torch.Tensor]=None,
1077
  f0: Optional[torch.Tensor]=None,
1078
+ # f0d: Optional[torch.Tensor]=None,
1079
  envelope: Optional[torch.Tensor]=None,
1080
  phase: Optional[torch.Tensor]=None,
1081
  ) -> Dict[str, torch.Tensor]:
 
1092
  encoder_inputs["envelope"] = envelope
1093
  if phase is not None:
1094
  encoder_inputs["phase"] = phase
1095
+ # if f0 is not None:
1096
+ # encoder_inputs["f0"] = f0
1097
 
1098
+ encoder_outputs = self.encoder(encoder_inputs, f0=f0, layer="encoder")
1099
+ logits = self.decoder(input_ids, encoder_outputs)
 
 
1100
 
1101
  loss = None
1102
  if labels is not None:
 
1285
  all_f0 = torch.cat([f["f0"] for f in features])
1286
  batch["f0"] = all_f0.unsqueeze(0)
1287
 
1288
+ # if "f0" in features[0] and features[0]["f0"] is not None:
1289
+ # f0_labels = batch.get("labels", None)
1290
+ # aligned_features = []
1291
+ # for feature in features:
1292
+ # f0 = feature["f0"]
1293
+ # length = f0.shape
1294
+ # if length != f0_labels.shape[-1]:
1295
+ # ctx = f0_labels.shape[-1]
1296
+ # aligned_features.append(align_f0(f0.unsqueeze(0), ctx))
1297
+ # else:
1298
+ # aligned_features.append(f0)
1299
+ # all_aligned_f0 = torch.cat(aligned_features)
1300
+ # batch["f0d"] = all_aligned_f0
1301
 
1302
  if "envelope" in features[0] and features[0]["envelope"] is not None:
1303
  env_list = [f["envelope"] for f in features]
 
1398
  resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
1399
  waveform = resampler(waveform)
1400
 
 
1401
  return waveform.flatten()
1402
 
1403
  def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
 
1447
  batch["envelope"] = torch.stack(envelope_list)
1448
  batch["phase"] = torch.stack(phase_list)
1449
 
 
1450
  wav_1d = wav.unsqueeze(0)
1451
 
1452
  if waveforms:
 
1467
  f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1468
  f0 = f0
1469
  batch["f0"] = torch.from_numpy(f0).float()
1470
+
1471
  if spectrogram and waveforms and pitch:
1472
  spec_mean = batch["spectrogram"].mean()
1473
  spec_std = batch["spectrogram"].std() + 1e-6
 
1703
  save_safetensors=False,
1704
  eval_on_start=eval_on_start,
1705
  batch_eval_metrics=batch_eval_metrics,
1706
+ max_grad_norm=max_grad_norm,
 
1707
  )
1708
 
1709
  def main():
 
1756
  text_dims=512,
1757
  text_idx=4,
1758
  act="swish",
1759
+ debug={},#{"encoder", "decoder", "residual", "rotary"},
1760
  cross_attn=True,
1761
+ f0_rotary=False,
1762
+ features = ["spectrogram"]#, "waveform", "pitch", "f0", "envelope", "phase"],
1763
  )
1764
 
1765
  sanity_check = False
1766
  training_args = sanity(sanity_check)
1767
  dataset_config = {
1768
  "spectrogram": True,
1769
+ "waveforms": False,
1770
+ "pitch": False,
1771
  "downsamples": False,
1772
+ "frequency": False,
1773
  "hilbert": False,
1774
  "hop_length": 128,
1775
  "fmin": 150,
 
1814
 
1815
  if __name__ == "__main__":
1816
  main()
1817
+