Sin2pi committed · verified
Commit 83bbb30 · 1 Parent(s): 80c70a4

Delete model_hf.py

Files changed (1): model_hf.py +0 -1625
model_hf.py DELETED
@@ -1,1625 +0,0 @@
import os
import pyworld as pw
import math
import warnings
import logging
import torch
import torchaudio
import torch.nn.functional as F
import torch.nn.init as init
from torch import nn, Tensor
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Dict, Union, List, Tuple, Any
from functools import partial
from datetime import datetime
from datasets import load_dataset, Audio
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
import transformers
from dataclasses import dataclass
from opimizer import MaxFactor  # local module (spelling follows the repo's file name); only used by the commented-out optimizer block in main()
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
transformers.utils.logging.set_verbosity_error()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

PATH = 'E:/hf'
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]

def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
                  title="", markers=None, marker_labels=None,
                  show_voiced_regions=True, show_energy=False):
    num_plots = sum([x is not None, w is not None, p is not None, per is not None])
    if num_plots == 0:
        raise ValueError("No data to plot. Please provide at least one input tensor.")
    t_spans = []

    if w is not None:
        w_np = w[sample_idx].detach().cpu().numpy()
        if w_np.ndim > 1:
            w_np = w_np.squeeze()
        t_spans.append(len(w_np) / sr)
    if x is not None:
        x_np = x[sample_idx].detach().cpu().numpy()
        if x_np.shape[0] < x_np.shape[1]:
            x_np = x_np.T
        t_spans.append(x_np.shape[0] * hop_length / sr)
    if p is not None:
        p_np = p[sample_idx].detach().cpu().numpy()
        if p_np.ndim > 1:
            p_np = p_np.squeeze()
        t_spans.append(len(p_np) * hop_length / sr)
    if per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_spans.append(len(per_np) * hop_length / sr)
    max_t = max(t_spans) if t_spans else 0
    fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4 * num_plots), sharex=True)
    if num_plots == 1:
        axs = [axs]
    if show_voiced_regions and per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_per = np.arange(len(per_np)) * hop_length / sr
        threshold = 0.5
        for ax in axs:
            for i in range(len(per_np) - 1):
                if per_np[i] > threshold:
                    ax.axvspan(t_per[i], t_per[i + 1], color='lightblue', alpha=0.2, zorder=0)
    cu_ax = 0
    if w is not None:
        w_np = w[sample_idx].detach().cpu().numpy()
        if w_np.ndim > 1:
            w_np = w_np.squeeze()
        t = np.arange(len(w_np)) / sr
        axs[cu_ax].plot(t, w_np, color="tab:blue")

        if show_energy:
            frame_length = hop_length
            hop_length_energy = hop_length // 2
            energy = []
            for i in range(0, len(w_np) - frame_length, hop_length_energy):
                frame = w_np[i:i + frame_length]
                energy.append(np.sqrt(np.mean(frame**2)))
            energy = np.array(energy)
            energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
            t_energy = np.arange(len(energy)) * hop_length_energy / sr
            axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
            axs[cu_ax].legend(loc='upper right')
        axs[cu_ax].set_title("Waveform")
        axs[cu_ax].set_ylabel("Amplitude")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
        cu_ax += 1

    if x is not None:
        x_np = x[sample_idx].detach().cpu().numpy()
        if x_np.shape[0] < x_np.shape[1]:
            x_np = x_np.T
        axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
                          extent=[0, x_np.shape[0] * hop_length / sr, 0, x_np.shape[1]])
        axs[cu_ax].set_title("Spectrogram")
        axs[cu_ax].set_ylabel("Mel Bin")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
        cu_ax += 1

    if p is not None:
        p_np = p[sample_idx].detach().cpu().numpy()
        if p_np.ndim > 1:
            p_np = p_np.squeeze()
        t_p = np.arange(len(p_np)) * hop_length / sr
        axs[cu_ax].plot(t_p, p_np, color="tab:green")
        axs[cu_ax].set_title("Pitch")
        axs[cu_ax].set_ylabel("Frequency (Hz)")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
        axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
        cu_ax += 1

    if per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_per = np.arange(len(per_np)) * hop_length / sr
        axs[cu_ax].plot(t_per, per_np, color="tab:red")
        axs[cu_ax].set_title("Period (Voice Activity)")
        axs[cu_ax].set_ylabel("Periodicity")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
        axs[cu_ax].set_ylim([-0.05, 1.05])
        axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)

    if markers is not None:
        for i, t in enumerate(markers):
            label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
            for ax in axs:
                ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
        if marker_labels:
            axs[0].legend(loc='upper right', fontsize='small')
    axs[-1].set_xlabel("t (s)")
    fig.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
    return fig

def dict_to(d, device, dtype=dtype):
    """Because PyTorch should have this built-in but doesn't"""
    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
            for k, v in d.items()}

def exists(v):
    return v is not None

def default(v, b):
    return v if exists(v) else b

class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Conv2d(nn.Conv2d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple],
                 eps=1e-8, elementwise_affine=True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64

def tox():
    return {"device": get_device(), "dtype": get_dtype()}

def sinusoids(length, channels, max_tscale=10000):
    assert channels % 2 == 0
    log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
    inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
    scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)

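# Illustrative sanity check (editor's sketch, not part of the original file):
# sinusoids() returns the standard fixed sinusoidal positional table of shape
# (length, channels), sin in the first half of the channel axis, cos in the second.
#
#   pe = sinusoids(length=1500, channels=512)
#   assert pe.shape == (1500, 512)
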
class rotary(nn.Module):
    def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=True, debug: List[str] = [], use_pbias=False):
        super(rotary, self).__init__()

        self.use_pbias = use_pbias
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.radii = radii
        self.dim = self.head_dim
        self.debug = debug
        self.counter = 0
        self.last_theta = None

        self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
        self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)

    def theta_freqs(self, theta):
        # Mel-spaced frequency ladder scaled by theta. Returned as a plain tensor:
        # wrapping it in a fresh nn.Parameter here, as the original did, would
        # detach the graph on every call and cut gradients to self.theta.
        freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000 / 700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
        return freq

    @staticmethod
    def inverse_mel_scale_scalar(mel_freq: float) -> float:
        return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)

    @staticmethod
    def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
        return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)

    @staticmethod
    def mel_scale_scalar(freq: float) -> float:
        return 1127.0 * math.log(1.0 + freq / 700.0)

    @staticmethod
    def mel_scale(freq: Tensor) -> Tensor:
        return 1127.0 * (1.0 + freq / 700.0).log()

    def return_f0(self, f0=None):
        if f0 is not None:
            self.f0 = f0
            return f0.squeeze(0).to(device, dtype)
        elif hasattr(self, 'f0') and self.f0 is not None:
            return self.f0.squeeze(0).to(device, dtype)
        return None

    def get_pitch_bias(self, f0):
        if f0 is None:
            return None
        f0_flat = f0.squeeze().float()
        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
                                        f0_norm.unsqueeze(1)))
        return f0_sim.unsqueeze(0).unsqueeze(0)

    def f0proj(self, f0):
        if f0.ndim == 3:
            f0 = f0.squeeze(0)
        # Note: the projection is re-created on every call, so it is never trained.
        self.f0_proj = nn.Linear(1, self.head_dim // 2, device=device, dtype=dtype)
        f0 = f0.to(device, dtype)
        f0 = self.f0_proj(f0.unsqueeze(-1))
        if f0.ndim == 3:
            f0 = f0.squeeze(0)
        return f0.to(device=device, dtype=dtype)

    def align_f0(self, ctx, f0):
        # Resample a 1-D, 2-D, or 3-D f0 track to ctx frames by nearest-neighbour indexing.
        f0 = self.f0proj(f0)
        if f0.dim() == 3:
            batch, length, dims = f0.shape
            if length == ctx:
                return f0
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            return f0[:, idx, :]
        if f0.dim() == 1:
            length = f0.shape[0]
            if length == ctx:
                return f0
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            return f0[idx]
        else:
            length, dims = f0.shape
            if length == ctx:
                return f0
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            return f0[idx, :]

    def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
        f0 = enc.get("f0") if enc is not None else None
        if isinstance(x, int):
            ctx = x
        elif isinstance(x, torch.Tensor) and x.ndim == 2:
            batch, ctx = x.shape
        elif isinstance(x, torch.Tensor) and x.ndim == 3:
            batch, ctx, dims = x.shape
        else:
            batch, head, ctx, head_dim = x.shape
        t = torch.arange(ctx, device=device, dtype=dtype)

        if f0 is not None and f0.dim() == 2:
            if f0.shape[0] == 1:
                f0 = f0.squeeze(0)
            else:
                f0 = f0.view(-1)

        if f0 is not None:
            f0_mean = f0.mean()
            theta = f0_mean + self.theta
        else:
            theta = self.theta

        freqs = self.theta_freqs(theta)
        freqs = t[:, None] * freqs[None, :]

        radius = None
        if self.radii and f0 is not None:
            radius = f0.to(device, dtype)
            L = radius.shape[0]
            if L != ctx:
                ratio = L / ctx
                idx = torch.arange(ctx, device=f0.device)
                idx = (idx * ratio).long().clamp(0, L - 1)
                radius = radius[idx]
            freqs = torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs)
        else:
            freqs = torch.polar(torch.ones_like(freqs), freqs)

        if "radius" in self.debug and self.counter % 100 == 0 and radius is not None:
            theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
            print(f"  [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.abs().mean():.2f} [ctx] {ctx}")

        if "theta" in self.debug and self.counter % 100 == 0:
            if self.last_theta is None or abs(self.last_theta - theta.item()) > 1.0:
                self.last_theta = theta.item()
                print(f"[Theta] {self.last_theta:.2f}")

        self.counter += 1
        return freqs.unsqueeze(0)

    @staticmethod
    def apply_rotary(x, freqs):
        x1 = x[..., :freqs.shape[-1] * 2]
        x2 = x[..., freqs.shape[-1] * 2:]
        orig_shape = x1.shape
        if x1.ndim == 2:
            x1 = x1.unsqueeze(0)
        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
        x1 = torch.view_as_complex(x1) * freqs
        x1 = torch.view_as_real(x1).flatten(-2)
        x1 = x1.view(orig_shape)
        return torch.cat([x1.type_as(x), x2], dim=-1)

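# Illustrative usage sketch (editor's addition, not part of the original file).
# rotary() returns complex per-position frequencies; apply_rotary() then rotates
# query/key vectors pairwise in the complex plane.
#
#   rope = rotary(dims=512, head=4)
#   q = torch.randn(1, 4, 250, 128, device=device)   # (batch, head, ctx, head_dim)
#   freqs = rope(250)                                # (1, 250, head_dim // 2), complex
#   q_rotated = rope.apply_rotary(q, freqs)
#   assert q_rotated.shape == q.shape
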
class MultiheadA(nn.Module):
    _seen = set()
    rbf = False

    def __init__(self, dims: int, head: int, rotary_emb: bool = True,
                 zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
        super(MultiheadA, self).__init__()

        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0

        self.q = nn.Linear(dims, dims).to(device, dtype)
        self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
        self.v = nn.Linear(dims, dims).to(device, dtype)
        self.o = nn.Linear(dims, dims).to(device, dtype)

        self.pad_token = 0
        self.rotary_emb = rotary_emb
        self.minz = minz
        self.maxz = maxz
        self.zero_val = zero_val
        self.optim_attn = optim_attn
        self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)

        if rotary_emb:
            self.rope = rotary(
                dims=dims,
                head=head,
                debug=debug,
                radii=True,
            )
        else:
            self.rope = None

    def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
        q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
        k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
        qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
        qk_cosine = qk_cosine + mask
        weights = F.softmax(qk_cosine, dim=-1)
        out = torch.matmul(weights, v)
        return out

    def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
        scale = (self.dims // self.head) ** -0.25
        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
        if rbf_ratio <= 0.0:
            return dot_scores
        q_norm = q.pow(2).sum(dim=-1, keepdim=True)
        k_norm = k.pow(2).sum(dim=-1, keepdim=True)
        qk = torch.matmul(q, k.transpose(-1, -2))
        dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc=None, layer=None, feature_type="audio", need_weights=True) -> tuple:
        x = x.to(device, dtype)
        if xa is not None:
            xa = xa.to(device, dtype)
        scale = (self.dims // self.head) ** -0.25

        z = default(xa, x).to(device, dtype)
        q = self.q(x)
        k = self.k(z)
        v = self.v(z)

        # Reshape to heads first so q2/k2 exist on both branches (the original
        # defined them only on the rotary path, which broke rotary_emb=False).
        q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        q2 = q.shape[2]
        k2 = k.shape[2]

        if self.rotary_emb:
            q = self.rope.apply_rotary(q, self.rope(q2, enc=enc, layer=layer))
            k = self.rope.apply_rotary(k, self.rope(k2, enc=enc, layer=layer))
        batch, head, ctx, head_dim = q.shape

        # The original computed rbf_scores and then unconditionally overwrote
        # them with plain dot-product scores; keep one or the other.
        if self.rbf:
            qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
        if self.rope is not None and self.rope.use_pbias:
            f0 = enc.get("f0", None) if enc is not None else None
            pbias = self.rope.get_pitch_bias(f0)
            if pbias is not None:
                qk = qk + pbias[:, :, :q2, :q2]
        token_ids = k[:, :, :, 0]
        zscale = torch.ones_like(token_ids)
        fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
        zscale[token_ids.float() == self.pad_token] = fzero

        if mask is not None:
            mask = mask[:q2, :q2]
            qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
        qk = qk * zscale.unsqueeze(-2)
        w = F.softmax(qk, dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

        if "multihead" in self.debug and self.counter % 100 == 0:
            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
        self.counter += 1
        return self.o(wv), qk

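# Minimal shape check for MultiheadA (editor's sketch, not from the original file):
#
#   attn = MultiheadA(dims=512, head=4)
#   x = torch.randn(2, 100, 512, device=device, dtype=dtype)
#   out, scores = attn(x)    # out: (2, 100, 512), scores: (2, 4, 100, 100)
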
class t_gate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gate_projections = nn.ModuleList([
            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            for _ in range(num_types)])
        self.type_classifier = nn.Sequential(
            Linear(dims, num_types),
            nn.Softmax(dim=-1))

    def forward(self, x):
        type_probs = self.type_classifier(x)
        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
        comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
        return comb_gate

class m_gate(nn.Module):
    def __init__(self, dims, mem_size=64):
        super().__init__()
        self.m_key = nn.Parameter(torch.randn(mem_size, dims))
        self.m_val = nn.Parameter(torch.randn(mem_size, 1))
        self.gate_proj = nn.Sequential(Linear(dims, dims // 2), nn.SiLU(), Linear(dims // 2, 1))

    def forward(self, x):
        d_gate = torch.sigmoid(self.gate_proj(x))
        attention = torch.matmul(x, self.m_key.transpose(0, 1))
        attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
        m_gate = torch.matmul(attention, self.m_val)
        m_gate = torch.sigmoid(m_gate)
        return 0.5 * (d_gate + m_gate)

class c_gate(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.integ = Linear(dims * 5, dims)

    def forward(self, x, features):
        s_feat = features.get("spectrogram", x)
        w_feat = features.get("waveform", x)
        p_feat = features.get("pitch", x)
        e_feat = features.get("envelope", x)
        ph_feat = features.get("phase", x)
        s = self.s_gate(x) * s_feat
        w = self.w_gate(x) * w_feat
        p = self.p_gate(x) * p_feat
        e = self.e_gate(x) * e_feat
        ph = self.ph_gate(x) * ph_feat
        comb = torch.cat([s, w, p, e, ph], dim=-1)
        return self.integ(comb)

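# Shape sketch for the gates (editor's addition): each gate maps (batch, ctx, dims)
# to a per-position scalar in (0, 1) that scales the MLP branch of a Residual block.
#
#   gate = m_gate(dims=512, mem_size=64)
#   g = gate(torch.randn(2, 100, 512))   # -> (2, 100, 1)
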
class Residual(nn.Module):
    _seen = set()

    def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
        super().__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head
        self.cross_attn = cross_attn
        self.features = features
        self.debug = debug
        self.counter = 0
        self.dropout = 0.01

        self.do_blend = "no_blend" not in self.debug
        self.blend = nn.Parameter(torch.tensor(0.5))
        self.skip_gates = True if "skip_gates" in self.debug else False

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
        self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None)

        mlp = dims * 4
        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))

        # Gate on the constructor flags, not the gate classes themselves: the
        # original tested the (always-truthy) class objects, so the tgate/mgate
        # flags were silently ignored and t_gate was always built.
        self.t_gate = t_gate(dims=dims, num_types=4) if tgate else None
        self.m_gate = m_gate(dims=dims, mem_size=mem_size) if mgate else None
        self.c_gate = c_gate(dims=dims) if cgate else None

        self.lna = RMSNorm(dims)
        self.lnb = RMSNorm(dims) if cross_attn else None
        self.lnc = RMSNorm(dims)

        if not any([tgate, mgate, cgate]):
            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

    def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
        x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
        xb = x
        if self.attnb and xa is not None:
            x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]

        if self.do_blend:
            b = torch.sigmoid(self.blend)
            x = b * xb + (1 - b) * x

        if self.skip_gates:
            x = x + self.mlp(self.lnc(x))
        else:
            normx = self.lnc(x)
            mlp_out = self.mlp(normx)

            if self.t_gate:
                gate = self.t_gate(normx)
                x = x + gate * mlp_out

            elif self.m_gate:
                gate = self.m_gate(normx)
                x = x + gate * mlp_out

            elif self.c_gate:
                # c_gate expects a dict of per-feature tensors shaped like normx;
                # the original passed self.features (a list of names, which has
                # no .get). Fall back to an empty dict so each stream defaults
                # to normx inside c_gate.
                gate_output = self.c_gate(normx, {})
                x = x + gate_output

            else:
                if hasattr(self, 'mlp_gate'):
                    mlp_gate = self.mlp_gate(normx)
                    x = x + mlp_gate * mlp_out
                else:
                    x = x + mlp_out

        if "residual" in self.debug and self.counter % 100 == 0:
            print(f"Step {self.counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
            if self.t_gate:
                print(f"Step {self.counter}: Using t_gate: {self.t_gate}")
            elif self.m_gate:
                print(f"Step {self.counter}: Using m_gate: {self.m_gate}")
            elif self.c_gate:
                print(f"Step {self.counter}: Using c_gate: {self.c_gate}")
            else:
                print(f"Step {self.counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
        self.counter += 1
        return x

class FEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2), act_fn,
            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)

        if use_rope:
            # The rotary class in this file has no 2D-axial mode, so the original
            # use_2d_axial/spec_shape arguments (which rotary does not accept)
            # are dropped here.
            self.rope = rotary(dims=dims, head=head, debug=[])
        else:
            self.rope = None
        self.positional = lambda length: sinusoids(length, dims)

        self.norm = RMSNorm(dims)
        self._norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None, feature_type="audio"):
        if feature_type in ["envelope", "phase"]:
            feature_type = "spectrogram"
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer, feature_type=feature_type)
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="audio"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer, feature_type=feature_type)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self._norm(x)
        return x

class WEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.downsample = nn.Sequential(
            Conv1d(input_dims, dims // 8, kernel_size=15, stride=8, padding=7), act_fn,
            Conv1d(dims // 8, dims // 4, kernel_size=7, stride=4, padding=3), act_fn,
            Conv1d(dims // 4, dims, kernel_size=9, stride=5, padding=4), act_fn)

        self.encoder = nn.Sequential(
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims // 8), act_fn,
            Conv1d(dims, dims, kernel_size=1), act_fn)
        if use_rope:
            # As in FEncoder, the unsupported use_2d_axial argument is dropped.
            self.rope = rotary(dims=dims, head=head, theta=50.0, debug=[])
        else:
            self.rope = None
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None):
        if not self.use_rope or self.rope is None:
            return x
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer, feature_type="waveform")
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="waveform"):
        x = self.downsample(x)
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        return self.norm(x)

class PEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims // 4, kernel_size=7, stride=8, padding=3), act_fn,
            Conv1d(dims // 4, dims // 2, kernel_size=5, stride=4, padding=2), act_fn,
            Conv1d(dims // 2, dims, kernel_size=5, stride=5, padding=2), act_fn)

        if use_rope:
            # As in FEncoder, the unsupported use_2d_axial argument is dropped.
            self.rope = rotary(dims=dims, head=head, theta=100.0, debug=[])
        else:
            self.rope = None
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None):
        if not self.use_rope or self.rope is None:
            return x
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer, feature_type="pitch")
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="pitch"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

class AudioEncoder(nn.Module):
    _seen = set()

    def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
        super(AudioEncoder, self).__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.features = features
        self.dropout = 0.01

        # c_gate is only meaningful when all three primary streams are present.
        # (Set comparison, so the order of the features list does not matter.)
        cgate = set(features) == {"spectrogram", "waveform", "pitch"}

        # The sub-encoders look the activation up by name, so pass the string
        # `act` through; the original passed an activation module, which fell
        # through to the GELU default inside each encoder.
        self.blocks = nn.ModuleDict({
            "spectrogram": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "spectrogram" in features else None),

            "waveform": nn.ModuleList(
                [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "waveform" in features else None),

            "pitch": nn.ModuleList(
                [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "pitch" in features else None),

            "envelope": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "envelope" in features else None),

            "phase": nn.ModuleList(
                [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
                if "phase" in features else None),
        })

    def forward(self, enc, layer="encoder"):
        enc = dict_to(enc, device, dtype)
        out = {}
        out.update(enc)

        for f in self.features:
            if f in enc and f in self.blocks:
                x = enc[f]
                for block in self.blocks[f]:
                    x = block(x, enc=enc, layer=layer)
                out[f] = x

        if self.counter < 1 and "encoder" in self.debug:
            s = enc.get("spectrogram")
            w = enc.get("waveform")
            p = default(enc.get("pitch"), enc.get("f0"))
            plot_waveform(x=s, w=w, p=p, hop_length=128)
            shapes = {k: v.shape for k, v in enc.items()}
            print(f"Step {self.counter}: mode: {list(enc.keys())}: shapes: {shapes}")
        self.counter += 1
        return out

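# Editor's sketch of the encoder contract (not from the original file): the
# encoder consumes a dict of feature tensors and returns the same dict with each
# configured feature replaced by its encoded (batch, ctx, dims) representation.
#
#   encoder = AudioEncoder(mels=128, ctx=1500, dims=512, head=4, layer=2,
#                          debug=[], features=["spectrogram"])
#   spec = torch.randn(1, 128, 300)              # (batch, mels, frames)
#   out = encoder({"spectrogram": spec})
#   # out["spectrogram"].shape == (1, 300, 512)
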
class TextDecoder(nn.Module):
    def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
                 debug: List[str], features: List[str]):
        super(TextDecoder, self).__init__()

        self.ctx = ctx
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.dropout = 0.01
        self.features = features
        self.do_blend = "no_blend" not in self.debug
        self.sequential = "sequential" in self.debug

        self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
        with torch.no_grad():
            self.token.weight[0].zero_()
        # The original used torch.empty() here, leaving the positional embedding
        # uninitialized (and untouched by init_weights); give it a small normal init.
        self.positional = nn.Parameter(torch.empty(ctx, dims).normal_(mean=0.0, std=0.02), requires_grad=True)

        self.block = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
            for _ in range(layer)])

        self.blocks = nn.ModuleDict({
            f: nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
                              for _ in range(layer)]) for f in features})

        self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
        self.ln_dec = RMSNorm(dims)

        # Additive causal mask: 0 on and below the diagonal, -inf above it.
        # The original registered a lower-triangular matrix of ones, which is
        # added to the logits in MultiheadA and therefore never actually blocked
        # attention to future positions.
        mask = torch.triu(torch.full((ctx, ctx), float("-inf")), diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
        if order is None:
            order = self.features

        mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x) + self.positional[:x.shape[1]]
        x = F.dropout(x, p=self.dropout, training=self.training)

        for block in self.block:
            x = block(x, xa=None, mask=mask, enc=None, layer=layer)

        for f in order:
            if f in enc:
                xa = enc[f]
                for block in self.blocks[f]:
                    out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)

                    if self.sequential:
                        x = out
                    else:
                        a = torch.sigmoid(self.blend[f])
                        x = a * out + (1 - a) * x

        if self.counter < 1 and "decoder" in self.debug:
            shapes = {k: v.shape for k, v in enc.items()}
            print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}: shapes: {shapes}")
        self.counter += 1

        x = self.ln_dec(x)
        return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

class Echo(nn.Module):
    def __init__(self, param: Dimensions):
        super().__init__()
        self.param = param

        self.encoder = AudioEncoder(
            mels=param.mels,
            ctx=param.aud_ctx,
            dims=param.aud_dims,
            head=param.aud_head,
            layer=param.aud_idx,
            act=param.act,
            debug=param.debug,
            features=param.features,
        )

        self.decoder = TextDecoder(
            vocab=param.vocab,
            ctx=param.text_ctx,
            dims=param.text_dims,
            head=param.text_head,
            layer=param.text_idx,
            cross_attn=param.cross_attn,
            debug=param.debug,
            features=param.features,
        )

    def forward(self,
                labels=None,
                input_ids=None,
                waveform: Optional[torch.Tensor] = None,
                spectrogram: torch.Tensor = None,
                pitch: Optional[torch.Tensor] = None,
                f0: Optional[torch.Tensor] = None,
                envelope: Optional[torch.Tensor] = None,
                phase: Optional[torch.Tensor] = None,
                ) -> Dict[str, torch.Tensor]:

        encoder_inputs = {}
        if spectrogram is not None:
            encoder_inputs["spectrogram"] = spectrogram
        if waveform is not None:
            encoder_inputs["waveform"] = waveform
        if pitch is not None:
            encoder_inputs["pitch"] = pitch
        if envelope is not None:
            encoder_inputs["envelope"] = envelope
        if phase is not None:
            encoder_inputs["phase"] = phase
        if f0 is not None:
            encoder_inputs["f0"] = f0

        encoder_outputs = self.encoder(encoder_inputs)
        logits = self.decoder(input_ids, encoder_outputs)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)

        return {"logits": logits, "loss": loss}

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _init_weights(self, module):
        # Initialize a single module; called once per submodule via self.apply().
        # (The original looped over self.named_modules() inside this hook, which
        # re-initialized the entire model once per submodule and reset the counts.)
        std = 0.02
        if isinstance(module, RMSNorm):
            if module.weight is not None:
                nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, nn.Linear):
            if module.weight is not None:
                nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, MultiheadA):
            self.init_counts["MultiheadA"] += 1
        elif isinstance(module, TextDecoder):
            self.init_counts["TextDecoder"] += 1
        elif isinstance(module, AudioEncoder):
            self.init_counts["AudioEncoder"] += 1
        elif isinstance(module, Residual):
            self.init_counts["Residual"] += 1

    def init_weights(self):
        print("Initializing model weights...")
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "Conv2d": 0, "RMSNorm": 0,
            "TextDecoder": 0, "AudioEncoder": 0, "Residual": 0,
            "MultiheadA": 0, "FEncoder": 0, "WEncoder": 0, "PEncoder": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")

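# End-to-end smoke test (editor's sketch, not part of the original file):
#
#   param = Dimensions(vocab=100, text_ctx=64, text_dims=128, text_head=2, text_idx=1,
#                      mels=128, aud_ctx=1500, aud_dims=128, aud_head=2, aud_idx=1,
#                      act="gelu", debug=[], cross_attn=True, features=["spectrogram"])
#   model = Echo(param).to(device)
#   model.init_weights()
#   spec = torch.randn(1, 128, 300)
#   ids = torch.randint(0, 100, (1, 16))
#   out = model(labels=ids, input_ids=ids, spectrogram=spec)
#   print(out["logits"].shape, out["loss"])     # (1, 16, 100), scalar loss
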
@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        all_keys = set()
        for f in features:
            all_keys.update(f.keys())
        batch = {}
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
        eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)

        for key in all_keys:
            if key == "label":
                labels_list = [f["label"] for f in features]
                max_len = max(len(l) for l in labels_list)
                all_ids, all_labels = [], []
                for label in labels_list:
                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                    decoder_input = [bos_token_id] + label_list
                    label_eos = label_list + [eos_token_id]
                    input_len = max_len + 1 - len(decoder_input)
                    label_len = max_len + 1 - len(label_eos)
                    padded_input = decoder_input + [pad_token_id] * input_len
                    padded_labels = label_eos + [pad_token_id] * label_len
                    all_ids.append(padded_input)
                    all_labels.append(padded_labels)
                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
            elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
                items = [f[key] for f in features if key in f]
                max_len = max(item.shape[-1] for item in items)
                padded = []
                for item in items:
                    pad_width = max_len - item.shape[-1]
                    if pad_width > 0:
                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
                    else:
                        pad_item = item
                    padded.append(pad_item)
                batch[key] = torch.stack(padded)
        return batch

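# Collation sketch (editor's addition): items of different lengths come out
# right-padded to a common shape, with BOS prepended to the decoder inputs and
# EOS appended to the labels.
#
#   collator = DataCollator(tokenizer=tokenizer)
#   batch = collator([
#       {"label": [5, 6, 7], "spectrogram": torch.randn(128, 90)},
#       {"label": [5, 6],    "spectrogram": torch.randn(128, 100)},
#   ])
#   # batch["input_ids"].shape == (2, 4); batch["spectrogram"].shape == (2, 128, 100)
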
def hilbert_transform(x):
    # Hilbert transform via the analytic-signal construction: zero the negative
    # frequencies, double the positive ones, and take the imaginary part. The
    # original used rfft/irfft, whose enforced Hermitian symmetry cannot
    # represent the one-sided analytic spectrum, so it did not actually return
    # the Hilbert transform.
    N = x.shape[-1]
    xf = torch.fft.fft(x)
    h = torch.zeros(N, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    return torch.fft.ifft(xf * h).imag

def analytic_signal(x):
    return x + 1j * hilbert_transform(x)

def hilbert_transform_2d(x, dim=-1):
    # 1-D Hilbert transform along an arbitrary axis. The original only filled the
    # filter for dim == -1 and silently returned zeros for every other axis.
    N = x.shape[dim]
    xf = torch.fft.fft(x, dim=dim)
    h1d = torch.zeros(N, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h1d[0] = h1d[N // 2] = 1
        h1d[1:N // 2] = 2
    else:
        h1d[0] = 1
        h1d[1:(N + 1) // 2] = 2
    h_shape = [1] * x.ndim
    h_shape[dim] = N
    return torch.fft.ifft(xf * h1d.view(h_shape), dim=dim).imag

def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    # rfft2 is a full FFT along dim -2 and a real FFT along dim -1, so the
    # frequency grid must use fftfreq on -2 and rfftfreq on -1 (the original
    # used rfftfreq on both, giving a kernel that did not match xf's shape).
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2]) * 2,
        torch.fft.rfftfreq(x.shape[-1]) * 2,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j * h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device))

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase

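# Quick correctness check against SciPy (editor's sketch; scipy is not imported
# by this file and is assumed to be available):
#
#   from scipy.signal import hilbert as scipy_hilbert
#   x = torch.randn(1024, dtype=torch.float64)
#   ours = hilbert_transform(x)
#   ref = torch.from_numpy(np.imag(scipy_hilbert(x.numpy())))
#   assert torch.allclose(ours, ref, atol=1e-10)
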
def load_wave(wave_data, sample_rate):
    if isinstance(wave_data, str):
        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sr = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    return waveform.flatten()

def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
                     hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
                     pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
                     norm=None, normalized=False, downsamples=False, period=False, hilbert=False):

    audio = batch["audio"]
    sr = sampling_rate = audio["sampling_rate"]
    wav = load_wave(wave_data=audio, sample_rate=sr)

    if spectrogram:
        transform = torchaudio.transforms.MelSpectrogram(
            f_max=fmax,
            f_min=fmin,
            n_mels=n_mels,
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            norm=norm,
            normalized=normalized,
            power=power,
            center=center,
            mel_scale=mel_scale,
            window_fn=window_fn,
            pad_mode=pad_mode)

        mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec

        if hilbert:
            envelope, phase = process_spectrogram_with_hilbert(spec)
            batch["envelope"] = envelope
            batch["phase"] = phase

    wav_1d = wav.unsqueeze(0)

    if waveforms:
        batch["waveform"] = wav_1d

    if pitch:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sampling_rate,
                       frame_period=hop_length / sampling_rate * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
        f0 = torch.from_numpy(f0)
        batch["pitch"] = f0.unsqueeze(0)

    if frequency:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sampling_rate, frame_period=hop_length / sampling_rate * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
        f0 = torch.from_numpy(f0)
        batch["f0"] = f0

    if spectrogram and waveforms and pitch:
        spec_mean = batch["spectrogram"].mean()
        spec_std = batch["spectrogram"].std() + 1e-6
        batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std

        wav_mean = batch["waveform"].mean()
        wav_std = batch["waveform"].std() + 1e-6
        batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std

        if batch["pitch"].max() > 1.0:
            pitch_min = 50.0
            pitch_max = 500.0
            batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)

    batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
    return batch

def levenshtein(reference_words, hypothesis_words):
    m, n = len(reference_words), len(hypothesis_words)
    dist_matrix = [[0 for _ in range(n + 1)] for _ in range(m + 1)]

    for i in range(m + 1):
        dist_matrix[i][0] = i
    for j in range(n + 1):
        dist_matrix[0][j] = j

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if reference_words[i - 1] == hypothesis_words[j - 1]:
                dist_matrix[i][j] = dist_matrix[i - 1][j - 1]
            else:
                substitution = dist_matrix[i - 1][j - 1] + 1
                insertion = dist_matrix[i][j - 1] + 1
                deletion = dist_matrix[i - 1][j] + 1
                dist_matrix[i][j] = min(substitution, insertion, deletion)

    return dist_matrix[m][n]

def wer_batch(references, hypotheses):
    total_errors = 0
    total_words = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words = ref.lower().split()
        errors = levenshtein(ref_words, hyp.lower().split())
        total_errors += errors
        total_words += len(ref_words)
    return (total_errors / total_words) * 100 if total_words > 0 else 0.0

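# Worked example (editor's addition): one substitution in a four-word reference
# gives 25% word error rate.
#
#   wer_batch(["the cat sat down"], ["the dog sat down"])   # -> 25.0
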
def compute_metrics(pred, compute_result: bool = True, print_pred: bool = False, num_samples: int = 0, tokenizer=None, model=None):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
        if not isinstance(pred_ids, torch.Tensor):
            pred_ids = torch.tensor(pred_ids)
        pred_ids = pred_ids.argmax(dim=-1)

    pred_ids = pred_ids.tolist()
    label_ids = label_ids.tolist()

    pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
    label_ids = [[pad_token_id if token == -100 else token for token in seq] for seq in label_ids]

    if print_pred:
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
        for i in range(min(num_samples, len(pred_str))):
            print(f"Preds: {pred_str[i]}")
            print(f"Label: {label_str[i]}")
            print(f"Preds: {pred_ids[i]}")
            print(f"Label: {label_ids[i]}")
            print("--------------------------------")

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = wer_batch(label_str, pred_str)

    if model is None:
        global global_model
        if 'global_model' in globals():
            model = global_model

    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
        else:
            print("Warning: Zero trainable parameters detected")
            efficiency_score = 0.0
    else:
        print("Warning: Model not available for parameter counting")
        trainable_params = 0.0
        efficiency_score = 0.0

    if hasattr(wer, "item"):
        wer = wer.item()

    metrics = {
        "wer": float(wer),
        "trainable_params_M": float(trainable_params),
        "efficiency_score": float(efficiency_score),
    }
    return metrics

logger = logging.getLogger(__name__)

def create_model(param: Dimensions) -> Echo:
    # Use the module-level device so CPU-only machines still work
    # (the original hard-coded 'cuda').
    model = Echo(param).to(device)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")

    return model

def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
    orig_encode = tokenizer.encode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids

    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                ids = [id for id in ids if id not in [0, 1, 2]]
            results.append(tokenizer.decode(ids))
        return results

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    # Monkey-patch at function scope, not inside save_pretrained, so the patched
    # encode/batch_decode are actually in effect when the tokenizer is returned.
    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
    if dataset_config is None:
        dataset_config = {
            "spectrogram": True,
            "waveforms": True,
            "pitch": True,
            "frequency": True,
            "downsamples": True,
            "hop_length": 128,
            "fmin": 50,
            "fmax": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sampling_rate": 16000,
        }

    dataset = load_dataset(
        "mozilla-foundation/common_voice_17_0",
        "en",
        token=token,
        trust_remote_code=True,
        streaming=True)

    dataset = dataset.rename_column("sentence", "transcription")
    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])

    if sanity_check:
        dataset = dataset["test"].take(10)
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        dataset = dataset.map(function=prepare_fn, remove_columns=["audio", "transcription"]).with_format(type="torch")
        train_dataset = dataset
        test_dataset = dataset
    else:
        def filter_func(x):
            return (0 < len(x["transcription"]) < 512 and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < 1500 * 160)

        dataset = dataset.filter(filter_func)
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        train_dataset = dataset["train"].take(10000).shuffle()
        test_dataset = dataset["test"].take(1000).shuffle()

        train_dataset = train_dataset.map(
            function=prepare_fn,
            remove_columns=["audio", "transcription"]
        ).with_format(type="torch")

        test_dataset = test_dataset.map(
            function=prepare_fn,
            remove_columns=["audio", "transcription"]
        ).with_format(type="torch")

    return train_dataset, test_dataset

def get_training_args(
    log_dir: str,
    batch_eval_metrics: bool = False,
    max_steps: int = 10,
    save_steps: int = 1000,
    eval_steps: int = 1,
    warmup_steps: int = 0,
    num_train_epochs: int = 1,
    logging_steps: int = 1,
    eval_on_start: bool = False,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
) -> Seq2SeqTrainingArguments:

    return Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,
        eval_accumulation_steps=1,
        eval_strategy="steps",
        save_strategy="no",
        max_steps=max_steps,
        save_steps=save_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs,
        logging_steps=logging_steps,
        logging_dir=log_dir,
        logging_strategy="steps",
        report_to=["tensorboard"],
        push_to_hub=False,
        disable_tqdm=False,
        save_total_limit=1,
        label_names=["labels"],
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        lr_scheduler_type="cosine",
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        save_safetensors=False,
        eval_on_start=eval_on_start,
        batch_eval_metrics=batch_eval_metrics,
        max_grad_norm=max_grad_norm,
    )

def main():
    token = ""
    log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H_%M_%S'))
    os.makedirs(name=log_dir, exist_ok=True)
    tokenizer = setup_tokenizer(token)

    def sanity(sanity: bool):
        if sanity:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10,
                save_steps=0,
                eval_steps=1,
                warmup_steps=0,
                logging_steps=1,
                eval_on_start=True,
                learning_rate=5e-6,
                weight_decay=0.01,
                max_grad_norm=0.6,
            )
        else:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10000,
                save_steps=10000,
                eval_steps=1000,
                warmup_steps=1000,
                logging_steps=100,
                eval_on_start=False,
                learning_rate=2.5e-4,
                weight_decay=0.01,
                max_grad_norm=0.6,
            )
        return training_args

    param = Dimensions(
        mels=128,
        aud_ctx=1500,
        aud_head=4,
        aud_dims=512,
        aud_idx=4,
        vocab=40000,
        text_ctx=512,
        text_head=4,
        text_dims=512,
        text_idx=4,
        act="swish",
        debug=["encoder"],
        cross_attn=True,
        features=["spectrogram"],
    )

    sanity_check = False

    training_args = sanity(sanity_check)
    dataset_config = {
        "spectrogram": True,
        "waveforms": False,
        "pitch": False,
        "downsamples": False,
        "frequency": True,
        "hilbert": False,
        "hop_length": 128,
        "fmin": 150,
        "fmax": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sampling_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 1.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}

    model = create_model(param)

    global global_model
    global_model = model

    metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1,
                         tokenizer=tokenizer, model=model)

    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer,
        token=token,
        sanity_check=sanity_check,
        dataset_config=dataset_config)

    # optimizer = MaxFactor(model.parameters(), lr=0.025, beta2_decay=-0.8, eps=(1e-10, 1e-7), d=1.0,
    #     weight_decay=0.025, gamma=0.99, max=False)

    # optimizer = torch.optim.AdamW(model.parameters(), lr=0.00025, eps=1e-8, weight_decay=0.025, betas=(0.9, 0.999),
    #     amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)

    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=0.0, last_epoch=-1)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
        # optimizers=(optimizer, scheduler)
    )

    model.init_weights()
    trainer.train()

if __name__ == "__main__":
    main()