Sin2pi committed on
Commit 80c70a4 · verified · 1 Parent(s): 00f642e

Update modelA.py

Files changed (1)
  1. modelA.py +563 -406
modelA.py CHANGED
@@ -8,10 +8,15 @@ import torchaudio
8
  import torch.nn.functional as F
9
  import torch.nn.init as init
10
  from torch import nn, Tensor
11
- from datasets import load_dataset, Audio
12
- from torch.utils.data import Dataset, DataLoader, random_split
 
13
  import numpy as np
14
- from typing import Optional, Dict, Union, List, Tuple
 
 
 
 
15
  import transformers
16
  from dataclasses import dataclass
17
  from opimizer import MaxFactor
@@ -20,22 +25,44 @@ torch.backends.cudnn.allow_tf32 = True
20
  torch.backends.cuda.matmul.allow_tf32 = True
21
  torch.set_float32_matmul_precision('high')
22
  transformers.utils.logging.set_verbosity_error()
 
23
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
  dtype = torch.float32
 
 
25
  logging.basicConfig(level=logging.ERROR)
26
27
  @dataclass
28
  class Dimensions:
29
  vocab: int
30
- text_ctx: int
31
- text_dims: int
32
- text_head: int
33
- text_idx: int
34
  mels: int
35
- aud_ctx: int
36
- aud_dims: int
37
- aud_head: int
38
- aud_idx: int
39
  act: str
40
  debug: List[str]
41
  cross_attn: bool
@@ -59,6 +86,132 @@ def get_generation_config(param):
59
  use_cache=False,
60
  return_timestamps=False)
61
62
  def dict_to(d, device, dtype=dtype):
63
  return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
64
  for k, v in d.items()}
@@ -100,17 +253,17 @@ class RMSNorm(nn.Module):
100
  self.eps = eps
101
  self.elementwise_affine = elementwise_affine
102
  if self.elementwise_affine:
103
- self.weight = nn.Parameter(torch.empty(self.normalized_shape))
104
  init.ones_(self.weight)
105
  else:
106
  self.register_parameter("weight", None)
107
  def forward(self, x):
108
- return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
109
 
110
  def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
111
  weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
112
  eps: float = 1e-5) -> Tensor:
113
- return F.layer_norm(x, normalized_shape, weight, bias, eps)
114
 
115
  def get_device():
116
  return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -128,9 +281,8 @@ def sinusoids(length, channels, max_tscale=10000):
128
  scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
129
  return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
130
 
131
-
132
  class rotary(nn.Module):
133
- def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=True, debug: List[str] = [], use_pbias=False):
134
  super(rotary, self).__init__()
135
 
136
  self.use_pbias = use_pbias
@@ -143,96 +295,17 @@ class rotary(nn.Module):
143
  self.counter = 0
144
  self.last_theta = None
145
 
146
- self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
147
- self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
148
-
149
- # def theta_freqs(self, theta):
150
- # freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
151
- # freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
152
- # return freqs
153
-
154
- # def mel_geodesic_rotary(f0, theta):
155
- # mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
156
- # fisher_info = torch.var(mel_f0) + 1e-8
157
- # adaptive_theta = theta * torch.sqrt(fisher_info)
158
- # freqs = self.theta_freqs(adaptive_theta)
159
- # return freqs
160
-
161
- # def compute_pitch_fisher_info(f0, window_size=10):
162
- # if f0.dim() == 1:
163
- # f0 = f0.unsqueeze(0)
164
- # mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
165
- # fisher_info = torch.nn.functional.avg_pool1d(
166
- # mel_f0.unsqueeze(0),
167
- # kernel_size=window_size,
168
- # stride=1,
169
- # padding=window_size//2
170
- # ).squeeze(0)
171
- # fisher_info = (fisher_info - fisher_info.min()) / (fisher_info.max() - fisher_info.min() + 1e-8)
172
- # return fisher_info
173
-
174
- # def compute_advanced_fisher_info(f0, window_size=10):
175
- # mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
176
- # local_mean = torch.nn.functional.avg_pool1d(
177
- # mel_f0.unsqueeze(0), window_size, 1, window_size//2
178
- # ).squeeze(0)
179
-
180
- # local_var = torch.nn.functional.avg_pool1d(
181
- # (mel_f0 - local_mean).pow(2).unsqueeze(0),
182
- # window_size, 1, window_size//2
183
- # ).squeeze(0)
184
-
185
- # fisher_info = 1.0 / (local_var + 1e-8)
186
- # return fisher_info
187
-
188
- # def test_fisher_info(self, f0):
189
- # """Test Fisher information computation.""" # fisher_info = self.compute_pitch_fisher_info(f0)
190
-
191
- # print(f"f0 range: {f0.min():.1f} - {f0.max():.1f}")
192
- # print(f"Fisher info range: {fisher_info.min():.3f} - {fisher_info.max():.3f}")
193
- # print(f"Fisher info mean: {fisher_info.mean():.3f}")
194
-
195
- # # Visualize: high Fisher info = meaningful pitch changes
196
- # return fisher_info
197
-
198
- # def forward(self, x=None, enc=None, layer=None, feature_type="audio"):
199
-
200
- # if f0 is not None:
201
- # # Compute Fisher information
202
- # fisher_info = self.compute_pitch_fisher_info(f0)
203
-
204
- # # Use Fisher info to weight pitch influence
205
- # f0_weighted = f0 * fisher_info
206
-
207
- # # Apply to both theta and radius
208
- # f0_mean = f0_weighted.mean()
209
- # theta = f0_mean + self.theta
210
-
211
- # if self.radii:
212
- # radius = f0_weighted.to(device, dtype)
213
-
214
-
215
-
216
- def theta_freqs(self, theta):
217
- freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
218
- freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
219
- return freqs
220
 
221
- def mel_scale_scalar(freq: float) -> float:
222
  return 1127.0 * math.log(1.0 + freq / 700.0)
223
 
224
- def mel_scale(freq: Tensor) -> Tensor:
225
  return 1127.0 * (1.0 + freq / 700.0).log()
226
 
227
- def return_f0(self, f0=None):
228
- if f0 is not None:
229
- self.f0 = f0
230
- self.update_base(f0)
231
- return f0.squeeze(0).to(device, dtype)
232
- elif hasattr(self, 'f0') and self.f0 is not None:
233
- return self.f0.squeeze(0).to(device, dtype)
234
- return None
235
-
236
  def pitch_bias(self, f0):
237
  if f0 is None:
238
  return None
@@ -242,9 +315,31 @@ class rotary(nn.Module):
242
  f0_norm.unsqueeze(1)))
243
  return f0_sim.unsqueeze(0).unsqueeze(0)
244
245
 
246
  def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
247
  f0 = enc.get("f0") if enc is not None else None
 
248
  if isinstance(x, int):
249
  ctx = x
250
  elif isinstance(x, torch.Tensor) and x.ndim == 2:
@@ -252,46 +347,33 @@ class rotary(nn.Module):
252
  elif isinstance(x, torch.Tensor) and x.ndim == 3:
253
  batch, ctx, dims = x.shape
254
  else:
255
- batch, head, ctx, head_dim = x.shape
256
- t = torch.arange(ctx, device=device, dtype=dtype)
257
-
258
- if f0 is not None and f0.dim() == 2:
259
- if f0.shape[0] == 1:
260
- f0 = f0.squeeze(0)
261
- else:
262
- f0 = f0.view(-1)
263
 
264
  if f0 is not None:
265
- f0_mean = f0.mean()
266
- theta = f0_mean + self.theta
 
267
  else:
268
  theta = self.theta
269
 
270
  freqs = self.theta_freqs(theta)
271
-
272
- freqs = t[:, None] * freqs[None, :]
273
-
274
  if self.radii and f0 is not None:
275
  radius = f0.to(device, dtype)
276
- L = radius.shape[0]
277
- if L != ctx:
278
- F = L / ctx
279
- idx = torch.arange(ctx, device=f0.device)
280
- idx = (idx * F).long().clamp(0, L - 1)
281
- radius = radius[idx]
282
- freqs = torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs)
283
  else:
284
- freqs = torch.polar(torch.ones_like(freqs), freqs)
285
-
286
- if "radius" in self.debug and self.counter % 100 == 0:
287
- theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
288
- print(f" [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
289
 
290
- if "theta" in self.debug and self.counter % 100 == 0:
291
- if self.last_theta is None or abs(self.last_theta - theta.item()) > 1.0:
292
- self.last_theta = theta.item()
293
- print(f"[Theta] {self.last_theta:.2f}")
294
-
 
 
295
  self.counter += 1
296
  return freqs.unsqueeze(0)
297
 
@@ -308,12 +390,11 @@ class rotary(nn.Module):
308
  x1 = x1.view(orig_shape)
309
  return torch.cat([x1.type_as(x), x2], dim=-1)
310
 
311
-
312
  class MultiheadA(nn.Module):
313
- _seen = set()
314
  rbf = False
315
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
316
- zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False, use_pbias=False):
317
  super(MultiheadA, self).__init__()
318
 
319
  self.dims = dims
@@ -345,8 +426,29 @@ class MultiheadA(nn.Module):
345
  )
346
  else:
347
  self.rope = None
348
 
349
- def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:
350
 
351
  x = x.to(device, dtype)
352
  if xa is not None:
@@ -365,8 +467,8 @@ class MultiheadA(nn.Module):
365
  q2 = q.shape[2]
366
  k2 = k.shape[2]
367
 
368
- q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
369
- k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
370
  else:
371
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
372
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -374,14 +476,25 @@ class MultiheadA(nn.Module):
374
 
375
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
376
377
  token_ids = k[:, :, :, 0]
378
  zscale = torch.ones_like(token_ids)
379
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
380
  zscale[token_ids.float() == self.pad_token] = fzero
381
 
382
  if mask is not None:
383
- mask = mask[:q2, :q2]
384
- qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
385
  qk = qk * zscale.unsqueeze(-2)
386
  w = F.softmax(qk, dim=-1).to(q.dtype)
387
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
@@ -392,8 +505,9 @@ class MultiheadA(nn.Module):
392
  return self.o(wv), qk
393
 
394
  class t_gate(nn.Module):
395
- def __init__(self, dims, num_types=4):
396
  super().__init__()
 
397
  self.gate_projections = nn.ModuleList([
398
  nn.Sequential(Linear(dims, 1), nn.Sigmoid())
399
  for _ in range(num_types)])
@@ -401,19 +515,25 @@ class t_gate(nn.Module):
401
  Linear(dims, num_types),
402
  nn.Softmax(dim=-1))
403
  def forward(self, x):
 
 
404
  type_probs = self.type_classifier(x)
405
  gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
406
  comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
407
  return comb_gate
408
 
409
  class m_gate(nn.Module):
410
- def __init__(self, dims, mem_size=64):
411
  super().__init__()
412
- self.m_key = nn.Parameter(torch.randn(mem_size, dims))
413
- self.m_val = nn.Parameter(torch.randn(mem_size, 1))
414
- self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
415
-
 
 
416
  def forward(self, x):
 
 
417
  d_gate = torch.sigmoid(self.gate_proj(x))
418
  attention = torch.matmul(x, self.m_key.transpose(0, 1))
419
  attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
@@ -422,16 +542,20 @@ class m_gate(nn.Module):
422
  return 0.5 * (d_gate + m_gate)
423
 
424
  class c_gate(nn.Module):
425
- def __init__(self, dims):
426
  super().__init__()
427
- self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
428
- self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
429
- self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
430
- self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
431
- self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
432
- self.integ = Linear(dims*5, dims)
 
 
433
 
434
  def forward(self, x, features):
 
 
435
  s_feat = features.get("spectrogram", x)
436
  w_feat = features.get("waveform", x)
437
  p_feat = features.get("pitch", x)
@@ -445,9 +569,21 @@ class c_gate(nn.Module):
445
  comb = torch.cat([s, w, p, e, ph], dim=-1)
446
  return self.integ(comb)
447
448
  class Residual(nn.Module):
449
  _seen = set()
450
- def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
451
  tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
452
  super().__init__()
453
 
@@ -455,80 +591,44 @@ class Residual(nn.Module):
455
  self.head = head
456
  self.ctx = ctx
457
  self.head_dim = dims // head
458
- self.cross_attn = cross_attn
459
  self.features = features
460
  self.debug = debug
461
  self.counter = 0
462
  self.dropout = 0.01
463
-
464
- self.t_gate = tgate
465
- self.m_gate = mgate
466
- self.c_gate = cgate
467
- self.do_blend = "no_blend" not in self.debug
468
  self.blend = nn.Parameter(torch.tensor(0.5))
469
- self.skip_gates = True if "skip_gates" in self.debug else False
470
-
471
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
472
- "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
473
- "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
474
- "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
475
- act_fn = act_map.get(act, nn.GELU())
476
-
477
- self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
478
- self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None)
479
 
480
  mlp = dims * 4
481
  self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
482
 
483
- self.t_gate = t_gate(dims=dims, num_types=4) if tgate else None
484
- self.m_gate = m_gate(dims=dims, mem_size=mem_size) if mgate else None
485
- self.c_gate = c_gate(dims=dims) if cgate else None
 
486
 
487
  self.lna = RMSNorm(dims)
488
- self.lnb = RMSNorm(dims) if cross_attn else None
489
  self.lnc = RMSNorm(dims)
490
 
491
- if not any([t_gate, m_gate, c_gate]):
492
- self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
493
-
494
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
495
-
496
- x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
497
- xb = x
498
- if self.attnb and xa is not None:
499
- x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
500
-
501
- if self.do_blend:
502
- b = torch.sigmoid(self.blend)
503
- x = b * xb + (1 - b) * x
504
 
505
- if self.skip_gates:
506
- x = x + self.mlp(self.lnc(x))
507
- else:
508
- normx = self.lnc(x)
509
- mlp_out = self.mlp(normx)
510
-
511
- if self.t_gate:
512
- gate = self.t_gate(normx)
513
- x = x + gate * mlp_out
514
-
515
- elif self.m_gate:
516
- gate = self.m_gate(normx)
517
- x = x + gate * mlp_out
518
 
519
- elif self.c_gate:
520
- gate_output = self.c_gate(normx, self.features)
521
- x = x + gate_output
522
-
523
- else:
524
- if hasattr(self, 'mlp_gate'):
525
- mlp_gate = self.mlp_gate(normx)
526
- x = x + mlp_gate * mlp_out
527
- else:
528
- x = x + mlp_out
529
-
530
- return x
531
-
532
  class FEncoder(nn.Module):
533
  def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
534
  super().__init__()
@@ -539,8 +639,7 @@ class FEncoder(nn.Module):
539
  self.use_rope = use_rope
540
  self.dims = dims
541
 
542
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
543
- act_fn = act_map.get(act, nn.GELU())
544
 
545
  self.encoder = nn.Sequential(
546
  Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
@@ -551,11 +650,13 @@ class FEncoder(nn.Module):
551
  if spec_shape is not None:
552
  self.rope = rotary(
553
  dims=self.head_dim,
 
554
  use_2d_axial=True,
555
  spec_shape=spec_shape, debug=[])
556
  else:
557
  self.rope = rotary(
558
  dims=self.head_dim,
 
559
  use_2d_axial=False, debug=[])
560
  else:
561
  self.rope = None
@@ -569,7 +670,7 @@ class FEncoder(nn.Module):
569
  feature_type = "spectrogram"
570
  batch, ctx, dims = x.shape
571
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
572
- if feature_type == "spectrogram" and hasattr(self.rope, 'use_2d_axial') and self.rope.use_2d_axial:
573
  rope_freqs = self.rope(ctx, layer=layer, input_type="spectrogram")
574
  else:
575
  rope_freqs = self.rope(ctx, layer=layer, input_type="audio")
@@ -597,8 +698,7 @@ class WEncoder(nn.Module):
597
  self.use_rope = use_rope
598
  self.dims = dims
599
 
600
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
601
- act_fn = act_map.get(act, nn.GELU())
602
 
603
  self.downsample = nn.Sequential(
604
  Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
@@ -611,8 +711,8 @@ class WEncoder(nn.Module):
611
  if use_rope:
612
  self.rope = rotary(
613
  dims=self.head_dim,
614
- use_2d_axial=False,
615
- theta=50.0, debug=[])
616
  else:
617
  self.rope = None
618
  self.positional = lambda length: sinusoids(length, dims)
@@ -649,8 +749,7 @@ class PEncoder(nn.Module):
649
  self.use_rope = use_rope
650
  self.dims = dims
651
 
652
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
653
- act_fn = act_map.get(act, nn.GELU())
654
 
655
  self.encoder = nn.Sequential(
656
  Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
@@ -660,8 +759,8 @@ class PEncoder(nn.Module):
660
  if use_rope:
661
  self.rope = rotary(
662
  dims=self.head_dim,
663
- use_2d_axial=False,
664
- theta=100.0, debug=[])
665
  else:
666
  self.rope = None
667
  self.positional = lambda length: sinusoids(length, dims)
@@ -687,10 +786,10 @@ class PEncoder(nn.Module):
687
  x = self.norm(x)
688
  return x
689
 
690
- class AudioEncoder(nn.Module):
691
  _seen = set()
692
- def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
693
- super(AudioEncoder, self).__init__()
694
 
695
  self.dims = dims
696
  self.head = head
@@ -700,9 +799,12 @@ class AudioEncoder(nn.Module):
700
  self.counter = 0
701
  self.features = features
702
  self.dropout = 0.01
 
 
703
 
704
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),"tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
705
- act_fn = act_map.get(act, nn.GELU())
 
706
 
707
  if features == ["spectrogram", "waveform", "pitch"]:
708
  cgate=True
@@ -737,80 +839,55 @@ class AudioEncoder(nn.Module):
737
  if "phase" in features else None),
738
  })
739
 
740
- def forward(self, enc, layer="encoder"):
741
- enc = dict_to(enc, device, dtype)
742
- out = {}
743
- out.update(enc)
744
-
745
- for f in self.features:
746
- if f in enc and f in self.blocks:
747
- x = enc[f]
748
- for block in self.blocks[f]:
749
- x = block(x, enc=enc, layer=layer)
750
- out[f] = x
751
-
752
- return out
753
-
754
- class TextDecoder(nn.Module):
755
- def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
756
- debug: List[str], features: List[str]):
757
- super(TextDecoder, self).__init__()
758
-
759
- self.ctx = ctx
760
- self.dims = dims
761
- self.head = head
762
- self.head_dim = dims // head
763
- self.debug = debug
764
- self.counter = 0
765
- self.dropout = 0.01
766
- self.features = features
767
- self.do_blend = "no_blend" not in self.debug
768
- self.sequential = "sequential" in self.debug
769
-
770
- self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
771
- with torch.no_grad():
772
- self.token.weight[0].zero_()
773
- self.positional = nn.Parameter(data=torch.empty(ctx, dims), requires_grad=True)
774
-
775
  self.block = nn.ModuleList([
776
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
777
- for _ in range(layer)])
778
-
779
- self.blocks = nn.ModuleDict({
780
- f: nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
781
- for _ in range(layer)]) for f in features})
782
 
783
- self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
784
  self.ln_dec = RMSNorm(dims)
785
 
 
 
 
 
 
 
 
786
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
787
  self.register_buffer("mask", mask, persistent=False)
788
 
789
- def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
 
790
 
791
- if order is None:
792
- order = self.features
793
-
794
- mask = self.mask[:x.shape[1], :x.shape[1]]
795
  x = self.token(x) + self.positional[:x.shape[1]]
796
  x = F.dropout(x, p=self.dropout, training=self.training)
797
-
 
799
  for block in self.block:
 
800
  x = block(x, xa=None, mask=mask, enc=None, layer=layer)
801
 
802
- for f in order:
803
  if f in enc:
804
- xa = enc[f]
805
- for block in self.blocks[f]:
806
- out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)
807
-
808
  if self.sequential:
809
  x = out
810
  else:
811
- a = torch.sigmoid(self.blend[f])
812
  x = a * out + (1 - a) * x
813
-
814
 
815
  x = self.ln_dec(x)
816
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
@@ -820,38 +897,28 @@ class Echo(nn.Module):
820
  super().__init__()
821
  self.param = param
822
 
823
- self.encoder = AudioEncoder(
 
824
  mels=param.mels,
825
- ctx=param.aud_ctx,
826
- dims=param.aud_dims,
827
- head=param.aud_head,
828
- layer=param.aud_idx,
829
- act=param.act,
830
  debug=param.debug,
831
  features=param.features,
 
832
  )
833
 
834
- self.decoder = TextDecoder(
835
- vocab=param.vocab,
836
- ctx=param.text_ctx,
837
- dims=param.text_dims,
838
- head=param.text_head,
839
- layer=param.text_idx,
840
- cross_attn=param.cross_attn,
841
- debug=param.debug,
842
- features=param.features,
843
- )
844
-
845
  def forward(self,
846
  labels=None,
847
- waveform: Optional[torch.Tensor]=None,
848
  input_ids=None,
849
- spectrogram: torch.Tensor=None,
 
850
  pitch: Optional[torch.Tensor]=None,
851
  f0: Optional[torch.Tensor]=None,
852
  envelope: Optional[torch.Tensor]=None,
853
  phase: Optional[torch.Tensor]=None,
854
- ) -> Dict[str, torch.Tensor]:
855
 
856
  encoder_inputs = {}
857
  if spectrogram is not None:
@@ -866,9 +933,10 @@ class Echo(nn.Module):
866
  encoder_inputs["phase"] = phase
867
  if f0 is not None:
868
  encoder_inputs["f0"] = f0
 
 
869
 
870
- encoder_outputs = self.encoder(encoder_inputs)
871
- logits = self.decoder(input_ids, encoder_outputs)
872
 
873
  loss = None
874
  if labels is not None:
@@ -888,7 +956,7 @@ class Echo(nn.Module):
888
  std = 0.02
889
  self.init_counts = {
890
  "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
891
- "Conv2d": 0, "SEBlock": 0, "TextDecoder": 0, "AudioEncoder": 0,
892
  "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
893
  "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
894
  "WEncoder": 0, "PEncoder": 0}
@@ -914,12 +982,9 @@ class Echo(nn.Module):
914
  nn.init.zeros_(module.bias)
915
  self.init_counts["Conv2d"] += 1
916
  elif isinstance(module, MultiheadA):
917
-
918
  self.init_counts["MultiheadA"] += 1
919
- elif isinstance(module, TextDecoder):
920
- self.init_counts["TextDecoder"] += 1
921
- elif isinstance(module, AudioEncoder):
922
- self.init_counts["AudioEncoder"] += 1
923
  elif isinstance(module, Residual):
924
  self.init_counts["Residual"] += 1
925
 
@@ -957,10 +1022,11 @@ class Echo(nn.Module):
957
  encoder_inputs["phase"] = phase
958
  if f0 is not None:
959
  encoder_inputs["f0"] = f0
960
- encoder_outputs = self.encoder(encoder_inputs)
961
  for i in range(max_length - 1):
962
  with torch.no_grad():
963
- logits = self.decoder(ids, encoder_outputs)
 
964
  next_token_logits = logits[:, -1, :]
965
  if i < min_length:
966
  next_token_logits[:, eos_token_id] = 0
@@ -985,10 +1051,9 @@ class Echo(nn.Module):
985
  })
986
  return Config()
987
 
988
-
989
- def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
990
  from tokenizers import Tokenizer
991
- tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
992
  orig_encode = tokenizer.encode
993
  def enc(text, add_special_tokens=True):
994
  ids = orig_encode(text).ids
@@ -1005,6 +1070,11 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
1005
  ids = ids[1:]
1006
  while ids and ids[-1] in [0, 2]:
1007
  ids = ids[:-1]
1008
  results.append(tokenizer.decode(ids))
1009
  return results
1010
 
@@ -1019,95 +1089,165 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
1019
  tokenizer.eos_token_id = 2
1020
  return tokenizer
1021
 
1022
- def extract_features(batch, tokenizer, sample_rate=16000, n_mels=128, n_fft=1024, hop_length=256):
1023
  audio = batch["audio"]
1024
- waveform = torch.tensor(audio["array"]).float()
1025
- if waveform.dim() == 2:
1026
- waveform = waveform.mean(dim=0)
1027
-
1028
- # mel_spectrogram = transform(wav)
1029
- # log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1030
- # log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1031
- # spec = (log_mel + 4.0) / 4.0
1032
- # spec = torch.tensor(spec)
1033
-
1034
- mel = torchaudio.transforms.MelSpectrogram(
1035
- sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
1036
- )
1037
- spec = mel(waveform)
1038
- spec = torch.clamp(spec, min=1e-10).log10()
1039
- spec = torch.tensor(spec) if not isinstance(spec, torch.Tensor) else spec
1040
- wav_np = waveform.numpy().astype(np.float64)
 
1041
  f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
1042
  f0 = pw.stonemask(wav_np, f0, t, sample_rate)
1043
- f0 = torch.from_numpy(f0).float()
1044
- transcription = batch.get("sentence", batch.get("transcription", ""))
1045
- input_ids = tokenizer.encode(transcription)
 
1046
  return {
1047
  "spectrogram": spec,
1048
  "f0": f0,
1049
- "input_ids": input_ids,
1050
- "labels": input_ids,
 
1051
  }
1052
 
1053
- def prepare_datasets(tokenizer, token: str, sample_rate=16000, n_mels=128, n_fft=1024, hop_length=256):
1054
- raw_train = load_dataset(
1055
- "google/fleurs", "en_us", token=token, split="train[:1000]", trust_remote_code=True
1056
- )
1057
- raw_test = load_dataset(
1058
- "google/fleurs", "en_us", token=token, split="test[:100]", trust_remote_code=True
1059
- )
1060
- raw_train = raw_train.cast_column("audio", Audio(sampling_rate=sample_rate))
1061
- raw_test = raw_test.cast_column("audio", Audio(sampling_rate=sample_rate))
1062
- train_dataset = raw_train.map(
1063
- lambda x: extract_features(x, tokenizer, sample_rate, n_mels, n_fft, hop_length),
1064
- remove_columns=raw_train.column_names
1065
- )
1066
- test_dataset = raw_test.map(
1067
- lambda x: extract_features(x, tokenizer, sample_rate, n_mels, n_fft, hop_length),
1068
- remove_columns=raw_test.column_names
1069
- )
1070
- return train_dataset, test_dataset
1071
 
1072
  @dataclass
1073
  class DataCollator:
1074
  tokenizer: Any
1075
 
1076
  def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
 
 
 
 
1077
  pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
1078
  bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
1079
  eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)
1080
 
1081
- # Gather and pad spectrograms and f0
1082
- specs = [f["spectrogram"] for f in features]
1083
- f0s = [f["f0"] for f in features]
1084
- specs = [torch.tensor(s) if not isinstance(s, torch.Tensor) else s for s in specs]
1085
- f0s = [torch.tensor(f0) if not isinstance(f0, torch.Tensor) else f0 for f0 in f0s]
1086
- max_spec_len = max(s.shape[-1] for s in specs)
1087
- max_f0_len = max(f0.shape[-1] for f0 in f0s)
1088
- padded_specs = torch.stack([
1089
- torch.nn.functional.pad(s, (0, max_spec_len - s.shape[-1])) for s in specs
1090
- ])
1091
- padded_f0s = torch.stack([
1092
- torch.nn.functional.pad(f0, (0, max_f0_len - f0.shape[-1])) for f0 in f0s
1093
- ])
1094
-
1095
- input_ids_list = [f["input_ids"] for f in features]
1096
- # Ensure all are lists, not tensors
1097
- input_ids_list = [ids.tolist() if isinstance(ids, torch.Tensor) else ids for ids in input_ids_list]
1098
- max_len = max(len(ids) for ids in input_ids_list)
1099
- # Add BOS to input_ids, EOS to labels, pad both to max_len+1
1100
- input_ids = [[bos_token_id] + ids + [pad_token_id] * (max_len - len(ids)) for ids in input_ids_list]
1101
- labels = [ids + [eos_token_id] + [pad_token_id] * (max_len - len(ids)) for ids in input_ids_list]
1102
- input_ids = torch.tensor(input_ids, dtype=torch.long)
1103
- labels = torch.tensor(labels, dtype=torch.long)
1104
-
1105
- return {
1106
- "spectrogram": padded_specs,
1107
- "f0": padded_f0s,
1108
- "input_ids": input_ids,
1109
- "labels": labels,
1110
- }
1111
 
1112
  def levenshtein(reference_words, hypothesis_words):
1113
  m, n = len(reference_words), len(hypothesis_words)
@@ -1137,7 +1277,7 @@ def wer_batch(references, hypotheses):
1137
  total_words += len(ref_words)
1138
  return (total_errors / total_words) * 100 if total_words > 0 else 0.0
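Illustrative sketch (not from the diff): a quick sanity check of the WER helpers above, assuming the full levenshtein/wer_batch bodies from modelA.py are in scope.
refs = ["the cat sat on the mat"]
hyps = ["the cat sat on mat"]
print(wer_batch(refs, hyps))  # one deleted word over six reference words -> ~16.67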
1139
 
1140
- def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0):
1141
  pred_ids = pred.predictions
1142
  label_ids = pred.label_ids
1143
  if isinstance(pred_ids, tuple):
@@ -1146,21 +1286,25 @@ def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samp
1146
  if not isinstance(pred_ids, torch.Tensor):
1147
  pred_ids = torch.tensor(pred_ids)
1148
  pred_ids = pred_ids.argmax(dim=-1)
 
1149
  pred_ids = pred_ids.tolist()
1150
  label_ids = label_ids.tolist()
1151
  pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
1152
  label_ids = [[pad_token_id if token == -100 else token for token in seq] for seq in label_ids]
1153
- def strip_trailing(seq, pad_token_id):
1154
- while seq and seq[-1] == pad_token_id:
1155
- seq = seq[:-1]
1156
- return seq
1157
- pred_ids = [strip_trailing(seq, pad_token_id) for seq in pred_ids]
1158
- label_ids = [strip_trailing(seq, pad_token_id) for seq in label_ids]
1159
  if print_pred:
1160
  for i in range(min(num_samples, len(pred_ids))):
1161
- print(f"Pred: '{tokenizer.batch_decode([pred_ids[i]])[0]}'")
1162
- print(f"Label: '{tokenizer.batch_decode([label_ids[i]])[0]}'")
 
1163
  print("-" * 40)
 
1164
  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
1165
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1166
  wer = wer_batch(label_str, pred_str)
@@ -1170,9 +1314,9 @@ def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samp
1170
  else:
1171
  trainable_params = 0.0
1172
  efficiency_score = 0.0
 
1173
  return {
1174
  "wer": float(wer),
1175
- "trainable_params_M": float(trainable_params),
1176
  "efficiency_score": float(efficiency_score),
1177
  }
1178
 
@@ -1183,10 +1327,13 @@ def main():
1183
  tokenizer = setup_tokenizer(token)
1184
  train_dataset, test_dataset = prepare_datasets(tokenizer, token)
1185
  param = Dimensions(
1186
- mels=128, aud_ctx=1500, aud_head=4, aud_dims=512, aud_idx=4,
1187
- vocab=40000, text_ctx=512, text_head=4, text_dims=512, text_idx=4,
1188
- act="swish", debug={"radius"}, cross_attn=True, features=["spectrogram"]
1189
- )
 
 
 
1190
  model = Echo(param).to('cuda')
1191
  print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
1192
  print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
@@ -1202,7 +1349,6 @@ def main():
1202
  logging_steps=10,
1203
  logging_dir=log_dir,
1204
  eval_strategy="steps",
1205
-
1206
  save_strategy="steps",
1207
  report_to=["tensorboard"],
1208
  push_to_hub=False,
@@ -1214,17 +1360,28 @@ def main():
1214
  batch_eval_metrics=False,
1215
  )
1216
  from functools import partial
1217
- metrics_fn = partial(compute_metrics, print_pred=True, num_samples=2, tokenizer=tokenizer, model=model)
1218
  trainer = Seq2SeqTrainer(
1219
  args=training_args,
1220
  model=model,
1221
- train_dataset=train_dataset,
1222
- eval_dataset=test_dataset,
1223
- data_collator=DataCollator(tokenizer=tokenizer),
1224
  compute_metrics=metrics_fn,
 
1225
  )
1226
  model.init_weights()
1227
  trainer.train()
1228
 
1229
  if __name__ == "__main__":
1230
- main()
 
 
8
  import torch.nn.functional as F
9
  import torch.nn.init as init
10
  from torch import nn, Tensor
11
+
12
+ import matplotlib.pyplot as plt
13
+ from typing import Optional, Dict, Union, List, Tuple, Any
14
  import numpy as np
15
+ from functools import partial
16
+ from datetime import datetime
17
+ from datasets import load_dataset, Audio
18
+ from transformers.trainer_seq2seq import Seq2SeqTrainer
19
+ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
20
  import transformers
21
  from dataclasses import dataclass
22
  from opimizer import MaxFactor
 
25
  torch.backends.cuda.matmul.allow_tf32 = True
26
  torch.set_float32_matmul_precision('high')
27
  transformers.utils.logging.set_verbosity_error()
28
+
29
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
  dtype = torch.float32
31
+
32
+ warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
 
35
+ PATH = 'E:/hf'
36
+ os.environ['HF_HOME'] = PATH
37
+ os.environ['HF_DATASETS_CACHE'] = PATH
38
+ os.environ['TORCH_HOME'] = PATH
39
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
40
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
41
+
42
+ def get_activation(act: str) -> nn.Module:
43
+ """Get activation function by name."""
44
+ act_map = {
45
+ "gelu": nn.GELU(),
46
+ "relu": nn.ReLU(),
47
+ "sigmoid": nn.Sigmoid(),
48
+ "tanh": nn.Tanh(),
49
+ "swish": nn.SiLU(),
50
+ "tanhshrink": nn.Tanhshrink(),
51
+ "softplus": nn.Softplus(),
52
+ "softshrink": nn.Softshrink(),
53
+ "leaky_relu": nn.LeakyReLU(),
54
+ "elu": nn.ELU()
55
+ }
56
+ return act_map.get(act, nn.GELU())
57
+
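Illustrative sketch (not from the diff): the new get_activation factory replaces the act_map dictionaries previously duplicated in each encoder; unknown names fall back to GELU.
act_fn = get_activation("swish")      # nn.SiLU()
fallback = get_activation("unknown")  # nn.GELU()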
58
  @dataclass
59
  class Dimensions:
60
  vocab: int
61
+ ctx: int
62
+ dims: int
63
+ head: int
64
+ layer: int
65
  mels: int
 
 
 
 
66
  act: str
67
  debug: List[str]
68
  cross_attn: bool
 
86
  use_cache=False,
87
  return_timestamps=False)
88
 
89
+ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
90
+ title="", markers=None, marker_labels=None,
91
+ show_voiced_regions=True, show_energy=False):
92
+ num_plots = sum([x is not None, w is not None, p is not None, per is not None])
93
+ if num_plots == 0:
94
+ raise ValueError("No data to plot. Please provide at least one input tensor.")
95
+ t_spans = []
96
+
97
+ if w is not None:
98
+ w_np = w[sample_idx].detach().cpu().numpy()
99
+ if w_np.ndim > 1:
100
+ w_np = w_np.squeeze()
101
+ t_spans.append(len(w_np) / sr)
102
+ if x is not None:
103
+ x_np = x[sample_idx].detach().cpu().numpy()
104
+ if x_np.shape[0] < x_np.shape[1]:
105
+ x_np = x_np.T
106
+ t_spans.append(x_np.shape[0] * hop_length / sr)
107
+ if p is not None:
108
+ p_np = p[sample_idx].detach().cpu().numpy()
109
+ if p_np.ndim > 1:
110
+ p_np = p_np.squeeze()
111
+ t_spans.append(len(p_np) * hop_length / sr)
112
+ if per is not None:
113
+ per_np = per[sample_idx].detach().cpu().numpy()
114
+ if per_np.ndim > 1:
115
+ per_np = per_np.squeeze()
116
+ t_spans.append(len(per_np) * hop_length / sr)
117
+ max_t = max(t_spans) if t_spans else 0
118
+ fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
119
+ if num_plots == 1:
120
+ axs = [axs]
121
+ if show_voiced_regions and per is not None:
122
+ per_np = per[sample_idx].detach().cpu().numpy()
123
+ if per_np.ndim > 1:
124
+ per_np = per_np.squeeze()
125
+ t_per = np.arange(len(per_np)) * hop_length / sr
126
+ threshold = 0.5
127
+ for ax in axs:
128
+ for i in range(len(per_np)-1):
129
+ if per_np[i] > threshold:
130
+ ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
131
+ cu_ax = 0
132
+ if w is not None:
133
+ w_np = w[sample_idx].detach().cpu().numpy()
134
+ if w_np.ndim > 1:
135
+ w_np = w_np.squeeze()
136
+ t = np.arange(len(w_np)) / sr
137
+ axs[cu_ax].plot(t, w_np, color="tab:blue")
138
+
139
+ if show_energy:
140
+ frame_length = hop_length
141
+ hop_length_energy = hop_length // 2
142
+ energy = []
143
+ for i in range(0, len(w_np)-frame_length, hop_length_energy):
144
+ frame = w_np[i:i+frame_length]
145
+ energy.append(np.sqrt(np.mean(frame**2)))
146
+ energy = np.array(energy)
147
+ energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
148
+ t_energy = np.arange(len(energy)) * hop_length_energy / sr
149
+ axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
150
+ axs[cu_ax].legend(loc='upper right')
151
+ axs[cu_ax].set_title("Waveform")
152
+ axs[cu_ax].set_ylabel("Amplitude")
153
+ axs[cu_ax].set_xlim([0, max_t])
154
+ axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
155
+ cu_ax += 1
156
+
157
+ if x is not None:
158
+ x_np = x[sample_idx].detach().cpu().numpy()
159
+ if x_np.shape[0] < x_np.shape[1]:
160
+ x_np = x_np.T
161
+ axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
162
+ extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
163
+ axs[cu_ax].set_title("Spectrogram")
164
+ axs[cu_ax].set_ylabel("Mel Bin")
165
+ axs[cu_ax].set_xlim([0, max_t])
166
+ axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
167
+ cu_ax += 1
168
+
169
+ if p is not None:
170
+ p_np = p[sample_idx].detach().cpu().numpy()
171
+ if p_np.ndim > 1:
172
+ p_np = p_np.squeeze()
173
+ t_p = np.arange(len(p_np)) * hop_length / sr
174
+ axs[cu_ax].plot(t_p, p_np, color="tab:green")
175
+ axs[cu_ax].set_title("Pitch")
176
+ axs[cu_ax].set_ylabel("Frequency (Hz)")
177
+ axs[cu_ax].set_xlim([0, max_t])
178
+ axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
179
+ axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
180
+ cu_ax += 1
181
+
182
+ if per is not None:
183
+ per_np = per[sample_idx].detach().cpu().numpy()
184
+ if per_np.ndim > 1:
185
+ per_np = per_np.squeeze()
186
+ t_per = np.arange(len(per_np)) * hop_length / sr
187
+ axs[cu_ax].plot(t_per, per_np, color="tab:red")
188
+ axs[cu_ax].set_title("Period (Voice Activity)")
189
+ axs[cu_ax].set_ylabel("periodocity")
190
+ axs[cu_ax].set_xlim([0, max_t])
191
+ axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
192
+ axs[cu_ax].set_ylim([-0.05, 1.05])
193
+ axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
194
+
195
+ if markers is not None:
196
+ for i, t in enumerate(markers):
197
+ label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
198
+ for ax in axs:
199
+ ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
200
+ if marker_labels:
201
+ axs[0].legend(loc='upper right', fontsize='small')
202
+ axs[-1].set_xlabel("t (s)")
203
+ fig.suptitle(title, fontsize=16)
204
+ plt.tight_layout(rect=[0, 0, 1, 0.97]) # type: ignore
205
+ plt.show()
206
+ return fig
207
+
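Illustrative sketch (not from the diff): calling plot_waveform with dummy batch-first tensors; the shapes below are assumptions chosen to match the [sample_idx] indexing used above.
w = torch.randn(1, 32000)          # 2 s of audio at 16 kHz
x = torch.randn(1, 128, 200)       # mel spectrogram, [batch, mels, frames]
p = torch.rand(1, 200) * 300       # per-frame pitch in Hz
fig = plot_waveform(x=x, w=w, p=p, sr=16000, hop_length=160, title="sample")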
208
+ def valid(default_value, *items):
209
+ """Get first non-None item"""
210
+ for item in items:
211
+ if item is not None:
212
+ return item
213
+ return default_value
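Illustrative sketch (not from the diff): valid() returns the first non-None item, falling back to its first argument.
assert valid(0, None, 5) == 5
assert valid(0, None, None) == 0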
214
+
215
  def dict_to(d, device, dtype=dtype):
216
  return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
217
  for k, v in d.items()}
 
253
  self.eps = eps
254
  self.elementwise_affine = elementwise_affine
255
  if self.elementwise_affine:
256
+ self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore
257
  init.ones_(self.weight)
258
  else:
259
  self.register_parameter("weight", None)
260
  def forward(self, x):
261
+ return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore
262
 
263
  def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
264
  weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
265
  eps: float = 1e-5) -> Tensor:
266
+ return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore
267
 
268
  def get_device():
269
  return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
281
  scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
282
  return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
283
 
 
284
  class rotary(nn.Module):
285
+ def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False):
286
  super(rotary, self).__init__()
287
 
288
  self.use_pbias = use_pbias
 
295
  self.counter = 0
296
  self.last_theta = None
297
 
298
+ self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
299
+ theta = (torch.tensor(10000, device=device, dtype=dtype))
300
+ self.theta = nn.Parameter(theta, requires_grad=True)
301
+ self.theta_values = []
302
 
303
+ def mel_scale_scalar(self, freq: float) -> float:
304
  return 1127.0 * math.log(1.0 + freq / 700.0)
305
 
306
+ def mel_scale(self, freq: Tensor) -> Tensor:
307
  return 1127.0 * (1.0 + freq / 700.0).log()
308
309
  def pitch_bias(self, f0):
310
  if f0 is None:
311
  return None
 
315
  f0_norm.unsqueeze(1)))
316
  return f0_sim.unsqueeze(0).unsqueeze(0)
317
 
318
+ def theta_freqs(self, theta):
319
+ if theta.dim() == 0:
320
+ theta = theta.unsqueeze(0)
321
+ freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
322
+ torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
323
+ self.dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
324
+
325
+ return freq
326
+
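Illustrative note (not from the diff): theta_freqs lays the dim // 2 rotary frequencies out on a mel-like curve and scales the whole bank by theta / 220, so a larger theta (e.g. per-frame f0 added to the base) raises every frequency proportionally. Roughly, for bin k of K = dim // 2:
# freq_k = (theta / 220) * 700 * (10 ** (k / (K - 1) * log10(1 + 8000 / 700)) - 1) / 1000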
327
+ def _apply_radii(self, freqs, f0, ctx):
328
+ if self.radii and f0 is not None:
329
+ radius = f0.to(device, dtype)
330
+ L = radius.shape[0]
331
+ if L != ctx:
332
+ F = L / ctx
333
+ idx = torch.arange(ctx, device=f0.device)
334
+ idx = (idx * F).long().clamp(0, L - 1)
335
+ radius = radius[idx]
336
+ return torch.polar(radius.unsqueeze(-1), freqs)
337
+ else:
338
+ return torch.polar(torch.ones_like(freqs), freqs)
339
 
340
  def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
341
  f0 = enc.get("f0") if enc is not None else None
342
+
343
  if isinstance(x, int):
344
  ctx = x
345
  elif isinstance(x, torch.Tensor) and x.ndim == 2:
 
347
  elif isinstance(x, torch.Tensor) and x.ndim == 3:
348
  batch, ctx, dims = x.shape
349
  else:
350
+ batch, head, ctx, head_dim = x.shape # type: ignore
351
 
352
  if f0 is not None:
353
+ if f0.dim() == 2:
354
+ f0 = f0.squeeze(0)
355
+ theta = f0 + self.theta
356
  else:
357
  theta = self.theta
358
 
359
  freqs = self.theta_freqs(theta)
360
+ t = torch.arange(ctx, device=device, dtype=dtype)
361
+ freqs = t[:, None] * freqs
362
+
363
  if self.radii and f0 is not None:
364
  radius = f0.to(device, dtype)
365
+ freqs = torch.polar(radius.unsqueeze(-1), freqs)
366
  else:
367
+ radius = torch.ones_like(freqs)
368
+ freqs = torch.polar(radius, freqs)
 
 
 
369
 
370
+ if "radius" in self.debug and self.counter == 10:
371
+ theta_value = theta.mean()
372
+ radius_shape = radius.shape if 'radius' in locals() else "N/A"
373
+ radius_mean = radius.mean() if 'radius' in locals() else 0.0
374
+ print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
375
+ print(f" [{layer}] [Radius] {radius}")
376
+ # self.theta_values.append(theta.item())
377
  self.counter += 1
378
  return freqs.unsqueeze(0)
379
 
 
390
  x1 = x1.view(orig_shape)
391
  return torch.cat([x1.type_as(x), x2], dim=-1)
392
 
 
393
  class MultiheadA(nn.Module):
394
+
395
  rbf = False
396
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
397
+ zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
398
  super(MultiheadA, self).__init__()
399
 
400
  self.dims = dims
 
426
  )
427
  else:
428
  self.rope = None
429
+
430
+ def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
431
+ q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
432
+ k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
433
+ qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
434
+ qk_cosine = qk_cosine + mask
435
+ weights = F.softmax(qk_cosine, dim=-1)
436
+ out = torch.matmul(weights, v)
437
+ return out
438
+
439
+ def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
440
+ scale = (self.dims // self.head) ** -0.25
441
+ dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
442
+ if rbf_ratio <= 0.0:
443
+ return dot_scores
444
+ q_norm = q.pow(2).sum(dim=-1, keepdim=True)
445
+ k_norm = k.pow(2).sum(dim=-1, keepdim=True)
446
+ qk = torch.matmul(q, k.transpose(-1, -2))
447
+ dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
448
+ rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
449
+ return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
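Illustrative note (not from the diff): rbf_scores blends scaled dot-product scores with a Gaussian (RBF) kernel on the pairwise squared distance, roughly scores = (1 - r) * (q @ k.T) * scale + r * exp(-dist_sq / (2 * sigma**2)) with r = rbf_ratio and sigma = rbf_sigma; rbf_ratio=0.0 recovers plain dot-product attention, and the class-level rbf flag gates whether forward() uses it.
scores = attn.rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.3)  # attn is a MultiheadA instance (assumed)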
450
 
451
+ def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:
452
 
453
  x = x.to(device, dtype)
454
  if xa is not None:
 
467
  q2 = q.shape[2]
468
  k2 = k.shape[2]
469
 
470
+ q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer))) # type: ignore
471
+ k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer))) # type: ignore
472
  else:
473
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
474
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
476
 
477
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
478
 
479
+ if self.rbf:
480
+ qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
481
+ if self.use_pbias:
482
+ pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None) # type: ignore
483
+ if pbias is not None:
484
+ qk = qk + pbias[:,:,:q2,:q2]
485
+
486
  token_ids = k[:, :, :, 0]
487
  zscale = torch.ones_like(token_ids)
488
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
489
  zscale[token_ids.float() == self.pad_token] = fzero
490
 
491
  if mask is not None:
492
+ # mask = mask[:q2, :q2]#torch.tril(torch.ones(q2, q2, device=q.device))
493
+ # audio_mask = torch.ones(q2, k2 - q2, device=q.device)
494
+ # mask = torch.cat([mask, audio_mask], dim=-1)
495
+ mask = mask.unsqueeze(0).unsqueeze(0)
496
+ qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
497
+
498
  qk = qk * zscale.unsqueeze(-2)
499
  w = F.softmax(qk, dim=-1).to(q.dtype)
500
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 
505
  return self.o(wv), qk
506
 
507
  class t_gate(nn.Module):
508
+ def __init__(self, dims, num_types=4, enabled=True):
509
  super().__init__()
510
+ self.enabled = enabled
511
  self.gate_projections = nn.ModuleList([
512
  nn.Sequential(Linear(dims, 1), nn.Sigmoid())
513
  for _ in range(num_types)])
 
515
  Linear(dims, num_types),
516
  nn.Softmax(dim=-1))
517
  def forward(self, x):
518
+ if not self.enabled:
519
+ return None
520
  type_probs = self.type_classifier(x)
521
  gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
522
  comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
523
  return comb_gate
524
 
525
  class m_gate(nn.Module):
526
+ def __init__(self, dims, mem_size=64, enabled=True):
527
  super().__init__()
528
+ self.enabled = enabled
529
+ if enabled:
530
+ self.m_key = nn.Parameter(torch.randn(mem_size, dims))
531
+ self.m_val = nn.Parameter(torch.randn(mem_size, 1))
532
+ self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
533
+
534
  def forward(self, x):
535
+ if not self.enabled:
536
+ return None
537
  d_gate = torch.sigmoid(self.gate_proj(x))
538
  attention = torch.matmul(x, self.m_key.transpose(0, 1))
539
  attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
 
542
  return 0.5 * (d_gate + m_gate)
543
 
544
  class c_gate(nn.Module):
545
+ def __init__(self, dims, enabled=True):
546
  super().__init__()
547
+ self.enabled = enabled
548
+ if enabled:
549
+ self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
550
+ self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
551
+ self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
552
+ self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
553
+ self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
554
+ self.integ = Linear(dims*5, dims)
555
 
556
  def forward(self, x, features):
557
+ if not self.enabled:
558
+ return None
559
  s_feat = features.get("spectrogram", x)
560
  w_feat = features.get("waveform", x)
561
  p_feat = features.get("pitch", x)
 
569
  comb = torch.cat([s, w, p, e, ph], dim=-1)
570
  return self.integ(comb)
571
 
572
+ class mlp_gate(nn.Module):
573
+ def __init__(self, dims, enabled=True):
574
+ super().__init__()
575
+ self.enabled = enabled
576
+ if enabled:
577
+ self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
578
+
579
+ def forward(self, x):
580
+ if not self.enabled:
581
+ return None
582
+ return self.gate(x)
583
+
584
  class Residual(nn.Module):
585
  _seen = set()
586
+ def __init__(self, ctx, dims, head, act, debug: List[str] = [],
587
  tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
588
  super().__init__()
589
 
 
591
  self.head = head
592
  self.ctx = ctx
593
  self.head_dim = dims // head
 
594
  self.features = features
595
  self.debug = debug
596
  self.counter = 0
597
  self.dropout = 0.01
598
+
 
 
 
 
599
  self.blend = nn.Parameter(torch.tensor(0.5))
600
+ act_fn = get_activation(act)
601
+ self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
602
+
603
+ if not any([tgate, mgate, cgate]):
604
+ self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
605
+ else:
606
+ self.mlp_gate = None
 
 
 
607
 
608
  mlp = dims * 4
609
  self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
610
 
611
+ self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
612
+ self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
613
+ self.c_gate = c_gate(dims=dims, enabled=cgate)
614
+ self.mlp_gate = mlp_gate(dims=dims, enabled=not any([tgate, mgate, cgate]))
615
 
616
  self.lna = RMSNorm(dims)
617
+ self.lnb = RMSNorm(dims)
618
  self.lnc = RMSNorm(dims)
619
 
 
 
 
620
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
621
 
622
+ b = torch.sigmoid(self.blend)
623
+ ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0]
624
+ bx = b * ax + (1 - b) * x
625
+ cx = self.lnb(bx)
626
+ dx = self.mlp(cx)
627
+ ex = valid(self.mlp_gate(cx), self.t_gate(cx), self.m_gate(cx))  # first non-None gate output (t_gate, then m_gate), falling back to mlp_gate
628
+ fx = x + ex * dx  # gate scales the MLP output, as in the per-gate branches this forward replaces
629
+ gx = self.lnc(fx)
630
+ return gx
631
632
  class FEncoder(nn.Module):
633
  def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
634
  super().__init__()
 
639
  self.use_rope = use_rope
640
  self.dims = dims
641
 
642
+ act_fn = get_activation(act)
 
643
 
644
  self.encoder = nn.Sequential(
645
  Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
 
650
  if spec_shape is not None:
651
  self.rope = rotary(
652
  dims=self.head_dim,
653
+ head=self.head,
654
  use_2d_axial=True,
655
  spec_shape=spec_shape, debug=[])
656
  else:
657
  self.rope = rotary(
658
  dims=self.head_dim,
659
+ head=self.head,
660
  use_2d_axial=False, debug=[])
661
  else:
662
  self.rope = None
 
670
  feature_type = "spectrogram"
671
  batch, ctx, dims = x.shape
672
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
673
+ if feature_type == "spectrogram" and self.rope is not None:
674
  rope_freqs = self.rope(ctx, layer=layer, input_type="spectrogram")
675
  else:
676
  rope_freqs = self.rope(ctx, layer=layer, input_type="audio")
 
698
  self.use_rope = use_rope
699
  self.dims = dims
700
 
701
+ act_fn = get_activation(act)
 
702
 
703
  self.downsample = nn.Sequential(
704
  Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
 
711
  if use_rope:
712
  self.rope = rotary(
713
  dims=self.head_dim,
714
+ head=self.head,
715
+ debug=[])
716
  else:
717
  self.rope = None
718
  self.positional = lambda length: sinusoids(length, dims)
 
749
  self.use_rope = use_rope
750
  self.dims = dims
751
 
752
+ act_fn = get_activation(act)
 
753
 
754
  self.encoder = nn.Sequential(
755
  Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
 
759
  if use_rope:
760
  self.rope = rotary(
761
  dims=self.head_dim,
762
+ head=self.head,
763
+ debug=[])
764
  else:
765
  self.rope = None
766
  self.positional = lambda length: sinusoids(length, dims)
 
786
  x = self.norm(x)
787
  return x
788
 
789
+ class SpeechTransformer(nn.Module):
790
  _seen = set()
791
+ def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
792
+ super(SpeechTransformer, self).__init__()
793
 
794
  self.dims = dims
795
  self.head = head
 
799
  self.counter = 0
800
  self.features = features
801
  self.dropout = 0.01
802
+ self.sequential = "sequential" in debug
803
+ act_fn = get_activation(act)
804
 
805
+ self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
806
+ self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
807
+ self.register_buffer("audio_embedding", sinusoids(ctx, dims))
808
 
809
  if features == ["spectrogram", "waveform", "pitch"]:
810
  cgate=True
 
839
  if "phase" in features else None),
840
  })
841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
  self.block = nn.ModuleList([
843
+ Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
844
+ for _ in range(layer)])
 
 
 
 
+        self.blend = nn.Parameter(torch.tensor(0.5))
         self.ln_dec = RMSNorm(dims)

+        def get_mask(text_ctx, aud_ctx):
+            mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
+            audio_mask = torch.ones(text_ctx, aud_ctx - text_ctx, device=device)
+            full_mask = torch.cat([mask, audio_mask], dim=-1)
+            return full_mask
+        self.register_buffer("mask_ax", get_mask(ctx, ctx), persistent=False)
+
         mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
         self.register_buffer("mask", mask, persistent=False)

+    def forward(self, enc, layer="encoder"):
+        enc = dict_to(enc, device, dtype)

+        x = enc.get("input_ids").long()

         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)
+
+        out = {}
+        out.update(enc)
+
+        for f in self.features:
+            if f in enc and f in self.blocks:
+                xa = enc[f]
+                for block in self.blocks[f]:  # type: ignore
+                    xa = block(xa, enc=enc, layer=layer)
+                out[f] = xa
+                xa = xa + self.audio_embedding[:xa.shape[1]]

         for block in self.block:
+            mask = self.mask[:x.shape[1], :x.shape[1]]
             x = block(x, xa=None, mask=mask, enc=None, layer=layer)

+        for f in self.features:
             if f in enc:
+                mask = self.mask_ax[:x.shape[1], :xa.shape[1]]
+                for block in self.block:
+                    out = block(x, xa=xa, mask=mask, enc=None, layer=layer)

                 if self.sequential:
                     x = out
                 else:
+                    a = torch.sigmoid(self.blend)
                     x = a * out + (1 - a) * x

         x = self.ln_dec(x)
         return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
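# Illustrative sketch (editorial, not part of this diff): how the blended decoder
# update above behaves, and a minimal call. The feature dict shape is assumed to
# match the DataCollator output further down; parameter values are taken from main().
# With blend initialised at 0.5, sigmoid(0.5) ~= 0.62, so the gated update is
#     x = 0.62 * cross_attn_out + 0.38 * x
# Example usage:
#     model = SpeechTransformer(vocab=40000, mels=128, ctx=2048, dims=512,
#                               head=4, layer=4, debug=[], features=["spectrogram"])
#     logits = model({"input_ids": ids, "spectrogram": spec})   # (batch, text_len, vocab)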
 
         super().__init__()
         self.param = param

+        self.SpeechTransformer = SpeechTransformer(
+            vocab=param.vocab,
             mels=param.mels,
+            ctx=param.ctx,
+            dims=param.dims,
+            head=param.head,
+            layer=param.layer,
             debug=param.debug,
             features=param.features,
+            act=param.act,
             )

     def forward(self,
         labels=None,
         input_ids=None,
+        waveform: Optional[torch.Tensor]=None,
+        spectrogram: Optional[torch.Tensor]=None,
         pitch: Optional[torch.Tensor]=None,
         f0: Optional[torch.Tensor]=None,
         envelope: Optional[torch.Tensor]=None,
         phase: Optional[torch.Tensor]=None,
+        ) -> Dict[str, Optional[torch.Tensor]]:

         encoder_inputs = {}
         if spectrogram is not None:

             encoder_inputs["phase"] = phase
         if f0 is not None:
             encoder_inputs["f0"] = f0
+        if input_ids is not None:
+            encoder_inputs["input_ids"] = input_ids

+        logits = self.SpeechTransformer(encoder_inputs)

         loss = None
         if labels is not None:
 
         std = 0.02
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
+            "Conv2d": 0, "SEBlock": 0, "SpeechTransformer": 0,
             "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
             "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
             "WEncoder": 0, "PEncoder": 0}
 
                 nn.init.zeros_(module.bias)
                 self.init_counts["Conv2d"] += 1
             elif isinstance(module, MultiheadA):
                 self.init_counts["MultiheadA"] += 1
+            elif isinstance(module, SpeechTransformer):
+                self.init_counts["SpeechTransformer"] += 1
             elif isinstance(module, Residual):
                 self.init_counts["Residual"] += 1
 
 
             encoder_inputs["phase"] = phase
         if f0 is not None:
             encoder_inputs["f0"] = f0
+
         for i in range(max_length - 1):
             with torch.no_grad():
+                encoder_inputs["input_ids"] = ids
+                logits = self.SpeechTransformer(encoder_inputs)
                 next_token_logits = logits[:, -1, :]
                 if i < min_length:
                     next_token_logits[:, eos_token_id] = 0
 
             })
         return Config()

+def setup_tokenizer(token: str):
     from tokenizers import Tokenizer
+    tokenizer = Tokenizer.from_file("./tokenizer.json")
     orig_encode = tokenizer.encode
     def enc(text, add_special_tokens=True):
         ids = orig_encode(text).ids

             ids = ids[1:]
             while ids and ids[-1] in [0, 2]:
                 ids = ids[:-1]
+
+            if isinstance(ids, (torch.Tensor, np.ndarray)):
+                ids = ids.tolist()
             results.append(tokenizer.decode(ids))
         return results

     tokenizer.eos_token_id = 2
     return tokenizer
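# Illustrative sketch (editorial, not part of this diff): expected round-trip of the
# wrapped tokenizer, assuming ./tokenizer.json exists and pad=0, bos=1, eos=2 as set above.
#     tokenizer = setup_tokenizer(token="")
#     ids = tokenizer.encode("hello world")       # special tokens added by the patched encode
#     text = tokenizer.batch_decode([ids])[0]     # leading/trailing 0, 1, 2 stripped before decode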
 
+def load_wave(wave_data, sample_rate):
+    if isinstance(wave_data, str):
+        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
+    elif isinstance(wave_data, dict):
+        waveform = torch.tensor(data=wave_data["array"]).float()
+        sr = wave_data["sampling_rate"]
+    else:
+        raise TypeError("Invalid wave_data format.")
+
+    return waveform
+
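# Editorial note (sketch, not part of this diff): load_wave reads `sr` but never
# resamples to `sample_rate` before returning. A minimal fix, if resampling is wanted,
# using torchaudio.functional.resample:
#     if sr != sample_rate:
#         waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)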
+def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **dataset_config):
+
     audio = batch["audio"]
+    sr = audio["sampling_rate"]
+    wav = load_wave(wave_data=audio, sample_rate=sr)
+
+    dataset_config = {
+        "hop_length": 256,
+        "f_min": 150,
+        "f_max": 2000,
+        "n_mels": 128,
+        "n_fft": 1024,
+        "sample_rate": 16000,
+        "pad_mode": "constant",
+        "center": True,
+        "power": 1.0,
+        "window_fn": torch.hann_window,
+        "mel_scale": "htk",
+        "norm": None,
+        "normalized": False}
+
+    transform = torchaudio.transforms.MelSpectrogram(
+        **dataset_config
+    )
+
+    mel_spectrogram = transform(wav)
+    log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
+    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
+    spec = (log_mel + 4.0) / 4.0
+    spec = torch.tensor(spec)
+    # batch["spectrogram"] = spec
+
+    wav_np = wav.numpy().astype(np.float64)
     f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
     f0 = pw.stonemask(wav_np, f0, t, sample_rate)
+    f0 = torch.from_numpy(f0)
+
+    labels = tokenizer.encode(batch["transcription"])
+
     return {
         "spectrogram": spec,
         "f0": f0,
+        "labels": labels,
+        # "waveform": wav,
+        # "pitch": f0,
     }
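# Editorial note (sketch, not part of this diff): `pw` is assumed to be pyworld,
# imported near the top of the file. The spectrogram normalisation above follows the
# Whisper-style recipe:
#     log_mel = clamp(mel, min=1e-10).log10()
#     log_mel = max(log_mel, log_mel.max() - 8.0)   # cap dynamic range at 8 log10 units
#     spec    = (log_mel + 4.0) / 4.0               # roughly rescales values into [0, 1]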
 
+def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
+
+    if sanity_check:
+        test = load_dataset(
+            "google/fleurs", "en_us", token=token, split="test[:10]", trust_remote_code=True
+        ).cast_column("audio", Audio(sampling_rate=sample_rate))
+
+        dataset = test.map(
+            lambda x: extract_features(x, tokenizer, **dataset_config),
+            remove_columns=test.column_names)
+        dataset = dataset.with_format(type="torch")
+        train_dataset = dataset
+        test_dataset = dataset
+    else:
+
+        cache_dir = "./processed_datasets"
+        os.makedirs(cache_dir, exist_ok=True)
+        cache_file_train = os.path.join(cache_dir, "train.arrow")
+        cache_file_test = os.path.join(cache_dir, "test.arrow")
+
+        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
+            from datasets import Dataset
+            train_dataset = Dataset.load_from_disk(cache_file_train)
+            test_dataset = Dataset.load_from_disk(cache_file_test)
+            return train_dataset, test_dataset
+
+        def filter_func(x):
+            return (0 < len(x["transcription"]) < 512 and
+                    len(x["audio"]["array"]) > 0 and
+                    len(x["audio"]["array"]) < 1500 * 160)
+
+        raw_train = load_dataset(
+            "google/fleurs", "en_us", token=token, split="train[:1000]", trust_remote_code=True)
+        raw_test = load_dataset(
+            "google/fleurs", "en_us", token=token, split="test[:100]", trust_remote_code=True)
+
+        raw_train = raw_train.filter(filter_func)
+        raw_test = raw_test.filter(filter_func)
+
+        raw_train = raw_train.cast_column("audio", Audio(sampling_rate=sample_rate))
+        raw_test = raw_test.cast_column("audio", Audio(sampling_rate=sample_rate))
+
+        train_dataset = raw_train.map(
+            lambda x: extract_features(x, tokenizer, **dataset_config),
+            remove_columns=raw_train.column_names)
+        test_dataset = raw_test.map(
+            lambda x: extract_features(x, tokenizer, **dataset_config),
+            remove_columns=raw_test.column_names)
+
+        train_dataset.save_to_disk(cache_file_train)
+        test_dataset.save_to_disk(cache_file_test)
+    return train_dataset, test_dataset
 
 @dataclass
 class DataCollator:
     tokenizer: Any

     def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+        all_keys = set()
+        for f in features:
+            all_keys.update(f.keys())
+        batch = {}
         pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
         bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
         eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)

+        for key in all_keys:
+            if key == "labels":
+                labels_list = [f["labels"] for f in features]
+                max_len = max(len(l) for l in labels_list)
+                all_ids, all_labels = [], []
+                for label in labels_list:
+                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
+                    decoder_input = [bos_token_id] + label_list
+                    label_eos = label_list + [eos_token_id]
+                    input_len = max_len + 1 - len(decoder_input)
+                    label_len = max_len + 1 - len(label_eos)
+                    padded_input = decoder_input + [pad_token_id] * input_len
+                    padded_labels = label_eos + [pad_token_id] * label_len
+                    all_ids.append(padded_input)
+                    all_labels.append(padded_labels)
+                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
+                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
+
+            elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
+
+                items = [f[key] for f in features if key in f]
+                items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
+                max_len = max(item.shape[-1] for item in items)
+                padded = []
+                for item in items:
+                    pad_width = max_len - item.shape[-1]
+                    if pad_width > 0:
+                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
+                    else:
+                        pad_item = item
+                    padded.append(pad_item)
+                batch[key] = torch.stack(padded)
+                if key == "spectrogram":
+                    batch["spectrogram"] = batch[key]
+        return batch
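# Illustrative sketch (editorial, not part of this diff): minimal use of the collator
# with a PyTorch DataLoader, assuming `train_dataset` and `tokenizer` come from the
# helpers above.
#     from torch.utils.data import DataLoader
#     loader = DataLoader(train_dataset, batch_size=2,
#                         collate_fn=DataCollator(tokenizer=tokenizer))
#     batch = next(iter(loader))   # keys: input_ids, labels, spectrogram, f0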
 
 def levenshtein(reference_words, hypothesis_words):
     m, n = len(reference_words), len(hypothesis_words)

         total_words += len(ref_words)
     return (total_errors / total_words) * 100 if total_words > 0 else 0.0
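# Editorial note (not part of this diff): wer_batch returns word error rate as a
# percentage, i.e. 100 * total word-level edit distance / total reference words.
# Example: ref "the cat sat" vs hyp "the cat sat down" -> 1 insertion / 3 words = 33.3%.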
 
+def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, optimizer=None, scheduler=None):
     pred_ids = pred.predictions
     label_ids = pred.label_ids
     if isinstance(pred_ids, tuple):

     if not isinstance(pred_ids, torch.Tensor):
         pred_ids = torch.tensor(pred_ids)
     pred_ids = pred_ids.argmax(dim=-1)
+
     pred_ids = pred_ids.tolist()
     label_ids = label_ids.tolist()
     pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
     label_ids = [[pad_token_id if token == -100 else token for token in seq] for seq in label_ids]
+
     if print_pred:
         for i in range(min(num_samples, len(pred_ids))):
+
+            pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
+            label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
+
+            print(f"Pred tokens: {pred_ids[i]}")
+            print(f"Label tokens: {label_ids[i]}")
+            print(f"Pred: '{pred_str[i]}'")
+            print(f"Label: '{label_str[i]}'")
+
             print("-" * 40)
+
     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
     wer = wer_batch(label_str, pred_str)

     else:
         trainable_params = 0.0
         efficiency_score = 0.0
+
     return {
         "wer": float(wer),
         "efficiency_score": float(efficiency_score),
     }
 
 
     tokenizer = setup_tokenizer(token)
     train_dataset, test_dataset = prepare_datasets(tokenizer, token)
     param = Dimensions(
+        vocab=40000, ctx=2048, dims=512, head=4, layer=4,
+        mels=128, act="swish",
+        debug=[],
+        cross_attn=True,
+        features=["spectrogram"]
+        )
+
     model = Echo(param).to('cuda')
     print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

         logging_steps=10,
         logging_dir=log_dir,
         eval_strategy="steps",
         save_strategy="steps",
         report_to=["tensorboard"],
         push_to_hub=False,

         batch_eval_metrics=False,
     )
     from functools import partial
+    metrics_fn = partial(compute_metrics,
+        print_pred=True,
+        num_samples=1,
+        tokenizer=tokenizer, model=model)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00025, eps=1e-8, weight_decay=0.025, betas=(0.9, 0.999),
+        amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
+
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
+
     trainer = Seq2SeqTrainer(
         args=training_args,
         model=model,
+        train_dataset=train_dataset,  # type: ignore
+        eval_dataset=test_dataset,  # type: ignore
+        data_collator=DataCollator(tokenizer=tokenizer),  # type: ignore
         compute_metrics=metrics_fn,
+        optimizers=(optimizer, scheduler)  # type: ignore
     )
     model.init_weights()
     trainer.train()
 
 if __name__ == "__main__":
+    main()
+