Sin2pi commited on
Commit
2b083de
·
verified ·
1 Parent(s): e1d2095

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +153 -876
model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pyworld as pw
2
  import os
3
  import math, random
@@ -5,11 +6,8 @@ import warnings
5
  import logging
6
  import gzip
7
  import base64
8
- import re
9
- from einops import rearrange, repeat
10
  import torch
11
  import torchaudio
12
- import torchcrepe
13
  import torch.nn.functional as F
14
  import torch.nn.init as init
15
  from torch import nn, Tensor
@@ -23,21 +21,11 @@ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
23
  import transformers
24
  import evaluate
25
  from dataclasses import dataclass
26
- from math import pi, log
27
-
28
- torch.backends.cudnn.allow_tf32 = True
29
- torch.backends.cuda.matmul.allow_tf32 = True
30
- torch.set_float32_matmul_precision('high')
31
- transformers.utils.logging.set_verbosity_error()
32
 
33
  device = torch.device(device="cuda:0")
34
  dtype = torch.float32
35
 
36
- torch.set_default_dtype(dtype)
37
- warnings.filterwarnings("ignore")
38
- logging.basicConfig(level=logging.ERROR)
39
- tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), "dtype": torch.float32}
40
-
41
  extractor = None
42
  tokenizer = None
43
  optimizer = None
@@ -64,11 +52,6 @@ class Dimensions:
64
  features: List[str]
65
  f0_rotary: bool
66
 
67
-
68
- import numpy as np
69
- import matplotlib.pyplot as plt
70
-
71
-
72
  def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
73
  title="", markers=None, marker_labels=None,
74
  show_voiced_regions=True, show_energy=False):
@@ -147,7 +130,6 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
147
  axs[current_ax].set_ylabel("Mel Bin")
148
  axs[current_ax].set_xlim([0, max_time])
149
  axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
150
- # fig.colorbar(im, ax=axs[current_ax])
151
  current_ax += 1
152
 
153
  if p is not None:
@@ -265,6 +247,37 @@ class ParameterCycler:
265
  param.requires_grad = (x == self.current_idx)
266
  print(f"Parameter {x}: requires_grad={param.requires_grad}")
267
  self.current_idx = (self.current_idx + 1) % len(self.parameters)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  class rotary(nn.Module):
270
  _seen = set()
@@ -272,173 +285,64 @@ class rotary(nn.Module):
272
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
273
  super().__init__()
274
 
 
275
  self.use_pbias = use_pbias
276
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
277
  self.dtype = torch.float32
278
  self.debug = debug
279
  self._counter = 0
280
- self.dims = dims
281
  self.max_ctx = max_ctx
282
  self.radii = radii
283
  f0_factor = 0.5
284
- self.learned_adaptation: bool = False
285
  pitch_scale = 1.0
286
  radius = 1
287
 
288
- if self.learned_adaptation:
289
  self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=self.device, dtype=self.dtype), requires_grad=True)
290
  else:
291
  self.register_buffer('f0_scale', torch.tensor(f0_factor))
292
 
293
  self.theta = nn.Parameter(torch.tensor(theta, device=self.device, dtype=self.dtype), requires_grad=True)
294
  self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale, device=self.device, dtype=self.dtype), requires_grad=True)
295
- freqs = 1. / (theta ** (torch.arange(0, dims, 2)[:(dims // 2)].float() / dims))
296
  self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
297
  self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
298
 
299
- # self.cycler = ParameterCycler(parameters=[self.theta, self.pitch_scale, self.freqs])
300
- # self.reset_parameters()
301
-
302
- def get_pitch_bias(self, f0):
303
- if f0 is None:
304
- return None
305
- f0_flat = f0.squeeze().float()
306
- f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
307
- f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
308
- f0_norm.unsqueeze(1)) * self.pitch_scale)
309
- return f0_sim.unsqueeze(0).unsqueeze(0)
310
-
311
- def add_to_rotary(self):
312
- def get_sim(self, freqs):
313
- real = freqs.real.squeeze(0)
314
- imag = freqs.imag.squeeze(0)
315
- vecs = torch.cat([real.unsqueeze(-2), imag.unsqueeze(-2)], dim=-1)
316
- vecs = vecs.squeeze(-2)
317
- return F.cosine_similarity(vecs.unsqueeze(1), vecs.unsqueeze(0), dim=-1)
318
-
319
- def fwd_sim(self, x=None, f0=None):
320
- freqs = self.forward(x, f0)
321
- sim = get_sim(self, freqs)
322
- return freqs, sim
323
-
324
- rotary.get_sim = get_sim
325
- rotary.fwd_sim = fwd_sim
326
-
327
- def align_f0(self, f0, ctx):
328
- b, l = f0.shape
329
- if l == ctx:
330
- return f0.squeeze(0).float()
331
- frames_per_token = l / ctx
332
- idx = torch.arange(ctx, device=self.device, dtype=torch.float32)
333
- src_idx = (idx * frames_per_token).long().clamp(0, l-1)
334
- batch_idx = torch.arange(b, device=self.device, dtype=torch.float32).unsqueeze(1)
335
- f0 = f0[batch_idx, src_idx]
336
- return f0.squeeze(0).float()
337
-
338
- # def align_f0(self, f0, ctx):
339
- # b, l = f0.shape
340
- # if l == ctx:
341
- # return f0.squeeze(0).float()
342
- # frames = l / ctx
343
- # idx = torch.arange(ctx, device=f0.device)
344
- # f0 = (idx * frames).long()
345
- # # b_idx = torch.arange(b, device=f0.device).unsqueeze(1)
346
- # # f0 = f0[b_idx, idx.unsqueeze(0).expand(b, -1)]
347
- # return f0.squeeze(0).float()
348
-
349
- def scale_f0(self, f0):
350
- f0_min = f0.min(dim=1, keepdim=True)[0]
351
- f0_max = f0.max(dim=1, keepdim=True)[0]
352
- denom = f0_max - f0_min + 1e-8
353
- normalized_f0 = (f0 - f0_min) / denom
354
- normalized_f0 = torch.clamp(normalized_f0, 0.0, 1.0)
355
- return normalized_f0
356
-
357
- def process_f0(f0, threshold=0.05):
358
- thresholded_f0 = torch.where(f0 < threshold, torch.zeros_like(f0), f0)
359
- return thresholded_f0
360
-
361
- def map_perceptual(self, f0_mean, theta=10000.0):
362
- if f0_mean >= theta:
363
- return torch.log(f0_mean / theta)
364
- else:
365
- return -torch.log(theta / f0_mean)
366
-
367
- def linear_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
368
- mapped_freq = ((freq - min_freq) / (max_freq - min_freq)) * target_max
369
- return mapped_freq
370
-
371
- def log_map(self, freq, min_freq=40.0, max_freq=400.0, target_max=10000.0):
372
- log_freq = torch.log(freq)
373
- log_min_freq = torch.log(min_freq)
374
- log_max_freq = torch.log(max_freq)
375
- mapped_log_freq = ((log_freq - log_min_freq) / (log_max_freq - log_min_freq)) * torch.log(torch.tensor(target_max, device=self.device))
376
- return mapped_log_freq
377
-
378
- def get_f0_adapted_freqs(self, ctx, f0=None):
379
- f0_min: float = 80.0,
380
- f0_max: float = 500.0,
381
- base_freq: float = 1.0,
382
- positions = torch.arange(ctx, device=device, dtype=torch.float)
383
- freqs = base_freq.clone()
384
- if f0 is not None:
385
- f0_norm = torch.clamp((f0 - f0_min) / (f0_max - f0_min), 0.0, 1.0)
386
- freq_mod = torch.pow(torch.linspace(0.5, 1.5, self.dims//2, device=device),
387
- f0_norm.unsqueeze(-1) * self.f0_scale)
388
- freqs = freqs * freq_mod
389
- freqs = torch.outer(positions, freqs)
390
- return torch.polar(torch.ones_like(freqs), freqs)
391
-
392
- def forward(self, x=None, f0=None, layer=None) -> Tensor:
393
- # self.cycler.toggle_requires_grad()
394
  if isinstance(x, int):
395
  ctx = x
396
  else:
397
  batch, ctx, dims = x.shape
398
  t = torch.arange(ctx, device=self.device).float()
399
-
400
- if self.learned_adaptation:
401
- freqs = self.get_f0_adapted_freqs(ctx, f0)
402
- x_complex = torch.view_as_complex(
403
- x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
404
- x_rotated = x_complex * freqs.unsqueeze(0).unsqueeze(0)
405
- freqs = torch.view_as_real(x_rotated).flatten(3).type_as(x)
406
-
407
  if f0 is not None:
408
  f0_mean=f0.mean()+1e-8
409
- pitch_scale=self.pitch_scale
410
- theta=f0_mean*pitch_scale
411
- freqs = 1.0 / (theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
412
  else:
413
  freqs = self.freqs
414
-
415
  freqs = torch.einsum('i,j->ij', t, freqs)
416
  freqs = freqs.float()
417
-
418
- if self.radii and f0 is not None:
419
-
420
- radius = self.align_f0(f0, ctx)
421
-
422
- # radius = torch.clamp(radius, min=50.0, max=500.0) # Clamp to voice range
423
- # radius = radius / 500.0 # Normalize to [0.1, 1.0] range
424
- # radius = radius.float()
425
-
426
  radius = radius.float()
427
- freqs = torch.polar(radius.unsqueeze(-1), freqs)
428
  else:
429
- freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
 
 
430
  if "rotary" in self.debug:
431
  if f0 is not None:
432
  key = f"{self._counter}_{theta:.2f}"
433
  if key not in rotary._seen:
434
  if not hasattr(self, '_prev_f0_theta'):
435
  self._prev_f0_theta = theta
436
- print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
437
  elif abs(self._prev_f0_theta - theta) > 100.0:
438
- print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
439
- print(f"f0_mean: {f0_mean} Hz, freqs: {freqs.shape}, ctx: {ctx}, dims: {self.dims}, block: {layer}")
440
- if self.radii:
441
- print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
442
  self._prev_f0_theta = theta
443
  rotary._seen.add(key)
444
  self._counter += 1
@@ -474,55 +378,6 @@ class rotary(nn.Module):
474
  x1 = torch.view_as_real(x1).flatten(-2)
475
  return torch.cat([x1.type_as(x), x2], dim=-1)
476
 
477
- def optim_attn(q, k, v, mask=None, scale=None, pad_token=0, fzero_val=0.0001):
478
-
479
- batch, heads, ctx, dims = q.shape
480
- token_ids = k[:, :, :, 0]
481
- is_padding = (token_ids.float() == pad_token).unsqueeze(-2)
482
- log_scale_factor = -10.0
483
- attn_mask = torch.zeros((batch, heads, ctx, ctx), device=q.device)
484
-
485
- if mask is not None:
486
- attn_mask = attn_mask + mask.unsqueeze(0).unsqueeze(0)
487
- attn_mask = torch.where(is_padding,
488
- torch.tensor(log_scale_factor, device=q.device),
489
- attn_mask)
490
- attn_output = torch.nn.functional.scaled_dot_product_attention(
491
- q, k, v, attn_mask=attn_mask,
492
- dropout_p=0.0, is_causal=False)
493
- attn_output = attn_output.permute(0, 2, 1, 3).flatten(start_dim=2)
494
- return attn_output
495
-
496
- def parallel_slice(self, q, k, v, mask=None):
497
- batch, head, ctx, dims = q.shape
498
- head_dim = self.head_dim
499
- batch, ctx, dims = q.shape
500
- ctx_len = k.shape[1]
501
- head = dims // head_dim
502
-
503
- scores = torch.zeros(batch, head, ctx, ctx_len, device=q.device)
504
-
505
- for h in range(head):
506
- start_idx = h * head_dim
507
- end_idx = start_idx + head_dim
508
- q_h = q[:, :, start_idx:end_idx]
509
- k_h = k[:, :, start_idx:end_idx]
510
-
511
- scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
512
-
513
- if mask is not None:
514
- scores = scores + mask.unsqueeze(0).unsqueeze(0)
515
-
516
- attn_weights = F.softmax(scores, dim=-1)
517
-
518
- output = torch.zeros_like(q)
519
- for h in range(head):
520
- start_idx = h * head_dim
521
- end_idx = start_idx + head_dim
522
- v_h = v[:, :, start_idx:end_idx]
523
- output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
524
- return output
525
-
526
  class MultiheadA(nn.Module):
527
  _seen = set()
528
  rbf = False
@@ -578,20 +433,19 @@ class MultiheadA(nn.Module):
578
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
579
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
580
 
581
- def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None,
582
- return_attn: bool = False, f0: Tensor = None, layer = None) -> tuple:
583
 
584
- batch, ctx, dims = x.shape
585
  scale = (self.dims // self.head) ** -0.25
586
 
587
  z = xa if xa is not None else x
588
  q = self.q(x).to(x.dtype)
589
  k = self.k(z).to(x.dtype)
590
  v = self.v(z).to(x.dtype)
 
591
 
592
  if self.rotary_emb:
593
- qf = self.rope(q.size(1), f0=f0, layer=layer)
594
- kf = self.rope(k.size(1), f0=f0, layer=layer)
595
 
596
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
597
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -605,18 +459,13 @@ class MultiheadA(nn.Module):
605
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
606
  v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
607
  batch, head, ctx, head_dim = q.shape
608
-
609
- if self.optim_attn and not return_attn:
610
- wv = optim_attn(q * scale, k * scale, v, mask=mask,
611
- pad_token=self.pad_token, fzero_val=torch.clamp(F.softplus(self.fzero), self.minz, self.maxz).item())
612
- return self.o(wv), None
613
-
614
  if self.rbf:
615
  qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
616
 
617
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
618
- if f0 is not None and self.rope.use_pbias:
619
- pbias = self.rope.pbias(f0)
620
  if pbias is not None:
621
  qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
622
  token_ids = k[:, :, :, 0]
@@ -628,8 +477,6 @@ class MultiheadA(nn.Module):
628
  mask = mask[:q.shape[2], :q.shape[2]]
629
  qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
630
  qk = qk * zscale.unsqueeze(-2)
631
- if return_attn:
632
- return qk, v
633
  w = F.softmax(qk, dim=-1).to(q.dtype)
634
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
635
 
@@ -736,12 +583,12 @@ class Residual(nn.Module):
736
  if not any([t_gate, m_gate, c_gate]):
737
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
738
 
739
- def forward(self, x, xa=None, mask=None, f0=None, mode=None, layer=None):
740
  bln = self.blend
741
- x = x + self.attna(self.lna(x), mask=mask, f0=f0, layer=layer)[0]
742
 
743
  if self.attnb and xa is not None:
744
- c = self.attnb(self.lnb(x), xa, f0=f0, mask=None, layer=layer)[0]
745
  b = torch.sigmoid(bln)
746
  x = b * x + (1 - b) * c
747
 
@@ -756,7 +603,7 @@ class Residual(nn.Module):
756
  gate = self.m_gate(normx)
757
  x = x + gate * mlp_out
758
 
759
- elif self.c_gate and mode is not None:
760
  gate_output = self.c_gate(normx, self.features)
761
  x = x + gate_output
762
 
@@ -796,7 +643,7 @@ class PEncoder(nn.Module):
796
  Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
797
  Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2),act_fn)
798
 
799
- def forward(self, x, f0=None, layer=None):
800
  x = self.encoder(x).permute(0, 2, 1)
801
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
802
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -825,7 +672,7 @@ class WEncoder(nn.Module):
825
  self.positional = lambda length: sinusoids(length, dims)
826
  self.norm = RMSNorm(dims)
827
 
828
- def forward(self, x, f0=None, layer=None):
829
  x = self.downsample(x)
830
  x = self.encoder(x)
831
  x = x.permute(0, 2, 1)
@@ -852,13 +699,49 @@ class FEncoder(nn.Module):
852
  self.norm = RMSNorm(dims)
853
  self._norm = RMSNorm(dims)
854
 
855
- def forward(self, x, f0=None, layer=None):
856
  x = self.encoder(x).permute(0, 2, 1)
857
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
858
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
859
  x = self._norm(x)
860
  return x
861
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862
  class AudioEncoder(nn.Module):
863
  _seen = set()
864
  def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str],
@@ -882,8 +765,7 @@ class AudioEncoder(nn.Module):
882
  self.f0_rotary = f0_rotary
883
 
884
  self.rope = rotary(
885
- dims=self.head_dim,
886
- )
887
 
888
  act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
889
  "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
@@ -920,35 +802,32 @@ class AudioEncoder(nn.Module):
920
  FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
921
  for _ in range(layer)])
922
 
923
- def forward(self, x, f0=None, layer="encoder"):
 
924
  if self._counter < 1:
925
- s = x.get("spectrogram")
926
- w = x.get("waveform")
927
- p = f0 if f0 is not None else x.get("pitch")
928
  plot_waveform(x=s, w=w, p=p, hop_length=128)
929
 
930
  enc = {}
 
 
 
 
 
 
 
 
931
 
932
- # if f0 is not None:
933
- # f0 = self.f0(f0)
934
-
935
- #self.rope(x=f, f0=f0, layer=layer)
936
-
937
- for y in self.features:
938
- if y in x and y in self.blocks:
939
- f = x[y]
940
- for block in self.blocks[y]:
941
- f = block(f, f0=f0, layer=layer)
942
- enc[y] = f
943
-
944
  if "encoder" in self.debug and self._counter % 100 == 0:
945
- names = list(x.keys())
946
- shapes = {k: v.shape for k, v in x.items()}
947
  print(f"Step {self._counter}: mode: {names}")
948
  print(f"shapes: {shapes}")
949
  for name, param in self.named_parameters():
950
  if param.requires_grad:
951
- print(f"🎛️ ENCODER LAYER {name}: grad_norm={param.median():.4f}")
952
  self._counter += 1
953
  return enc
954
 
@@ -992,8 +871,23 @@ class TextDecoder(nn.Module):
992
 
993
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
994
  self.register_buffer("mask", mask, persistent=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
995
 
996
- def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
 
997
  bln = self.blend
998
  x = x.to(device)
999
  if order is None:
@@ -1001,13 +895,15 @@ class TextDecoder(nn.Module):
1001
  mask = self.mask[:x.shape[1], :x.shape[1]]
1002
  x = self.token(x) + self.positional[:x.shape[1]]
1003
  x = F.dropout(x, p=self.dropout, training=self.training)
 
1004
  for block in self.block:
1005
- x = block(x, xa=None, mask=mask, layer=layer)
 
1006
  for f in order:
1007
- if f in enc:
1008
- xa = enc[f]
1009
  for block in self.blocks[f]:
1010
- out = block(x=x, xa=xa, mask=None, layer=layer)
1011
  a = torch.sigmoid(bln[f])
1012
  x = a * out + (1 - a) * x
1013
  x = self.ln_dec(x)
@@ -1015,9 +911,8 @@ class TextDecoder(nn.Module):
1015
  if "decoder" in self.debug and self._counter % 100 == 0:
1016
  for name, param in self.named_parameters():
1017
  if param.requires_grad:
1018
- print(f"🎚️ DECODER LAYER {name}: grad_norm={param.median():.4f}")
1019
  self._counter += 1
1020
-
1021
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
1022
 
1023
  class Echo(nn.Module):
@@ -1075,7 +970,7 @@ class Echo(nn.Module):
1075
  spectrogram: torch.Tensor=None,
1076
  pitch: Optional[torch.Tensor]=None,
1077
  f0: Optional[torch.Tensor]=None,
1078
- # f0d: Optional[torch.Tensor]=None,
1079
  envelope: Optional[torch.Tensor]=None,
1080
  phase: Optional[torch.Tensor]=None,
1081
  ) -> Dict[str, torch.Tensor]:
@@ -1092,10 +987,12 @@ class Echo(nn.Module):
1092
  encoder_inputs["envelope"] = envelope
1093
  if phase is not None:
1094
  encoder_inputs["phase"] = phase
1095
- # if f0 is not None:
1096
- # encoder_inputs["f0"] = f0
1097
-
1098
- encoder_outputs = self.encoder(encoder_inputs, f0=f0, layer="encoder")
 
 
1099
  logits = self.decoder(input_ids, encoder_outputs)
1100
 
1101
  loss = None
@@ -1167,7 +1064,7 @@ class Echo(nn.Module):
1167
  print(f"{module_type}: {count}")
1168
 
1169
  def register_gradient_hooks(self):
1170
- """Add this method to your Echo model class"""
1171
  for name, param in self.named_parameters():
1172
  if param.requires_grad:
1173
  if "encoder" in name:
@@ -1175,643 +1072,23 @@ class Echo(nn.Module):
1175
  elif "decoder" in name:
1176
  param.register_hook(lambda grad, n=name: self._print_decoder_grad(n, grad))
1177
 
1178
- print("📊 Gradient debugging hooks registered")
1179
  return self
1180
 
1181
  def _print_encoder_grad(self, name, grad):
1182
  if grad is not None and self.count == 10:
1183
  norm = grad.median().item()
1184
- print(f"🎛️ ENCODER GRAD: {name} = {norm:.6f}")
1185
 
1186
  return None
1187
 
1188
  def _print_decoder_grad(self, name, grad):
1189
  if grad is not None and self.count == 10:
1190
  norm = grad.median().item()
1191
- print(f"🎚️ DECODER GRAD: {name} = {norm:.6f}")
1192
  return None
1193
 
1194
  def reset_counter(self):
1195
- """Reset the internal counter for debugging purposes."""
1196
  self._counter = 0
1197
  print("Counter reset to 0.")
1198
-
1199
- metric = evaluate.load(path="wer")
1200
-
1201
- def align_f0(f0, ctx):
1202
- ctx = torch.tensor(ctx)
1203
- bat, length = f0.shape
1204
- if length == ctx:
1205
- return f0
1206
- frames = length / ctx
1207
- idx = torch.arange(ctx, device=f0.device)
1208
- idx = (idx * frames).long()
1209
- batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
1210
- return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
1211
-
1212
- @dataclass
1213
- class DataCollator:
1214
- tokenizer: Any
1215
- def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
1216
- pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
1217
- bos_token_id = tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else 1
1218
-
1219
- batch = {}
1220
-
1221
- if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
1222
- spectrogram_list = [f["spectrogram"] for f in features]
1223
- max_len_feat = max(f.shape[-1] for f in spectrogram_list)
1224
- pad_spectrogram = []
1225
- for feat in spectrogram_list:
1226
- current_len = feat.shape[-1]
1227
- padding = max_len_feat - current_len
1228
- if padding > 0:
1229
- pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
1230
- else:
1231
- pad_feat = feat
1232
- pad_spectrogram.append(pad_feat)
1233
- batch["spectrogram"] = torch.stack(pad_spectrogram)
1234
-
1235
- if "waveform" in features[0] and features[0]["waveform"] is not None:
1236
- waveform_list = [f["waveform"] for f in features]
1237
- max_len_wav = max(w.shape[-1] for w in waveform_list)
1238
- pad_waveforms = []
1239
- for wav in waveform_list:
1240
- current_len = wav.shape[-1]
1241
- padding = max_len_wav - current_len
1242
- if padding > 0:
1243
- if wav.ndim == 1:
1244
- wav = wav.unsqueeze(0)
1245
- pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
1246
- else:
1247
- pad_wav = wav
1248
- pad_waveforms.append(pad_wav)
1249
- batch["waveform"] = torch.stack(pad_waveforms)
1250
-
1251
- if "label" in features[0] and features[0]["label"] is not None:
1252
- labels_list = [f["label"] for f in features]
1253
- max_len = max(len(l) for l in labels_list)
1254
- all_ids = []
1255
- all_labels = []
1256
-
1257
- for label in labels_list:
1258
- label_list = label.tolist() if isinstance(label, torch.Tensor) else label
1259
- decoder_input = [bos_token_id] + label_list
1260
- label_eos = label_list + [pad_token_id]
1261
- input_len = max_len + 1 - len(decoder_input)
1262
- label_len = max_len + 1 - len(label_eos)
1263
- padded_input = decoder_input + [pad_token_id] * input_len
1264
- padded_labels = label_eos + [pad_token_id] * label_len
1265
- all_ids.append(padded_input)
1266
- all_labels.append(padded_labels)
1267
- batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
1268
- batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
1269
-
1270
- if "pitch" in features[0] and features[0]["pitch"] is not None:
1271
- pitch_list = [f["pitch"] for f in features]
1272
- max_len_pitch = max(e.shape[-1] for e in pitch_list)
1273
- pad_pitch = []
1274
- for pitch in pitch_list:
1275
- current_len = pitch.shape[-1]
1276
- padding = max_len_pitch - current_len
1277
- if padding > 0:
1278
- pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
1279
- else:
1280
- pad_pitch_item = pitch
1281
- pad_pitch.append(pad_pitch_item)
1282
- batch["pitch"] = torch.stack(pad_pitch)
1283
-
1284
- if "f0" in features[0] and features[0]["f0"] is not None:
1285
- all_f0 = torch.cat([f["f0"] for f in features])
1286
- batch["f0"] = all_f0.unsqueeze(0)
1287
-
1288
- # if "f0" in features[0] and features[0]["f0"] is not None:
1289
- # f0_labels = batch.get("labels", None)
1290
- # aligned_features = []
1291
- # for feature in features:
1292
- # f0 = feature["f0"]
1293
- # length = f0.shape
1294
- # if length != f0_labels.shape[-1]:
1295
- # ctx = f0_labels.shape[-1]
1296
- # aligned_features.append(align_f0(f0.unsqueeze(0), ctx))
1297
- # else:
1298
- # aligned_features.append(f0)
1299
- # all_aligned_f0 = torch.cat(aligned_features)
1300
- # batch["f0d"] = all_aligned_f0
1301
-
1302
- if "envelope" in features[0] and features[0]["envelope"] is not None:
1303
- env_list = [f["envelope"] for f in features]
1304
- max_len = max(f.shape[-1] for f in env_list)
1305
- pad_env = []
1306
- for feat in env_list:
1307
- current_len = feat.shape[-1]
1308
- padding = max_len_feat - current_len
1309
- if padding > 0:
1310
- pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
1311
- else:
1312
- pad_feat = feat
1313
- pad_env.append(pad_feat)
1314
- batch["envelope"] = torch.stack(pad_env)
1315
-
1316
- if "phase" in features[0] and features[0]["phase"] is not None:
1317
- ph_list = [f["phase"] for f in features]
1318
- max_len = max(f.shape[-1] for f in ph_list)
1319
- pad_ph = []
1320
- for feat in ph_list:
1321
- current_len = feat.shape[-1]
1322
- padding = max_len_feat - current_len
1323
- if padding > 0:
1324
- pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
1325
- else:
1326
- pad_feat = feat
1327
- pad_ph.append(pad_feat)
1328
- batch["phase"] = torch.stack(pad_ph)
1329
- return batch
1330
-
1331
- def hilbert_transform(x):
1332
- N = x.shape[-1]
1333
- xf = torch.fft.rfft(x)
1334
- h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
1335
- if N % 2 == 0:
1336
- h[0] = h[N//2] = 1
1337
- h[1:N//2] = 2
1338
- else:
1339
- h[0] = 1
1340
- h[1:(N+1)//2] = 2
1341
- return torch.fft.irfft(xf * h, n=N)
1342
-
1343
- def analytic_signal(x):
1344
- return x + 1j * hilbert_transform(x)
1345
-
1346
- def hilbert_transform_2d(x, dim=-1):
1347
- N = x.shape[dim]
1348
- if dim == -1 or dim == len(x.shape) - 1:
1349
- xf = torch.fft.rfft(x)
1350
- else:
1351
- xf = torch.fft.rfft(x, dim=dim)
1352
- h_shape = [1] * len(x.shape)
1353
- h_shape[dim] = N // 2 + 1
1354
- h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
1355
- if dim == -1 or dim == len(x.shape) - 1:
1356
- if N % 2 == 0:
1357
- h[..., 0] = h[..., -1] = 1
1358
- h[..., 1:-1] = 2
1359
- else:
1360
- h[..., 0] = 1
1361
- h[..., 1:] = 2
1362
- else:
1363
- pass
1364
- return torch.fft.irfft(xf * h, n=N, dim=dim)
1365
-
1366
- def hilbert_transform_true_2d(x):
1367
- xf = torch.fft.rfft2(x)
1368
- h1, h2 = torch.meshgrid(
1369
- torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
1370
- torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
1371
- indexing='ij')
1372
- h = -1j / (math.pi * (h1 + 1j*h2))
1373
- h[0, 0] = 0
1374
- return torch.fft.irfft2(xf * h.to(x.device))
1375
-
1376
- def process_spectrogram_with_hilbert(spec):
1377
- analytic = spec + 1j * hilbert_transform(spec)
1378
- envelope = torch.abs(analytic)
1379
- phase = torch.angle(analytic)
1380
- return envelope, phase
1381
-
1382
- def load_wave(wave_data, sample_rate):
1383
- if isinstance(wave_data, str):
1384
- waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
1385
- elif isinstance(wave_data, dict):
1386
- waveform = torch.tensor(data=wave_data["array"]).float()
1387
- sr = wave_data["sampling_rate"]
1388
- else:
1389
- raise TypeError("Invalid wave_data format.")
1390
-
1391
- if waveform.dim() == 1:
1392
- waveform = waveform.unsqueeze(0)
1393
-
1394
- if sr != sample_rate:
1395
- original_length = waveform.shape[1]
1396
- target_length = int(original_length * (sample_rate / sr))
1397
-
1398
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
1399
- waveform = resampler(waveform)
1400
-
1401
- return waveform.flatten()
1402
-
1403
- def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
1404
- hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
1405
- pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
1406
- norm=None, normalized=False, downsamples=False, period=False, hilbert=False):
1407
-
1408
- dtype = torch.float32
1409
- device = torch.device("cuda:0")
1410
- audio = batch["audio"]
1411
- sampling_rate = audio["sampling_rate"]
1412
- sr = audio["sampling_rate"]
1413
- wav = load_wave(wave_data=audio, sample_rate=sr)
1414
-
1415
- if spectrogram:
1416
- transform = torchaudio.transforms.MelSpectrogram(
1417
- f_max=fmax,
1418
- f_min=fmin,
1419
- n_mels=n_mels,
1420
- sample_rate=sr,
1421
- n_fft=n_fft,
1422
- hop_length=hop_length,
1423
- norm=norm,
1424
- normalized=normalized,
1425
- power=power,
1426
- center=center,
1427
- mel_scale=mel_scale,
1428
- window_fn=window_fn,
1429
- pad_mode=pad_mode)
1430
-
1431
- mel_spectrogram = transform(wav)
1432
- log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1433
- log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1434
- spec = (log_mel + 4.0) / 4.0
1435
- spec = torch.tensor(spec)
1436
- batch["spectrogram"] = spec
1437
-
1438
- if hilbert:
1439
- envelope_list = []
1440
- phase_list = []
1441
-
1442
- for ch_idx in range(spec.shape[0]):
1443
- envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
1444
- envelope_list.append(envelope)
1445
- phase_list.append(phase)
1446
-
1447
- batch["envelope"] = torch.stack(envelope_list)
1448
- batch["phase"] = torch.stack(phase_list)
1449
-
1450
- wav_1d = wav.unsqueeze(0)
1451
-
1452
- if waveforms:
1453
- batch["waveform"] = wav_1d
1454
-
1455
- if pitch:
1456
- wav_np = wav.numpy().astype(np.float64)
1457
- f0, t = pw.dio(wav_np, sampling_rate,
1458
- frame_period=hop_length/sampling_rate*1000)
1459
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1460
- f0 = torch.from_numpy(f0).float()
1461
- batch["pitch"] = f0.unsqueeze(0)
1462
-
1463
- if frequency:
1464
- wav_np = wav.numpy().astype(np.float64)
1465
- f0, t = pw.dio(wav_np, sampling_rate,
1466
- frame_period=hop_length/sampling_rate*1000)
1467
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1468
- f0 = f0
1469
- batch["f0"] = torch.from_numpy(f0).float()
1470
-
1471
- if spectrogram and waveforms and pitch:
1472
- spec_mean = batch["spectrogram"].mean()
1473
- spec_std = batch["spectrogram"].std() + 1e-6
1474
- batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std
1475
-
1476
- wav_mean = batch["waveform"].mean()
1477
- wav_std = batch["waveform"].std() + 1e-6
1478
- batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std
1479
-
1480
- if batch["pitch"].max() > 1.0:
1481
- pitch_min = 50.0
1482
- pitch_max = 600.0
1483
- batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
1484
-
1485
- batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
1486
- return batch
1487
-
1488
def compute_metrics(eval_pred, compute_result: bool = True,
                    print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):
    """Compute word error rate and a parameter-efficiency score for one evaluation pass.

    Args:
        eval_pred: Object exposing ``predictions`` and ``label_ids``.
            ``predictions`` may be a tuple (only its first entry is used),
            a 3-D score array/tensor (reduced with ``argmax`` over the last
            axis) or already-decoded token ids.
        compute_result: Accepted for interface compatibility; not used here.
        print_pred: When true, print up to ``num_samples`` decoded
            prediction/label pairs together with their raw token ids.
        num_samples: Upper bound on the number of pairs printed.
        tokenizer: Object providing ``batch_decode(ids, skip_special_tokens=...)``.
            Required.
        pitch: Accepted for interface compatibility; not used here.
        model: Model whose trainable parameters are counted. When ``None`` the
            module-level ``global_model`` is used if it exists.

    Returns:
        dict with ``wer`` (percent), ``trainable_params_M`` (millions of
        trainable parameters) and ``efficiency_score``
        (``(100 - wer) / trainable_params_M``, or 0.0 when it cannot be computed).

    Raises:
        ValueError: If ``tokenizer`` is ``None``.
    """
    if tokenizer is None:
        raise ValueError("compute_metrics requires a tokenizer to decode predictions")

    pred_logits = eval_pred.predictions
    label_ids = eval_pred.label_ids

    if hasattr(pred_logits, "cpu"):
        pred_logits = pred_logits.cpu()
    if hasattr(label_ids, "cpu"):
        label_ids = label_ids.cpu()

    pred_ids = pred_logits[0] if isinstance(pred_logits, tuple) else pred_logits

    # 3-D input is per-token scores over the vocabulary: pick the best token.
    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
        if not isinstance(pred_ids, torch.Tensor):
            pred_ids = torch.tensor(pred_ids)
        pred_ids = pred_ids.argmax(dim=-1)

    # Hand plain Python lists to the tokenizer regardless of the input rank.
    if hasattr(pred_ids, "tolist"):
        pred_ids = pred_ids.tolist()
    if hasattr(label_ids, "tolist"):
        label_ids = label_ids.tolist()

    # -100 marks positions ignored by the loss; swap it for the pad id so the
    # sequence can be decoded. Falls back to 0 when the tokenizer has no pad id.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0
    label_ids = [[pad_id if tok == -100 else tok for tok in seq] for seq in label_ids]

    if print_pred and num_samples > 0:
        # The un-stripped decode is only needed for this debug output.
        raw_pred = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
        raw_label = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
        for i in range(min(num_samples, len(raw_pred))):
            print(f"Preds: {raw_pred[i]}")
            print(f"Label: {raw_label[i]}")
            print(f"preds: {pred_ids[i]}")
            print(f"label: {label_ids[i]}")
            print("--------------------------------")

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    if model is None:
        # main() publishes the model under this module-level name.
        model = globals().get("global_model")

    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
        else:
            print("Warning: Zero trainable parameters detected")
            efficiency_score = 0.0
    else:
        print("Warning: Model not available for parameter counting")
        trainable_params = 0.0
        efficiency_score = 0.0

    if hasattr(wer, "item"):
        wer = wer.item()

    return {
        "wer": float(wer),
        "trainable_params_M": float(trainable_params),
        "efficiency_score": float(efficiency_score),
    }
1554
-
1555
- logger = logging.getLogger(__name__)
1556
-
1557
def create_model(param: Dimensions) -> Echo:
    """Instantiate an ``Echo`` model on CUDA and report its parameter counts.

    Args:
        param: Model dimensions passed straight to ``Echo``.

    Returns:
        The constructed model, moved to the ``'cuda'`` device.
    """
    model = Echo(param).to('cuda')

    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        count = tensor.numel()
        total_params += count
        if tensor.requires_grad:
            trainable_params += count

    trainable_msg = f"Trainable parameters: {trainable_params:,}"
    total_msg = f"Total parameters: {total_params:,}"

    # Counts go to the module logger first, then to stdout.
    logger.info(trainable_msg)
    logger.info(total_msg)
    print(trainable_msg)
    print(total_msg)

    return model
1567
-
1568
def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
    """Load a ``tokenizers.Tokenizer`` from disk and give it a small HF-style interface.

    The loaded object is patched in place with:

    * ``encode(text, add_special_tokens=True)`` returning a list of ids,
    * ``batch_decode(ids_list, skip_special_tokens=True)`` returning strings,
    * ``save_pretrained(save_dir)`` writing ``tokenizer.json``,
    * ``pad_token_id`` / ``bos_token_id`` / ``eos_token_id`` set to 0 / 1 / 2.

    Args:
        token: Accepted for interface compatibility; not used here.
        local_tokenizer_path: Directory containing ``tokenizer.json``.

    Returns:
        The patched tokenizer object.
    """
    from tokenizers import Tokenizer

    pad_id, bos_id, eos_id = 0, 1, 2

    # os.path.join avoids the doubled separator the default path produced.
    tokenizer = Tokenizer.from_file(os.path.join(local_tokenizer_path, "tokenizer.json"))
    orig_encode = tokenizer.encode

    # Looked up once instead of on every encode() call.
    named_special_ids = {tokenizer.token_to_id(name) for name in ("<PAD>", "<BOS>", "<EOS>")}
    # Kept as a tuple so membership uses ==, as the original list did.
    decode_skip_ids = (pad_id, bos_id, eos_id)

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            ids = [tok for tok in ids if tok not in named_special_ids]
        return ids

    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                ids = [tok for tok in ids if tok not in decode_skip_ids]
            results.append(tokenizer.decode(ids))
        return results

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(os.path.join(save_dir, "tokenizer.json"))

    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = pad_id
    tokenizer.bos_token_id = bos_id
    tokenizer.eos_token_id = eos_id
    return tokenizer
1595
-
1596
def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[any, any]:
    """Load the ``google/fleurs`` ``en_us`` data and turn it into model features.

    Args:
        tokenizer: Forwarded to ``extract_features`` for every example.
        token: Access token passed to ``load_dataset``.
        sanity_check: When true, use the first 10 test examples for both the
            training and the evaluation set.
        dataset_config: Keyword arguments forwarded to ``extract_features``;
            a built-in default is used when ``None``.

    Returns:
        ``(train_dataset, test_dataset)``, both formatted as torch tensors.
    """
    if dataset_config is None:
        dataset_config = {
            "spectrogram": True,
            "waveforms": True,
            "pitch": True,
            "frequency": True,
            "downsamples": True,
            "hop_length": 128,
            "fmin": 50,
            "fmax": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sampling_rate": 16000,
        }

    wanted_columns = ["audio", "transcription"]

    def to_features(split, fn, drop_columns):
        # Run feature extraction over one split and expose it as torch tensors.
        return split.map(function=fn, remove_columns=drop_columns).with_format(type="torch")

    corpus = load_dataset(
        "google/fleurs",
        "en_us",
        token=token,
        trust_remote_code=True,
        streaming=False)
    corpus = corpus.cast_column(column="audio", feature=Audio(sampling_rate=16000))
    corpus = corpus.select_columns(["audio", "transcription"])

    if sanity_check:
        subset = corpus["test"].take(10)
        subset = subset.select_columns(["audio", "transcription"])
        logger.info(f"Sanity dataset size: {subset.num_rows}")
        print(f"Sanity dataset size: {subset.num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        subset = to_features(subset, prepare_fn, wanted_columns)
        # The same ten examples serve as both train and eval data.
        return subset, subset

    def keep_example(example):
        # Non-empty transcript under 512 characters, non-empty audio under 1500 * 160 samples.
        n_samples = len(example["audio"]["array"])
        return 0 < len(example["transcription"]) < 512 and 0 < n_samples < 1500 * 160

    corpus = corpus.filter(keep_example).shuffle(seed=4)
    logger.info(f"Dataset size: {corpus['train'].num_rows}, {corpus['test'].num_rows}")
    print(f"Dataset size: {corpus['train'].num_rows}, {corpus['test'].num_rows}")
    prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
    columns_to_remove = list(next(iter(corpus.values())).features)

    train_dataset = corpus["train"]
    test_dataset = corpus["test"].take(50)
    logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")

    train_dataset = to_features(train_dataset, prepare_fn, columns_to_remove)
    test_dataset = to_features(test_dataset, prepare_fn, columns_to_remove)

    return train_dataset, test_dataset
1660
-
1661
def get_training_args(
    log_dir: str,
    batch_eval_metrics: bool = False,
    max_steps: int = 10,
    save_steps: int = 1000,
    eval_steps: int = 1,
    warmup_steps: int = 0,
    num_train_epochs: int = 1,
    logging_steps: int = 1,
    eval_on_start: bool = False,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
) -> Seq2SeqTrainingArguments:
    """Build the ``Seq2SeqTrainingArguments`` used by this training script.

    ``log_dir`` is used as both the output and the logging directory. The
    remaining parameters are forwarded under the same names; everything else
    is fixed by this function.
    """
    # Settings that never vary between runs of this script.
    fixed = {
        "per_device_train_batch_size": 1,
        "per_device_eval_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "eval_accumulation_steps": 1,
        "tf32": True,
        "bf16": True,
        "eval_strategy": "steps",
        "save_strategy": "steps",
        "logging_strategy": "steps",
        "report_to": ["tensorboard"],
        "push_to_hub": False,
        "disable_tqdm": False,
        "save_total_limit": 1,
        "label_names": ["labels"],
        "optim": "adamw_torch",
        "lr_scheduler_type": "cosine",
        "save_safetensors": False,
    }

    # Settings supplied by the caller.
    per_run = {
        "output_dir": log_dir,
        "logging_dir": log_dir,
        "max_steps": max_steps,
        "save_steps": save_steps,
        "eval_steps": eval_steps,
        "warmup_steps": warmup_steps,
        "num_train_epochs": num_train_epochs,
        "logging_steps": logging_steps,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "eval_on_start": eval_on_start,
        "batch_eval_metrics": batch_eval_metrics,
        "max_grad_norm": max_grad_norm,
    }

    return Seq2SeqTrainingArguments(**fixed, **per_run)
1708
-
1709
def main():
    """Script entry point: build tokenizer, model and datasets, then train.

    Side effects: creates an hour-stamped log directory under
    ``./output/logs``, publishes the model as the module-level
    ``global_model`` and runs ``Seq2SeqTrainer.train()``.
    """
    global global_model

    token = ""
    log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H'))
    os.makedirs(name=log_dir, exist_ok=True)
    tokenizer = setup_tokenizer(token)

    # Step schedules for the short smoke test and for a regular run.
    sanity_schedule = {
        "max_steps": 10,
        "save_steps": 0,
        "eval_steps": 1,
        "warmup_steps": 0,
        "logging_steps": 1,
        "learning_rate": 5e-6,
    }
    full_schedule = {
        "max_steps": 1000,
        "save_steps": 1000,
        "eval_steps": 100,
        "warmup_steps": 100,
        "logging_steps": 10,
        "learning_rate": 2.5e-4,
    }

    param = Dimensions(
        mels=128,
        aud_ctx=1500,
        aud_head=4,
        aud_dims=512,
        aud_idx=4,
        vocab=40000,
        text_ctx=512,
        text_head=4,
        text_dims=512,
        text_idx=4,
        act="swish",
        debug={},  # debug tags such as "encoder", "decoder", "residual", "rotary" are left disabled
        cross_attn=True,
        f0_rotary=False,
        features=["spectrogram"],  # only the spectrogram feature is enabled
    )

    sanity_check = False
    schedule = sanity_schedule if sanity_check else full_schedule
    training_args = get_training_args(
        log_dir,
        batch_eval_metrics=False,
        eval_on_start=False,
        weight_decay=0.01,
        **schedule,
    )

    dataset_config = {
        "spectrogram": True,
        "waveforms": False,
        "pitch": False,
        "downsamples": False,
        "frequency": False,
        "hilbert": False,
        "hop_length": 128,
        "fmin": 150,
        "fmax": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sampling_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 2.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}

    model = create_model(param)
    # Published at module level so compute_metrics can find it.
    global_model = model

    metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
                         tokenizer=tokenizer, model=model)

    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer,
        token=token,
        sanity_check=sanity_check,
        dataset_config=dataset_config)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
    )

    model.init_weights()
    trainer.train()
1814
-
1815
- if __name__ == "__main__":
1816
- main()
1817
 
 
1
+
2
  import pyworld as pw
3
  import os
4
  import math, random
 
6
  import logging
7
  import gzip
8
  import base64
 
 
9
  import torch
10
  import torchaudio
 
11
  import torch.nn.functional as F
12
  import torch.nn.init as init
13
  from torch import nn, Tensor
 
21
  import transformers
22
  import evaluate
23
  from dataclasses import dataclass
24
+ import matplotlib.pyplot as plt
 
 
 
 
 
25
 
26
  device = torch.device(device="cuda:0")
27
  dtype = torch.float32
28
 
 
 
 
 
 
29
  extractor = None
30
  tokenizer = None
31
  optimizer = None
 
52
  features: List[str]
53
  f0_rotary: bool
54
 
 
 
 
 
 
55
  def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
56
  title="", markers=None, marker_labels=None,
57
  show_voiced_regions=True, show_energy=False):
 
130
  axs[current_ax].set_ylabel("Mel Bin")
131
  axs[current_ax].set_xlim([0, max_time])
132
  axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
 
133
  current_ax += 1
134
 
135
  if p is not None:
 
247
  param.requires_grad = (x == self.current_idx)
248
  print(f"Parameter {x}: requires_grad={param.requires_grad}")
249
  self.current_idx = (self.current_idx + 1) % len(self.parameters)
250
+
251
def extract_f0(waveform, sampling_rate=16000, hop_length=128, device="cuda:0"):
    """Estimate the F0 contour of a single waveform with pyworld (DIO + StoneMask).

    Args:
        waveform: ``None``, a list (only its first entry is analysed), a
            ``torch.Tensor`` or anything ``torch.tensor`` accepts. A 3-D
            tensor has axis 1 squeezed; a 2-D tensor contributes only its
            first row.
        sampling_rate: Sample rate handed to pyworld.
        hop_length: Hop between analysis frames in samples; converted to the
            millisecond ``frame_period`` pyworld expects.
        device: Device the resulting tensor is moved to.

    Returns:
        A float32 tensor with two leading singleton axes added in front of
        the F0 track, or ``None`` when there is no audio to analyse
        (``None`` input, empty list, or zero samples).
    """
    if waveform is None:
        return None

    if isinstance(waveform, list):
        if not waveform:
            return None
        # A list is treated as a batch; only its first entry is used.
        waveform = waveform[0]

    if not isinstance(waveform, torch.Tensor):
        waveform = torch.tensor(waveform)

    if waveform.dim() == 3:
        waveform = waveform.squeeze(1)
    if waveform.dim() == 2:
        waveform = waveform[0]

    # Mirror the empty-list case instead of handing pyworld zero samples.
    if waveform.numel() == 0:
        return None

    # pyworld needs a C-contiguous float64 array.
    wav_np = np.ascontiguousarray(waveform.detach().cpu().numpy(), dtype=np.float64)

    frame_period_ms = hop_length / sampling_rate * 1000
    f0, t = pw.dio(wav_np, sampling_rate, frame_period=frame_period_ms)
    f0 = pw.stonemask(wav_np, f0, t, sampling_rate)

    f0_tensor = torch.from_numpy(f0).float().to(device)
    return f0_tensor.unsqueeze(0).unsqueeze(0)
281
 
282
  class rotary(nn.Module):
283
  _seen = set()
 
285
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
286
  super().__init__()
287
 
288
+ self.dims = dims
289
  self.use_pbias = use_pbias
290
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
291
  self.dtype = torch.float32
292
  self.debug = debug
293
  self._counter = 0
294
+
295
  self.max_ctx = max_ctx
296
  self.radii = radii
297
  f0_factor = 0.5
298
+ self.adaptation: bool = False
299
  pitch_scale = 1.0
300
  radius = 1
301
 
302
+ if self.adaptation:
303
  self.f0_scale = nn.Parameter(torch.tensor(f0_factor, device=self.device, dtype=self.dtype), requires_grad=True)
304
  else:
305
  self.register_buffer('f0_scale', torch.tensor(f0_factor))
306
 
307
  self.theta = nn.Parameter(torch.tensor(theta, device=self.device, dtype=self.dtype), requires_grad=True)
308
  self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale, device=self.device, dtype=self.dtype), requires_grad=True)
309
+ freqs = 1. / (theta ** (torch.arange(0, dims, 2, device=self.device, dtype=self.dtype)[:(dims // 2)].float() / dims))
310
  self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
311
  self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)
312
 
313
+ def forward(self, x=None, feat=None, layer=None) -> Tensor:
314
+ f0 = feat.get("f0") if feat else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  if isinstance(x, int):
316
  ctx = x
317
  else:
318
  batch, ctx, dims = x.shape
319
  t = torch.arange(ctx, device=self.device).float()
 
 
 
 
 
 
 
 
320
  if f0 is not None:
321
  f0_mean=f0.mean()+1e-8
322
+ theta=f0_mean*self.pitch_scale
323
+ freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
 
324
  else:
325
  freqs = self.freqs
 
326
  freqs = torch.einsum('i,j->ij', t, freqs)
327
  freqs = freqs.float()
328
+
329
+ if self.radii:
330
+ radius = feat.get("f0d") if feat else self.radius
 
 
 
 
 
 
331
  radius = radius.float()
 
332
  else:
333
+ radius = self.radius
334
+ freqs = torch.polar(radius.unsqueeze(-1), freqs) # freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
335
+
336
  if "rotary" in self.debug:
337
  if f0 is not None:
338
  key = f"{self._counter}_{theta:.2f}"
339
  if key not in rotary._seen:
340
  if not hasattr(self, '_prev_f0_theta'):
341
  self._prev_f0_theta = theta
 
342
  elif abs(self._prev_f0_theta - theta) > 100.0:
343
+ print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
344
+ if self.radii:
345
+ print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
 
346
  self._prev_f0_theta = theta
347
  rotary._seen.add(key)
348
  self._counter += 1
 
378
  x1 = torch.view_as_real(x1).flatten(-2)
379
  return torch.cat([x1.type_as(x), x2], dim=-1)
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  class MultiheadA(nn.Module):
382
  _seen = set()
383
  rbf = False
 
433
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
434
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
435
 
436
+ def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, feat=None, layer = None) -> tuple:
 
437
 
 
438
  scale = (self.dims // self.head) ** -0.25
439
 
440
  z = xa if xa is not None else x
441
  q = self.q(x).to(x.dtype)
442
  k = self.k(z).to(x.dtype)
443
  v = self.v(z).to(x.dtype)
444
+ batch, ctx, dims = q.shape
445
 
446
  if self.rotary_emb:
447
+ qf = self.rope(q.size(1), layer=layer, feat=feat)
448
+ kf = self.rope(k.size(1), layer=layer, feat=feat)
449
 
450
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
451
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
459
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
460
  v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
461
  batch, head, ctx, head_dim = q.shape
462
+
 
 
 
 
 
463
  if self.rbf:
464
  qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
465
 
466
  qk = (q * scale) @ (k * scale).transpose(-1, -2)
467
+ if self.rope.use_pbias:
468
+ pbias = self.rope.pbias(feat.get("f0"))
469
  if pbias is not None:
470
  qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
471
  token_ids = k[:, :, :, 0]
 
477
  mask = mask[:q.shape[2], :q.shape[2]]
478
  qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
479
  qk = qk * zscale.unsqueeze(-2)
 
 
480
  w = F.softmax(qk, dim=-1).to(q.dtype)
481
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
482
 
 
583
  if not any([t_gate, m_gate, c_gate]):
584
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
585
 
586
+ def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, feat=None, layer = None):
587
  bln = self.blend
588
+ x = x + self.attna(self.lna(x), xa=None, mask=mask, layer=layer, feat=feat)[0]
589
 
590
  if self.attnb and xa is not None:
591
+ c = self.attnb(self.lnb(x), xa, mask=None, layer=layer, feat=feat)[0]
592
  b = torch.sigmoid(bln)
593
  x = b * x + (1 - b) * c
594
 
 
603
  gate = self.m_gate(normx)
604
  x = x + gate * mlp_out
605
 
606
+ elif self.c_gate is not None:
607
  gate_output = self.c_gate(normx, self.features)
608
  x = x + gate_output
609
 
 
643
  Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
644
  Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2),act_fn)
645
 
646
+ def forward(self, x, feat=None, layer=None):
647
  x = self.encoder(x).permute(0, 2, 1)
648
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
649
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 
672
  self.positional = lambda length: sinusoids(length, dims)
673
  self.norm = RMSNorm(dims)
674
 
675
+ def forward(self, x, feat=None, layer=None):
676
  x = self.downsample(x)
677
  x = self.encoder(x)
678
  x = x.permute(0, 2, 1)
 
699
  self.norm = RMSNorm(dims)
700
  self._norm = RMSNorm(dims)
701
 
702
+ def forward(self, x, feat=None, layer=None):
703
  x = self.encoder(x).permute(0, 2, 1)
704
  x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
705
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
706
  x = self._norm(x)
707
  return x
708
+
709
class F0Encoder(nn.Module):
    """Convolutional encoder for an F0 track.

    Three ``Conv1d`` layers (the last one with ``groups=dims``), each
    followed by the chosen activation, then a positional term from
    ``sinusoids``, dropout and an ``RMSNorm``.
    """

    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1):
        super().__init__()

        self.head_dim = dims // head
        self.dropout = 0.01

        # Unknown activation names fall back to GELU.
        activations = {
            "gelu": nn.GELU(),
            "relu": nn.ReLU(),
            "sigmoid": nn.Sigmoid(),
            "tanh": nn.Tanh(),
            "swish": nn.SiLU(),
            "tanhshrink": nn.Tanhshrink(),
            "softplus": nn.Softplus(),
            "softshrink": nn.Softshrink(),
            "leaky_relu": nn.LeakyReLU(),
            "elu": nn.ELU(),
        }
        act_fn = activations.get(act, nn.GELU())

        stem = Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2)
        mixer = Conv1d(dims, dims, kernel_size=5, padding=2)
        per_channel = Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims)
        # The same activation module instance follows every convolution.
        self.encoder = nn.Sequential(stem, act_fn, mixer, act_fn, per_channel, act_fn)

        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)
        self._norm = RMSNorm(dims)

    def forward(self, x, feat=None, layer=None):
        """Encode ``x``; ``feat`` and ``layer`` are accepted but unused.

        1-D input gets two leading axes and 2-D input gets a channel axis at
        position 1; other ranks pass through unchanged.
        """
        rank = x.dim()
        if rank == 2:
            x = x.unsqueeze(1)
        elif rank == 1:
            x = x.unsqueeze(0).unsqueeze(0)

        hidden = self.encoder(x).permute(0, 2, 1)
        pos = self.positional(hidden.shape[1]).to(hidden.device, hidden.dtype)
        hidden = hidden + pos
        hidden = nn.functional.dropout(hidden, p=self.dropout, training=self.training)
        return self._norm(hidden)
744
+
745
  class AudioEncoder(nn.Module):
746
  _seen = set()
747
  def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str],
 
765
  self.f0_rotary = f0_rotary
766
 
767
  self.rope = rotary(
768
+ dims=self.head_dim)
 
769
 
770
  act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
771
  "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
 
802
  FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
803
  for _ in range(layer)])
804
 
805
+ def forward(self, feat, layer="encoder"):
806
+
807
  if self._counter < 1:
808
+ s = feat.get("spectrogram")
809
+ w = feat.get("waveform")
810
+ p = default(feat.get("f0"), feat.get("pitch"))
811
  plot_waveform(x=s, w=w, p=p, hop_length=128)
812
 
813
  enc = {}
814
+ enc.update(feat)
815
+
816
+ for f in self.features:
817
+ if f in feat and f in self.blocks:
818
+ x = feat[f]
819
+ for block in self.blocks[f]:
820
+ x = block(x, feat=feat, layer=layer)
821
+ enc[f] = x
822
 
 
 
 
 
 
 
 
 
 
 
 
 
823
  if "encoder" in self.debug and self._counter % 100 == 0:
824
+ names = list(feat.keys())
825
+ shapes = {k: v.shape for k, v in feat.items()}
826
  print(f"Step {self._counter}: mode: {names}")
827
  print(f"shapes: {shapes}")
828
  for name, param in self.named_parameters():
829
  if param.requires_grad:
830
+ print(f"ENCODER LAYER {name}: grad_norm={param.median():.4f}")
831
  self._counter += 1
832
  return enc
833
 
 
871
 
872
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
873
  self.register_buffer("mask", mask, persistent=False)
874
+
875
+ rotary_emb = False
876
+ if rotary_emb:
877
+ self.rope = rotary(
878
+ dims=self.head_dim,
879
+ debug = debug,
880
+ radii=False,
881
+ learned_pitch=False,
882
+ learned_freq=False,
883
+ learned_theta=False,
884
+ learned_radius=False,
885
+ )
886
+ else:
887
+ self.rope = None
888
 
889
+ def forward(self, x, feat, order=None, layer='decoder') -> Tensor:
890
+
891
  bln = self.blend
892
  x = x.to(device)
893
  if order is None:
 
895
  mask = self.mask[:x.shape[1], :x.shape[1]]
896
  x = self.token(x) + self.positional[:x.shape[1]]
897
  x = F.dropout(x, p=self.dropout, training=self.training)
898
+
899
  for block in self.block:
900
+ x = block(x, xa=None, mask=mask, feat=feat, layer=layer)
901
+
902
  for f in order:
903
+ if f in feat:
904
+ xa = feat[f]
905
  for block in self.blocks[f]:
906
+ out = block(x=x, xa=xa, mask=None, feat=feat, layer=layer)
907
  a = torch.sigmoid(bln[f])
908
  x = a * out + (1 - a) * x
909
  x = self.ln_dec(x)
 
911
  if "decoder" in self.debug and self._counter % 100 == 0:
912
  for name, param in self.named_parameters():
913
  if param.requires_grad:
914
+ print(f"DECODER LAYER {name}: grad_norm={param.median():.4f}")
915
  self._counter += 1
 
916
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
917
 
918
  class Echo(nn.Module):
 
970
  spectrogram: torch.Tensor=None,
971
  pitch: Optional[torch.Tensor]=None,
972
  f0: Optional[torch.Tensor]=None,
973
+ f0d: Optional[torch.Tensor]=None,
974
  envelope: Optional[torch.Tensor]=None,
975
  phase: Optional[torch.Tensor]=None,
976
  ) -> Dict[str, torch.Tensor]:
 
987
  encoder_inputs["envelope"] = envelope
988
  if phase is not None:
989
  encoder_inputs["phase"] = phase
990
+ if f0 is not None:
991
+ encoder_inputs["f0"] = f0
992
+ if f0d is not None:
993
+ encoder_inputs["f0d"] = f0d
994
+
995
+ encoder_outputs = self.encoder(encoder_inputs)
996
  logits = self.decoder(input_ids, encoder_outputs)
997
 
998
  loss = None
 
1064
  print(f"{module_type}: {count}")
1065
 
1066
  def register_gradient_hooks(self):
1067
+
1068
  for name, param in self.named_parameters():
1069
  if param.requires_grad:
1070
  if "encoder" in name:
 
1072
  elif "decoder" in name:
1073
  param.register_hook(lambda grad, n=name: self._print_decoder_grad(n, grad))
1074
 
1075
+ print("Gradient debugging hooks registered")
1076
  return self
1077
 
1078
  def _print_encoder_grad(self, name, grad):
1079
  if grad is not None and self.count == 10:
1080
  norm = grad.median().item()
1081
+ print(f"ENCODER GRAD: {name} = {norm:.6f}")
1082
 
1083
  return None
1084
 
1085
  def _print_decoder_grad(self, name, grad):
1086
  if grad is not None and self.count == 10:
1087
  norm = grad.median().item()
1088
+ print(f"DECODER GRAD: {name} = {norm:.6f}")
1089
  return None
1090
 
1091
  def reset_counter(self):
 
1092
  self._counter = 0
1093
  print("Counter reset to 0.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1094