Sin2pi committed (verified)
Commit 5c4028e · 1 Parent(s): 2b25a15

Update README.md

Files changed (1):
  1. README.md +0 -812

README.md CHANGED
@@ -296,815 +296,3 @@ MaxFactor is a custom PyTorch optimizer with adaptive learning rates and special
 
  ** this model deviates in a lot of ways from standard transformer models.
 
-
-```python
-import os
-import math
-import warnings
-import logging
-from itertools import chain
-import torch
-import torch.nn.functional as F
-from torch import nn, Tensor
-from tensordict import TensorDict
-from typing import Optional, Dict, Union, List, Tuple
-import numpy as np
-from functools import partial
-from datetime import datetime
-from transformers.trainer_seq2seq import Seq2SeqTrainer
-from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
-from echoutils import *
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-dtype = torch.float32
-warnings.filterwarnings("ignore")
-logging.basicConfig(level=logging.ERROR)
-
-class rotary(nn.Module):
-    def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):
-
-        super(rotary, self).__init__()
-        self.use_pbias = use_pbias
-        self.dims = dims
-        self.head = head
-        self.head_dim = dims // head
-        self.radii = radii
-        self.debug = debug
-        self.counter = 0
-        self.last_theta = None
-        self.axial = axial
-
-        self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
-        theta = (torch.tensor(10000, device=device, dtype=dtype))
-        self.theta = nn.Parameter(theta, requires_grad=True)
-        self.theta_values = []
-
-        if axial and spec_shape is not None:
-            time_frames, freq_bins = spec_shape
-            self.time_frames = time_frames
-            self.freq_bins = freq_bins
-
-            time_theta = 50.0
-            time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
-            self.register_buffer('time_freqs', time_freqs)
-
-            freq_theta = 100.0
-            freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
-            self.register_buffer('freq_freqs', freq_freqs)
-
-    def pitch_bias(self, f0):
-        if f0 is None:
-            return None
-        f0_flat = f0.squeeze().float()
-        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
-        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
-                                        f0_norm.unsqueeze(1)))
-        return f0_sim.unsqueeze(0).unsqueeze(0)
-
-    def theta_freqs(self, theta):
-        if theta.dim() == 0:
-            theta = theta.unsqueeze(0)
-        freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
-            torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
-                self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
-        return freq
-
-    def _apply_radii(self, freqs, f0, ctx):
-        if self.radii and f0 is not None:
-            radius = f0.to(device, dtype)
-            L = radius.shape[0]
-            if L != ctx:
-                F = L / ctx
-                idx = torch.arange(ctx, device=f0.device)
-                idx = (idx * F).long().clamp(0, L - 1)
-                radius = radius[idx]
-                return torch.polar(radius.unsqueeze(-1), freqs), radius
-            else:
-                return torch.polar(radius.unsqueeze(-1), freqs), radius
-        else:
-            return torch.polar(torch.ones_like(freqs), freqs), None
-
-    def check_f0(self, f0, f0t, ctx):
-        if f0 is not None and f0.shape[1] == ctx:
-            return f0
-        elif f0t is not None and f0t.shape[1] == ctx:
-            return f0t
-        else:
-            return None
-
-    def axial_freqs(self, ctx):
-        if not self.axial:
-            return None
-        time_frames = self.time_frames
-        freq_bins = self.freq_bins
-
-        t = torch.arange(ctx, device=device, dtype=dtype)
-        t_x = (t % time_frames).float()
-        t_y = torch.div(t, time_frames, rounding_mode='floor').float()
-        freqs_x = torch.outer(t_x, self.time_freqs)
-        freqs_y = torch.outer(t_y, self.freq_freqs)
-        freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
-        freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
-        return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
-
-    def forward(self, x=None, en=None, f=None, layer=None) -> Tensor:
-        ctx = x
-        f0 = en.get("f0") if en is not None else None
-        f0t = en.get("f0t") if en is not None else None
-
-        f0 = self.check_f0(f0, f0t, ctx)
-        if f0 is not None:
-            if f0.dim() == 2:
-                f0 = f0.squeeze(0)
-            theta = f0 + self.theta
-        else:
-            theta = self.theta
-        freqs = self.theta_freqs(theta)
-        t = torch.arange(ctx, device=device, dtype=dtype)
-        freqs = t[:, None] * freqs
-        freqs, radius = self._apply_radii(freqs, f0, ctx)
-
-        if self.axial and f == "spectrogram":
-            freqs_2d = self.axial_freqs(ctx)
-            if freqs_2d is not None:
-                return freqs_2d.unsqueeze(0)
-
-        if "radius" in self.debug and self.counter == 10:
-            print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
-        self.counter += 1
-        return freqs.unsqueeze(0)
-
-    @staticmethod
-    def apply_rotary(x, freqs):
-        x1 = x[..., :freqs.shape[-1]*2]
-        x2 = x[..., freqs.shape[-1]*2:]
-        orig_shape = x1.shape
-        if x1.ndim == 2:
-            x1 = x1.unsqueeze(0)
-        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
-        x1 = torch.view_as_complex(x1) * freqs
-        x1 = torch.view_as_real(x1).flatten(-2)
-        x1 = x1.view(orig_shape)
-        return torch.cat([x1.type_as(x), x2], dim=-1)
-
-class MultiheadA(nn.Module):
-
-    rbf = False
-    def __init__(self, dims: int, head: int, rotary_emb: bool = True,
-                 zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
-        super(MultiheadA, self).__init__()
-
-        self.dims = dims
-        self.head = head
-        self.head_dim = dims // head
-        self.debug = debug
-        self.counter = 0
-        self.use_pbias = use_pbias
-
-        self.q = nn.Linear(dims, dims).to(device, dtype)
-        self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
-        self.v = nn.Linear(dims, dims).to(device, dtype)
-        self.o = nn.Linear(dims, dims).to(device, dtype)
-
-        self.pad_token = 0
-        self.rotary_emb = rotary_emb
-        self.minz = minz
-        self.maxz = maxz
-        self.zero_val = zero_val
-        self.optim_attn = optim_attn
-        self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
-
-        if rotary_emb:
-            self.rope = rotary(
-                dims=dims,
-                head=head,
-                debug=debug,
-                radii=False,
-                )
-        else:
-            self.rope = None
-
-    def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
-        q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
-        k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
-        qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
-        qk_cosine = qk_cosine + mask
-        weights = F.softmax(qk_cosine, dim=-1)
-        out = torch.matmul(weights, v)
-        return out
-
-    def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
-        scale = (self.dims // self.head) ** -0.25
-        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
-        if rbf_ratio <= 0.0:
-            return dot_scores
-        q_norm = q.pow(2).sum(dim=-1, keepdim=True)
-        k_norm = k.pow(2).sum(dim=-1, keepdim=True)
-        qk = torch.matmul(q, k.transpose(-1, -2))
-        dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
-        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
-        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
-
-    def forward(self, x: Tensor, xa=None, mask=None, en=None, layer=None, f=None) -> tuple:
-
-        x = x.to(device, dtype)
-        if xa is not None:
-            xa = xa.to(device, dtype)
-        scale = (self.dims // self.head) ** -0.25
-
-        z = default(xa, x).to(device, dtype)
-        q = self.q(x)
-        k = self.k(z)
-        v = self.v(z)
-
-        if self.rotary_emb:
-            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            q2 = q.shape[2]
-            k2 = k.shape[2]
-
-            q = self.rope.apply_rotary(q, (self.rope(x=q2, en=en, f=f, layer=layer)))
-            k = self.rope.apply_rotary(k, (self.rope(x=k2, en=en, f=f, layer=layer)))
-        else:
-            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-
-        qk = (q * scale) @ (k * scale).transpose(-1, -2)
-
-        if self.rbf:
-            qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
-        if self.use_pbias:
-            pbias = self.rope.pitch_bias(f0=en.get("f0", None) if en is not None else None)
-            if pbias is not None:
-                qk = qk + pbias[:, :, :q2, :q2]
-
-        token_ids = k[:, :, :, 0]
-        zscale = torch.ones_like(token_ids)
-        fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
-        zscale[token_ids.float() == self.pad_token] = fzero
-
-        if mask is not None:
-            if mask.dim() == 4:
-                mask = mask[0, 0]
-            mask = mask[:q2, :k2] if xa is not None else mask[:q2, :q2]
-            qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
-
-        qk = qk * zscale.unsqueeze(-2)
-        w = F.softmax(qk, dim=-1).to(q.dtype)
-        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
-
-        if "multihead" in self.debug and self.counter % 100 == 0:
-            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
-        self.counter += 1
-        return self.o(wv), qk
-
-    @staticmethod
-    def split(X: Tensor) -> (Tensor, Tensor):
-        half_dim = X.shape[-1] // 2
-        return X[..., :half_dim], X[..., half_dim:]
-
-class t_gate(nn.Module):
-    def __init__(self, dims, num_types=4, enabled=True):
-        super().__init__()
-        self.enabled = enabled
-        self.gate_projections = nn.ModuleList([
-            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-            for _ in range(num_types)])
-        self.type_classifier = nn.Sequential(
-            Linear(dims, num_types),
-            nn.Softmax(dim=-1))
-    def forward(self, x):
-        if not self.enabled:
-            return None
-        type_probs = self.type_classifier(x)
-        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
-        comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
-        return comb_gate
-
-class m_gate(nn.Module):
-    def __init__(self, dims, mem_size=64, enabled=True):
-        super().__init__()
-        self.enabled = enabled
-        if enabled:
-            self.m_key = nn.Parameter(torch.randn(mem_size, dims))
-            self.m_val = nn.Parameter(torch.randn(mem_size, 1))
-            self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
-
-    def forward(self, x):
-        if not self.enabled:
-            return None
-        d_gate = torch.sigmoid(self.gate_proj(x))
-        attention = torch.matmul(x, self.m_key.transpose(0, 1))
-        attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
-        m_gate = torch.matmul(attention, self.m_val)
-        m_gate = torch.sigmoid(m_gate)
-        return 0.5 * (d_gate + m_gate)
-
-class c_gate(nn.Module):
-    def __init__(self, dims, enabled=True):
-        super().__init__()
-        self.enabled = enabled
-        if enabled:
-            self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-            self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-            self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-            self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-            self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-            self.integ = Linear(dims*5, dims)
-
-    def forward(self, x, features):
-        if not self.enabled:
-            return None
-        s_feat = features.get("spectrogram", x)
-        w_feat = features.get("waveform", x)
-        p_feat = features.get("pitch", x)
-        e_feat = features.get("envelope", x)
-        ph_feat = features.get("phase", x)
-        s = self.s_gate(x) * s_feat
-        w = self.w_gate(x) * w_feat
-        p = self.p_gate(x) * p_feat
-        e = self.e_gate(x) * e_feat
-        ph = self.ph_gate(x) * ph_feat
-        comb = torch.cat([s, w, p, e, ph], dim=-1)
-        return self.integ(comb)
-
-class mlp_gate(nn.Module):
-    def __init__(self, dims, head, enabled=True, one_shot=True):
-        super().__init__()
-        self.enabled = enabled
-        if enabled:
-            self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-
-    def forward(self, x, xa=None, f=None):
-        if not self.enabled:
-            return None
-        return self.gate(x)
-
-class Residual(nn.Module):
-    _seen = set()
-    def __init__(self, ctx, dims, head, act, debug: List[str] = [],
-                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None, one_shot=False):
-        super().__init__()
-
-        self.dims = dims
-        self.head = head
-        self.ctx = ctx
-        self.head_dim = dims // head
-        self.features = features
-        self.debug = debug
-        self.counter = 0
-        self.dropout = 0.01
-        self.one_shot = one_shot
-
-        self.blend = nn.Parameter(torch.tensor(0.5))
-        act_fn = get_activation(act)
-        self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
-        self.curiosity = curiosity(dims, head)
-
-        if not any([tgate, mgate, cgate]):
-            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-        else:
-            self.mlp_gate = None
-
-        mlp = dims * 4
-        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
-
-        self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
-        self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
-        self.c_gate = c_gate(dims=dims, enabled=cgate)
-        self.mlp_gate = mlp_gate(dims=dims, head=head, enabled=not any([tgate, mgate, cgate]), one_shot=True)
-
-        self.lna = RMSNorm(dims)
-        self.lnb = RMSNorm(dims)
-        self.lnc = RMSNorm(dims)
-
-    def forward(self, x, xa=None, mask=None, en=None, layer=None, f=None) -> Tensor:
-
-        b = torch.sigmoid(self.blend)
-        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, en=en, layer=layer, f=f)[0]
-        bx = b * ax + (1 - b) * x
-        cx = self.lnb(bx)
-        dx = self.mlp(cx)
-        # fall back to the memory/MLP gates when the type gate is disabled
-        ex = self.t_gate(cx)
-        if ex is None:
-            ex = default(self.m_gate(cx), self.mlp_gate(cx))
-        fx = x + ex + dx
-        gx = self.lnc(fx)
-        return gx
-
-class OneShot(nn.Module):
-    def __init__(self, dims: int, head: int, scale: float = 0.3):
-        super().__init__()
-        self.head = head
-        self.hdim = dims // head
-        self.scale = scale
-        self.q_proj = Linear(dims, dims)
-        self.k_proj = Linear(dims, dims)
-
-    def forward(self, x: Tensor, guide: Tensor, f=None) -> Tensor | None:
-        B, Q, _ = x.shape
-        K = guide.size(1)
-        q = self.q_proj(x).view(B, Q, self.head, self.hdim).transpose(1, 2)
-        k = self.k_proj(guide).view(B, K, self.head, self.hdim).transpose(1, 2)
-        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.hdim)
-        return bias
-
-class curiosity(nn.Module):
-    def __init__(self, d, h, bias=True):
-        super().__init__()
-        self.h = h
-        self.dh = d // h
-        self.qkv = nn.Linear(d, d * 3, bias=bias)
-        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
-        self.o = nn.Linear(d, d, bias=bias)
-        self.g = nn.Parameter(torch.zeros(h))
-
-    def split(self, x):
-        b, t, _ = x.shape
-        return x.view(b, t, self.h, self.dh).transpose(1, 2)
-
-    def merge(self, x):
-        b, h, t, dh = x.shape
-        return x.transpose(1, 2).contiguous().view(b, t, h * dh)
-
-    def forward(self, x, xa, mask=None):
-        q, k, v = self.qkv(x).chunk(3, -1)
-        qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
-        q, k, v = map(self.split, (q, k, v))
-        qa, ka, va = map(self.split, (qa, ka, va))
-        dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
-        dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
-        if mask is not None: dots = dots.masked_fill(mask, -9e15)
-        p = dots.softmax(-1)
-        pa = dots_aux.softmax(-1)
-        h_main = p @ v
-        h_aux = pa @ va
-        g = torch.sigmoid(self.g).view(1, -1, 1, 1)
-        out = self.merge(h_main * (1 - g) + h_aux * g)
-        return self.o(out)
-
-class PositionalEncoding(nn.Module):
-    def __init__(self, dims, ctx):
-        super(PositionalEncoding, self).__init__()
-        self.dims = dims
-        self.ctx = ctx
-        self.pe = self.get_positional_encoding(max_ctx=ctx)
-
-    def get_positional_encoding(self, max_ctx):
-        pe = torch.zeros(max_ctx, self.dims)
-        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.dims, 2, dtype=torch.float32)
-            * (-math.log(10000.0) / self.dims)
-            )
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        return pe.to(device)
-
-    def forward(self, x):
-        ctx = x.size(1)
-        pe = self.pe[:, :ctx, :]
-        x = x * math.sqrt(self.dims)
-        x = x + pe
-        return x
-
-class FEncoder(nn.Module):
-    def __init__(self, mels, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None, debug=[]):
-        super().__init__()
-
-        self.head = head
-        self.head_dim = dims // head
-        self.dropout = 0.01
-        self.use_rope = use_rope
-        self.dims = dims
-        self.debug = debug
-        act_fn = get_activation(act)
-        self.attend_pitch = False
-
-        if self.attend_pitch:
-            self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
-            self.mlp = nn.Sequential(
-                nn.Linear(dims, dims),
-                nn.ReLU(),
-                nn.Linear(dims, dims),
-                )
-        else:
-            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
-            self.mlp = None
-
-        self.encoder = nn.Sequential(
-            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
-            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
-            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
-
-        if use_rope:
-            if spec_shape is not None:
-                self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
-        else:
-            self.rope = None
-            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
-        self.norm = RMSNorm(dims)
-
-    def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
-        batch, ctx, dims = x.shape
-        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
-        freqs = self.rope(ctx, en=en, f=f, layer=layer)
-        x = self.rope.apply_rotary(x, freqs)
-        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
-
-        return x
-
-    def forward(self, x: Tensor, en=None, f=None, layer=None):
-        x = self.encoder(x).permute(0, 2, 1)
-        if self.use_rope:
-            x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
-        else:
-            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
-
-        if self.mlp is not None:
-            x = self.mlp(x)
-
-        if self.attend_pitch:
-            xa = en["input_ids"]
-            if xa is not None:
-                q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
-                out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
-                out = self.o(out)
-                x = x + out
-
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self.norm(x)
-        return x
-
-class WEncoder(nn.Module):
-    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
-        super().__init__()
-
-        self.head = head
-        self.head_dim = dims // head
-        self.dropout = 0.01
-        self.use_rope = use_rope
-        self.dims = dims
-        self.debug = debug
-        act_fn = get_activation(act)
-        self.target_length = None
-        self.encoder = nn.Sequential(
-            Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
-            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
-            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
-
-        if use_rope:
-            if spec_shape is not None:
-                self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
-        else:
-            self.rope = None
-            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
-        self.norm = RMSNorm(dims)
-
-    def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
-        batch, ctx, dims = x.shape
-        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
-        freqs = self.rope(ctx, en=en, f=f, layer=layer)
-        x = self.rope.apply_rotary(x, freqs)
-        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
-        return x
-
-    def forward(self, x: Tensor, en=None, f=None, layer=None):
-        x = self.encoder(x).permute(0, 2, 1)
-        if self.target_length and x.shape[1] != self.target_length:
-            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
-        if self.use_rope:
-            x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
-        else:
-            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-
-        print(f"X: {x.shape} {f}") if "encoder" in self.debug else None
-        return self.norm(x)
-
-class PEncoder(nn.Module):
-    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=True, debug=[], one_shot=False, spec_shape=None):
-        super().__init__()
-
-        self.head = head
-        self.head_dim = dims // head
-        self.dims = dims
-        self.dropout = 0.01
-        self.use_rope = use_rope
-        self.debug = debug
-        act_fn = get_activation(act)
-
-        self.encoder = nn.Sequential(
-            Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
-            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
-            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
-
-        if use_rope:
-            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
-        else:
-            self.rope = None
-            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
-
-        self.norm = RMSNorm(dims)
-
-    def rope_to_feature(self, x, en=None, f="pitch", layer="PEncoder"):
-        batch, ctx, dims = x.shape
-        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
-        freqs = self.rope(ctx, en=en, f=f, layer=layer)
-        x = self.rope.apply_rotary(x, freqs)
-        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
-        return x
-
-    def forward(self, x: Tensor, en=None, f="pitch", layer="PEncoder"):
-
-        if x.dim() == 2:
-            x = x.unsqueeze(0)
-
-        x = self.encoder(x).permute(0, 2, 1)
-        if self.use_rope:
-            x = self.rope_to_feature(x, en=en, f=f, layer=layer)
-        else:
-            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self.norm(x)
-        print(f"X: {x.shape} {f}") if "PEncoder" in self.debug else None
-        return x
-
-class theBridge(nn.Module):
-    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
-                 debug: List[str], features: List[str], act: str = "gelu"):
-        super(theBridge, self).__init__()
-
-        tgate = True
-        mgate = False
-        cgate = False
-
-        self.debug = debug
-        self.counter = 0
-        self.dropout = 0.01
-        self.features = features
-        self.do_blend = "no_blend" not in self.debug
-        self.sequential = "sequential" in self.debug
-        self.layer = layer
-
-        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
-        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
-        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
-        self.norm = RMSNorm(dims)
-        self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, 10000)
-        self.rotary = rotary(dims=dims, head=head, debug=debug, radii=False)
-
-        with torch.no_grad():
-            self.token.weight[0].zero_()
-
-        act_fn = get_activation(act)
-        if features == ["spectrogram", "waveform", "pitch"]:
-            cgate = True
-        else:
-            cgate = False
-
-        self.blockA = nn.ModuleDict()
-        self.blockA["waveform"] = nn.ModuleList(
-            [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
-            [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
-            for _ in range(layer)] if "waveform" in features else None)
-
-        for feature_type in ["spectrogram", "aperiodic", "harmonic"]:
-            if feature_type in features:
-                self.blockA[feature_type] = nn.ModuleList(
-                    [FEncoder(mels=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                    [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
-            else:
-                self.blockA[feature_type] = None
-
-        for feature_type in ["pitch", "phase"]:
-            if feature_type in features:
-                self.blockA[feature_type] = nn.ModuleList(
-                    [PEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act_fn)] +
-                    [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
-            else:
-                self.blockA[feature_type] = None
-
-        self.blockB = nn.ModuleList([
-            Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
-            for _ in range(layer)])
-
-        self.modal = nn.ModuleList([
-            Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
-            for _ in range(layer)])
-
-        mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
-        self.register_buffer("mask", mask, persistent=False)
-
-        self.norm = RMSNorm(dims)
-
-    def forward(self, x, xa, en, f, sequential=False) -> Tensor:
-        mask = self.mask[:x.shape[1], :x.shape[1]]
-        x = self.token(x.long()) + self.positional[:x.shape[1]]
-
-        out = {}
-        out["input_ids"] = x
-        out.update(en)
-
-        for b in chain(self.blockA[f] or []):
-            xa = b(x=xa, en=out, f=f, layer="en")
-
-        for b in chain(self.blockB or []):
-            x = b(x=x, xa=None, mask=mask, en=out, f=f, layer="dec")
-            y = b(x, xa=xa, mask=None, en=out, f=f, layer="cross")
-            if sequential:
-                x = y
-            else:
-                a = torch.sigmoid(self.blend)
-                x = a * y + (1 - a) * x
-        for b in self.modal:
-            xc = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None, en=out, f=f, layer="modal")
-            xm = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None, en=out, f=f, layer="modal")
-            if sequential:
-                x = xm
-            else:
-                a = torch.sigmoid(self.blend)
-                x = a * x + (1 - a) * xm
-
-        if self.counter < 1 and "encoder" in self.debug:
-            shapes = {k: v.shape for k, v in en.items()}
-            print(f"Step {self.counter}: mode: {list(en.keys())}: shapes: {shapes}")
-        self.counter += 1
-
-        x = self.norm(x)
-        x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
-
-        return x
-
-class Echo(nn.Module):
-    def __init__(self, param: Dimensions):
-        super().__init__()
-        self.param = param
-
-        self.processor = theBridge(
-            vocab=param.vocab,
-            mels=param.mels,
-            ctx=param.ctx,
-            dims=param.dims,
-            head=param.head,
-            layer=param.layer,
-            features=param.features,
-            act=param.act,
-            debug=param.debug,
-            )
-
-    def forward(self,
-        labels=None,
-        input_ids=None,
-        waveform: Optional[torch.Tensor]=None,
-        spectrogram: Optional[torch.Tensor]=None,
-        pitch: Optional[torch.Tensor]=None,
-        f0: Optional[torch.Tensor]=None,
-        f0t: Optional[torch.Tensor]=None,
-        harmonic: Optional[torch.Tensor]=None,
-        aperiodic: Optional[torch.Tensor]=None,
-        phase: Optional[torch.Tensor]=None,
-        ) -> Dict[str, Optional[torch.Tensor]]:
-
-        en = {}
-        if f0 is not None:
-            en["f0"] = f0
-        if f0t is not None:
-            en["f0t"] = f0t
-        if harmonic is not None:
-            en["harmonic"] = harmonic
-        if aperiodic is not None:
-            en["aperiodic"] = aperiodic
-        if phase is not None:
-            en["phase"] = phase
-        if pitch is not None:
-            en["pitch"] = pitch
-        if waveform is not None:
-            en["waveform"] = waveform
-        if spectrogram is not None:
-            en["spectrogram"] = spectrogram
-
-        x = input_ids
-        for f, xa in en.items():
-            logits = self.processor(x, xa, en, f)
-
-        loss = None
-        if labels is not None:
-            loss = F.cross_entropy(
-                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
-
-        return {"logits": logits, "loss": loss}
-
-    @property
-    def device(self):
-        return next(self.parameters()).device
-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-```
 
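For context, the block below is a minimal sketch of how the removed `Echo` model would be wired up, reconstructed from the constructor arguments visible in the diff above. The `Dimensions` config object (assumed here to come from `echoutils`), the concrete hyperparameter values, and the tensor shapes are illustrative assumptions, not part of the original README or this commit.

```python
# Hypothetical usage sketch of the removed Echo model (not part of this commit).
# Field names follow theBridge's constructor; concrete values are illustrative guesses.
import torch
from echoutils import Dimensions  # assumed location of the config dataclass

param = Dimensions(
    vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4,
    act="gelu", features=["spectrogram"], debug=[],
)
model = Echo(param).to("cuda" if torch.cuda.is_available() else "cpu")

ids = torch.randint(1, param.vocab, (1, 32), device=model.device)   # decoder token ids
spec = torch.randn(1, param.mels, 300, device=model.device)         # (batch, mels, frames)

out = model(input_ids=ids, labels=ids, spectrogram=spec)
print(out["loss"], out["logits"].shape)
```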