Sin2pi committed on
Commit e284fca · verified · 1 parent: 3e934d2

Create modelA.py


Simplified loop; removed excess unused functions.

Files changed (1)
  1. modelA.py +1102 -0
modelA.py ADDED
import os
import math
import warnings
import logging

import numpy as np
import pyworld as pw
import torch
import torchaudio
import torch.nn.functional as F
import torch.nn.init as init
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader, random_split

import transformers
from datasets import load_dataset, Audio
from typing import Optional, Dict, Union, List, Tuple
from dataclasses import dataclass
from transformers.generation.configuration_utils import GenerationConfig

torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
transformers.utils.logging.set_verbosity_error()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
logging.basicConfig(level=logging.ERROR)

@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]

def get_generation_config(param):
    return GenerationConfig(
        max_length=param.text_ctx,
        pad_token_id=getattr(param, "pad_token_id", 0),
        bos_token_id=getattr(param, "bos_token_id", 1),
        eos_token_id=getattr(param, "eos_token_id", 2),
        do_sample=False,
        num_beams=1,
        early_stopping=False,
        length_penalty=1.0,
        no_repeat_ngram_size=0,
        repetition_penalty=1.0,
        temperature=1.0,
        decoder_start_token_id=1,
        is_multilingual=False,
        use_cache=False,
        return_timestamps=False)

def dict_to(d, device, dtype=dtype):
    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
            for k, v in d.items()}

def exists(v):
    return v is not None

def default(v, b):
    return v if exists(v) else b

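# Hedged sanity sketch (illustrative values only): default() falls back to
# its second argument exactly when the first is None.
assert default(None, 3) == 3 and default(5, 3) == 5
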
class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Conv2d(nn.Conv2d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple],
                 eps = 1e-8, elementwise_affine = True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

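# Hedged sanity sketch (illustrative shapes): RMSNorm normalizes over the
# last dimension and preserves the input shape.
assert RMSNorm(8)(torch.randn(2, 3, 8)).shape == (2, 3, 8)
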
def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64

def tox():
    return {"device": get_device(), "dtype": get_dtype()}

def sinusoids(length, channels, max_tscale=10000):
    assert channels % 2 == 0
    log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
    inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
    scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)

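# Hedged sanity sketch (illustrative sizes): sinusoids() returns one row per
# position, with sin features in the first half and cos in the second.
assert sinusoids(length=4, channels=8).shape == (4, 8)
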
class rotary(nn.Module):
    def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=True, debug: List[str] = [], use_pbias=False):
        super(rotary, self).__init__()

        self.use_pbias = use_pbias
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.radii = radii
        self.dim = self.head_dim
        self.debug = debug
        self.counter = 0
        self.last_theta = None

        self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
        self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)

    def theta_freqs(self, theta):
        # mel-inspired frequency spacing scaled by the (learnable) theta;
        # return the tensor directly so the graph stays connected to
        # self.theta (wrapping it in a fresh nn.Parameter every call would
        # detach it and leak parameters)
        freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
        return freq

    @staticmethod
    def mel_scale_scalar(freq: float) -> float:
        return 1127.0 * math.log(1.0 + freq / 700.0)

    @staticmethod
    def mel_scale(freq: Tensor) -> Tensor:
        return 1127.0 * (1.0 + freq / 700.0).log()

    def return_f0(self, f0=None):
        if f0 is not None:
            self.f0 = f0
            return f0.squeeze(0).to(device, dtype)
        elif hasattr(self, 'f0') and self.f0 is not None:
            return self.f0.squeeze(0).to(device, dtype)
        return None

    def pitch_bias(self, f0):
        if f0 is None:
            return None
        f0_flat = f0.squeeze().float()
        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
                                        f0_norm.unsqueeze(1)))
        return f0_sim.unsqueeze(0).unsqueeze(0)

    def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
        f0 = enc.get("f0") if enc is not None else None
        if isinstance(x, int):
            ctx = x
        elif isinstance(x, torch.Tensor) and x.ndim == 2:
            batch, ctx = x.shape
        elif isinstance(x, torch.Tensor) and x.ndim == 3:
            batch, ctx, dims = x.shape
        else:
            batch, head, ctx, head_dim = x.shape
        t = torch.arange(ctx, device=device, dtype=dtype)

        if f0 is not None and f0.dim() == 2:
            if f0.shape[0] == 1:
                f0 = f0.squeeze(0)
            else:
                f0 = f0.view(-1)

        if f0 is not None:
            f0_mean = f0.mean()
            theta = f0_mean + self.theta
        else:
            theta = self.theta

        freqs = self.theta_freqs(theta)
        freqs = t[:, None] * freqs[None, :]

        radius = None
        if self.radii and f0 is not None:
            radius = f0.to(device, dtype)
            L = radius.shape[0]
            if L != ctx:
                # resample the per-frame radius to the attention context
                # length ("ratio" rather than "F", which would shadow
                # torch.nn.functional)
                ratio = L / ctx
                idx = torch.arange(ctx, device=f0.device)
                idx = (idx * ratio).long().clamp(0, L - 1)
                radius = radius[idx]
            freqs = torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs)
        else:
            freqs = torch.polar(torch.ones_like(freqs), freqs)

        if "radius" in self.debug and radius is not None and self.counter % 100 == 0:
            theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
            print(f" [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")

        if "theta" in self.debug and self.counter % 100 == 0:
            if self.last_theta is None or abs(self.last_theta - theta.item()) > 1.0:
                self.last_theta = theta.item()
                print(f"[Theta] {self.last_theta:.2f}")

        self.counter += 1
        return freqs.unsqueeze(0)

    @staticmethod
    def apply_rotary(x, freqs):
        x1 = x[..., :freqs.shape[-1]*2]
        x2 = x[..., freqs.shape[-1]*2:]
        orig_shape = x1.shape
        if x1.ndim == 2:
            x1 = x1.unsqueeze(0)
        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
        x1 = torch.view_as_complex(x1) * freqs
        x1 = torch.view_as_real(x1).flatten(-2)
        x1 = x1.view(orig_shape)
        return torch.cat([x1.type_as(x), x2], dim=-1)

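# Hedged usage sketch (illustrative shapes, not used by the model below):
# rotate a dummy (batch, head, ctx, head_dim) tensor, mirroring how
# MultiheadA drives the module. The _demo names are hypothetical.
_rope_demo = rotary(dims=64, head=4)
_q_demo = torch.randn(1, 4, 10, 16, device=device, dtype=dtype)
assert rotary.apply_rotary(_q_demo, _rope_demo(x=10)).shape == _q_demo.shape
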
class MultiheadA(nn.Module):
    _seen = set()
    rbf = False
    def __init__(self, dims: int, head: int, rotary_emb: bool = True,
                 zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False, use_pbias=False):
        super(MultiheadA, self).__init__()

        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.use_pbias = use_pbias

        self.q = nn.Linear(dims, dims).to(device, dtype)
        self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
        self.v = nn.Linear(dims, dims).to(device, dtype)
        self.o = nn.Linear(dims, dims).to(device, dtype)

        self.pad_token = 0
        self.rotary_emb = rotary_emb
        self.minz = minz
        self.maxz = maxz
        self.zero_val = zero_val
        self.optim_attn = optim_attn
        self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)

        if rotary_emb:
            self.rope = rotary(
                dims=dims,
                head=head,
                debug=debug,
                radii=True,
            )
        else:
            self.rope = None

    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:

        x = x.to(device, dtype)
        if xa is not None:
            xa = xa.to(device, dtype)
        scale = (self.dims // self.head) ** -0.25

        z = default(xa, x).to(device, dtype)
        q = self.q(x)
        k = self.k(z)
        v = self.v(z)

        # q2/k2 are needed by the mask path regardless of rotary_emb, so the
        # head split happens once, up front
        q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        q2 = q.shape[2]
        k2 = k.shape[2]

        if self.rotary_emb:
            q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
            k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))

        qk = (q * scale) @ (k * scale).transpose(-1, -2)

        # soft de-weighting of padded keys: the first key channel is used as a
        # proxy for the token id, and positions matching pad_token are scaled
        # down by the small frozen factor fzero
        token_ids = k[:, :, :, 0]
        zscale = torch.ones_like(token_ids)
        fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
        zscale[token_ids.float() == self.pad_token] = fzero

        if mask is not None:
            mask = mask[:q2, :q2]
            qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
        qk = qk * zscale.unsqueeze(-2)
        w = F.softmax(qk, dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

        if "multihead" in self.debug and self.counter % 100 == 0:
            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
        self.counter += 1
        return self.o(wv), qk

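# Hedged sketch (illustrative shapes): the self-attention round trip
# preserves (batch, ctx, dims); the second return value is the scaled
# pre-softmax score tensor.
_demo_mha = MultiheadA(dims=64, head=4)
_demo_x = torch.randn(1, 10, 64, device=device, dtype=dtype)
_demo_out, _demo_qk = _demo_mha(_demo_x)
assert _demo_out.shape == (1, 10, 64) and _demo_qk.shape == (1, 4, 10, 10)
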
class t_gate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gate_projections = nn.ModuleList([
            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            for _ in range(num_types)])
        self.type_classifier = nn.Sequential(
            Linear(dims, num_types),
            nn.Softmax(dim=-1))

    def forward(self, x):
        type_probs = self.type_classifier(x)
        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
        comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
        return comb_gate

class m_gate(nn.Module):
    def __init__(self, dims, mem_size=64):
        super().__init__()
        self.m_key = nn.Parameter(torch.randn(mem_size, dims))
        self.m_val = nn.Parameter(torch.randn(mem_size, 1))
        self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))

    def forward(self, x):
        d_gate = torch.sigmoid(self.gate_proj(x))
        attention = torch.matmul(x, self.m_key.transpose(0, 1))
        attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
        m_gate = torch.matmul(attention, self.m_val)
        m_gate = torch.sigmoid(m_gate)
        return 0.5 * (d_gate + m_gate)

class c_gate(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.integ = Linear(dims*5, dims)

    def forward(self, x, features):
        s_feat = features.get("spectrogram", x)
        w_feat = features.get("waveform", x)
        p_feat = features.get("pitch", x)
        e_feat = features.get("envelope", x)
        ph_feat = features.get("phase", x)
        s = self.s_gate(x) * s_feat
        w = self.w_gate(x) * w_feat
        p = self.p_gate(x) * p_feat
        e = self.e_gate(x) * e_feat
        ph = self.ph_gate(x) * ph_feat
        comb = torch.cat([s, w, p, e, ph], dim=-1)
        return self.integ(comb)

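# Hedged sketch (illustrative shapes): t_gate mixes four sigmoid gates with
# softmax type weights, yielding one scalar gate per position.
_demo_gate = t_gate(dims=16)
assert _demo_gate(torch.randn(2, 5, 16)).shape == (2, 5, 1)
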
class Residual(nn.Module):
    _seen = set()
    def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
        super().__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head
        self.cross_attn = cross_attn
        self.features = features
        self.debug = debug
        self.counter = 0
        self.dropout = 0.01

        self.do_blend = "no_blend" not in self.debug
        self.blend = nn.Parameter(torch.tensor(0.5))
        self.skip_gates = True if "skip_gates" in self.debug else False

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
        self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None)

        mlp = dims * 4
        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))

        self.t_gate = t_gate(dims=dims, num_types=4) if tgate else None
        self.m_gate = m_gate(dims=dims, mem_size=mem_size) if mgate else None
        self.c_gate = c_gate(dims=dims) if cgate else None

        self.lna = RMSNorm(dims)
        self.lnb = RMSNorm(dims) if cross_attn else None
        self.lnc = RMSNorm(dims)

        # test the constructor flags, not the gate classes themselves (class
        # objects are always truthy, so the old check never fired)
        if not any([tgate, mgate, cgate]):
            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

    def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:

        x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
        xb = x
        if self.attnb and xa is not None:
            x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]

        if self.do_blend:
            b = torch.sigmoid(self.blend)
            x = b * xb + (1 - b) * x

        if self.skip_gates:
            x = x + self.mlp(self.lnc(x))
        else:
            normx = self.lnc(x)
            mlp_out = self.mlp(normx)

            if self.t_gate:
                gate = self.t_gate(normx)
                x = x + gate * mlp_out

            elif self.m_gate:
                gate = self.m_gate(normx)
                x = x + gate * mlp_out

            elif self.c_gate:
                # c_gate expects a dict of feature tensors, not the list of
                # feature names; fall back to an empty dict so every stream
                # defaults to normx when no dict is available
                gate_output = self.c_gate(normx, enc if isinstance(enc, dict) else {})
                x = x + gate_output

            else:
                if hasattr(self, 'mlp_gate'):
                    mlp_gate = self.mlp_gate(normx)
                    x = x + mlp_gate * mlp_out
                else:
                    x = x + mlp_out

        return x

class FEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)

        if use_rope:
            # rotary() takes the full model width plus the head count; the
            # 2D-axial variant referenced in earlier revisions no longer
            # exists, so a plain rotary is used for all feature types
            self.rope = rotary(dims=dims, head=head, debug=[])
        else:
            self.rope = None
            self.positional = lambda length: sinusoids(length, dims)

        self.norm = RMSNorm(dims)
        self._norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None, feature_type="audio"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer)
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="audio"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer, feature_type=feature_type)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self._norm(x)
        return x

class WEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.downsample = nn.Sequential(
            Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
            Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn,
            Conv1d(dims//4, dims, kernel_size=9, stride=5, padding=4), act_fn)

        self.encoder = nn.Sequential(
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn,
            Conv1d(dims, dims, kernel_size=1), act_fn)

        if use_rope:
            # same fix as FEncoder: pass dims and head, no 2D-axial kwargs
            self.rope = rotary(dims=dims, head=head, theta=50.0, debug=[])
        else:
            self.rope = None
            self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None):
        if not self.use_rope or self.rope is None:
            return x
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer)
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="waveform"):
        x = self.downsample(x)
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        return self.norm(x)

class PEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)

        if use_rope:
            # same fix as FEncoder: pass dims and head, no 2D-axial kwargs
            self.rope = rotary(dims=dims, head=head, theta=100.0, debug=[])
        else:
            self.rope = None
            self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None):
        if not self.use_rope or self.rope is None:
            return x
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer)
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="pitch"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

class AudioEncoder(nn.Module):
    _seen = set()
    def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
        super(AudioEncoder, self).__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.features = features
        self.dropout = 0.01

        # the combined-feature gate only makes sense when all three of these
        # streams are present
        cgate = features == ["spectrogram", "waveform", "pitch"]

        # act is forwarded as a string: each sub-encoder resolves it against
        # its own act_map (passing a module instance, as before, silently
        # fell back to GELU for every choice)
        self.blocks = nn.ModuleDict({

            "spectrogram": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "spectrogram" in features else None),

            "waveform": nn.ModuleList(
                [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "waveform" in features else None),

            "pitch": nn.ModuleList(
                [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "pitch" in features else None),

            "envelope": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "envelope" in features else None),

            "phase": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "phase" in features else None),
        })

    def forward(self, enc, layer="encoder"):
        enc = dict_to(enc, device, dtype)
        out = {}
        out.update(enc)

        for f in self.features:
            if f in enc and f in self.blocks:
                x = enc[f]
                for block in self.blocks[f]:
                    x = block(x, enc=enc, layer=layer)
                out[f] = x

        return out

class TextDecoder(nn.Module):
    def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
                 debug: List[str], features: List[str]):
        super(TextDecoder, self).__init__()

        self.ctx = ctx
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.dropout = 0.01
        self.features = features
        self.do_blend = "no_blend" not in self.debug
        self.sequential = "sequential" in self.debug

        self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
        with torch.no_grad():
            self.token.weight[0].zero_()
        # small random init: a bare torch.empty left the positional table
        # uninitialized, and init_weights() does not cover this parameter
        self.positional = nn.Parameter(torch.empty(ctx, dims).normal_(std=0.02), requires_grad=True)

        self.block = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
            for _ in range(layer)])

        self.blocks = nn.ModuleDict({
            f: nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
                              for _ in range(layer)]) for f in features})

        self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
        self.ln_dec = RMSNorm(dims)

        mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, enc, order=None, layer='decoder') -> Tensor:

        if order is None:
            order = self.features

        mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x) + self.positional[:x.shape[1]]
        x = F.dropout(x, p=self.dropout, training=self.training)

        for block in self.block:
            x = block(x, xa=None, mask=mask, enc=None, layer=layer)

        for f in order:
            if f in enc:
                xa = enc[f]
                for block in self.blocks[f]:
                    out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)

                    if self.sequential:
                        x = out
                    else:
                        a = torch.sigmoid(self.blend[f])
                        x = a * out + (1 - a) * x

        x = self.ln_dec(x)
        # output projection is tied to the embedding table
        return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

class Echo(nn.Module):
    def __init__(self, param: Dimensions):
        super().__init__()
        self.param = param

        self.encoder = AudioEncoder(
            mels=param.mels,
            ctx=param.aud_ctx,
            dims=param.aud_dims,
            head=param.aud_head,
            layer=param.aud_idx,
            act=param.act,
            debug=param.debug,
            features=param.features,
        )

        self.decoder = TextDecoder(
            vocab=param.vocab,
            ctx=param.text_ctx,
            dims=param.text_dims,
            head=param.text_head,
            layer=param.text_idx,
            cross_attn=param.cross_attn,
            debug=param.debug,
            features=param.features,
        )

    def forward(self,
                labels=None,
                waveform: Optional[torch.Tensor]=None,
                input_ids=None,
                spectrogram: torch.Tensor=None,
                pitch: Optional[torch.Tensor]=None,
                f0: Optional[torch.Tensor]=None,
                envelope: Optional[torch.Tensor]=None,
                phase: Optional[torch.Tensor]=None,
                ) -> Dict[str, torch.Tensor]:

        encoder_inputs = {}
        if spectrogram is not None:
            encoder_inputs["spectrogram"] = spectrogram
        if waveform is not None:
            encoder_inputs["waveform"] = waveform
        if pitch is not None:
            encoder_inputs["pitch"] = pitch
        if envelope is not None:
            encoder_inputs["envelope"] = envelope
        if phase is not None:
            encoder_inputs["phase"] = phase
        if f0 is not None:
            encoder_inputs["f0"] = f0

        encoder_outputs = self.encoder(encoder_inputs)
        logits = self.decoder(input_ids, encoder_outputs)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)

        return {"logits": logits, "loss": loss}

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, nn.Linear):
            if module.weight is not None:
                nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, MultiheadA):
            self.init_counts["MultiheadA"] += 1
        elif isinstance(module, TextDecoder):
            self.init_counts["TextDecoder"] += 1
        elif isinstance(module, AudioEncoder):
            self.init_counts["AudioEncoder"] += 1
        elif isinstance(module, Residual):
            self.init_counts["Residual"] += 1

    def init_weights(self):
        print("Initializing model weights...")
        # counts are reset once here; apply() then hands each submodule to
        # _init_weights exactly once (the old version reset the counts and
        # re-walked the whole module tree on every callback)
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "SEBlock": 0, "TextDecoder": 0, "AudioEncoder": 0,
            "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
            "WEncoder": 0, "PEncoder": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")

    def generate(self, input_ids=None, spectrogram=None, waveform=None, pitch=None, f0=None,
                 envelope=None, phase=None, tokenizer=None, max_length=128, min_length=1, device=None, **kwargs):
        if device is None:
            device = self.device
        pad_token_id = getattr(tokenizer, "pad_token_id", 0)
        bos_token_id = getattr(tokenizer, "bos_token_id", 1)
        eos_token_id = getattr(tokenizer, "eos_token_id", 2)
        batch_size = 1
        for x in [spectrogram, waveform, pitch, f0, envelope, phase]:
            if x is not None:
                batch_size = x.shape[0]
                break
        ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
        encoder_inputs = {}
        if spectrogram is not None:
            encoder_inputs["spectrogram"] = spectrogram
        if waveform is not None:
            encoder_inputs["waveform"] = waveform
        if pitch is not None:
            encoder_inputs["pitch"] = pitch
        if envelope is not None:
            encoder_inputs["envelope"] = envelope
        if phase is not None:
            encoder_inputs["phase"] = phase
        if f0 is not None:
            encoder_inputs["f0"] = f0
        encoder_outputs = self.encoder(encoder_inputs)
        for i in range(max_length - 1):
            with torch.no_grad():
                logits = self.decoder(ids, encoder_outputs)
            next_token_logits = logits[:, -1, :]
            if i < min_length:
                # -inf (rather than 0) guarantees EOS cannot win the argmax
                # before min_length is reached, since logits may be negative
                next_token_logits[:, eos_token_id] = -float("inf")
            next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            ids = torch.cat([ids, next_tokens], dim=1)
            if (next_tokens == eos_token_id).all() and i >= min_length:
                break
        return ids

    @property
    def config(self):
        class Config:
            pad_token_id = getattr(self.param, "pad_token_id", 0)
            bos_token_id = getattr(self.param, "bos_token_id", 1)
            eos_token_id = getattr(self.param, "eos_token_id", 2)
            def to_json_string(self):
                import json
                return json.dumps({
                    "pad_token_id": self.pad_token_id,
                    "bos_token_id": self.bos_token_id,
                    "eos_token_id": self.eos_token_id,
                })
        return Config()

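# Hedged smoke test with tiny illustrative dimensions: one forward pass
# through Echo on random inputs to confirm the wiring. A sketch only, not
# part of training; the _demo_* names are hypothetical.
_demo_param = Dimensions(
    vocab=100, text_ctx=32, text_dims=64, text_head=2, text_idx=1,
    mels=128, aud_ctx=64, aud_dims=64, aud_head=2, aud_idx=1,
    act="gelu", debug=[], cross_attn=True, features=["spectrogram"])
_demo_model = Echo(_demo_param).to(device)
_demo_out = _demo_model(
    input_ids=torch.randint(1, 100, (1, 8), device=device),
    spectrogram=torch.randn(1, 128, 50, device=device),
    labels=torch.randint(1, 100, (1, 8), device=device))
assert _demo_out["logits"].shape == (1, 8, 100)
del _demo_model, _demo_out
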
token = ""

param = Dimensions(
    mels=128,
    aud_ctx=1500,
    aud_head=4,
    aud_dims=512,
    aud_idx=4,
    vocab=40000,
    text_ctx=512,
    text_head=4,
    text_dims=512,
    text_idx=4,
    act="swish",
    debug=[],
    cross_attn=True,
    features=["spectrogram"],
)

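# Hedged sanity check: the attention and rotary code assume dims divide
# evenly across heads.
assert param.aud_dims % param.aud_head == 0
assert param.text_dims % param.text_head == 0
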
def setup_tokenizer(token, local_tokenizer_path: str = "./"):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
    orig_encode = tokenizer.encode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids

    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                # drop a leading BOS (1) and trailing PAD (0) / EOS (2)
                if ids and ids[0] == 1:
                    ids = ids[1:]
                while ids and ids[-1] in [0, 2]:
                    ids = ids[:-1]
            results.append(tokenizer.decode(ids))
        return results

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

raw_dataset = load_dataset(
    "google/fleurs",
    "en_us",
    token=token if token else None,
    split="train[:1000]",
    trust_remote_code=True,
)
raw_dataset = raw_dataset.cast_column("audio", Audio(sampling_rate=16000))

class SimpleSpeechDataset(Dataset):
    def __init__(self, hf_dataset):
        self.samples = []
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=1024, hop_length=256, n_mels=128
        )
        for item in hf_dataset:
            waveform = torch.tensor(item["audio"]["array"]).float()
            if waveform.dim() == 2:
                waveform = waveform.mean(dim=0)
            spec = self.mel(waveform)
            wav_np = waveform.numpy().astype(np.float64)
            # frame_period of 256 samples at 16 kHz = 16 ms, matching the mel
            # hop_length so the f0 frames line up with spectrogram frames
            f0, t = pw.dio(wav_np, 16000, frame_period=256/16000*1000)
            f0 = pw.stonemask(wav_np, f0, t, 16000)
            f0 = torch.from_numpy(f0).float()
            self.samples.append({
                "spectrogram": spec,
                "f0": f0,
                "transcription": item["sentence"] if "sentence" in item else item["transcription"]
            })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def simple_collate(batch):
    specs = [item["spectrogram"] for item in batch]
    f0s = [item["f0"] for item in batch]
    labels = [item["transcription"] for item in batch]
    max_spec_len = max(s.shape[-1] for s in specs)
    max_f0_len = max(f0.shape[-1] for f0 in f0s)
    padded_specs = torch.stack([
        torch.nn.functional.pad(s, (0, max_spec_len - s.shape[-1])) for s in specs
    ])
    padded_f0s = torch.stack([
        torch.nn.functional.pad(f0, (0, max_f0_len - f0.shape[-1])) for f0 in f0s
    ])
    return {"spectrogram": padded_specs, "f0": padded_f0s, "transcription": labels}

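# Hedged sketch (illustrative shapes): simple_collate right-pads each field
# to the longest item in the batch.
_demo_batch = simple_collate([
    {"spectrogram": torch.zeros(128, 10), "f0": torch.zeros(10), "transcription": "a"},
    {"spectrogram": torch.zeros(128, 7), "f0": torch.zeros(6), "transcription": "b"},
])
assert _demo_batch["spectrogram"].shape == (2, 128, 10)
assert _demo_batch["f0"].shape == (2, 10)
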
dataset = SimpleSpeechDataset(raw_dataset)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_set, batch_size=1, shuffle=True, collate_fn=simple_collate)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, collate_fn=simple_collate)

tokenizer = setup_tokenizer(token)

model = Echo(param).to(device)
max_steps = 10000
optimizer = torch.optim.AdamW(
    model.parameters(), lr=0.00025, eps=1e-8, weight_decay=0.025, betas=(0.9, 0.999),
    amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max_steps, eta_min=0.0, last_epoch=-1
)

def wer(ref, hyp):
    r = ref.split()
    h = hyp.split()
    # int32 (not uint8) so edit distances above 255 do not overflow
    d = np.zeros((len(r)+1, len(h)+1), dtype=np.int32)
    for i in range(len(r)+1):
        d[i][0] = i
    for j in range(len(h)+1):
        d[0][j] = j
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                substitution = d[i-1][j-1] + 1
                insertion = d[i][j-1] + 1
                deletion = d[i-1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)
    wer_value = d[len(r)][len(h)] / float(len(r)) if len(r) > 0 else 0.0
    return min(wer_value, 1.0)

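# Hedged sanity check for the WER metric (illustrative strings): one
# substitution out of three reference words is a WER of 1/3.
assert wer("the cat sat", "the cat sat") == 0.0
assert abs(wer("the cat sat", "the dog sat") - 1/3) < 1e-9
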
model.train()
step = 0
while step < max_steps:
    for batch in train_loader:
        if step >= max_steps:
            break
        x = batch["spectrogram"].to(model.device)
        f0 = batch["f0"].to(model.device)
        input_ids = [tokenizer.encode(t) for t in batch["transcription"]]
        max_len = max(len(ids) for ids in input_ids)
        input_ids = [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in input_ids]
        labels = torch.tensor(input_ids, dtype=torch.long, device=model.device)
        # teacher forcing with a one-token shift: the decoder sees BOS plus
        # tokens[:-1] and is trained to predict the tokens themselves.
        # Feeding identical input_ids and labels would only teach the model
        # to copy its own input, and generate() would never learn to extend
        # from BOS.
        bos = torch.full((labels.shape[0], 1), tokenizer.bos_token_id, dtype=torch.long, device=model.device)
        dec_input = torch.cat([bos, labels[:, :-1]], dim=1)
        out = model(input_ids=dec_input, spectrogram=x, f0=f0, labels=labels)
        loss = out["loss"]
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            current_lr = scheduler.get_last_lr()[0]
            print(f"Step {step}: Train loss: {loss.item():.4f} | LR: {current_lr:.6f}")
        step += 1

model.eval()
total_wer = 0
n = 0
with torch.no_grad():
    for batch in test_loader:
        x = batch["spectrogram"].to(model.device)
        f0 = batch["f0"].to(model.device)
        pred_ids = model.generate(spectrogram=x, f0=f0, tokenizer=tokenizer, max_length=32)
        pred_text = tokenizer.batch_decode(pred_ids.tolist())
        ref_text = batch["transcription"]
        print(f"REF: {ref_text[0]}")
        print(f"PRED: {pred_text[0]}")
        w = wer(ref_text[0], pred_text[0])
        print(f"WER: {w:.2f}")
        total_wer += w
        n += 1
print(f"\nAverage WER: {total_wer/n:.2f}")