Sin2pi committed on
Commit
d064036
·
verified ·
1 Parent(s): 3e0f32e

Delete model.py

Browse files
Files changed (1)
  1. model.py +0 -1759
model.py DELETED
@@ -1,1759 +0,0 @@
1
-
2
- import os
3
- import pyworld as pw
4
- import math
5
- import warnings
6
- import time
7
- import random
8
- import logging
9
- import gzip
10
- import base64
11
- import torch
12
- import torchaudio
13
- import torch.nn.functional as F
14
- import torch.nn.init as init
15
- from torch import nn, Tensor
16
- from torch.utils.data import Dataset, DataLoader
17
- import numpy as np
18
- from einops import rearrange
19
- import matplotlib.pyplot as plt
20
- from typing import Optional, Dict, Union, List, Tuple, Any
21
- from functools import partial
22
- from datetime import datetime
23
- from datasets import load_dataset, Audio
24
- from torch.utils.tensorboard import SummaryWriter
25
- import tqdm
26
- from tqdm import tqdm
27
- import evaluate
28
- from dataclasses import dataclass
29
- import aiohttp
30
- torch.backends.cudnn.allow_tf32 = True
31
- torch.backends.cuda.matmul.allow_tf32 = True
32
- torch.set_float32_matmul_precision('high')
33
-
34
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
- dtype = torch.float32
36
-
37
- warnings.filterwarnings("ignore")
38
- logging.basicConfig(level=logging.ERROR)
39
-
40
- extractor = None
41
- tokenizer = None
42
- optimizer = None
43
- scheduler = None
44
- model = None
45
-
46
- @dataclass
47
- class Dimensions:
48
- vocab: int
49
- text_ctx: int
50
- text_dims: int
51
- text_head: int
52
- text_idx: int
53
- mels: int
54
- aud_ctx: int
55
- aud_dims: int
56
- aud_head: int
57
- aud_idx: int
58
- act: str
59
- debug: List[str]
60
- cross_attn: bool
61
- features: List[str]
62
-
63
- def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
64
- title="", markers=None, marker_labels=None,
65
- show_voiced_regions=True, show_energy=False):
66
- num_plots = sum([x is not None, w is not None, p is not None, per is not None])
67
- if num_plots == 0:
68
- raise ValueError("No data to plot. Please provide at least one input tensor.")
69
- time_spans = []
70
-
71
- if w is not None:
72
- w_np = w[sample_idx].detach().cpu().numpy()
73
- if w_np.ndim > 1:
74
- w_np = w_np.squeeze()
75
- time_spans.append(len(w_np) / sr)
76
- if x is not None:
77
- x_np = x[sample_idx].detach().cpu().numpy()
78
- if x_np.shape[0] < x_np.shape[1]:
79
- x_np = x_np.T
80
- time_spans.append(x_np.shape[0] * hop_length / sr)
81
- if p is not None:
82
- p_np = p[sample_idx].detach().cpu().numpy()
83
- if p_np.ndim > 1:
84
- p_np = p_np.squeeze()
85
- time_spans.append(len(p_np) * hop_length / sr)
86
- if per is not None:
87
- per_np = per[sample_idx].detach().cpu().numpy()
88
- if per_np.ndim > 1:
89
- per_np = per_np.squeeze()
90
- time_spans.append(len(per_np) * hop_length / sr)
91
- max_time = max(time_spans) if time_spans else 0
92
- fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
93
- if num_plots == 1:
94
- axs = [axs]
95
- if show_voiced_regions and per is not None:
96
- per_np = per[sample_idx].detach().cpu().numpy()
97
- if per_np.ndim > 1:
98
- per_np = per_np.squeeze()
99
- t_per = np.arange(len(per_np)) * hop_length / sr
100
- threshold = 0.5
101
- for ax in axs:
102
- for i in range(len(per_np)-1):
103
- if per_np[i] > threshold:
104
- ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
105
- current_ax = 0
106
- if w is not None:
107
- w_np = w[sample_idx].detach().cpu().numpy()
108
- if w_np.ndim > 1:
109
- w_np = w_np.squeeze()
110
- t = np.arange(len(w_np)) / sr
111
- axs[current_ax].plot(t, w_np, color="tab:blue")
112
-
113
- if show_energy:
114
- frame_length = hop_length
115
- hop_length_energy = hop_length // 2
116
- energy = []
117
- for i in range(0, len(w_np)-frame_length, hop_length_energy):
118
- frame = w_np[i:i+frame_length]
119
- energy.append(np.sqrt(np.mean(frame**2)))
120
- energy = np.array(energy)
121
- energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
122
- t_energy = np.arange(len(energy)) * hop_length_energy / sr
123
- axs[current_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
124
- axs[current_ax].legend(loc='upper right')
125
- axs[current_ax].set_title("Waveform")
126
- axs[current_ax].set_ylabel("Amplitude")
127
- axs[current_ax].set_xlim([0, max_time])
128
- axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
129
- current_ax += 1
130
-
131
- if x is not None:
132
- x_np = x[sample_idx].detach().cpu().numpy()
133
- if x_np.shape[0] < x_np.shape[1]:
134
- x_np = x_np.T
135
- im = axs[current_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
136
- extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
137
- axs[current_ax].set_title("Spectrogram")
138
- axs[current_ax].set_ylabel("Mel Bin")
139
- axs[current_ax].set_xlim([0, max_time])
140
- axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
141
- current_ax += 1
142
-
143
- if p is not None:
144
- p_np = p[sample_idx].detach().cpu().numpy()
145
- if p_np.ndim > 1:
146
- p_np = p_np.squeeze()
147
- t_p = np.arange(len(p_np)) * hop_length / sr
148
- axs[current_ax].plot(t_p, p_np, color="tab:green")
149
- axs[current_ax].set_title("Pitch")
150
- axs[current_ax].set_ylabel("Frequency (Hz)")
151
- axs[current_ax].set_xlim([0, max_time])
152
- axs[current_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
153
- axs[current_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
154
- current_ax += 1
155
-
156
- if per is not None:
157
- per_np = per[sample_idx].detach().cpu().numpy()
158
- if per_np.ndim > 1:
159
- per_np = per_np.squeeze()
160
- t_per = np.arange(len(per_np)) * hop_length / sr
161
- axs[current_ax].plot(t_per, per_np, color="tab:red")
162
- axs[current_ax].set_title("Period (Voice Activity)")
163
- axs[current_ax].set_ylabel("periodocity")
164
- axs[current_ax].set_xlim([0, max_time])
165
- axs[current_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
166
- axs[current_ax].set_ylim([-0.05, 1.05])
167
- axs[current_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
168
-
169
- if markers is not None:
170
- for i, t in enumerate(markers):
171
- label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
172
- for ax in axs:
173
- ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
174
- if marker_labels:
175
- axs[0].legend(loc='upper right', fontsize='small')
176
- axs[-1].set_xlabel("Time (s)")
177
- fig.suptitle(title, fontsize=16)
178
- plt.tight_layout(rect=[0, 0, 1, 0.97])
179
- plt.show()
180
- return fig
181
-
182
- def dict_to(d, device, dtype=dtype):
183
- """Because PyTorch should have this built-in but doesn't"""
184
- return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
185
- for k, v in d.items()}
186
-
187
- def exists(v):
188
- return v is not None
189
-
190
- def default(v, b):
191
- return v if exists(v) else b
192
-
193
- class Conv1d(nn.Conv1d):
194
- def _conv_forward(
195
- self, x: Tensor, weight: Tensor, bias) -> Tensor:
196
- return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
197
-
198
- class Conv2d(nn.Conv2d):
199
- def _conv_forward(
200
- self, x: Tensor, weight: Tensor, bias) -> Tensor:
201
- return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
202
-
203
- class Linear(nn.Module):
204
- def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
205
- super(Linear, self).__init__()
206
- self.linear = nn.Linear(in_features, out_features, bias=bias)
207
- init.xavier_uniform_(self.linear.weight)
208
- if bias:
209
- init.zeros_(self.linear.bias)
210
- def forward(self, x: Tensor) -> Tensor:
211
- return self.linear(x)
212
-
213
- class RMSNorm(nn.Module):
214
- def __init__(self, dims: Union[int, Tensor, List, Tuple],
215
- eps = 1e-8, elementwise_affine = True):
216
- super(RMSNorm, self).__init__()
217
- if isinstance(dims, int):
218
- self.normalized_shape = (dims,)
219
- else:
220
- self.normalized_shape = tuple(dims)
221
- self.eps = eps
222
- self.elementwise_affine = elementwise_affine
223
- if self.elementwise_affine:
224
- self.weight = nn.Parameter(torch.empty(self.normalized_shape))
225
- init.ones_(self.weight)
226
- else:
227
- self.register_parameter("weight", None)
228
- def forward(self, x):
229
- return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
230
-
231
- def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
232
- weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
233
- eps: float = 1e-5) -> Tensor:
234
- return F.layer_norm(x, normalized_shape, weight, bias, eps)
235
-
236
- def get_device():
237
- return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
238
-
239
- def get_dtype():
240
- return torch.float32 if torch.cuda.is_available() else torch.float64
241
-
242
- def tox():
243
- return {"device": get_device(), "dtype": get_dtype()}
244
-
245
- def sinusoids(length, channels, max_timescale=10000):
246
- assert channels % 2 == 0
247
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
248
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
249
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
250
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
251
-
252
- class rotary(nn.Module):
253
- def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False):
254
- super(rotary, self).__init__()
255
-
256
- self.use_pbias = use_pbias
257
- self.dims = dims
258
- self.head = head
259
- self.head_dim = dims // head
260
- self.radii = radii
261
- self.dim = self.head_dim
262
- self.debug = debug
263
- self.counter = 0
264
- self.last_theta = None
265
-
266
- self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
267
- self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
268
- freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
269
- self.freqs = nn.Parameter(freqs.clone().detach(), requires_grad=True)
270
-
271
- def return_f0(self, f0=None):
272
- if f0 is not None:
273
- self.f0 = f0
274
- self.update_base(f0)
275
- return f0.squeeze(0).to(device, dtype)
276
- elif hasattr(self, 'f0') and self.f0 is not None:
277
- return self.f0.squeeze(0).to(device, dtype)
278
- return None
279
-
280
- def update_base(self, f0):
281
- f0 = f0.squeeze(0).to(device, dtype)
282
- theta = f0.mean() + 1e-8
283
- freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
284
- self.freqs.data.copy_(freqs)
285
- self.theta.data.copy_(theta)
286
-
287
- def get_pitch_bias(self, f0):
288
- if f0 is None:
289
- return None
290
- f0_flat = f0.squeeze().float()
291
- f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
292
- f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
293
- f0_norm.unsqueeze(1)))
294
- return f0_sim.unsqueeze(0).unsqueeze(0)
295
-
296
- def f0proj(self, f0):
297
- if f0.ndim == 3:
298
- f0 = f0.squeeze(0)
299
- if self.f0_proj is None: self.f0_proj = nn.Linear(1, self.head_dim // 2, device=device, dtype=dtype)
300
- f0 = f0.to(device, dtype)
301
- f0 = self.f0_proj(f0.unsqueeze(-1))
302
- if f0.ndim == 3:
303
- f0 = f0.squeeze(0)
304
- return f0.to(device=device, dtype=dtype)
305
-
306
- def synth_f0(self, f0, ctx):
307
- if f0.dim() == 1:
308
- length = f0.shape[0]
309
- if length == ctx:
310
- return f0
311
- frames = length / ctx
312
- idx = torch.arange(ctx, device=f0.device)
313
- return f0[(idx * frames).long().clamp(0, length - 1)]
314
-
315
- def align_f0(self, ctx, f0):
316
- f0 = self.f0proj(f0)
317
- if f0.dim() == 3:
318
- batch, length, dims = f0.shape
319
- if length == ctx:
320
- return f0
321
- frames = length / ctx
322
- idx = torch.arange(ctx, device=f0.device)
323
- idx = (idx * frames).long().clamp(0, length - 1)
324
- return f0[:, idx, :]
325
- if f0.dim() == 1:
326
- length = f0.shape[0]
327
- if length == ctx:
328
- return f0
329
- frames = length / ctx
330
- idx = torch.arange(ctx, device=f0.device)
331
- idx = (idx * frames).long().clamp(0, length - 1)
332
- return f0[idx]
333
- else:
334
- length, dims = f0.shape
335
- if length == ctx:
336
- return f0
337
- frames = length / ctx
338
- idx = torch.arange(ctx, device=f0.device)
339
- idx = (idx * frames).long().clamp(0, length - 1)
340
- return f0[idx, :]
341
-
342
- def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
343
- f0 = enc.get("f0") if enc is not None else None
344
- if isinstance(x, int):
345
- ctx = x
346
- elif isinstance(x, torch.Tensor) and x.ndim == 2:
347
- batch, ctx = x.shape
348
- elif isinstance(x, torch.Tensor) and x.ndim == 3:
349
- batch, ctx, dims = x.shape
350
- else:
351
- batch, head, ctx, head_dim = x.shape
352
- t = torch.arange(ctx, device=device, dtype=dtype)
353
-
354
- if f0 is not None and f0.dim() == 2:
355
- if f0.shape[0] == 1:
356
- f0 = f0.squeeze(0)
357
- else:
358
- f0 = f0.view(-1)
359
-
360
- if f0 is not None and layer == "encoder":
361
- f0_mean = f0.mean()
362
- theta = f0_mean + self.theta
363
- else:
364
- theta = self.theta
365
- freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
366
- self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
367
-
368
- freqs = t[:, None] * freqs[None, :]
369
- if self.radii and f0 is not None:
370
- radius = f0.to(device, dtype)
371
- L = radius.shape[0]
372
- if L != ctx:
373
- step = L / ctx
374
- idx = torch.arange(ctx, device=f0.device)
375
- idx = (idx * step).long().clamp(0, L - 1)
376
- radius = radius[idx]
377
- rad = radius
378
- radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
379
- radius = torch.sigmoid(radius)
380
- else:
381
- radius = torch.ones_like(freqs)
382
- freqs = torch.polar(radius, freqs)
383
-
384
- if "radius" in self.debug and self.counter % 100 == 0:
385
- theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
386
- print(f" [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None}")
387
-
388
- if "rot3" in self.debug and self.counter % 100 == 0:
389
- theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
390
- print(f" [{layer}] [f0] {f0.shape if f0 is not None else None} [Theta] {theta_value:.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
391
-
392
- if "rot3" in self.debug and self.counter % 100 == 0:
393
- print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
394
-
395
- if "theta" in self.debug and self.counter % 100 == 0:
396
- if self.last_theta is None or abs(self.last_theta - theta.item()) > 1.0:
397
- self.last_theta = theta.item()
398
- print(f"[Theta] {self.last_theta:.2f}")
399
-
400
- self.counter += 1
401
- return freqs.unsqueeze(0)
402
-
403
- @staticmethod
404
- def apply_rotary(x, freqs):
405
- x1 = x[..., :freqs.shape[-1]*2]
406
- x2 = x[..., freqs.shape[-1]*2:]
407
- orig_shape = x1.shape
408
- if x1.ndim == 2:
409
- x1 = x1.unsqueeze(0)
410
- x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
411
- x1 = torch.view_as_complex(x1) * freqs
412
- x1 = torch.view_as_real(x1).flatten(-2)
413
- x1 = x1.view(orig_shape)
414
- return torch.cat([x1.type_as(x), x2], dim=-1)
415
-
416
- class MultiheadA(nn.Module):
417
- _seen = set()
418
- rbf = False
419
- def __init__(self, dims: int, head: int, rotary_emb: bool = True,
420
- zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
421
- super(MultiheadA, self).__init__()
422
-
423
- self.dims = dims
424
- self.head = head
425
- self.head_dim = dims // head
426
- self.debug = debug
427
- self.counter = 0
428
-
429
- self.q = Linear(dims, dims).to(device, dtype)
430
- self.k = Linear(dims, dims, bias=False).to(device, dtype)
431
- self.v = Linear(dims, dims).to(device, dtype)
432
- self.o = Linear(dims, dims).to(device, dtype)
433
-
434
- self.pad_token = 0
435
- self.rotary_emb = rotary_emb
436
- self.minz = minz
437
- self.maxz = maxz
438
- self.zero_val = zero_val
439
- self.optim_attn = optim_attn
440
- self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
441
-
442
- if rotary_emb:
443
- self.rope = rotary(
444
- dims=dims,
445
- head=head,
446
- debug=debug,
447
- radii=True,
448
- )
449
- else:
450
- self.rope = None
451
-
452
- def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
453
- scale = (self.dims // self.head) ** -0.25
454
- dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
455
- if rbf_ratio <= 0.0:
456
- return dot_scores
457
- q_norm = q.pow(2).sum(dim=-1, keepdim=True)
458
- k_norm = k.pow(2).sum(dim=-1, keepdim=True)
459
- qk = torch.matmul(q, k.transpose(-1, -2))
460
- dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
461
- rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
462
- return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
463
-
464
- def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
465
- x = x.to(device, dtype)
466
- if xa is not None:
467
- xa = xa.to(device, dtype)
468
-
469
- batch, ctx, dims = x.shape
470
- scale = (self.dims // self.head) ** -0.25
471
-
472
- z = default(xa, x).to(device, dtype)
473
- q = self.q(x)
474
- k = self.k(z)
475
- v = self.v(z)
476
- qlen = q.shape[1]
477
- klen = k.shape[1]
478
-
479
- if self.rotary_emb:
480
- q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
481
- k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
482
- v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
483
- qlen = q.shape[2]
484
- klen = k.shape[2]
485
-
486
- q = self.rope.apply_rotary(q, (self.rope(qlen, enc=enc, layer=layer)))
487
- k = self.rope.apply_rotary(k, (self.rope(klen, enc=enc, layer=layer)))
488
- else:
489
- q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
490
- k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
491
- v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
492
- batch, head, ctx, head_dim = q.shape
493
-
494
- if self.rbf:
495
- qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
496
- else:
497
- qk = (q * scale) @ (k * scale).transpose(-1, -2)
498
- if self.rope.use_pbias:
499
- f0 = enc.get("f0", None) if enc is not None else None
500
- pbias = self.rope.get_pitch_bias(f0)
501
- if pbias is not None:
502
- qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
503
- token_ids = k[:, :, :, 0]
504
- zscale = torch.ones_like(token_ids, device=device, dtype=dtype)
505
- fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
506
- zscale[token_ids.float() == self.pad_token] = fzero
507
-
508
- if mask is not None:
509
- mask = mask[:q.shape[2], :q.shape[2]]
510
- qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
511
- qk = qk * zscale.unsqueeze(-2)
512
- w = F.softmax(qk, dim=-1).to(q.dtype)
513
- wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
514
-
515
- if "multihead" in self.debug and self.counter % 100 == 0:
516
- print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
517
- self.counter += 1
518
- return self.o(wv), qk.detach()
519
-
520
- class t_gate(nn.Module):
521
- def __init__(self, dims, num_types=4):
522
- super().__init__()
523
- self.gate_projections = nn.ModuleList([
524
- nn.Sequential(Linear(dims, 1), nn.Sigmoid())
525
- for _ in range(num_types)])
526
- self.type_classifier = nn.Sequential(
527
- Linear(dims, num_types),
528
- nn.Softmax(dim=-1))
529
- def forward(self, x):
530
- type_probs = self.type_classifier(x)
531
- gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
532
- comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
533
- return comb_gate
534
-
535
- class m_gate(nn.Module):
536
- def __init__(self, dims, mem_size=64):
537
- super().__init__()
538
- self.m_key = nn.Parameter(torch.randn(mem_size, dims))
539
- self.m_val = nn.Parameter(torch.randn(mem_size, 1))
540
- self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
541
-
542
- def forward(self, x):
543
- d_gate = torch.sigmoid(self.gate_proj(x))
544
- attention = torch.matmul(x, self.m_key.transpose(0, 1))
545
- attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
546
- m_gate = torch.matmul(attention, self.m_val)
547
- m_gate = torch.sigmoid(m_gate)
548
- return 0.5 * (d_gate + m_gate)
549
-
550
- class c_gate(nn.Module):
551
- def __init__(self, dims):
552
- super().__init__()
553
- self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
554
- self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
555
- self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
556
- self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
557
- self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
558
- self.integ = Linear(dims*5, dims)
559
-
560
- def forward(self, x, features):
561
- s_feat = features.get("spectrogram", x)
562
- w_feat = features.get("waveform", x)
563
- p_feat = features.get("pitch", x)
564
- e_feat = features.get("envelope", x)
565
- ph_feat = features.get("phase", x)
566
- s = self.s_gate(x) * s_feat
567
- w = self.w_gate(x) * w_feat
568
- p = self.p_gate(x) * p_feat
569
- e = self.e_gate(x) * e_feat
570
- ph = self.ph_gate(x) * ph_feat
571
- comb = torch.cat([s, w, p, e, ph], dim=-1)
572
- return self.integ(comb)
573
-
574
- class Residual(nn.Module):
575
- _seen = set()
576
- def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
577
- tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
578
- super().__init__()
579
-
580
- self.dims = dims
581
- self.head = head
582
- self.ctx = ctx
583
- self.head_dim = dims // head
584
- self.cross_attn = cross_attn
585
- self.features = features
586
- self.debug = debug
587
- self.counter = 0
588
- self.dropout = 0.01
589
-
590
- self.t_gate = tgate
591
- self.m_gate = mgate
592
- self.c_gate = cgate
593
- self.skip_gates=True
594
-
595
- self.blend = nn.Parameter(torch.tensor(0.5))
596
-
597
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
598
- "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
599
- "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
600
- "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
601
- act_fn = act_map.get(act, nn.GELU())
602
-
603
- self.attna = MultiheadA(dims=dims, head=head, rotary_emb=True, debug=debug)
604
- self.attnb = (MultiheadA(dims=dims, head=head, rotary_emb=True, debug=debug) if cross_attn else None)
605
-
606
- mlp = dims * 4
607
- self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
608
-
609
- self.t_gate = t_gate(dims=dims, num_types=4) if tgate else None
610
- self.m_gate = m_gate(dims=dims, mem_size=mem_size) if mgate else None
611
- self.c_gate = c_gate(dims=dims) if cgate else None
612
-
613
- self.lna = RMSNorm(dims)
614
- self.lnb = RMSNorm(dims) if cross_attn else None
615
- self.lnc = RMSNorm(dims)
616
-
617
- if not any([tgate, mgate, cgate]):
618
- self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
619
-
620
- def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
621
- x = x.to(device, dtype)
622
- if xa is not None:
623
- xa = xa.to(device, dtype)
624
-
625
- bln = self.blend
626
- x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
627
-
628
- if self.attnb and xa is not None:
629
- c = self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
630
- b = torch.sigmoid(bln)
631
- x = b * x + (1 - b) * c
632
-
633
- normx = self.lnc(x)
634
- mlp_out = self.mlp(normx)
635
-
636
- if self.skip_gates:
637
- x = x + mlp_out
638
- else:
639
- if self.t_gate:
640
- gate = self.t_gate(normx)
641
- x = x + gate * mlp_out
642
-
643
- elif self.m_gate:
644
- gate = self.m_gate(normx)
645
- x = x + gate * mlp_out
646
-
647
- elif self.c_gate:
648
- gate_output = self.c_gate(normx, self.features)
649
- x = x + gate_output
650
-
651
- else:
652
- if hasattr(self, 'mlp_gate'):
653
- mlp_gate = self.mlp_gate(normx)
654
- x = x + mlp_gate * mlp_out
655
- else:
656
- x = x + mlp_out
657
-
658
- if "residual" in self.debug and self.counter % 100 == 0:
659
- print(f"Step {self.counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
660
- if self.t_gate:
661
- print(f"Step {self.counter}: Using t_gate: {self.t_gate}")
662
- elif self.m_gate:
663
- print(f"Step {self.counter}: Using m_gate: {self.m_gate}")
664
- elif self.c_gate:
665
- print(f"Step {self.counter}: Using c_gate: {self.c_gate}")
666
- else:
667
- print(f"Step {self.counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
668
- self.counter += 1
669
-
670
- return x
671
-
672
- class FEncoder(nn.Module):
673
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
674
- super().__init__()
675
-
676
- self.head = head
677
- self.head_dim = dims // head
678
- self.dropout = 0.01
679
- self.use_rope = use_rope
680
- self.dims = dims
681
-
682
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
683
- act_fn = act_map.get(act, nn.GELU())
684
-
685
- self.encoder = nn.Sequential(
686
- Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
687
- Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
688
- Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
689
-
690
- if use_rope:
691
- if spec_shape is not None:
692
- self.rope = rotary(
693
- dims=self.head_dim,
694
- use_2d_axial=True,
695
- spec_shape=spec_shape, debug=[])
696
- else:
697
- self.rope = rotary(
698
- dims=self.head_dim,
699
- use_2d_axial=False, debug=[])
700
- else:
701
- self.rope = None
702
- self.positional = lambda length: sinusoids(length, dims)
703
-
704
- self.norm = RMSNorm(dims)
705
- self._norm = RMSNorm(dims)
706
-
707
- def apply_rope_to_features(self, x, layer=None, feature_type="audio"):
708
- if feature_type in ["envelope", "phase"]:
709
- feature_type = "spectrogram"
710
- batch, ctx, dims = x.shape
711
- x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
712
- if feature_type == "spectrogram" and hasattr(self.rope, 'use_2d_axial') and self.rope.use_2d_axial:
713
- rope_freqs = self.rope(ctx, layer=layer, input_type="spectrogram")
714
- else:
715
- rope_freqs = self.rope(ctx, layer=layer, input_type="audio")
716
- x = self.rope.apply_rotary(x, rope_freqs)
717
- x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
718
- return x
719
-
720
- def forward(self, x, enc=None, layer=None, feature_type="audio"):
721
- x = self.encoder(x).permute(0, 2, 1)
722
- if self.use_rope:
723
- x = self.apply_rope_to_features(x, layer=layer, feature_type=feature_type)
724
- else:
725
- x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
726
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
727
- x = self._norm(x)
728
- return x
729
-
730
- class WEncoder(nn.Module):
731
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
732
- super().__init__()
733
-
734
- self.head = head
735
- self.head_dim = dims // head
736
- self.dropout = 0.01
737
- self.use_rope = use_rope
738
- self.dims = dims
739
-
740
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
741
- act_fn = act_map.get(act, nn.GELU())
742
-
743
- self.downsample = nn.Sequential(
744
- Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
745
- Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn,
746
- Conv1d(dims//4, dims, kernel_size=9, stride=5, padding=4), act_fn)
747
-
748
- self.encoder = nn.Sequential(
749
- Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn,
750
- Conv1d(dims, dims, kernel_size=1), act_fn)
751
- if use_rope:
752
- self.rope = rotary(
753
- dims=self.head_dim,
754
- use_2d_axial=False,
755
- theta=50.0, debug=[])
756
- else:
757
- self.rope = None
758
- self.positional = lambda length: sinusoids(length, dims)
759
- self.norm = RMSNorm(dims)
760
-
761
- def apply_rope_to_features(self, x, layer=None):
762
- if not self.use_rope or self.rope is None:
763
- return x
764
- batch, ctx, dims = x.shape
765
- x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
766
- rope_freqs = self.rope(ctx, layer=layer, input_type="waveform")
767
- x = self.rope.apply_rotary(x, rope_freqs)
768
- x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
769
- return x
770
-
771
- def forward(self, x, enc=None, layer=None, feature_type="waveform"):
772
- x = self.downsample(x)
773
- x = self.encoder(x)
774
- x = x.permute(0, 2, 1)
775
- if self.use_rope:
776
- x = self.apply_rope_to_features(x, layer=layer)
777
- else:
778
- x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
779
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
780
- return self.norm(x)
781
-
782
- class PEncoder(nn.Module):
783
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
784
- super().__init__()
785
-
786
- self.head = head
787
- self.head_dim = dims // head
788
- self.dropout = 0.01
789
- self.use_rope = use_rope
790
- self.dims = dims
791
-
792
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
793
- act_fn = act_map.get(act, nn.GELU())
794
-
795
- self.encoder = nn.Sequential(
796
- Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
797
- Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
798
- Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)
799
-
800
- if use_rope:
801
- self.rope = rotary(
802
- dims=self.head_dim,
803
- use_2d_axial=False,
804
- theta=100.0, debug=[])
805
- else:
806
- self.rope = None
807
- self.positional = lambda length: sinusoids(length, dims)
808
- self.norm = RMSNorm(dims)
809
-
810
- def apply_rope_to_features(self, x, layer=None):
811
- if not self.use_rope or self.rope is None:
812
- return x
813
- batch, ctx, dims = x.shape
814
- x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
815
- rope_freqs = self.rope(ctx, layer=layer, input_type="pitch")
816
- x = self.rope.apply_rotary(x, rope_freqs)
817
- x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
818
- return x
819
-
820
- def forward(self, x, enc=None, layer=None, feature_type="pitch"):
821
- x = self.encoder(x).permute(0, 2, 1)
822
- if self.use_rope:
823
- x = self.apply_rope_to_features(x, layer=layer)
824
- else:
825
- x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
826
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
827
- x = self.norm(x)
828
- return x
829
-
830
- class AudioEncoder(nn.Module):
831
- _seen = set()
832
- def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
833
- super(AudioEncoder, self).__init__()
834
-
835
- self.dims = dims
836
- self.head = head
837
- self.ctx = ctx
838
- self.head_dim = dims // head
839
- self.debug = debug
840
- self.counter = 0
841
- self.features = features
842
- self.dropout = 0.01
843
-
844
- act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),"tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
845
- act_fn = act_map.get(act, nn.GELU())
846
-
847
- if features == ["spectrogram", "waveform", "pitch"]:
848
- cgate=True
849
- else:
850
- cgate = False
851
-
852
- self.blocks = nn.ModuleDict({
853
- "spectrogram": nn.ModuleList(
854
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
855
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "spectrogram" in features else None
856
- ),
857
- "waveform": nn.ModuleList(
858
- [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act)] +
859
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "waveform" in features else None
860
- ),
861
- "pitch": nn.ModuleList(
862
- [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
863
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
864
- ),
865
- "envelope": nn.ModuleList(
866
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
867
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "envelope" in features else None
868
- ),
869
- "phase": nn.ModuleList(
870
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
871
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None
872
- )
873
- })
874
-
875
- def forward(self, enc, layer="encoder"):
876
- enc = dict_to(enc, device, dtype)
877
-
878
- if self.counter < 1:
879
- s = enc.get("spectrogram")
880
- w = enc.get("waveform")
881
- p = default(enc.get("pitch"), enc.get("f0"))
882
- plot_waveform(x=s, w=w, p=p, hop_length=128)
883
-
884
- out = {}
885
- out.update(enc)
886
-
887
- for f in self.features:
888
- if f in enc and f in self.blocks:
889
- x = enc[f]
890
- for block in self.blocks[f]:
891
- x = block(x, enc=enc, layer=layer)
892
- out[f] = x
893
-
894
- if "encoder" in self.debug and self.counter % 100 == 0:
895
- shapes = {k: v.shape for k, v in enc.items()}
896
- print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
897
- self.counter += 1
898
- return out
899
-
900
- class TextDecoder(nn.Module):
901
- def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
902
- debug: List[str], features: List[str]):
903
- super(TextDecoder, self).__init__()
904
-
905
- self.ctx = ctx
906
- self.dims = dims
907
- self.head = head
908
- self.head_dim = dims // head
909
- self.debug = debug
910
- self.counter = 0
911
- self.dropout = 0.01
912
- self.features = features
913
-
914
- self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
915
- with torch.no_grad():
916
- self.token.weight[0].zero_()
917
- self.positional = nn.Parameter(torch.empty(ctx, dims).normal_(std=0.02), requires_grad=True)
918
-
919
- self.block = nn.ModuleList([
920
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
921
- for _ in range(layer)])
922
-
923
- self.blocks = nn.ModuleDict({
924
- f: nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
925
- for _ in range(layer)]) for f in features})
926
-
927
- self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
928
- self.ln_dec = RMSNorm(dims)
929
-
930
- mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
931
- self.register_buffer("mask", mask, persistent=False)
932
-
933
- def forward(self, x, enc, order=None, layer='decoder', sequential=False) -> Tensor:
934
- enc = dict_to(enc, device, dtype)
935
- x = x.to(device)
936
- bln = self.blend
937
-
938
- if order is None:
939
- order = self.features
940
-
941
- mask = self.mask[:x.shape[1], :x.shape[1]]
942
- x = self.token(x) + self.positional[:x.shape[1]]
943
- x = F.dropout(x, p=self.dropout, training=self.training)
944
-
945
- for block in self.block:
946
- x = block(x, xa=None, mask=mask, enc=None, layer=layer)
947
-
948
- for f in order:
949
- if f in enc:
950
- xa = enc[f]
951
- for block in self.blocks[f]:
952
- out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)
953
-
954
- if sequential:
955
- x = out
956
- else:
957
- a = torch.sigmoid(bln[f])
958
- x = a * out + (1 - a) * x
959
-
960
- if "decoder" in self.debug and self.counter % 100 == 0:
961
- print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
962
- self.counter += 1
963
-
964
- x = self.ln_dec(x)
965
- return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
966
-
967
- class Echo(nn.Module):
968
- def __init__(self, param: Dimensions):
969
- super().__init__()
970
- self.param = param
971
- self.count = 0
972
-
973
- self.encoder = AudioEncoder(
974
- mels=param.mels,
975
- ctx=param.aud_ctx,
976
- dims=param.aud_dims,
977
- head=param.aud_head,
978
- layer=param.aud_idx,
979
- act=param.act,
980
- debug=param.debug,
981
- features=param.features,
982
- )
983
-
984
- self.decoder = TextDecoder(
985
- vocab=param.vocab,
986
- ctx=param.text_ctx,
987
- dims=param.text_dims,
988
- head=param.text_head,
989
- layer=param.text_idx,
990
- cross_attn=param.cross_attn,
991
- debug=param.debug,
992
- features=param.features,
993
- )
994
-
995
- all_head = torch.zeros(self.param.text_idx, self.param.text_head, dtype=torch.bool)
996
- all_head[self.param.text_idx // 2 :] = True
997
- self.register_buffer("alignment_head", all_head.to_sparse(), persistent=False)
998
-
999
- def update_base(self, f0):
1000
- for name, module in self.encoder.named_modules():
1001
- if isinstance(module, (rotary)):
1002
- module.update_base(f0)
1003
-
1004
- for name, module in self.decoder.named_modules():
1005
- if isinstance(module, (rotary)):
1006
- module.update_base(f0)
1007
-
1008
- def set_alignment_head(self, dump: bytes):
1009
- array = np.frombuffer(
1010
- gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
1011
- mask = torch.from_numpy(array).reshape(
1012
- self.param.text_idx, self.param.text_head)
1013
- self.register_buffer("alignment_head", mask.to_sparse(), persistent=False)
1014
-
1015
- def embed_audio(self, spectrogram: torch.Tensor):
1016
- return self.encoder(spectrogram)
1017
-
1018
- def logits(self,input_ids: torch.Tensor, encoder_output: torch.Tensor):
1019
- return self.decoder(input_ids, encoder_output)
1020
-
1021
- def forward(self,
1022
- decoder_input_ids=None,
1023
- labels=None,
1024
- waveform: Optional[torch.Tensor]=None,
1025
- input_ids=None,
1026
- spectrogram: torch.Tensor=None,
1027
- pitch: Optional[torch.Tensor]=None,
1028
- f0: Optional[torch.Tensor]=None,
1029
- f0d: Optional[torch.Tensor]=None,
1030
- envelope: Optional[torch.Tensor]=None,
1031
- phase: Optional[torch.Tensor]=None,
1032
- ) -> Dict[str, torch.Tensor]:
1033
-
1034
- decoder_input_ids = input_ids
1035
- encoder_inputs = {}
1036
- if spectrogram is not None:
1037
- encoder_inputs["spectrogram"] = spectrogram
1038
- if waveform is not None:
1039
- encoder_inputs["waveform"] = waveform
1040
- if pitch is not None:
1041
- encoder_inputs["pitch"] = pitch
1042
- if envelope is not None:
1043
- encoder_inputs["envelope"] = envelope
1044
- if phase is not None:
1045
- encoder_inputs["phase"] = phase
1046
- if f0 is not None:
1047
- encoder_inputs["f0"] = f0
1048
-
1049
- encoder_outputs = self.encoder(encoder_inputs)
1050
- logits = self.decoder(input_ids, encoder_outputs)
1051
-
1052
- loss = None
1053
- if labels is not None:
1054
- loss = F.cross_entropy(
1055
- logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
1056
-
1057
- self.count += 1
1058
- return {
1059
- "logits": logits,
1060
- "loss": loss,
1061
- }
1062
-
1063
- @property
1064
- def device(self):
1065
- return next(self.parameters()).device
1066
- @property
1067
- def dtype(self):
1068
- return next(self.parameters()).dtype
1069
-
1070
- def _init_weights(self, module):
1071
- std = 0.02
1072
- self.init_counts = {
1073
- "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
1074
- "Conv2d": 0, "SEBlock": 0, "TextDecoder": 0, "AudioEncoder": 0,
1075
- "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
1076
- "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
1077
- "WEncoder": 0, "PEncoder": 0}
1078
-
1079
- for name, module in self.named_modules():
1080
- if isinstance(module, RMSNorm):
1081
- nn.init.ones_(module.weight)
1082
- self.init_counts["RMSNorm"] += 1
1083
- elif isinstance(module, nn.Linear):
1084
- if module.weight is not None:
1085
- nn.init.xavier_uniform_(module.weight)
1086
- if module.bias is not None:
1087
- nn.init.zeros_(module.bias)
1088
- self.init_counts["Linear"] += 1
1089
- elif isinstance(module, Conv1d):
1090
- nn.init.normal_(module.weight, mean=0.0, std=std)
1091
- if module.bias is not None:
1092
- nn.init.zeros_(module.bias)
1093
- self.init_counts["Conv1d"] += 1
1094
- elif isinstance(module, Conv2d):
1095
- nn.init.normal_(module.weight, mean=0.0, std=std)
1096
- if module.bias is not None:
1097
- nn.init.zeros_(module.bias)
1098
- self.init_counts["Conv2d"] += 1
1099
- elif isinstance(module, MultiheadA):
1100
-
1101
- self.init_counts["MultiheadA"] += 1
1102
- elif isinstance(module, TextDecoder):
1103
- self.init_counts["TextDecoder"] += 1
1104
- elif isinstance(module, AudioEncoder):
1105
- self.init_counts["AudioEncoder"] += 1
1106
- elif isinstance(module, Residual):
1107
- self.init_counts["Residual"] += 1
1108
-
1109
- def init_weights(self):
1110
- print("Initializing model weights...")
1111
- self.apply(self._init_weights)
1112
- print("Initialization summary:")
1113
- for module_type, count in self.init_counts.items():
1114
- if count > 0:
1115
- print(f"{module_type}: {count}")
1116
-
1117
- def register_gradient_hooks(self):
1118
- for name, param in self.named_parameters():
1119
- if param.requires_grad:
1120
- if "encoder" in name:
1121
- param.register_hook(lambda grad, n=name: self._print_encoder_grad(n, grad))
1122
- elif "decoder" in name:
1123
- param.register_hook(lambda grad, n=name: self._print_decoder_grad(n, grad))
1124
-
1125
- print("Gradient debugging hooks registered")
1126
- return self
1127
-
1128
- def _print_encoder_grad(self, name, grad):
1129
- if grad is not None and self.count == 10:
1130
- norm = grad.median().item()
1131
- print(f"ENCODER GRAD: {name} = {norm:.6f}")
1132
-
1133
- return None
1134
-
1135
- def _print_decoder_grad(self, name, grad):
1136
- if grad is not None and self.count == 10:
1137
- norm = grad.median().item()
1138
- print(f"DECODER GRAD: {name} = {norm:.6f}")
1139
- return None
1140
-
1141
- def resetcounter(self):
1142
- self.counter = 0
1143
- print("Counter reset to 0.")
1144
-
1145
- def ctx_to_samples(audio_ctx, hop_length):
1146
- samples_token = hop_length * 2
1147
- n_samples = audio_ctx * samples_token
1148
- return n_samples
1149
-
1150
- def load_wave(wave_data, sample_rate):
1151
- if isinstance(wave_data, str):
1152
- waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
1153
- elif isinstance(wave_data, dict):
1154
- waveform = torch.tensor(data=wave_data["array"]).float()
1155
- sr = wave_data["sampling_rate"]
1156
- else:
1157
- raise TypeError("Invalid wave_data format.")
1158
-
1159
- if sr != sample_rate:
1160
- original_length = waveform.shape[1]
1161
- target_length = int(original_length * (sample_rate / sr))
1162
-
1163
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
1164
- waveform = resampler(waveform)
1165
-
1166
- return waveform
1167
-
1168
- def pad(array, target_length, axis=-1, dtype: torch.dtype = torch.float32):
1169
- if isinstance(array, np.ndarray):
1170
- array = torch.from_numpy(array).to(dtype)
1171
- if torch.is_tensor(array):
1172
- if array.shape[axis] > target_length:
1173
- array = array.index_select(
1174
- dim=axis,
1175
- index=torch.arange(
1176
- end=target_length, device=array.device, dtype=torch.long
1177
- ),
1178
- )
1179
- if array.shape[axis] < target_length:
1180
- pad_widths = [(0, 0)] * array.ndim
1181
- pad_widths[axis] = (0, target_length - array.shape[axis])
1182
- array = F.pad(
1183
- input=array, pad=[pad for sizes in pad_widths[::-1] for pad in sizes]
1184
- )
1185
- array = array.to(dtype=dtype)
1186
- else:
1187
- raise TypeError(
1188
- f"Unsupported input type: {type(array)}. Expected torch.Tensor or np.ndarray."
1189
- )
1190
- return array
1191
-
1192
- def exact_div(x, y):
1193
- assert x % y == 0
1194
- return x // y
1195
-
1196
- metrics = evaluate.load(path="wer")
1197
-
1198
- def hilbert_transform(x):
1199
- N = x.shape[-1]
1200
- xf = torch.fft.rfft(x)
1201
- h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
1202
- if N % 2 == 0:
1203
- h[0] = h[N//2] = 1
1204
- h[1:N//2] = 2
1205
- else:
1206
- h[0] = 1
1207
- h[1:(N+1)//2] = 2
1208
- return torch.fft.irfft(xf * h, n=N)
1209
-
1210
- def analytic_signal(x):
1211
- return x + 1j * hilbert_transform(x)
1212
-
1213
- def hilbert_transform_2d(x, dim=-1):
1214
- N = x.shape[dim]
1215
- if dim == -1 or dim == len(x.shape) - 1:
1216
- xf = torch.fft.rfft(x)
1217
- else:
1218
- xf = torch.fft.rfft(x, dim=dim)
1219
- h_shape = [1] * len(x.shape)
1220
- h_shape[dim] = N // 2 + 1
1221
- h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
1222
- if dim == -1 or dim == len(x.shape) - 1:
1223
- if N % 2 == 0:
1224
- h[..., 0] = h[..., -1] = 1
1225
- h[..., 1:-1] = 2
1226
- else:
1227
- h[..., 0] = 1
1228
- h[..., 1:] = 2
1229
- else:
1230
- pass
1231
- return torch.fft.irfft(xf * h, n=N, dim=dim)
1232
-
1233
- def hilbert_transform_true_2d(x):
1234
- xf = torch.fft.rfft2(x)
1235
- h1, h2 = torch.meshgrid(
1236
- torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
1237
- torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
1238
- indexing='ij')
1239
- h = -1j / (math.pi * (h1 + 1j*h2))
1240
- h[0, 0] = 0
1241
- return torch.fft.irfft2(xf * h.to(x.device))
1242
-
1243
- def process_spectrogram_with_hilbert(spec):
1244
- analytic = spec + 1j * hilbert_transform(spec)
1245
- envelope = torch.abs(analytic)
1246
- phase = torch.angle(analytic)
1247
- return envelope, phase
1248
-
1249
- @dataclass
1250
- class DataCollator:
1251
- tokenizer: Any
1252
-
1253
- def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
1254
- all_keys = set()
1255
- for f in features:
1256
- all_keys.update(f.keys())
1257
- batch = {}
1258
- pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
1259
- bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
1260
-
1261
- for key in all_keys:
1262
- if key == "label":
1263
- labels_list = [f["label"] for f in features]
1264
- max_len = max(len(l) for l in labels_list)
1265
- all_ids, all_labels = [], []
1266
- for label in labels_list:
1267
- label_list = label.tolist() if isinstance(label, torch.Tensor) else label
1268
- decoder_input = [bos_token_id] + label_list
1269
- label_eos = label_list + [pad_token_id]
1270
- input_len = max_len + 1 - len(decoder_input)
1271
- label_len = max_len + 1 - len(label_eos)
1272
- padded_input = decoder_input + [pad_token_id] * input_len
1273
- padded_labels = label_eos + [pad_token_id] * label_len
1274
- all_ids.append(padded_input)
1275
- all_labels.append(padded_labels)
1276
- batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
1277
- batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
1278
-
1279
- elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
1280
- items = [f[key] for f in features if key in f]
1281
- max_len = max(item.shape[-1] for item in items)
1282
- padded = []
1283
- for item in items:
1284
- pad_width = max_len - item.shape[-1]
1285
- if pad_width > 0:
1286
- pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
1287
- else:
1288
- pad_item = item
1289
- padded.append(pad_item)
1290
- batch[key] = torch.stack(padded)
1291
- if key == "spectrogram":
1292
- batch["spectrogram"] = batch[key]
1293
- return batch
1294
-
1295
- def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
1296
- hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
1297
- pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
1298
- norm=None, normalized=False, downsamples=False, period=False, hilbert=False):
1299
-
1300
- dtype = torch.float32
1301
- device = torch.device("cuda:0")
1302
- audio = batch["audio"]
1303
- sampling_rate = audio["sampling_rate"]
1304
- sr = audio["sampling_rate"]
1305
- wav = load_wave(wave_data=audio, sample_rate=sr)
1306
-
1307
- if spectrogram:
1308
- transform = torchaudio.transforms.MelSpectrogram(
1309
- f_max=fmax,
1310
- f_min=fmin,
1311
- n_mels=n_mels,
1312
- sample_rate=sr,
1313
- n_fft=n_fft,
1314
- hop_length=hop_length,
1315
- norm=norm,
1316
- normalized=normalized,
1317
- power=power,
1318
- center=center,
1319
- mel_scale=mel_scale,
1320
- window_fn=window_fn,
1321
- pad_mode=pad_mode)
1322
-
1323
- mel_spectrogram = transform(wav)
1324
- log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1325
- log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1326
- spec = (log_mel + 4.0) / 4.0
1327
- spec = spec.to(dtype)
1328
- batch["spectrogram"] = spec
1329
-
1330
- if hilbert:
1331
- envelope_list = []
1332
- phase_list = []
1333
-
1334
- for ch_idx in range(spec.shape[0]):
1335
- envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
1336
- envelope_list.append(envelope)
1337
- phase_list.append(phase)
1338
-
1339
- batch["envelope"] = torch.stack(envelope_list)
1340
- batch["phase"] = torch.stack(phase_list)
1341
-
1342
- wav_1d = wav.unsqueeze(0)
1343
-
1344
- if waveforms:
1345
- batch["waveform"] = wav_1d
1346
-
1347
- if pitch:
1348
- wav_np = wav.numpy().astype(np.float64)
1349
- f0, t = pw.dio(wav_np, sampling_rate,
1350
- frame_period=hop_length/sampling_rate*1000)
1351
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1352
- f0 = torch.from_numpy(f0)
1353
- batch["pitch"] = f0.unsqueeze(0)
1354
-
1355
- if frequency:
1356
- wav_np = wav.numpy().astype(np.float64)
1357
- f0, t = pw.dio(wav_np, sampling_rate, frame_period=hop_length/sampling_rate*1000)
1358
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
1359
- f0 = torch.from_numpy(f0)
1360
- batch["f0"] = f0
1361
-
1362
- if spectrogram and waveforms and pitch:
1363
- spec_mean = batch["spectrogram"].mean()
1364
- spec_std = batch["spectrogram"].std() + 1e-6
1365
- batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std
1366
-
1367
- wav_mean = batch["waveform"].mean()
1368
- wav_std = batch["waveform"].std() + 1e-6
1369
- batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std
1370
-
1371
- if batch["pitch"].max() > 1.0:
1372
- pitch_min = 50.0
1373
- pitch_max = 500.0
1374
- batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
1375
-
1376
- batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
1377
- return batch
1378
-
1379
- def compute_metrics(pred, tokenizer):
1380
- pred_ids = pred["predictions"]
1381
- label_ids = pred["label_ids"]
1382
- if isinstance(pred_ids, tuple):
1383
- pred_ids = pred_ids[0]
1384
- else:
1385
- pred_ids = pred_ids
1386
- if pred_ids.ndim == 3:
1387
- pred_ids = np.argmax(pred_ids, axis=-1)
1388
- label_ids[label_ids == -100] = tokenizer.pad_token_id
1389
- pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
1390
- label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1391
- wer = metrics.compute(predictions=pred_str, references=label_str)
1392
- return {"wer": wer}
1393
-
1394
- logger = logging.getLogger(__name__)
1395
-
1396
- def create_model(param: Dimensions) -> Echo:
1397
- model = Echo(param).to('cuda')
1398
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
1399
- total_params = sum(p.numel() for p in model.parameters())
1400
- logger.info(f"Trainable parameters: {trainable_params:,}")
1401
- logger.info(f"Total parameters: {total_params:,}")
1402
- print(f"Trainable parameters: {trainable_params:,}")
1403
- print(f"Total parameters: {total_params:,}")
1404
-
1405
- return model
1406
-
1407
- def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
1408
- from tokenizers import Tokenizer
1409
- tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
1410
- orig_encode = tokenizer.encode
1411
- def enc(text, add_special_tokens=True):
1412
- ids = orig_encode(text).ids
1413
- if not add_special_tokens:
1414
- sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
1415
- ids = [id for id in ids if id not in sp_ids]
1416
- return ids
1417
-
1418
- def bdec(ids_list, skip_special_tokens=True):
1419
- results = []
1420
- for ids in ids_list:
1421
- if not isinstance(ids, list):
1422
- ids = ids.tolist()
1423
- if skip_special_tokens:
1424
- ids = [id for id in ids if id not in [0, 1, 2]]
1425
- results.append(tokenizer.decode(ids))
1426
- return results
1427
- def save_pretrained(save_dir):
1428
- os.makedirs(save_dir, exist_ok=True)
1429
- tokenizer.save(f"{save_dir}/tokenizer.json")
1430
- tokenizer.encode = enc
1431
- tokenizer.batch_decode = bdec
1432
- tokenizer.save_pretrained = save_pretrained
1433
- tokenizer.pad_token_id = 0
1434
- tokenizer.bos_token_id = 1
1435
- tokenizer.eos_token_id = 2
1436
- return tokenizer
1437
-
1438
- def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
-     if dataset_config is None:
-         dataset_config = {
-             "spectrogram": True,
-             "waveforms": True,
-             "pitch": True,
-             "frequency": True,
-             "downsamples": True,
-             "hop_length": 128,
-             "fmin": 50,
-             "fmax": 2000,
-             "n_mels": 128,
-             "n_fft": 1024,
-             "sampling_rate": 16000,
-         }
-
-     if sanity_check:
-         dataset = load_dataset(
-             "./librispeech_asr.py", "clean", "train.100",
-             storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}},
-             token=token, trust_remote_code=True, streaming=False)
-
-         dataset = dataset.rename_column("text", "transcription")
-         dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
-
-         dataset = dataset["test"].take(10)
-         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
-         dataset = dataset.map(function=prepare_fn, remove_columns=["audio", "transcription"]).with_format(type="torch")
-         train_dataset = dataset
-         test_dataset = dataset
-     else:
-         cache_dir = "./processed_datasets"
-         os.makedirs(cache_dir, exist_ok=True)
-         cache_file_train = os.path.join(cache_dir, "train.arrow")
-         cache_file_test = os.path.join(cache_dir, "test.arrow")
-
-         # Reuse previously processed datasets if they were already saved to disk.
-         if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
-             from datasets import Dataset
-             train_dataset = Dataset.load_from_disk(cache_file_train)
-             test_dataset = Dataset.load_from_disk(cache_file_test)
-             return train_dataset, test_dataset
-
-         dataset = load_dataset(
-             "./librispeech_asr.py", "clean", "train.100",
-             storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}},
-             token=token, trust_remote_code=True, streaming=False)
-
-         dataset = dataset.rename_column("text", "transcription")
-         dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
-
-         def filter_func(x):
-             return (0 < len(x["transcription"]) < 512 and
-                     len(x["audio"]["array"]) > 0 and
-                     len(x["audio"]["array"]) < 1500 * 160)
-
-         dataset = dataset.filter(filter_func)
-         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
-
-         train_dataset = dataset["train.100"].take(10000)
-         test_dataset = dataset["test"].take(1000)
-         train_dataset = train_dataset.map(
-             function=prepare_fn,
-             remove_columns=["audio", "transcription"]
-         ).with_format(type="torch")
-
-         test_dataset = test_dataset.map(
-             function=prepare_fn,
-             remove_columns=["audio", "transcription"]
-         ).with_format(type="torch")
-
-         train_dataset.save_to_disk(cache_file_train)
-         test_dataset.save_to_disk(cache_file_test)
-
-     return train_dataset, test_dataset
-
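The sanity_check path is the cheap way to exercise the pipeline end to end; a sketch of the call, assuming extract_features (defined earlier in this file) accepts these flags and emits "spectrogram" and "label" columns:

    train_ds, test_ds = prepare_datasets(
        tokenizer, token="", sanity_check=True,
        dataset_config={"spectrogram": True, "waveforms": False, "pitch": False,
                        "frequency": False, "downsamples": False, "hop_length": 128,
                        "fmin": 50, "fmax": 2000, "n_mels": 128, "n_fft": 1024,
                        "sampling_rate": 16000})
    sample = train_ds[0]
    print(sample["spectrogram"].shape, sample["label"][:10])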
- @dataclass
- class DataCollator:
-     tokenizer: Any
-
-     def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-         all_keys = set()
-         for f in features:
-             all_keys.update(f.keys())
-         batch = {}
-         pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
-         bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
-         eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)
-
-         for key in all_keys:
-             if key == "label":
-                 # Build shifted decoder inputs (<BOS> + tokens) and targets (tokens + <EOS>),
-                 # both padded with the pad id to the longest sequence in the batch.
-                 labels_list = [f["label"] for f in features]
-                 max_len = max(len(l) for l in labels_list)
-                 all_ids, all_labels = [], []
-                 for label in labels_list:
-                     label_list = label.tolist() if isinstance(label, torch.Tensor) else label
-                     decoder_input = [bos_token_id] + label_list
-                     label_eos = label_list + [eos_token_id]
-                     input_len = max_len + 1 - len(decoder_input)
-                     label_len = max_len + 1 - len(label_eos)
-                     padded_input = decoder_input + [pad_token_id] * input_len
-                     padded_labels = label_eos + [pad_token_id] * label_len
-                     all_ids.append(padded_input)
-                     all_labels.append(padded_labels)
-                 batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
-                 batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
-             elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
-                 # Zero-pad each acoustic feature along its last (time) axis to the batch maximum.
-                 items = [f[key] for f in features if key in f]
-                 max_len = max(item.shape[-1] for item in items)
-                 padded = []
-                 for item in items:
-                     pad_width = max_len - item.shape[-1]
-                     if pad_width > 0:
-                         pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
-                     else:
-                         pad_item = item
-                     padded.append(pad_item)
-                 batch[key] = torch.stack(padded)
-         return batch
-
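A shape-level sketch of what the collator returns for two variable-length examples (128 mel bins, as in the configs used here); decoder inputs get a leading <BOS>, targets a trailing <EOS>, and both are padded with 0:

    collator = DataCollator(tokenizer=tokenizer)
    feats = [{"spectrogram": torch.randn(128, 300), "label": torch.tensor([5, 6, 7])},
             {"spectrogram": torch.randn(128, 250), "label": torch.tensor([8, 9])}]
    batch = collator(feats)
    print(batch["spectrogram"].shape)   # torch.Size([2, 128, 300])
    print(batch["input_ids"].tolist())  # [[1, 5, 6, 7], [1, 8, 9, 0]]
    print(batch["labels"].tolist())     # [[5, 6, 7, 2], [8, 9, 2, 0]]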
- def train_and_evaluate(
-     model, tokenizer, train_loader, eval_loader, optimizer, scheduler, loss_fn,
-     max_steps=10000, device='cuda', accumulation_steps=1, clear_cache=True,
-     log_interval=10, eval_interval=100, save_interval=1000,
-     checkpoint_dir="checkpoint_dir", log_dir="log_dir"
- ):
-     model.to(device)
-     os.makedirs(checkpoint_dir, exist_ok=True)
-     global_step = 0
-     scaler = torch.amp.GradScaler("cuda")
-     writer = SummaryWriter(log_dir=log_dir)
-     train_iterator = iter(train_loader)
-     total_loss = 0
-     step_in_report = 0
-     dataset_epochs = 0
-
-     progress_bar = tqdm(total=max_steps, desc="Training Progress", leave=True, colour='green')
-
-     model.train()
-     optimizer.zero_grad()
-
-     while global_step < max_steps:
-         try:
-             batch = next(train_iterator)
-         except StopIteration:
-             # The loader is exhausted: restart it and report the window average for this pass.
-             train_iterator = iter(train_loader)
-             batch = next(train_iterator)
-             dataset_epochs += 1
-             print(f"Starting dataset epoch {dataset_epochs}")
-
-             if step_in_report > 0:
-                 avg_loss = total_loss / step_in_report
-                 logging.info(f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}")
-                 total_loss = 0
-                 step_in_report = 0
-
-         start_time = time.time()
-
-         batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
-
-         with torch.autocast(device_type="cuda"):
-             output = model(**batch)
-             logits = output["logits"] if isinstance(output, dict) and "logits" in output else output
-             labels = batch["labels"]
-             # Flatten and drop padding positions (pad id 0) before computing the loss.
-             active_logits = logits.view(-1, logits.size(-1))
-             active_labels = labels.view(-1)
-             active_mask = active_labels != 0
-             active_logits = active_logits[active_mask]
-             active_labels = active_labels[active_mask]
-             loss = loss_fn(active_logits, active_labels)
-         total_loss += loss.item()
-         loss = loss / accumulation_steps
-
-         scaler.scale(loss).backward()
-
-         if (global_step + 1) % accumulation_steps == 0:
-             scaler.unscale_(optimizer)
-             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
-             scaler.step(optimizer)
-             scaler.update()
-             optimizer.zero_grad()
-             if clear_cache:
-                 torch.cuda.empty_cache()
-
-         end_time = time.time()
-         samples_per_sec = batch["spectrogram"].size(0) / (end_time - start_time)
-
-         if global_step % log_interval == 0:
-             writer.add_scalar(tag='Loss/train', scalar_value=total_loss / (step_in_report + 1), global_step=global_step)
-             lr = scheduler.get_last_lr()[0]
-             writer.add_scalar(tag='LearningRate', scalar_value=lr, global_step=global_step)
-             writer.add_scalar(tag='SamplesPerSec', scalar_value=samples_per_sec, global_step=global_step)
-
-         if global_step % eval_interval == 0:
-             model.eval()
-             eval_start_time = time.time()
-             eval_loss = 0
-             all_predictions = []
-             all_labels = []
-             batch_count = 0
-             total_samples = 0
-
-             with torch.no_grad():
-                 for eval_batch in eval_loader:
-                     eval_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in eval_batch.items()}
-                     output = model(**eval_batch)
-                     logits = output["logits"] if isinstance(output, dict) and "logits" in output else output
-                     labels = eval_batch["labels"]
-                     batch_size = logits.size(0)
-                     total_samples += batch_size
-                     # Padding is already excluded here via the loss function's ignore_index=0.
-                     loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
-                     eval_loss += loss.item()
-                     all_predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist())
-                     all_labels.extend(labels.cpu().numpy().tolist())
-                     batch_count += 1
-
-             eval_time = time.time() - eval_start_time
-             loss_avg = eval_loss / batch_count if batch_count > 0 else 0
-             predictions = {"predictions": np.array(all_predictions, dtype=object), "label_ids": np.array(all_labels, dtype=object)}
-             metrics = compute_metrics(pred=predictions, tokenizer=tokenizer)
-
-             writer.add_scalar('Loss/eval', loss_avg, global_step)
-             writer.add_scalar('WER', metrics['wer'], global_step)
-             writer.add_scalar('EvalSamples', total_samples, global_step)
-             writer.add_scalar('EvalTimeSeconds', eval_time, global_step)
-
-             lr = scheduler.get_last_lr()[0]
-             print(f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}")
-             logging.info(f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}")
-             model.train()
-
-         if global_step % save_interval == 0:
-             checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt')
-             torch.save(model.state_dict(), checkpoint_path)
-             logging.info(f"Model saved at step {global_step} to {checkpoint_path}")
-
-         lr = scheduler.get_last_lr()[0]
-         scheduler.step()
-         global_step += 1
-         step_in_report += 1
-
-         avg_loss = total_loss / step_in_report
-         postfix_dict = {
-             'loss': f'{avg_loss:.4f}',
-             'lr': f'{lr:.6f}',
-             'samp': f'{samples_per_sec:.1f}'
-         }
-         progress_bar.set_postfix(postfix_dict, refresh=True)
-         progress_bar.update(1)
-
-     final_model_path = os.path.join(checkpoint_dir, 'final_model.pt')
-     torch.save(model.state_dict(), final_model_path)
-     print(f"Training completed after {global_step} steps. Final model saved to {final_model_path}")
-     writer.close()
-     progress_bar.close()
-
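Checkpoints are plain state_dicts, so resuming from a saved step is just a matter of rebuilding the model with the same Dimensions and loading the file; the step number below is illustrative:

    model = create_model(param)
    state = torch.load("./checkpoints/checkpoint_step_1000.pt", map_location="cpu")
    model.load_state_dict(state)
    model.to(device).eval()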
- def get_optimizer(model, lr=5e-4, weight_decay=0.01):
-     return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6, betas=(0.9, 0.98))
-
- def get_scheduler(optimizer, total_steps=10000):
-     return torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.25, total_iters=total_steps, last_epoch=-1)
-
- def get_loss_fn():
-     return torch.nn.CrossEntropyLoss(ignore_index=0)
-
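With these defaults, LinearLR ramps the AdamW learning rate from 0.25 x 5e-4 = 1.25e-4 at step 0 up to 5e-4 at step total_steps, then holds it there; a quick check:

    opt = get_optimizer(torch.nn.Linear(4, 4))
    sch = get_scheduler(opt, total_steps=10000)
    print(sch.get_last_lr())   # [0.000125]
    for _ in range(10000):
        opt.step()
        sch.step()
    print(sch.get_last_lr())   # ~[0.0005], up to float rounding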
- def main():
-     token = ""
-     log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
-     os.makedirs(log_dir, exist_ok=True)
-     tokenizer = setup_tokenizer(token)
-
-     param = Dimensions(
-         mels=128, aud_ctx=1500, aud_head=4, aud_dims=512, aud_idx=4,
-         vocab=40000, text_ctx=512, text_head=4, text_dims=512, text_idx=4,
-         act="swish", debug=[], cross_attn=True, features=["spectrogram"]
-     )
-
-     dataset_config = {
-         "spectrogram": True, "waveforms": False, "pitch": False, "downsamples": False,
-         "frequency": True, "hilbert": False, "hop_length": 128, "fmin": 150, "fmax": 2000,
-         "n_mels": 128, "n_fft": 1024, "sampling_rate": 16000, "pad_mode": "constant",
-         "center": True, "power": 2.0, "window_fn": torch.hann_window, "mel_scale": "htk",
-         "norm": None, "normalized": False
-     }
-
-     model = create_model(param)
-     train_dataset, test_dataset = prepare_datasets(
-         tokenizer=tokenizer, token=token, sanity_check=False, dataset_config=dataset_config
-     )
-
-     collator = DataCollator(tokenizer=tokenizer)
-     train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collator, num_workers=0)
-     eval_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collator, num_workers=0)
-
-     optimizer = get_optimizer(model)
-     scheduler = get_scheduler(optimizer)
-     loss_fn = get_loss_fn()
-
-     train_and_evaluate(
-         model=model,
-         tokenizer=tokenizer,
-         train_loader=train_loader,
-         eval_loader=eval_loader,
-         optimizer=optimizer,
-         scheduler=scheduler,
-         loss_fn=loss_fn,
-         max_steps=10000,
-         device='cuda',
-         accumulation_steps=1,
-         clear_cache=False,
-         log_interval=10,
-         eval_interval=500,
-         save_interval=10000,
-         checkpoint_dir="./checkpoints",
-         log_dir=log_dir
-     )
-
- if __name__ == "__main__":
-     main()
-
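For a quick smoke test of the whole script, the cheapest path is to flip sanity_check=True inside main() and shrink the schedule, assuming extract_features accepts the same dataset_config, e.g.:

    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer, token=token, sanity_check=True, dataset_config=dataset_config)
    train_and_evaluate(model=model, tokenizer=tokenizer, train_loader=train_loader,
                       eval_loader=eval_loader, optimizer=optimizer, scheduler=scheduler,
                       loss_fn=loss_fn, max_steps=20, eval_interval=10, save_interval=20,
                       checkpoint_dir="./checkpoints", log_dir=log_dir)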