Sin2pi committed
Commit 2b8f805 · verified · 1 Parent(s): 3cbeffb

Delete modelA.py

Files changed (1)
  1. modelA.py +0 -1047
modelA.py DELETED
@@ -1,1047 +0,0 @@
1
- import os
2
- import math
3
- import warnings
4
- import logging
5
- from itertools import chain
6
- import torch
7
- import torch.nn.functional as F
8
- from torch import nn, Tensor
9
- from tensordict import TensorDict
10
- from typing import Optional, Dict, Union, List, Tuple
11
- import numpy as np
12
- from functools import partial
13
- from datetime import datetime
15
- from transformers.trainer_seq2seq import Seq2SeqTrainer
16
- from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
17
- from echoutils import *
18
-
19
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
- dtype = torch.float32
21
- warnings.filterwarnings("ignore")
22
- logging.basicConfig(level=logging.ERROR)
23
-
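- # Rotary position embedding: theta is learnable and can be shifted by per-utterance f0; optional f0-scaled radii and axial (time/frequency) variants for spectrogram input.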
24
- class rotary(nn.Module):
25
- def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):
26
-
27
- super(rotary, self).__init__()
28
- self.use_pbias = use_pbias
29
- self.dims = dims
30
- self.head = head
31
- self.head_dim = dims // head
32
- self.radii = radii
33
- self.debug = debug
34
- self.counter = 0
35
- self.last_theta = None
36
- self.axial = axial
37
-
38
- self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
39
- theta = (torch.tensor(10000, device=device, dtype=dtype))
40
- self.theta = nn.Parameter(theta, requires_grad=True)
41
- self.theta_values = []
42
-
43
- if axial and spec_shape is not None:
44
- time_frames, freq_bins = spec_shape
45
- self.time_frames = time_frames
46
- self.freq_bins = freq_bins
47
-
48
- time_theta = 50.0
49
- time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
50
- self.register_buffer('time_freqs', time_freqs)
51
-
52
- freq_theta = 100.0
53
- freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
54
- self.register_buffer('freq_freqs', freq_freqs)
55
-
56
- def pitch_bias(self, f0):
57
- if f0 is None:
58
- return None
59
- f0_flat = f0.squeeze().float()
60
- f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
61
- f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
62
- f0_norm.unsqueeze(1)))
63
- return f0_sim.unsqueeze(0).unsqueeze(0)
64
-
65
- def theta_freqs(self, theta):
66
- if theta.dim() == 0:
67
- theta = theta.unsqueeze(0)
68
- freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
69
- torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
70
- self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
71
- return freq
72
-
73
- def _apply_radii(self, freqs, f0, ctx):
74
- if self.radii and f0 is not None:
75
- radius = f0.to(device, dtype)
76
- L = radius.shape[0]
77
- if L != ctx:
78
- F = L / ctx
79
- idx = torch.arange(ctx, device=f0.device)
80
- idx = (idx * F).long().clamp(0, L - 1)
81
- radius = radius[idx]
82
- return torch.polar(radius.unsqueeze(-1), freqs), radius
83
- else:
84
- return torch.polar(radius.unsqueeze(-1), freqs), radius
85
- else:
86
- return torch.polar(torch.ones_like(freqs), freqs), None
87
-
88
- def check_f0(self, f0, f0t, ctx):
89
- if f0 is not None and f0.shape[1] == ctx:
90
- return f0
91
- elif f0t is not None and f0t.shape[1] == ctx:
92
- return f0t
93
- else:
94
- return None
95
-
96
- def axial_freqs(self, ctx):
97
- if not self.axial:
98
- return None
99
- time_frames = self.time_frames
100
- freq_bins = self.freq_bins
101
-
102
- t = torch.arange(ctx, device=device, dtype=dtype)
103
- t_x = (t % time_frames).float()
104
- t_y = torch.div(t, time_frames, rounding_mode='floor').float()
105
- freqs_x = torch.outer(t_x, self.time_freqs)
106
- freqs_y = torch.outer(t_y, self.freq_freqs)
107
- freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
108
- freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
109
- return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
110
-
111
- def forward(self, x=None, en=None, f=None, layer=None) -> Tensor:
112
- ctx=x
113
- f0 = en.get("f0") if en is not None else None
114
- f0t = en.get("f0t") if en is not None else None
115
-
116
- f0 = self.check_f0(f0, f0t, ctx)
117
- if f0 is not None:
118
- if f0.dim() == 2:
119
- f0 = f0.squeeze(0)
120
- theta = f0 + self.theta
121
- else:
122
- theta = self.theta
123
- freqs = self.theta_freqs(theta)
124
- t = torch.arange(ctx, device=device, dtype=dtype)
125
- freqs = t[:, None] * freqs
126
- freqs, radius = self._apply_radii(freqs, f0, ctx)
127
-
128
- if self.axial and f == "spectrogram":
129
- freqs_2d = self.axial_freqs(ctx)
130
- if freqs_2d is not None:
131
- return freqs_2d.unsqueeze(0)
132
-
133
- if "radius" in self.debug and self.counter == 10:
134
- print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
135
- self.counter += 1
136
- return freqs.unsqueeze(0)
137
-
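- # Rotates the first 2 * freqs.shape[-1] channels as complex pairs; remaining channels pass through unchanged.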
138
- @staticmethod
139
- def apply_rotary(x, freqs):
140
- x1 = x[..., :freqs.shape[-1]*2]
141
- x2 = x[..., freqs.shape[-1]*2:]
142
- orig_shape = x1.shape
143
- if x1.ndim == 2:
144
- x1 = x1.unsqueeze(0)
145
- x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
146
- x1 = torch.view_as_complex(x1) * freqs
147
- x1 = torch.view_as_real(x1).flatten(-2)
148
- x1 = x1.view(orig_shape)
149
- return torch.cat([x1.type_as(x), x2], dim=-1)
150
-
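- # Multi-head attention with rotary q/k; a learned near-zero factor (fzero) downscales scores at positions whose first key channel equals the pad token.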
151
- class MultiheadA(nn.Module):
152
-
153
- rbf = False
154
- def __init__(self, dims: int, head: int, rotary_emb: bool = True,
155
- zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
156
- super(MultiheadA, self).__init__()
157
-
158
- self.dims = dims
159
- self.head = head
160
- self.head_dim = dims // head
161
- self.debug = debug
162
- self.counter = 0
163
- self.use_pbias = use_pbias
164
-
165
- self.q = nn.Linear(dims, dims).to(device, dtype)
166
- self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
167
- self.v = nn.Linear(dims, dims).to(device, dtype)
168
- self.o = nn.Linear(dims, dims).to(device, dtype)
169
-
170
- self.pad_token = 0
171
- self.rotary_emb = rotary_emb
172
- self.minz = minz
173
- self.maxz = maxz
174
- self.zero_val = zero_val
175
- self.optim_attn = optim_attn
176
- self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
177
-
178
- if rotary_emb:
179
- self.rope = rotary(
180
- dims=dims,
181
- head=head,
182
- debug=debug,
183
- radii=False,
184
- )
185
- else:
186
- self.rope = None
187
-
188
- def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
189
- q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
190
- k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
191
- qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
192
- qk_cosine = qk_cosine + mask
193
- weights = F.softmax(qk_cosine, dim=-1)
194
- out = torch.matmul(weights, v)
195
- return out
196
-
197
- def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
198
- scale = (self.dims // self.head) ** -0.25
199
- dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
200
- if rbf_ratio <= 0.0:
201
- return dot_scores
202
- q_norm = q.pow(2).sum(dim=-1, keepdim=True)
203
- k_norm = k.pow(2).sum(dim=-1, keepdim=True)
204
- qk = torch.matmul(q, k.transpose(-1, -2))
205
- dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
206
- rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
207
- return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
208
-
209
- def forward(self, x: Tensor, xa = None, mask = None, en= None, layer = None, f=None) -> tuple:
210
-
211
- x = x.to(device, dtype)
212
- if xa is not None:
213
- xa = xa.to(device, dtype)
214
- scale = (self.dims // self.head) ** -0.25
215
-
216
- z = default(xa, x).to(device, dtype)
217
- q = self.q(x)
218
- k = self.k(z)
219
- v = self.v(z)
220
-
221
- if self.rotary_emb:
222
- q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
223
- k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
224
- v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
225
- q2 = q.shape[2]
226
- k2 = k.shape[2]
227
-
228
- q = self.rope.apply_rotary(q, (self.rope(x=q2, en=en, f=f, layer=layer)))
229
- k = self.rope.apply_rotary(k, (self.rope(x=k2, en=en, f=f, layer=layer)))
230
- else:
231
- q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
232
- k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
233
- v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
234
-
235
- qk = (q * scale) @ (k * scale).transpose(-1, -2)
236
-
237
- if self.rbf:
238
- qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
239
- if self.use_pbias:
240
- pbias = self.rope.pitch_bias(f0 = en.get("f0", None) if en is not None else None)
241
- if pbias is not None:
242
- qk = qk + pbias[:,:,:q2,:q2]
243
-
244
- token_ids = k[:, :, :, 0]
245
- zscale = torch.ones_like(token_ids)
246
- fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
247
- zscale[token_ids.float() == self.pad_token] = fzero
248
-
249
- if mask is not None:
250
- if mask.dim() == 4:
251
- mask = mask[0, 0]
252
- mask = mask[:q2, :k2] if xa is not None else mask[:q2, :q2]
253
- qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
254
-
255
- qk = qk * zscale.unsqueeze(-2)
256
- w = F.softmax(qk, dim=-1).to(q.dtype)
257
- wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
258
-
259
- if "multihead" in self.debug and self.counter % 100 == 0:
260
- print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
261
- self.counter += 1
262
- return self.o(wv), qk
263
-
264
- @staticmethod
265
- def split(X: Tensor) -> (Tensor, Tensor):
266
- half_dim = X.shape[-1] // 2
267
- return X[..., :half_dim], X[..., half_dim:]
268
-
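- # t_gate: num_types sigmoid gates combined by a softmax type classifier over the input.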
269
- class t_gate(nn.Module):
270
- def __init__(self, dims, num_types=4, enabled=True):
271
- super().__init__()
272
- self.enabled = enabled
273
- self.gate_projections = nn.ModuleList([
274
- nn.Sequential(Linear(dims, 1), nn.Sigmoid())
275
- for _ in range(num_types)])
276
- self.type_classifier = nn.Sequential(
277
- Linear(dims, num_types),
278
- nn.Softmax(dim=-1))
279
- def forward(self, x):
280
- if not self.enabled:
281
- return None
282
- type_probs = self.type_classifier(x)
283
- gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
284
- comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
285
- return comb_gate
286
-
287
- class m_gate(nn.Module):
288
- def __init__(self, dims, mem_size=64, enabled=True):
289
- super().__init__()
290
- self.enabled = enabled
291
- if enabled:
292
- self.m_key = nn.Parameter(torch.randn(mem_size, dims))
293
- self.m_val = nn.Parameter(torch.randn(mem_size, 1))
294
- self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
295
-
296
- def forward(self, x):
297
- if not self.enabled:
298
- return None
299
- d_gate = torch.sigmoid(self.gate_proj(x))
300
- attention = torch.matmul(x, self.m_key.transpose(0, 1))
301
- attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
302
- m_gate = torch.matmul(attention, self.m_val)
303
- m_gate = torch.sigmoid(m_gate)
304
- return 0.5 * (d_gate + m_gate)
305
-
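- # c_gate: per-feature sigmoid gates over spectrogram/waveform/pitch/envelope/phase streams, fused by a linear layer.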
306
- class c_gate(nn.Module):
307
- def __init__(self, dims, enabled=True):
308
- super().__init__()
309
- self.enabled = enabled
310
- if enabled:
311
- self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
312
- self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
313
- self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
314
- self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
315
- self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
316
- self.integ = Linear(dims*5, dims)
317
-
318
- def forward(self, x, features):
319
- if not self.enabled:
320
- return None
321
- s_feat = features.get("spectrogram", x)
322
- w_feat = features.get("waveform", x)
323
- p_feat = features.get("pitch", x)
324
- e_feat = features.get("envelope", x)
325
- ph_feat = features.get("phase", x)
326
- s = self.s_gate(x) * s_feat
327
- w = self.w_gate(x) * w_feat
328
- p = self.p_gate(x) * p_feat
329
- e = self.e_gate(x) * e_feat
330
- ph = self.ph_gate(x) * ph_feat
331
- comb = torch.cat([s, w, p, e, ph], dim=-1)
332
- return self.integ(comb)
333
-
334
- class mlp_gate(nn.Module):
335
- def __init__(self, dims, head, enabled=True, one_shot=True):
336
- super().__init__()
337
- self.enabled = enabled
338
- if enabled:
339
- self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
340
-
341
- def forward(self, x, xa=None, f=None):
342
- if not self.enabled:
343
- return None
344
- return self.gate(x)
345
-
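- # Residual block: pre-norm attention blended back into the stream via a learned sigmoid, then MLP plus a token-type/memory/mlp gate, followed by a final RMSNorm.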
346
- class Residual(nn.Module):
347
- _seen = set()
348
- def __init__(self, ctx, dims, head, act, debug: List[str] = [],
349
- tgate=True, mgate=False, cgate=False, mem_size=512, features=None, one_shot=False):
350
- super().__init__()
351
-
352
- self.dims = dims
353
- self.head = head
354
- self.ctx = ctx
355
- self.head_dim = dims // head
356
- self.features = features
357
- self.debug = debug
358
- self.counter = 0
359
- self.dropout = 0.01
360
- self.one_shot = one_shot
361
-
362
- self.blend = nn.Parameter(torch.tensor(0.5))
363
- act_fn = get_activation(act)
364
- self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
365
- self.curiosity = curiosity(dims, head)
366
-
372
- mlp = dims * 4
373
- self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
374
-
375
- self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
376
- self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
377
- self.c_gate = c_gate(dims=dims, enabled=cgate)
378
- self.mlp_gate = mlp_gate(dims=dims, head=head, enabled=not any([tgate, mgate, cgate]), one_shot=True)
379
-
380
- self.lna = RMSNorm(dims)
381
- self.lnb = RMSNorm(dims)
382
- self.lnc = RMSNorm(dims)
383
-
384
- def forward(self, x, xa=None, mask=None, en=None, layer=None, f=None) -> Tensor:
385
-
386
- b = torch.sigmoid(self.blend)
387
- ax = x + self.attn(self.lna(x), xa=xa, mask=mask, en=en, layer=layer, f=f)[0]
388
- bx = b * ax + (1 - b) * x
389
- cx = self.lnb(bx)
390
- dx = self.mlp(cx)
391
- ex = default(self.t_gate(cx), default(self.m_gate(cx), self.mlp_gate(cx)))  # first non-None gate output
392
- fx = x + ex + dx
393
- gx = self.lnc(fx)
394
- return gx
395
-
396
- class OneShot(nn.Module):
397
- def __init__(self, dims: int, head: int, scale: float = 0.3):
398
- super().__init__()
399
- self.head = head
400
- self.hdim = dims // head
401
- self.scale = scale
402
- self.q_proj = Linear(dims, dims)
403
- self.k_proj = Linear(dims, dims)
404
-
405
- def forward(self, x: Tensor, guide: Tensor, f=None) -> Tensor | None:
406
- B, Q, _ = x.shape
407
- K = guide.size(1)
408
- q = self.q_proj(x ).view(B, Q, self.head, self.hdim).transpose(1,2)
409
- k = self.k_proj(guide).view(B, K, self.head, self.hdim).transpose(1,2)
410
- bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.hdim)
411
- return bias
412
-
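- # curiosity: attention over x and an auxiliary stream xa, mixed per head by a learned sigmoid gate.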
413
- class curiosity(nn.Module):
414
- def __init__(self, d, h, bias=True):
415
- super().__init__()
416
- self.h = h
417
- self.dh = d // h
418
- self.qkv = nn.Linear(d, d * 3, bias=bias)
419
- self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
420
- self.o = nn.Linear(d, d, bias=bias)
421
- self.g = nn.Parameter(torch.zeros(h))
422
-
423
- def split(self, x):
424
- b, t, _ = x.shape
425
- return x.view(b, t, self.h, self.dh).transpose(1, 2)
426
-
427
- def merge(self, x):
428
- b, h, t, dh = x.shape
429
- return x.transpose(1, 2).contiguous().view(b, t, h * dh)
430
-
431
- def forward(self, x, xa, mask=None):
432
- q, k, v = self.qkv(x).chunk(3, -1)
433
- qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
434
- q, k, v = map(self.split, (q, k, v))
435
- qa, ka, va = map(self.split, (qa, ka, va))
436
- dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
437
- dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
438
- if mask is not None: dots = dots.masked_fill(mask, -9e15)
439
- p = dots.softmax(-1)
440
- pa = dots_aux.softmax(-1)
441
- h_main = p @ v
442
- h_aux = pa @ va
443
- g = torch.sigmoid(self.g).view(1, -1, 1, 1)
444
- out = self.merge(h_main * (1 - g) + h_aux * g)
445
- return self.o(out)
446
-
447
- class PositionalEncoding(nn.Module):
448
- def __init__(self, dims, ctx):
449
- super(PositionalEncoding, self).__init__()
450
- self.dims = dims
451
- self.ctx = ctx
452
- self.pe = self.get_positional_encoding(max_ctx=ctx)
453
-
454
- def get_positional_encoding(self, max_ctx):
455
- pe = torch.zeros(max_ctx, self.dims)
456
- position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
457
- div_term = torch.exp(
458
- torch.arange(0, self.dims, 2, dtype=torch.float32)
459
- * (-math.log(10000.0) / self.dims)
460
- )
461
- pe[:, 0::2] = torch.sin(position * div_term)
462
- pe[:, 1::2] = torch.cos(position * div_term)
463
- pe = pe.unsqueeze(0)
464
- return pe.to(device)
465
-
466
- def forward(self, x):
467
- ctx = x.size(1)
468
- pe = self.pe[:, :ctx, :]
469
- x = x * math.sqrt(self.dims)
470
- x = x + pe
471
- return x
472
-
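- # FEncoder: three stride-1 Conv1d layers (last depthwise) for mel-spectrogram input, with sinusoidal or rotary positions.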
473
- class FEncoder(nn.Module):
474
- def __init__(self, mels, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None, debug=[]):
475
- super().__init__()
476
-
477
- self.head = head
478
- self.head_dim = dims // head
479
- self.dropout = 0.01
480
- self.use_rope = use_rope
481
- self.dims = dims
482
- self.debug = debug
483
- act_fn = get_activation(act)
484
- self.attend_pitch = False
485
-
486
- if self.attend_pitch:
487
- self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
488
- self.mlp = nn.Sequential(
489
- nn.Linear(dims, dims),
490
- nn.ReLU(),
491
- nn.Linear(dims, dims),
492
- )
493
- else:
494
- self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
495
- self.mlp = None
496
-
497
- self.encoder = nn.Sequential(
498
- Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
499
- Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
500
- Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
501
-
502
- if use_rope:
503
- if spec_shape is not None:
504
- self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
505
- else:
506
- self.rope = None
507
- self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
508
- self.norm = RMSNorm(dims)
509
-
510
- def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
511
- batch, ctx, dims = x.shape
512
- x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
513
- freqs = self.rope(ctx, en=en, f=f, layer=layer)
514
- x = self.rope.apply_rotary(x, freqs)
515
- x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
516
-
517
- return x
518
-
519
- def forward(self, x: Tensor, en=None, f=None, layer = None):
520
- x = self.encoder(x).permute(0, 2, 1)
521
- if self.use_rope:
522
- x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
523
- else:
524
- x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
525
-
526
- if self.mlp is not None:
527
- x = self.mlp(x)
528
-
529
- if self.attend_pitch:
530
- xa = en["input_ids"]
531
- if xa is not None:
532
- q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
533
- out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
534
- out = self.o(out)
535
- x = x + out
536
-
537
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
538
- x = self.norm(x)
539
- return x
540
-
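- # WEncoder: strided Conv1d stack (4x, 2x, 2x = 16x downsampling) for raw waveform input.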
541
- class WEncoder(nn.Module):
542
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
543
- super().__init__()
544
-
545
- self.head = head
546
- self.head_dim = dims // head
547
- self.dropout = 0.01
548
- self.use_rope = use_rope
549
- self.dims = dims
550
- self.debug = debug
551
- act_fn = get_activation(act)
552
- self.target_length = None
553
- self.encoder = nn.Sequential(
554
- Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
555
- Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
556
- Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
557
-
558
- if use_rope:
559
- if spec_shape is not None:
560
- self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
561
- else:
562
- self.rope = None
563
- self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
564
- self.norm = RMSNorm(dims)
565
-
566
- def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
567
- batch, ctx, dims = x.shape
568
- x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
569
- freqs = self.rope(ctx, en=en, f=f, layer=layer)
570
- x = self.rope.apply_rotary(x, freqs)
571
- x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
572
- return x
573
-
574
- def forward(self, x: Tensor, en= None, f=None, layer = None):
575
- x = self.encoder(x).permute(0, 2, 1)
576
- if self.target_length and x.shape[1] != self.target_length:
577
- x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
578
- if self.use_rope:
579
- x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
580
- else:
581
- x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
582
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
583
-
584
- x = self.norm(x)
585
- print(f"X: {x.shape} {f}") if "encoder" in self.debug else None
586
- return x
587
-
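- # PEncoder: stride-1 Conv1d stack (kernels 7/5/3, last depthwise) for 1-D pitch/phase contours; rotary positions by default.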
588
- class PEncoder(nn.Module):
589
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=True, debug=[], one_shot=False, spec_shape=None):
590
- super().__init__()
591
-
592
- self.head = head
593
- self.head_dim = dims // head
594
- self.dims = dims
595
- self.dropout = 0.01
596
- self.use_rope = use_rope
597
- self.debug = debug
598
- act_fn = get_activation(act)
599
-
600
- self.encoder = nn.Sequential(
601
- Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
602
- Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
603
- Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
604
-
605
- if use_rope:
606
- self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
607
- else:
608
- self.rope = None
609
- self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
610
-
611
- self.norm = RMSNorm(dims)
612
-
613
- def rope_to_feature(self, x, en=None, f="pitch", layer="PEncoder"):
614
- batch, ctx, dims = x.shape
615
- x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
616
- freqs = self.rope(ctx, en=en, f=f, layer=layer)
617
- x = self.rope.apply_rotary(x, freqs)
618
- x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
619
- return x
620
-
621
- def forward(self, x: Tensor, en= None, f="pitch", layer="PEncoder"):
622
-
623
- if x.dim() == 2:
624
- x = x.unsqueeze(0)
625
-
626
- x = self.encoder(x).permute(0, 2, 1)
627
- if self.use_rope:
628
- x = self.rope_to_feature(x, en=en, f=f, layer=layer)
629
- else:
630
- x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
631
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
632
- x = self.norm(x)
633
- print(f"X: {x.shape} {f}") if "PEncoder" in self.debug else None
634
- return x
635
-
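- # theBridge: per-feature encoder stacks (blockA), a decoder stack with self- and cross-attention (blockB), a "modal" stack over the concatenated streams, and logits tied to the token embedding.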
636
- class theBridge(nn.Module):
637
- def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
638
- debug: List[str], features: List[str], act: str = "gelu"):
639
- super(theBridge, self).__init__()
640
-
641
- tgate = True
642
- mgate = False
643
- cgate = False
644
-
645
- self.debug = debug
646
- self.counter = 0
647
- self.dropout = 0.01
648
- self.features = features
649
- self.do_blend = "no_blend" not in self.debug
650
- self.sequential = "sequential" in self.debug
651
- self.layer = layer
652
-
653
- self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
654
- self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
655
- self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
656
- self.norm = RMSNorm(dims)
657
- self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, 10000)
658
- self.rotary = rotary(dims=dims, head=head, debug=debug, radii=False)
659
-
660
- with torch.no_grad():
661
- self.token.weight[0].zero_()
662
-
663
- act_fn = get_activation(act)
664
- if features == ["spectrogram", "waveform", "pitch"]:
665
- cgate=True
666
- else:
667
- cgate = False
668
-
669
- self.blockA = nn.ModuleDict()
670
- self.blockA["waveform"] = nn.ModuleList(
671
- [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
672
- [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
673
- for _ in range(layer)] if "waveform" in features else None)
674
-
675
- for feature_type in ["spectrogram", "aperiodic", "harmonic"]:
676
- if feature_type in features:
677
- self.blockA[feature_type] = nn.ModuleList(
678
- [FEncoder(mels=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
679
- [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
680
- else:
681
- self.blockA[feature_type] = None
682
-
683
- for feature_type in ["pitch", "phase"]:
684
- if feature_type in features:
685
- self.blockA[feature_type] = nn.ModuleList(
686
- [PEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act_fn)] +
687
- [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
688
- else:
689
- self.blockA[feature_type] = None
690
-
691
- self.blockB = nn.ModuleList([
692
- Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
693
- for _ in range(layer)])
694
-
695
- self.modal = nn.ModuleList([
696
- Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
697
- for _ in range(layer)])
698
-
699
- mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
700
- self.register_buffer("mask", mask, persistent=False)
701
-
704
- def forward(self, x, xa, en, f, sequential=False) -> Tensor:
705
- mask = self.mask[:x.shape[1], :x.shape[1]]
706
- x = self.token(x.long()) + self.positional[:x.shape[1]]
707
-
708
- out = {}
709
- out["input_ids"] = x
710
- out.update(en)
711
-
712
- for b in chain(self.blockA[f] or []):
713
- xa = b(x=xa, en=out, f=f, layer="en")
714
-
715
- for b in chain(self.blockB or []):
716
- x = b(x=x, xa=None, mask=mask, en=out, f=f, layer="dec")
717
- y = b(x, xa=xa, mask=None, en=out, f=f, layer="cross")
718
- if sequential:
719
- x = y
720
- else:
721
- a = torch.sigmoid(self.blend)
722
- x = a * y + (1 - a) * x
723
- for b in self.modal:
724
- xc = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None, en=out, f=f, layer="modal")
725
- xm = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None, en=out, f=f, layer="modal")
726
- if sequential:
727
- x = xm
728
- else:
729
- a = torch.sigmoid(self.blend)
730
- x = a * x + (1 - a) * xm
731
-
732
- if self.counter < 1 and "encoder" in self.debug:
733
- shapes = {k: v.shape for k, v in en.items()}
734
- print(f"Step {self.counter}: mode: {list(en.keys()) }: shapes: {shapes}")
735
- self.counter += 1
736
-
737
- x = self.norm(x)
738
- x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
739
-
740
- return x
741
-
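- # Echo: packs the available acoustic features into a dict, runs theBridge per feature stream, and returns logits plus cross-entropy loss (ignore_index=0).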
742
- class Echo(nn.Module):
743
- def __init__(self, param: Dimensions):
744
- super().__init__()
745
- self.param = param
746
-
747
- self.processor = theBridge(
748
- vocab=param.vocab,
749
- mels=param.mels,
750
- ctx=param.ctx,
751
- dims=param.dims,
752
- head=param.head,
753
- layer=param.layer,
754
- features=param.features,
755
- act=param.act,
756
- debug=param.debug,
757
- )
758
-
759
- def forward(self,
760
- labels=None,
761
- input_ids=None,
762
- waveform: Optional[torch.Tensor]=None,
763
- spectrogram: Optional[torch.Tensor]=None,
764
- pitch: Optional[torch.Tensor]=None,
765
- f0: Optional[torch.Tensor]=None,
766
- f0t: Optional[torch.Tensor]=None,
767
- harmonic: Optional[torch.Tensor]=None,
768
- aperiodic: Optional[torch.Tensor]=None,
769
- phase: Optional[torch.Tensor]=None,
770
- ) -> Dict[str, Optional[torch.Tensor]]:
771
-
772
- en = {}
775
- if f0 is not None:
776
- en["f0"] = f0
777
- if f0t is not None:
778
- en["f0t"] = f0t
779
- if harmonic is not None:
780
- en["harmonic"] = harmonic
781
- if aperiodic is not None:
782
- en["aperiodic"] = aperiodic
783
- if phase is not None:
784
- en["phase"] = phase
785
- if pitch is not None:
786
- en["pitch"] = pitch
787
- if waveform is not None:
788
- en["waveform"] = waveform
789
- if spectrogram is not None:
790
- en["spectrogram"] = spectrogram
791
-
792
- x = input_ids
793
- for f, xa in en.items():
794
-
795
- logits = self.processor(x, xa, en, f)
796
-
797
- loss = None
798
- if labels is not None:
799
- loss = F.cross_entropy(
800
- logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
801
-
802
- return {"logits": logits, "loss": loss}
803
-
804
- @property
805
- def device(self):
806
- return next(self.parameters()).device
807
- @property
808
- def dtype(self):
809
- return next(self.parameters()).dtype
810
-
811
- def _init_weights(self, module):
812
- std = 0.02
813
- self.init_counts = {
814
- "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
815
- "Conv2d": 0, "theBridge": 0, "Echo": 0,
816
- "Residual": 0, "MultiheadA": 0,
817
- "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
818
- "WEncoder": 0, "PEncoder": 0}
819
-
820
- for name, module in self.named_modules():
821
- if isinstance(module, RMSNorm):
822
- nn.init.ones_(module.weight)
823
- self.init_counts["RMSNorm"] += 1
824
- elif isinstance(module, nn.Linear):
825
- if module.weight is not None:
826
- nn.init.xavier_uniform_(module.weight)
827
- if module.bias is not None:
828
- nn.init.zeros_(module.bias)
829
- self.init_counts["Linear"] += 1
830
- elif isinstance(module, Conv1d):
831
- nn.init.normal_(module.weight, mean=0.0, std=std)
832
- if module.bias is not None:
833
- nn.init.zeros_(module.bias)
834
- self.init_counts["Conv1d"] += 1
835
- elif isinstance(module, Conv2d):
836
- nn.init.normal_(module.weight, mean=0.0, std=std)
837
- if module.bias is not None:
838
- nn.init.zeros_(module.bias)
839
- self.init_counts["Conv2d"] += 1
840
- elif isinstance(module, MultiheadA):
841
- self.init_counts["MultiheadA"] += 1
842
- elif isinstance(module, Residual):
843
- self.init_counts["Residual"] += 1
844
- elif isinstance(module, PEncoder):
845
- self.init_counts["PEncoder"] += 1
846
- elif isinstance(module, FEncoder):
847
- self.init_counts["FEncoder"] += 1
848
- elif isinstance(module, WEncoder):
849
- self.init_counts["WEncoder"] += 1
850
- elif isinstance(module, theBridge):
851
- self.init_counts["theBridge"] += 1
852
- elif isinstance(module, Echo):
853
- self.init_counts["Echo"] += 1
854
-
855
- def init_weights(self):
856
- print("Initializing model weights...")
857
- self.apply(self._init_weights)
858
- print("Initialization summary:")
859
- for module_type, count in self.init_counts.items():
860
- if count > 0:
861
- print(f"{module_type}: {count}")
862
-
863
- def generate(self, input_ids=None, spectrogram=None, waveform=None, pitch=None, f0=None,
864
- envelope=None, phase=None, tokenizer=None, max_length=128, min_length=1, device=None, **kwargs):
865
- if device is None:
866
- device = self.device
867
- pad_token_id = getattr(tokenizer, "pad_token_id", 0)
868
- bos_token_id = getattr(tokenizer, "bos_token_id", 1)
869
- eos_token_id = getattr(tokenizer, "eos_token_id", 2)
870
- batch_size = 1
871
- for x in [spectrogram, waveform, pitch, f0, envelope, phase]:
872
- if x is not None:
873
- batch_size = x.shape[0]
874
- break
875
- ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
876
- feature = {}
877
- if spectrogram is not None:
878
- feature["spectrogram"] = spectrogram
879
- if waveform is not None:
880
- feature["waveform"] = waveform
881
- if pitch is not None:
882
- feature["pitch"] = pitch
883
- if envelope is not None:
884
- feature["envelope"] = envelope
885
- if phase is not None:
886
- feature["phase"] = phase
887
- if f0 is not None:
888
- feature["f0"] = f0
889
-
890
- for i in range(max_length - 1):
891
- with torch.no_grad():
892
- feature["input_ids"] = ids
893
- logits = self(**feature)["logits"]  # Echo.forward returns {"logits", "loss"}
894
- next_token_logits = logits[:, -1, :]
895
- if i < min_length:
896
- next_token_logits[:, eos_token_id] = float("-inf")
897
- next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
898
- ids = torch.cat([ids, next_tokens], dim=1)
899
- if (next_tokens == eos_token_id).all() and i >= min_length:
900
- break
901
- return ids
902
-
903
- @property
904
- def config(self):
905
- class Config:
906
- pad_token_id = getattr(self.param, "pad_token_id", 0)
907
- bos_token_id = getattr(self.param, "bos_token_id", 1)
908
- eos_token_id = getattr(self.param, "eos_token_id", 2)
909
- def to_json_string(self):
910
- import json
911
- return json.dumps({
912
- "pad_token_id": self.pad_token_id,
913
- "bos_token_id": self.bos_token_id,
914
- "eos_token_id": self.eos_token_id,
915
- })
916
- return Config()
917
-
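- # Training entry point: feature-extraction flags, Dimensions hyperparameters, dataset preparation, and a Seq2SeqTrainer with AdamW and cosine annealing.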
918
- def main():
919
- token = ""
920
- log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
921
- os.makedirs(log_dir, exist_ok=True)
922
- tokenizer = setup_tokenizer("./")
923
-
924
- sanity_check = False
925
- streaming = False
926
- load_saved = False
927
- save_dataset = False
928
- cache_dir = None
929
- extract_args = None
930
-
931
- extract_args = {
932
- "waveform": False,
933
- "spec": True,
934
- "f0": False,
935
- "f0t": False,
936
- "pitch": True,
937
- "harmonics": False,
938
- "aperiodics": False,
939
- "phase_mod": False,
940
- "crepe": False,
941
- "sample_rate": 16000,
942
- "hop_length": 256,
943
- "mode": "mean",
944
- "debug": False,
945
- }
946
-
947
- param = Dimensions(
948
- vocab=40000,
949
- mels=128,
950
- ctx=2048,
951
- dims=512,
952
- head=4,
953
- layer=4,
954
- act="swish",
955
- debug={"encoder"},
956
- features = ["spectrogram", "pitch"],
957
- )
958
-
959
- train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=sanity_check, sample_rate=16000, streaming=streaming,
960
- load_saved=load_saved, save_dataset=save_dataset, cache_dir=cache_dir, extract_args=extract_args, max_ctx=param.ctx)
961
-
962
- model = Echo(param).to('cuda')
963
- print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
964
- print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
965
-
966
- from functools import partial
967
- metrics_fn = partial(compute_metrics,
968
- print_pred=True,
969
- num_samples=1,
970
- tokenizer=tokenizer, model=model)
971
-
972
- if sanity_check:
973
- training_args = Seq2SeqTrainingArguments(
974
- output_dir=log_dir,
975
- per_device_train_batch_size=1,
976
- per_device_eval_batch_size=1,
977
- max_steps=10,
978
- eval_steps=5,
979
- save_steps=0,
980
- warmup_steps=0,
981
- logging_steps=1,
982
- logging_dir=log_dir,
983
- eval_strategy="steps",
984
- save_strategy="no",
985
- logging_strategy="no",
986
- report_to=["tensorboard"],
987
- push_to_hub=False,
988
- save_total_limit=1,
989
- label_names=["labels"],
990
- save_safetensors=False,
991
- eval_on_start=True,
992
- batch_eval_metrics=False,
993
- disable_tqdm=False,
994
- include_tokens_per_second=True,
995
- include_num_input_tokens_seen=True,
996
- learning_rate=1e-7,
997
- weight_decay=0.01,
998
- )
999
- else:
1000
- training_args = Seq2SeqTrainingArguments(
1001
- output_dir=log_dir,
1002
- per_device_train_batch_size=1,
1003
- per_device_eval_batch_size=1,
1004
- max_steps=1000,
1005
- eval_steps=100,
1006
- save_steps=1000,
1007
- warmup_steps=100,
1008
- logging_steps=10,
1009
- logging_dir=log_dir,
1010
- logging_strategy="steps",
1011
- eval_strategy="steps",
1012
- save_strategy="no",
1013
- report_to=["tensorboard"],
1014
- push_to_hub=False,
1015
- save_total_limit=1,
1016
- label_names=["labels"],
1017
- save_safetensors=False,
1018
- eval_on_start=True,
1019
- batch_eval_metrics=False,
1020
- disable_tqdm=False,
1021
- include_tokens_per_second=True,
1022
- include_num_input_tokens_seen=True,
1023
- learning_rate=0.00025,
1024
- weight_decay=0.025,
1025
- )
1026
-
1027
- optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
1028
- amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
1029
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
1030
-
1031
- trainer = Seq2SeqTrainer(
1032
- args=training_args,
1033
- model=model,
1034
- train_dataset=train_dataset,
1035
- eval_dataset=test_dataset,
1036
- data_collator=DataCollator(tokenizer=tokenizer),
1037
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
1038
- compute_metrics=metrics_fn,
1039
- optimizers=(optimizer, scheduler)
1040
- )
1041
-
1042
- model.init_weights()
1043
- trainer.train()
1044
- if __name__ == "__main__":
1045
-
1046
- main()
1047
-