Update modelA.py
modelA.py CHANGED
@@ -8,10 +8,15 @@ import torchaudio
 import torch.nn.functional as F
 import torch.nn.init as init
 from torch import nn, Tensor
-
-
+
+import matplotlib.pyplot as plt
+from typing import Optional, Dict, Union, List, Tuple, Any
 import numpy as np
-from
+from functools import partial
+from datetime import datetime
+from datasets import load_dataset, Audio
+from transformers.trainer_seq2seq import Seq2SeqTrainer
+from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 import transformers
 from dataclasses import dataclass
 from opimizer import MaxFactor
@@ -20,22 +25,44 @@ torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.set_float32_matmul_precision('high')
 transformers.utils.logging.set_verbosity_error()
+
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
+
+warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)

+PATH = 'E:/hf'
+os.environ['HF_HOME'] = PATH
+os.environ['HF_DATASETS_CACHE'] = PATH
+os.environ['TORCH_HOME'] = PATH
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+def get_activation(act: str) -> nn.Module:
+    """Get activation function by name."""
+    act_map = {
+        "gelu": nn.GELU(),
+        "relu": nn.ReLU(),
+        "sigmoid": nn.Sigmoid(),
+        "tanh": nn.Tanh(),
+        "swish": nn.SiLU(),
+        "tanhshrink": nn.Tanhshrink(),
+        "softplus": nn.Softplus(),
+        "softshrink": nn.Softshrink(),
+        "leaky_relu": nn.LeakyReLU(),
+        "elu": nn.ELU()
+    }
+    return act_map.get(act, nn.GELU())
+
 @dataclass
 class Dimensions:
     vocab: int
-    text_ctx: int
-    text_dims: int
-    text_head: int
-    text_idx: int
+    ctx: int
+    dims: int
+    head: int
+    layer: int
     mels: int
-    aud_ctx: int
-    aud_dims: int
-    aud_head: int
-    aud_idx: int
     act: str
     debug: List[str]
     cross_attn: bool
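Note: a quick sanity check of the new get_activation helper — a minimal sketch, assuming modelA.py is importable as modelA:

    import torch
    from modelA import get_activation

    x = torch.randn(2, 4)
    for name in ("gelu", "swish", "unknown"):
        fn = get_activation(name)  # unknown names fall back to nn.GELU()
        print(name, type(fn).__name__, fn(x).shape)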
@@ -59,6 +86,132 @@ def get_generation_config(param):
     use_cache=False,
     return_timestamps=False)

+def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
+                  title="", markers=None, marker_labels=None,
+                  show_voiced_regions=True, show_energy=False):
+    num_plots = sum([x is not None, w is not None, p is not None, per is not None])
+    if num_plots == 0:
+        raise ValueError("No data to plot. Please provide at least one input tensor.")
+    t_spans = []
+
+    if w is not None:
+        w_np = w[sample_idx].detach().cpu().numpy()
+        if w_np.ndim > 1:
+            w_np = w_np.squeeze()
+        t_spans.append(len(w_np) / sr)
+    if x is not None:
+        x_np = x[sample_idx].detach().cpu().numpy()
+        if x_np.shape[0] < x_np.shape[1]:
+            x_np = x_np.T
+        t_spans.append(x_np.shape[0] * hop_length / sr)
+    if p is not None:
+        p_np = p[sample_idx].detach().cpu().numpy()
+        if p_np.ndim > 1:
+            p_np = p_np.squeeze()
+        t_spans.append(len(p_np) * hop_length / sr)
+    if per is not None:
+        per_np = per[sample_idx].detach().cpu().numpy()
+        if per_np.ndim > 1:
+            per_np = per_np.squeeze()
+        t_spans.append(len(per_np) * hop_length / sr)
+    max_t = max(t_spans) if t_spans else 0
+    fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
+    if num_plots == 1:
+        axs = [axs]
+    if show_voiced_regions and per is not None:
+        per_np = per[sample_idx].detach().cpu().numpy()
+        if per_np.ndim > 1:
+            per_np = per_np.squeeze()
+        t_per = np.arange(len(per_np)) * hop_length / sr
+        threshold = 0.5
+        for ax in axs:
+            for i in range(len(per_np)-1):
+                if per_np[i] > threshold:
+                    ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
+    cu_ax = 0
+    if w is not None:
+        w_np = w[sample_idx].detach().cpu().numpy()
+        if w_np.ndim > 1:
+            w_np = w_np.squeeze()
+        t = np.arange(len(w_np)) / sr
+        axs[cu_ax].plot(t, w_np, color="tab:blue")
+
+        if show_energy:
+            frame_length = hop_length
+            hop_length_energy = hop_length // 2
+            energy = []
+            for i in range(0, len(w_np)-frame_length, hop_length_energy):
+                frame = w_np[i:i+frame_length]
+                energy.append(np.sqrt(np.mean(frame**2)))
+            energy = np.array(energy)
+            energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
+            t_energy = np.arange(len(energy)) * hop_length_energy / sr
+            axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
+            axs[cu_ax].legend(loc='upper right')
+        axs[cu_ax].set_title("Waveform")
+        axs[cu_ax].set_ylabel("Amplitude")
+        axs[cu_ax].set_xlim([0, max_t])
+        axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
+        cu_ax += 1
+
+    if x is not None:
+        x_np = x[sample_idx].detach().cpu().numpy()
+        if x_np.shape[0] < x_np.shape[1]:
+            x_np = x_np.T
+        axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
+                          extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
+        axs[cu_ax].set_title("Spectrogram")
+        axs[cu_ax].set_ylabel("Mel Bin")
+        axs[cu_ax].set_xlim([0, max_t])
+        axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
+        cu_ax += 1
+
+    if p is not None:
+        p_np = p[sample_idx].detach().cpu().numpy()
+        if p_np.ndim > 1:
+            p_np = p_np.squeeze()
+        t_p = np.arange(len(p_np)) * hop_length / sr
+        axs[cu_ax].plot(t_p, p_np, color="tab:green")
+        axs[cu_ax].set_title("Pitch")
+        axs[cu_ax].set_ylabel("Frequency (Hz)")
+        axs[cu_ax].set_xlim([0, max_t])
+        axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
+        axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
+        cu_ax += 1
+
+    if per is not None:
+        per_np = per[sample_idx].detach().cpu().numpy()
+        if per_np.ndim > 1:
+            per_np = per_np.squeeze()
+        t_per = np.arange(len(per_np)) * hop_length / sr
+        axs[cu_ax].plot(t_per, per_np, color="tab:red")
+        axs[cu_ax].set_title("Period (Voice Activity)")
+        axs[cu_ax].set_ylabel("periodicity")
+        axs[cu_ax].set_xlim([0, max_t])
+        axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
+        axs[cu_ax].set_ylim([-0.05, 1.05])
+        axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
+
+    if markers is not None:
+        for i, t in enumerate(markers):
+            label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
+            for ax in axs:
+                ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
+        if marker_labels:
+            axs[0].legend(loc='upper right', fontsize='small')
+    axs[-1].set_xlabel("t (s)")
+    fig.suptitle(title, fontsize=16)
+    plt.tight_layout(rect=[0, 0, 1, 0.97])  # type: ignore
+    plt.show()
+    return fig
+
+def valid(default_value, *items):
+    """Get first non-None item"""
+    for item in items:
+        if item is not None:
+            return item
+    return default_value
+
 def dict_to(d, device, dtype=dtype):
     return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
             for k, v in d.items()}
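Note: a minimal smoke test for the new plot_waveform helper — a sketch assuming 16 kHz audio, a hop of 160 samples, and illustrative tensor shapes:

    import torch
    from modelA import plot_waveform

    w = torch.randn(1, 16000)           # 1 s waveform, batch of 1
    x = torch.randn(1, 200, 128)        # 200 mel frames x 128 bins
    p = torch.rand(1, 100) * 300.0      # fake pitch track in Hz
    fig = plot_waveform(x=x, w=w, p=p, sr=16000, hop_length=160, title="sanity check")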
@@ -100,17 +253,17 @@ class RMSNorm(nn.Module):
         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
-            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape))  # type: ignore
             init.ones_(self.weight)
         else:
             self.register_parameter("weight", None)
     def forward(self, x):
-        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)  # type: ignore

 def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
               weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
               eps: float = 1e-5) -> Tensor:
-    return F.layer_norm(x, normalized_shape, weight, bias, eps)
+    return F.layer_norm(x, normalized_shape, weight, bias, eps)  # type: ignore

 def get_device():
     return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -128,9 +281,8 @@ def sinusoids(length, channels, max_tscale=10000):
     scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
     return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)

-
 class rotary(nn.Module):
-    def __init__(self, dims, head, max_ctx=1500,
+    def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False):
         super(rotary, self).__init__()

         self.use_pbias = use_pbias
@@ -143,96 +295,17 @@ class rotary(nn.Module):
         self.counter = 0
         self.last_theta = None

-        self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
+        self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
+        theta = (torch.tensor(10000, device=device, dtype=dtype))
+        self.theta = nn.Parameter(theta, requires_grad=True)
+        self.theta_values = []
-
-
-
-        # freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
-        # freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
-        # return freqs
-
-    # def mel_geodesic_rotary(f0, theta):
-    #     mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
-    #     fisher_info = torch.var(mel_f0) + 1e-8
-    #     adaptive_theta = theta * torch.sqrt(fisher_info)
-    #     freqs = self.theta_freqs(adaptive_theta)
-    #     return freqs
-
-    # def compute_pitch_fisher_info(f0, window_size=10):
-    #     if f0.dim() == 1:
-    #         f0 = f0.unsqueeze(0)
-    #     mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
-    #     fisher_info = torch.nn.functional.avg_pool1d(
-    #         mel_f0.unsqueeze(0),
-    #         kernel_size=window_size,
-    #         stride=1,
-    #         padding=window_size//2
-    #     ).squeeze(0)
-    #     fisher_info = (fisher_info - fisher_info.min()) / (fisher_info.max() - fisher_info.min() + 1e-8)
-    #     return fisher_info
-
-    # def compute_advanced_fisher_info(f0, window_size=10):
-    #     mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
-    #     local_mean = torch.nn.functional.avg_pool1d(
-    #         mel_f0.unsqueeze(0), window_size, 1, window_size//2
-    #     ).squeeze(0)
-
-    #     local_var = torch.nn.functional.avg_pool1d(
-    #         (mel_f0 - local_mean).pow(2).unsqueeze(0),
-    #         window_size, 1, window_size//2
-    #     ).squeeze(0)
-
-    #     fisher_info = 1.0 / (local_var + 1e-8)
-    #     return fisher_info
-
-    # def test_fisher_info(self, f0):
-    #     """Test Fisher information computation."""  # fisher_info = self.compute_pitch_fisher_info(f0)
-
-    #     print(f"f0 range: {f0.min():.1f} - {f0.max():.1f}")
-    #     print(f"Fisher info range: {fisher_info.min():.3f} - {fisher_info.max():.3f}")
-    #     print(f"Fisher info mean: {fisher_info.mean():.3f}")
-
-    #     # Visualize: high Fisher info = meaningful pitch changes
-    #     return fisher_info
-
-    # def forward(self, x=None, enc=None, layer=None, feature_type="audio"):
-
-    #     if f0 is not None:
-    #         # Compute Fisher information
-    #         fisher_info = self.compute_pitch_fisher_info(f0)
-
-    #         # Use Fisher info to weight pitch influence
-    #         f0_weighted = f0 * fisher_info
-
-    #         # Apply to both theta and radius
-    #         f0_mean = f0_weighted.mean()
-    #         theta = f0_mean + self.theta
-
-    #         if self.radii:
-    #             radius = f0_weighted.to(device, dtype)
-
-
-
-    def theta_freqs(self, theta):
-        freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
-        freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
-        return freqs

-    def mel_scale_scalar(freq: float) -> float:
+    def mel_scale_scalar(self, freq: float) -> float:
         return 1127.0 * math.log(1.0 + freq / 700.0)

-    def mel_scale(freq: Tensor) -> Tensor:
+    def mel_scale(self, freq: Tensor) -> Tensor:
         return 1127.0 * (1.0 + freq / 700.0).log()

-    def return_f0(self, f0=None):
-        if f0 is not None:
-            self.f0 = f0
-            self.update_base(f0)
-            return f0.squeeze(0).to(device, dtype)
-        elif hasattr(self, 'f0') and self.f0 is not None:
-            return self.f0.squeeze(0).to(device, dtype)
-        return None
-
     def pitch_bias(self, f0):
         if f0 is None:
             return None
@@ -242,9 +315,31 @@ class rotary(nn.Module):
             f0_norm.unsqueeze(1)))
         return f0_sim.unsqueeze(0).unsqueeze(0)

+    def theta_freqs(self, theta):
+        if theta.dim() == 0:
+            theta = theta.unsqueeze(0)
+        freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
+            torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
+                      self.dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
+
+        return freq
+
+    def _apply_radii(self, freqs, f0, ctx):
+        if self.radii and f0 is not None:
+            radius = f0.to(device, dtype)
+            L = radius.shape[0]
+            if L != ctx:
+                F = L / ctx
+                idx = torch.arange(ctx, device=f0.device)
+                idx = (idx * F).long().clamp(0, L - 1)
+                radius = radius[idx]
+            return torch.polar(radius.unsqueeze(-1), freqs)
+        else:
+            return torch.polar(torch.ones_like(freqs), freqs)

     def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
         f0 = enc.get("f0") if enc is not None else None
+
         if isinstance(x, int):
             ctx = x
         elif isinstance(x, torch.Tensor) and x.ndim == 2:
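Note: theta_freqs maps theta (a learned pitch-like scalar anchored at 220 Hz) onto a mel-spaced grid of dim // 2 rotary frequencies spanning 0-8 kHz. A standalone sketch of the grid it produces, with dim hard-coded:

    import torch

    theta, dim = torch.tensor([220.0]), 64
    freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
        torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
                  dim // 2) / 2595) - 1) / 1000
    print(freq.shape)                             # torch.Size([1, 32])
    print(freq[0, 0].item(), freq[0, -1].item())  # 0.0 ... 8.0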
@@ -252,46 +347,33 @@ class rotary(nn.Module):
         elif isinstance(x, torch.Tensor) and x.ndim == 3:
             batch, ctx, dims = x.shape
         else:
-            batch, head, ctx, head_dim = x.shape
-        t = torch.arange(ctx, device=device, dtype=dtype)
-
-        if f0 is not None and f0.dim() == 2:
-            if f0.shape[0] == 1:
-                f0 = f0.squeeze(0)
-            else:
-                f0 = f0.view(-1)
+            batch, head, ctx, head_dim = x.shape  # type: ignore

         if f0 is not None:
-
-
+            if f0.dim() == 2:
+                f0 = f0.squeeze(0)
+            theta = f0 + self.theta
         else:
             theta = self.theta

         freqs = self.theta_freqs(theta)
-
-        freqs = t[:, None] * freqs
-
+        t = torch.arange(ctx, device=device, dtype=dtype)
+        freqs = t[:, None] * freqs
+
         if self.radii and f0 is not None:
             radius = f0.to(device, dtype)
-
-            if L != ctx:
-                F = L / ctx
-                idx = torch.arange(ctx, device=f0.device)
-                idx = (idx * F).long().clamp(0, L - 1)
-                radius = radius[idx]
-            freqs = torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs)
+            freqs = torch.polar(radius.unsqueeze(-1), freqs)
         else:
-
-
-        if "radius" in self.debug and self.counter % 100 == 0:
-            theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
-            print(f" [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
+            radius = torch.ones_like(freqs)
+            freqs = torch.polar(radius, freqs)

-        if "
-
-
-
+        if "radius" in self.debug and self.counter == 10:
+            theta_value = theta.mean()
+            radius_shape = radius.shape if 'radius' in locals() else "N/A"
+            radius_mean = radius.mean() if 'radius' in locals() else 0.0
+            print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
+            print(f" [{layer}] [Radius] {radius}")
+            # self.theta_values.append(theta.item())
         self.counter += 1
         return freqs.unsqueeze(0)
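Note: with radii enabled, each frame's rotary embedding becomes a complex number r * e^(i*angle), where the per-frame f0 supplies the magnitude r, so voiced frames carry pitch in the rotation's radius. A standalone sketch:

    import torch

    ctx, half_dim = 6, 4
    t = torch.arange(ctx, dtype=torch.float32)
    angles = t[:, None] * torch.linspace(0.1, 1.0, half_dim)   # (ctx, half_dim)
    f0 = torch.rand(ctx) * 300.0                               # per-frame pitch
    rot = torch.polar(f0.unsqueeze(-1).expand_as(angles), angles)
    print(rot.shape, rot.dtype)   # torch.Size([6, 4]) torch.complex64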
@@ -308,12 +390,11 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)

-
 class MultiheadA(nn.Module):

     rbf = False
     def __init__(self, dims: int, head: int, rotary_emb: bool = True,
-                 zero_val: float = 1e-
+                 zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
         super(MultiheadA, self).__init__()

         self.dims = dims
@@ -345,8 +426,29 @@ class MultiheadA(nn.Module):
             )
         else:
             self.rope = None
+
+    def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
+        q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
+        k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
+        qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
+        qk_cosine = qk_cosine + mask
+        weights = F.softmax(qk_cosine, dim=-1)
+        out = torch.matmul(weights, v)
+        return out
+
+    def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
+        scale = (self.dims // self.head) ** -0.25
+        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
+        if rbf_ratio <= 0.0:
+            return dot_scores
+        q_norm = q.pow(2).sum(dim=-1, keepdim=True)
+        k_norm = k.pow(2).sum(dim=-1, keepdim=True)
+        qk = torch.matmul(q, k.transpose(-1, -2))
+        dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
+        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
+        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

-    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:
+    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:

         x = x.to(device, dtype)
         if xa is not None:
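Note: rbf_scores blends standard dot-product attention with a Gaussian (RBF) kernel over query-key distances, using the identity ||q - k||^2 = ||q||^2 + ||k||^2 - 2 q.k. The same math, standalone:

    import torch

    q = torch.randn(1, 2, 5, 8)   # (batch, head, ctx, head_dim)
    k = torch.randn(1, 2, 5, 8)
    sigma, ratio = 1.0, 0.3
    dot = torch.matmul(q, k.transpose(-1, -2))
    dist_sq = q.pow(2).sum(-1, keepdim=True) + k.pow(2).sum(-1, keepdim=True).transpose(-1, -2) - 2 * dot
    rbf = torch.exp(-dist_sq / (2 * sigma ** 2))
    scores = (1 - ratio) * dot + ratio * rbf
    print(scores.shape)           # torch.Size([1, 2, 5, 5])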
@@ -365,8 +467,8 @@ class MultiheadA(nn.Module):
             q2 = q.shape[2]
             k2 = k.shape[2]

-            q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
-            k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
+            q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))  # type: ignore
+            k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))  # type: ignore
         else:
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -374,14 +476,25 @@ class MultiheadA(nn.Module):

         qk = (q * scale) @ (k * scale).transpose(-1, -2)

+        if self.rbf:
+            qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
+        if self.use_pbias:
+            pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)  # type: ignore
+            if pbias is not None:
+                qk = qk + pbias[:,:,:q2,:q2]
+
         token_ids = k[:, :, :, 0]
         zscale = torch.ones_like(token_ids)
         fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
         zscale[token_ids.float() == self.pad_token] = fzero

         if mask is not None:
-            mask = mask[:q2, :q2]
-
+            # mask = mask[:q2, :q2]#torch.tril(torch.ones(q2, q2, device=q.device))
+            # audio_mask = torch.ones(q2, k2 - q2, device=q.device)
+            # mask = torch.cat([mask, audio_mask], dim=-1)
+            mask = mask.unsqueeze(0).unsqueeze(0)
+            qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
+
         qk = qk * zscale.unsqueeze(-2)
         w = F.softmax(qk, dim=-1).to(q.dtype)
         wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
@@ -392,8 +505,9 @@ class MultiheadA(nn.Module):
         return self.o(wv), qk

 class t_gate(nn.Module):
-    def __init__(self, dims, num_types=4):
+    def __init__(self, dims, num_types=4, enabled=True):
         super().__init__()
+        self.enabled = enabled
         self.gate_projections = nn.ModuleList([
             nn.Sequential(Linear(dims, 1), nn.Sigmoid())
             for _ in range(num_types)])
@@ -401,19 +515,25 @@ class t_gate(nn.Module):
             Linear(dims, num_types),
             nn.Softmax(dim=-1))
     def forward(self, x):
+        if not self.enabled:
+            return None
         type_probs = self.type_classifier(x)
         gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
         comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
         return comb_gate

 class m_gate(nn.Module):
-    def __init__(self, dims, mem_size=64):
+    def __init__(self, dims, mem_size=64, enabled=True):
         super().__init__()
-        self.
-
-
-
+        self.enabled = enabled
+        if enabled:
+            self.m_key = nn.Parameter(torch.randn(mem_size, dims))
+            self.m_val = nn.Parameter(torch.randn(mem_size, 1))
+            self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
+
     def forward(self, x):
+        if not self.enabled:
+            return None
         d_gate = torch.sigmoid(self.gate_proj(x))
         attention = torch.matmul(x, self.m_key.transpose(0, 1))
         attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
@@ -422,16 +542,20 @@ class m_gate(nn.Module):
         return 0.5 * (d_gate + m_gate)

 class c_gate(nn.Module):
-    def __init__(self, dims):
+    def __init__(self, dims, enabled=True):
         super().__init__()
-        self.
-
-
-
-
-
+        self.enabled = enabled
+        if enabled:
+            self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
+            self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
+            self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
+            self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
+            self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
+            self.integ = Linear(dims*5, dims)

     def forward(self, x, features):
+        if not self.enabled:
+            return None
         s_feat = features.get("spectrogram", x)
         w_feat = features.get("waveform", x)
         p_feat = features.get("pitch", x)
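Note: the gate modules now share one enabled-flag contract: a disabled gate builds no parameters and returns None, which the caller treats as "try the next gate". A toy stand-in for the pattern (toy_gate is hypothetical; nn.Linear stands in for this file's Linear):

    import torch
    from torch import nn

    class toy_gate(nn.Module):
        def __init__(self, dims, enabled=True):
            super().__init__()
            self.enabled = enabled
            if enabled:
                self.gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid())
        def forward(self, x):
            if not self.enabled:
                return None
            return self.gate(x)

    x = torch.randn(2, 3, 8)
    print(toy_gate(8)(x).shape, toy_gate(8, enabled=False)(x))  # torch.Size([2, 3, 1]) None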
@@ -445,9 +569,21 @@ class c_gate(nn.Module):
         comb = torch.cat([s, w, p, e, ph], dim=-1)
         return self.integ(comb)

+class mlp_gate(nn.Module):
+    def __init__(self, dims, enabled=True):
+        super().__init__()
+        self.enabled = enabled
+        if enabled:
+            self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
+
+    def forward(self, x):
+        if not self.enabled:
+            return None
+        return self.gate(x)
+
 class Residual(nn.Module):
     _seen = set()
-    def __init__(self, ctx, dims, head, act,
+    def __init__(self, ctx, dims, head, act, debug: List[str] = [],
                  tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
         super().__init__()
@@ -455,80 +591,44 @@ class Residual(nn.Module):
         self.head = head
         self.ctx = ctx
         self.head_dim = dims // head
-        self.cross_attn = cross_attn
         self.features = features
         self.debug = debug
         self.counter = 0
         self.dropout = 0.01
-
-        self.t_gate = tgate
-        self.m_gate = mgate
-        self.c_gate = cgate
-        self.do_blend = "no_blend" not in self.debug
+
         self.blend = nn.Parameter(torch.tensor(0.5))
-
-
-
-
-
-
-
-
-        self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
-        self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None)
+        act_fn = get_activation(act)
+        self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
+
+        if not any([tgate, mgate, cgate]):
+            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
+        else:
+            self.mlp_gate = None

         mlp = dims * 4
         self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))

-        self.t_gate = t_gate(dims=dims, num_types=4
-        self.m_gate = m_gate(dims=dims, mem_size=mem_size
-        self.c_gate = c_gate(dims=dims
+        self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
+        self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
+        self.c_gate = c_gate(dims=dims, enabled=cgate)
+        self.mlp_gate = mlp_gate(dims=dims, enabled=not any([tgate, mgate, cgate]))

         self.lna = RMSNorm(dims)
-        self.lnb = RMSNorm(dims)
+        self.lnb = RMSNorm(dims)
         self.lnc = RMSNorm(dims)

-        if not any([t_gate, m_gate, c_gate]):
-            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
-
     def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
-
-        x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
-        xb = x
-        if self.attnb and xa is not None:
-            x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
-
-        if self.do_blend:
-            b = torch.sigmoid(self.blend)
-            x = b * xb + (1 - b) * x
-
-
-
-
-
-
-
-
-
-
-
-        elif self.m_gate:
-            gate = self.m_gate(normx)
-            x = x + gate * mlp_out
-
-        elif self.c_gate:
-            gate_output = self.c_gate(normx, self.features)
-            x = x + gate_output
-
-        else:
-            if hasattr(self, 'mlp_gate'):
-                mlp_gate = self.mlp_gate(normx)
-                x = x + mlp_gate * mlp_out
-            else:
-                x = x + mlp_out
-
-        return x
-
+        b = torch.sigmoid(self.blend)
+        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0]
+        bx = b * ax + (1 - b) * x
+        cx = self.lnb(bx)
+        dx = self.mlp(cx)
+        ex = valid(None, self.t_gate(cx), self.m_gate(cx), self.mlp_gate(cx))
+        fx = x + ex + dx
+        gx = self.lnc(fx)
+        return gx
+
 class FEncoder(nn.Module):
     def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
         super().__init__()
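Note: the gate selection in Residual.forward leans on the valid() helper added earlier in this diff — disabled gates return None, so the first enabled gate's output wins:

    def valid(default_value, *items):
        """Get first non-None item"""
        for item in items:
            if item is not None:
                return item
        return default_value

    print(valid(0, None, None, 3))  # 3 (first enabled gate's output)
    print(valid(0, None, None))     # 0 (fallback default)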
@@ -539,8 +639,7 @@ class FEncoder(nn.Module):
         self.use_rope = use_rope
         self.dims = dims

-
-        act_fn = act_map.get(act, nn.GELU())
+        act_fn = get_activation(act)

         self.encoder = nn.Sequential(
             Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
@@ -551,11 +650,13 @@ class FEncoder(nn.Module):
             if spec_shape is not None:
                 self.rope = rotary(
                     dims=self.head_dim,
+                    head=self.head,
                     use_2d_axial=True,
                     spec_shape=spec_shape, debug=[])
             else:
                 self.rope = rotary(
                     dims=self.head_dim,
+                    head=self.head,
                     use_2d_axial=False, debug=[])
         else:
             self.rope = None
@@ -569,7 +670,7 @@ class FEncoder(nn.Module):
             feature_type = "spectrogram"
         batch, ctx, dims = x.shape
         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
-        if feature_type == "spectrogram" and
+        if feature_type == "spectrogram" and self.rope is not None:
             rope_freqs = self.rope(ctx, layer=layer, input_type="spectrogram")
         else:
             rope_freqs = self.rope(ctx, layer=layer, input_type="audio")
@@ -597,8 +698,7 @@ class WEncoder(nn.Module):
         self.use_rope = use_rope
         self.dims = dims

-
-        act_fn = act_map.get(act, nn.GELU())
+        act_fn = get_activation(act)

         self.downsample = nn.Sequential(
             Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
@@ -611,8 +711,8 @@ class WEncoder(nn.Module):
         if use_rope:
             self.rope = rotary(
                 dims=self.head_dim,
-
-
+                head=self.head,
+                debug=[])
         else:
             self.rope = None
         self.positional = lambda length: sinusoids(length, dims)
@@ -649,8 +749,7 @@ class PEncoder(nn.Module):
         self.use_rope = use_rope
         self.dims = dims

-
-        act_fn = act_map.get(act, nn.GELU())
+        act_fn = get_activation(act)

         self.encoder = nn.Sequential(
             Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
@@ -660,8 +759,8 @@ class PEncoder(nn.Module):
         if use_rope:
             self.rope = rotary(
                 dims=self.head_dim,
-
-
+                head=self.head,
+                debug=[])
         else:
             self.rope = None
         self.positional = lambda length: sinusoids(length, dims)
@@ -687,10 +786,10 @@ class PEncoder(nn.Module):
         x = self.norm(x)
         return x

-class AudioEncoder(nn.Module):
+class SpeechTransformer(nn.Module):
     _seen = set()
-    def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
-        super(AudioEncoder, self).__init__()
+    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
+        super(SpeechTransformer, self).__init__()

         self.dims = dims
         self.head = head
@@ -700,9 +799,12 @@ class AudioEncoder(nn.Module):
         self.counter = 0
         self.features = features
         self.dropout = 0.01
+        self.sequential = "sequential" in debug
+        act_fn = get_activation(act)

-
-
+        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
+        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
+        self.register_buffer("audio_embedding", sinusoids(ctx, dims))

         if features == ["spectrogram", "waveform", "pitch"]:
             cgate=True
@@ -737,80 +839,55 @@ class AudioEncoder(nn.Module):
                 if "phase" in features else None),
         })

-    def forward(self, enc, layer="encoder"):
-        enc = dict_to(enc, device, dtype)
-        out = {}
-        out.update(enc)
-
-        for f in self.features:
-            if f in enc and f in self.blocks:
-                x = enc[f]
-                for block in self.blocks[f]:
-                    x = block(x, enc=enc, layer=layer)
-                out[f] = x
-
-        return out
-
-class TextDecoder(nn.Module):
-    def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
-                 debug: List[str], features: List[str]):
-        super(TextDecoder, self).__init__()
-
-        self.ctx = ctx
-        self.dims = dims
-        self.head = head
-        self.head_dim = dims // head
-        self.debug = debug
-        self.counter = 0
-        self.dropout = 0.01
-        self.features = features
-        self.do_blend = "no_blend" not in self.debug
-        self.sequential = "sequential" in self.debug
-
-        self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
-        with torch.no_grad():
-            self.token.weight[0].zero_()
-        self.positional = nn.Parameter(data=torch.empty(ctx, dims), requires_grad=True)
-
         self.block = nn.ModuleList([
-            Residual(ctx=ctx, dims=dims, head=head, act="gelu",
-            for _ in range(layer)])
-
-        self.blocks = nn.ModuleDict({
-            f: nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
-            for _ in range(layer)]) for f in features})
+            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
+            for _ in range(layer)])

-        self.blend = nn.
+        self.blend = nn.Parameter(torch.tensor(0.5))
         self.ln_dec = RMSNorm(dims)

+        def get_mask(text_ctx, aud_ctx):
+            mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
+            audio_mask = torch.ones(text_ctx, aud_ctx - text_ctx, device=device)
+            full_mask = torch.cat([mask, audio_mask], dim=-1)
+            return full_mask
+        self.register_buffer("mask_ax", get_mask(ctx, ctx), persistent=False)
+
         mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
         self.register_buffer("mask", mask, persistent=False)

-    def forward(self,
-
-
-        order = self.features
-
-        mask = self.mask[:x.shape[1], :x.shape[1]]
+    def forward(self, enc, layer="encoder"):
+        enc = dict_to(enc, device, dtype)
+
+        x = enc.get("input_ids").long()
         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)
-
+
+        out = {}
+        out.update(enc)
+
+        for f in self.features:
+            if f in enc and f in self.blocks:
+                xa = enc[f]
+                for block in self.blocks[f]:  # type: ignore
+                    xa = block(xa, enc=enc, layer=layer)
+                out[f] = xa
+                xa = xa + self.audio_embedding[:xa.shape[1]]

         for block in self.block:
+            mask = self.mask[:x.shape[1], :x.shape[1]]
             x = block(x, xa=None, mask=mask, enc=None, layer=layer)

-        for f in
+        for f in self.features:
             if f in enc:
-
-                for block in self.
-                    out = block(x
-
+                mask = self.mask_ax[:x.shape[1], :xa.shape[1]]
+                for block in self.block:
+                    out = block(x, xa=xa, mask=mask, enc=None, layer=layer)
                 if self.sequential:
                     x = out
                 else:
-                    a = torch.sigmoid(self.blend
+                    a = torch.sigmoid(self.blend)
                     x = a * out + (1 - a) * x
-

         x = self.ln_dec(x)
         return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
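Note: get_mask builds the text-to-audio attention mask — a causal lower-triangular block over text positions, concatenated with an all-ones block so text may attend to every audio frame. A shape sketch with illustrative sizes:

    import torch

    text_ctx, aud_ctx = 4, 10
    causal = torch.tril(torch.ones(text_ctx, text_ctx), diagonal=0)
    audio = torch.ones(text_ctx, aud_ctx - text_ctx)
    full = torch.cat([causal, audio], dim=-1)
    print(full.shape)  # torch.Size([4, 10])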
@@ -820,38 +897,28 @@ class Echo(nn.Module):
         super().__init__()
         self.param = param

-        self.encoder = AudioEncoder(
+        self.SpeechTransformer = SpeechTransformer(
+            vocab=param.vocab,
             mels=param.mels,
-            ctx=param.aud_ctx,
-            dims=param.aud_dims,
-            head=param.aud_head,
-            layer=param.aud_idx,
-            act=param.act,
+            ctx=param.ctx,
+            dims=param.dims,
+            head=param.head,
+            layer=param.layer,
             debug=param.debug,
             features=param.features,
+            act=param.act,
         )

-        self.decoder = TextDecoder(
-            vocab=param.vocab,
-            ctx=param.text_ctx,
-            dims=param.text_dims,
-            head=param.text_head,
-            layer=param.text_idx,
-            cross_attn=param.cross_attn,
-            debug=param.debug,
-            features=param.features,
-        )
-
     def forward(self,
         labels=None,
-        waveform: Optional[torch.Tensor]=None,
         input_ids=None,
-
+        waveform: Optional[torch.Tensor]=None,
+        spectrogram: Optional[torch.Tensor]=None,
         pitch: Optional[torch.Tensor]=None,
         f0: Optional[torch.Tensor]=None,
         envelope: Optional[torch.Tensor]=None,
         phase: Optional[torch.Tensor]=None,
-        ) -> Dict[str, torch.Tensor]:
+        ) -> Dict[str, Optional[torch.Tensor]]:

         encoder_inputs = {}
         if spectrogram is not None:
@@ -866,9 +933,10 @@ class Echo(nn.Module):
             encoder_inputs["phase"] = phase
         if f0 is not None:
             encoder_inputs["f0"] = f0
+        if input_ids is not None:
+            encoder_inputs["input_ids"] = input_ids

-
-        logits = self.decoder(input_ids, encoder_outputs)
+        logits = self.SpeechTransformer(encoder_inputs)

         loss = None
         if labels is not None:
@@ -888,7 +956,7 @@ class Echo(nn.Module):
         std = 0.02
         self.init_counts = {
             "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
-            "Conv2d": 0, "SEBlock": 0, "
+            "Conv2d": 0, "SEBlock": 0, "SpeechTransformer": 0,
             "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
             "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
             "WEncoder": 0, "PEncoder": 0}
@@ -914,12 +982,9 @@ class Echo(nn.Module):
             nn.init.zeros_(module.bias)
             self.init_counts["Conv2d"] += 1
         elif isinstance(module, MultiheadA):
-
             self.init_counts["MultiheadA"] += 1
-        elif isinstance(module, TextDecoder):
-            self.init_counts["TextDecoder"] += 1
-        elif isinstance(module, AudioEncoder):
-            self.init_counts["AudioEncoder"] += 1
+        elif isinstance(module, SpeechTransformer):
+            self.init_counts["SpeechTransformer"] += 1
         elif isinstance(module, Residual):
             self.init_counts["Residual"] += 1
@@ -957,10 +1022,11 @@ class Echo(nn.Module):
             encoder_inputs["phase"] = phase
         if f0 is not None:
             encoder_inputs["f0"] = f0
-
+
         for i in range(max_length - 1):
             with torch.no_grad():
-
+                encoder_inputs["input_ids"] = ids
+                logits = self.SpeechTransformer(encoder_inputs)
                 next_token_logits = logits[:, -1, :]
                 if i < min_length:
                     next_token_logits[:, eos_token_id] = 0
@@ -985,10 +1051,9 @@ class Echo(nn.Module):
     })
     return Config()

-
-def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
+def setup_tokenizer(token: str):
     from tokenizers import Tokenizer
-    tokenizer = Tokenizer.from_file(
+    tokenizer = Tokenizer.from_file("./tokenizer.json")
     orig_encode = tokenizer.encode
     def enc(text, add_special_tokens=True):
         ids = orig_encode(text).ids
@@ -1005,6 +1070,11 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
             ids = ids[1:]
             while ids and ids[-1] in [0, 2]:
                 ids = ids[:-1]
+
+            if isinstance(ids, torch.Tensor):
+                ids = ids.tolist()
+            elif isinstance(ids, np.ndarray):
+                ids = ids.tolist()
             results.append(tokenizer.decode(ids))
         return results
@@ -1019,95 +1089,165 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
     tokenizer.eos_token_id = 2
     return tokenizer

+def load_wave(wave_data, sample_rate):
+    if isinstance(wave_data, str):
+        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
+    elif isinstance(wave_data, dict):
+        waveform = torch.tensor(data=wave_data["array"]).float()
+        sr = wave_data["sampling_rate"]
+    else:
+        raise TypeError("Invalid wave_data format.")
+
+    return waveform
+
-def
+def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **dataset_config):
+
     audio = batch["audio"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    sr = audio["sampling_rate"]
+    wav = load_wave(wave_data=audio, sample_rate=sr)
+
+    dataset_config = {
+        "hop_length": 256,
+        "f_min": 150,
+        "f_max": 2000,
+        "n_mels": 128,
+        "n_fft": 1024,
+        "sample_rate": 16000,
+        "pad_mode": "constant",
+        "center": True,
+        "power": 1.0,
+        "window_fn": torch.hann_window,
+        "mel_scale": "htk",
+        "norm": None,
+        "normalized": False}
+
+    transform = torchaudio.transforms.MelSpectrogram(
+        **dataset_config
+    )
+
+    mel_spectrogram = transform(wav)
+    log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
+    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
+    spec = (log_mel + 4.0) / 4.0
+    spec = torch.tensor(spec)
+    # batch["spectrogram"] = spec
+
+    wav_np = wav.numpy().astype(np.float64)
     f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
     f0 = pw.stonemask(wav_np, f0, t, sample_rate)
-    f0 = torch.from_numpy(f0)
-
-
+    f0 = torch.from_numpy(f0)
+
     return {
         "spectrogram": spec,
        "f0": f0,
-        "
-        "
     }

-def prepare_datasets(tokenizer, token
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

 @dataclass
 class DataCollator:
     tokenizer: Any

     def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
         pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
         bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
         eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

 def levenshtein(reference_words, hypothesis_words):
     m, n = len(reference_words), len(hypothesis_words)
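Note: f0 extraction uses pyworld's two-pass estimate — dio for a coarse pitch track, stonemask to refine it. A standalone sketch on a synthetic 220 Hz tone, assuming pyworld is installed and imported as pw, as in this file:

    import numpy as np
    import pyworld as pw

    sr = 16000
    t = np.arange(sr) / sr
    wav = (0.5 * np.sin(2 * np.pi * 220.0 * t)).astype(np.float64)  # 1 s tone
    f0, times = pw.dio(wav, sr, frame_period=256 / sr * 1000)       # coarse track
    f0 = pw.stonemask(wav, f0, times, sr)                           # refined
    print(f0.shape, float(f0[f0 > 0].mean()))                       # mean near 220.0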
@@ -1137,7 +1277,7 @@ def wer_batch(references, hypotheses):
         total_words += len(ref_words)
     return (total_errors / total_words) * 100 if total_words > 0 else 0.0

-def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0):
     pred_ids = pred.predictions
     label_ids = pred.label_ids
     if isinstance(pred_ids, tuple):
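Note: wer_batch is plain word-level edit distance over the whole batch — insertions, deletions, and substitutions counted by levenshtein, divided by total reference words. A minimal check, assuming the levenshtein helper shown above:

    ref = "the cat sat on the mat".split()
    hyp = "the cat sit on mat".split()
    errors = levenshtein(ref, hyp)     # 2: one substitution, one deletion
    print(errors / len(ref) * 100)     # 33.33...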
@@ -1146,21 +1286,25 @@ def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0):
     if not isinstance(pred_ids, torch.Tensor):
         pred_ids = torch.tensor(pred_ids)
     pred_ids = pred_ids.argmax(dim=-1)
     pred_ids = pred_ids.tolist()
     label_ids = label_ids.tolist()
     pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
     label_ids = [[pad_token_id if token == -100 else token for token in seq] for seq in label_ids]
-    def strip_trailing(seq, pad_token_id):
-        while seq and seq[-1] == pad_token_id:
-            seq = seq[:-1]
-        return seq
-    pred_ids = [strip_trailing(seq, pad_token_id) for seq in pred_ids]
-    label_ids = [strip_trailing(seq, pad_token_id) for seq in label_ids]
     if print_pred:
         for i in range(min(num_samples, len(pred_ids))):
-
-
             print("-" * 40)
     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
     wer = wer_batch(label_str, pred_str)
@@ -1170,9 +1314,9 @@ def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0):
     else:
         trainable_params = 0.0
         efficiency_score = 0.0
     return {
         "wer": float(wer),
-        "trainable_params_M": float(trainable_params),
         "efficiency_score": float(efficiency_score),
     }
@@ -1183,10 +1327,13 @@ def main():
     tokenizer = setup_tokenizer(token)
    train_dataset, test_dataset = prepare_datasets(tokenizer, token)
     param = Dimensions(
-
-
-
-
     model = Echo(param).to('cuda')
     print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
@@ -1202,7 +1349,6 @@ def main():
         logging_steps=10,
         logging_dir=log_dir,
         eval_strategy="steps",
-
         save_strategy="steps",
         report_to=["tensorboard"],
         push_to_hub=False,
@@ -1214,17 +1360,28 @@ def main():
         batch_eval_metrics=False,
     )
     from functools import partial
-    metrics_fn = partial(compute_metrics,
     trainer = Seq2SeqTrainer(
         args=training_args,
         model=model,
-        train_dataset=train_dataset,
-        eval_dataset=test_dataset,
-        data_collator=DataCollator(tokenizer=tokenizer),
         compute_metrics=metrics_fn,
     )
     model.init_weights()
     trainer.train()

 if __name__ == "__main__":
-    main()
8 |
import torch.nn.functional as F
|
9 |
import torch.nn.init as init
|
10 |
from torch import nn, Tensor
|
11 |
+
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from typing import Optional, Dict, Union, List, Tuple, Any
|
14 |
import numpy as np
|
15 |
+
from functools import partial
|
16 |
+
from datetime import datetime
|
17 |
+
from datasets import load_dataset, Audio
|
18 |
+
from transformers.trainer_seq2seq import Seq2SeqTrainer
|
19 |
+
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
|
20 |
import transformers
|
21 |
from dataclasses import dataclass
|
22 |
from opimizer import MaxFactor
|
|
|
25 |
torch.backends.cuda.matmul.allow_tf32 = True
|
26 |
torch.set_float32_matmul_precision('high')
|
27 |
transformers.utils.logging.set_verbosity_error()
|
28 |
+
|
29 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
30 |
dtype = torch.float32
|
31 |
+
|
32 |
+
warnings.filterwarnings("ignore")
|
33 |
logging.basicConfig(level=logging.ERROR)
|
34 |
|
35 |
+
PATH = 'E:/hf'
|
36 |
+
os.environ['HF_HOME'] = PATH
|
37 |
+
os.environ['HF_DATASETS_CACHE'] = PATH
|
38 |
+
os.environ['TORCH_HOME'] = PATH
|
39 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
40 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
41 |
+
|
42 |
+
def get_activation(act: str) -> nn.Module:
|
43 |
+
"""Get activation function by name."""
|
44 |
+
act_map = {
|
45 |
+
"gelu": nn.GELU(),
|
46 |
+
"relu": nn.ReLU(),
|
47 |
+
"sigmoid": nn.Sigmoid(),
|
48 |
+
"tanh": nn.Tanh(),
|
49 |
+
"swish": nn.SiLU(),
|
50 |
+
"tanhshrink": nn.Tanhshrink(),
|
51 |
+
"softplus": nn.Softplus(),
|
52 |
+
"softshrink": nn.Softshrink(),
|
53 |
+
"leaky_relu": nn.LeakyReLU(),
|
54 |
+
"elu": nn.ELU()
|
55 |
+
}
|
56 |
+
return act_map.get(act, nn.GELU())
|
57 |
+
|
58 |
@dataclass
|
59 |
class Dimensions:
|
60 |
vocab: int
|
61 |
+
ctx: int
|
62 |
+
dims: int
|
63 |
+
head: int
|
64 |
+
layer: int
|
65 |
mels: int
|
|
|
|
|
|
|
|
|
66 |
act: str
|
67 |
debug: List[str]
|
68 |
cross_attn: bool
|
|
|
86 |
use_cache=False,
|
87 |
return_timestamps=False)
|
88 |
|
89 |
+
def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
|
90 |
+
title="", markers=None, marker_labels=None,
|
91 |
+
show_voiced_regions=True, show_energy=False):
|
92 |
+
num_plots = sum([x is not None, w is not None, p is not None, per is not None])
|
93 |
+
if num_plots == 0:
|
94 |
+
raise ValueError("No data to plot. Please provide at least one input tensor.")
|
95 |
+
t_spans = []
|
96 |
+
|
97 |
+
if w is not None:
|
98 |
+
w_np = w[sample_idx].detach().cpu().numpy()
|
99 |
+
if w_np.ndim > 1:
|
100 |
+
w_np = w_np.squeeze()
|
101 |
+
t_spans.append(len(w_np) / sr)
|
102 |
+
if x is not None:
|
103 |
+
x_np = x[sample_idx].detach().cpu().numpy()
|
104 |
+
if x_np.shape[0] < x_np.shape[1]:
|
105 |
+
x_np = x_np.T
|
106 |
+
t_spans.append(x_np.shape[0] * hop_length / sr)
|
107 |
+
if p is not None:
|
108 |
+
p_np = p[sample_idx].detach().cpu().numpy()
|
109 |
+
if p_np.ndim > 1:
|
110 |
+
p_np = p_np.squeeze()
|
111 |
+
t_spans.append(len(p_np) * hop_length / sr)
|
112 |
+
if per is not None:
|
113 |
+
per_np = per[sample_idx].detach().cpu().numpy()
|
114 |
+
if per_np.ndim > 1:
|
115 |
+
per_np = per_np.squeeze()
|
116 |
+
t_spans.append(len(per_np) * hop_length / sr)
|
117 |
+
max_t = max(t_spans) if t_spans else 0
|
118 |
+
fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
|
119 |
+
if num_plots == 1:
|
120 |
+
axs = [axs]
|
121 |
+
if show_voiced_regions and per is not None:
|
122 |
+
per_np = per[sample_idx].detach().cpu().numpy()
|
123 |
+
if per_np.ndim > 1:
|
124 |
+
per_np = per_np.squeeze()
|
125 |
+
t_per = np.arange(len(per_np)) * hop_length / sr
|
126 |
+
threshold = 0.5
|
127 |
+
for ax in axs:
|
128 |
+
for i in range(len(per_np)-1):
|
129 |
+
if per_np[i] > threshold:
|
130 |
+
ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
|
131 |
+
cu_ax = 0
|
132 |
+
if w is not None:
|
133 |
+
w_np = w[sample_idx].detach().cpu().numpy()
|
134 |
+
if w_np.ndim > 1:
|
135 |
+
w_np = w_np.squeeze()
|
136 |
+
t = np.arange(len(w_np)) / sr
|
137 |
+
axs[cu_ax].plot(t, w_np, color="tab:blue")
|
138 |
+
|
139 |
+
if show_energy:
|
140 |
+
frame_length = hop_length
|
141 |
+
hop_length_energy = hop_length // 2
|
142 |
+
energy = []
|
143 |
+
for i in range(0, len(w_np)-frame_length, hop_length_energy):
|
144 |
+
frame = w_np[i:i+frame_length]
|
145 |
+
energy.append(np.sqrt(np.mean(frame**2)))
|
146 |
+
energy = np.array(energy)
|
147 |
+
energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
|
148 |
+
t_energy = np.arange(len(energy)) * hop_length_energy / sr
|
149 |
+
axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
|
150 |
+
axs[cu_ax].legend(loc='upper right')
|
151 |
+
axs[cu_ax].set_title("Waveform")
|
152 |
+
axs[cu_ax].set_ylabel("Amplitude")
|
153 |
+
axs[cu_ax].set_xlim([0, max_t])
|
154 |
+
axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
|
155 |
+
cu_ax += 1
|
156 |
+
|
157 |
+
if x is not None:
|
158 |
+
x_np = x[sample_idx].detach().cpu().numpy()
|
159 |
+
if x_np.shape[0] < x_np.shape[1]:
|
160 |
+
x_np = x_np.T
|
161 |
+
axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
|
162 |
+
extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
|
163 |
+
axs[cu_ax].set_title("Spectrogram")
|
164 |
+
axs[cu_ax].set_ylabel("Mel Bin")
|
165 |
+
axs[cu_ax].set_xlim([0, max_t])
|
166 |
+
axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
|
167 |
+
cu_ax += 1
|
168 |
+
|
169 |
+
if p is not None:
|
170 |
+
p_np = p[sample_idx].detach().cpu().numpy()
|
171 |
+
if p_np.ndim > 1:
|
172 |
+
p_np = p_np.squeeze()
|
173 |
+
t_p = np.arange(len(p_np)) * hop_length / sr
|
174 |
+
axs[cu_ax].plot(t_p, p_np, color="tab:green")
|
175 |
+
axs[cu_ax].set_title("Pitch")
|
176 |
+
axs[cu_ax].set_ylabel("Frequency (Hz)")
|
177 |
+
axs[cu_ax].set_xlim([0, max_t])
|
178 |
+
axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
|
179 |
+
axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
|
180 |
+
cu_ax += 1
|
181 |
+
|
182 |
+
if per is not None:
|
183 |
+
per_np = per[sample_idx].detach().cpu().numpy()
|
184 |
+
if per_np.ndim > 1:
|
185 |
+
per_np = per_np.squeeze()
|
186 |
+
t_per = np.arange(len(per_np)) * hop_length / sr
|
187 |
+
axs[cu_ax].plot(t_per, per_np, color="tab:red")
|
188 |
+
axs[cu_ax].set_title("Period (Voice Activity)")
|
189 |
+
axs[cu_ax].set_ylabel("periodocity")
|
190 |
+
axs[cu_ax].set_xlim([0, max_t])
|
191 |
+
axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
|
192 |
+
axs[cu_ax].set_ylim([-0.05, 1.05])
|
193 |
+
axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
|
194 |
+
|
195 |
+
if markers is not None:
|
196 |
+
for i, t in enumerate(markers):
|
197 |
+
label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
|
198 |
+
for ax in axs:
|
199 |
+
ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
|
200 |
+
if marker_labels:
|
201 |
+
axs[0].legend(loc='upper right', fontsize='small')
|
202 |
+
axs[-1].set_xlabel("t (s)")
|
203 |
+
fig.suptitle(title, fontsize=16)
|
204 |
+
plt.tight_layout(rect=[0, 0, 1, 0.97]) # type: ignore
|
205 |
+
plt.show()
|
206 |
+
return fig
|
207 |
+
|
208 |
+
def valid(default_value, *items):
|
209 |
+
"""Get first non-None item"""
|
210 |
+
for item in items:
|
211 |
+
if item is not None:
|
212 |
+
return item
|
213 |
+
return default_value
|
214 |
+
|
215 |
def dict_to(d, device, dtype=dtype):
|
216 |
return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
|
217 |
for k, v in d.items()}
|
|
|
253 |
self.eps = eps
|
254 |
self.elementwise_affine = elementwise_affine
|
255 |
if self.elementwise_affine:
|
256 |
+
self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore
|
257 |
init.ones_(self.weight)
|
258 |
else:
|
259 |
self.register_parameter("weight", None)
|
260 |
def forward(self, x):
|
261 |
+
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore
|
262 |
|
263 |
def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
|
264 |
weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
|
265 |
eps: float = 1e-5) -> Tensor:
|
266 |
+
return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore
|
267 |
|
268 |
def get_device():
|
269 |
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
281 |
scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
|
282 |
return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
|
283 |
|
|
|
284 |
class rotary(nn.Module):
|
285 |
+
def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False):
|
286 |
super(rotary, self).__init__()
|
287 |
|
288 |
self.use_pbias = use_pbias
|
|
|
295 |
self.counter = 0
|
296 |
self.last_theta = None
|
297 |
|
298 |
+
self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
|
299 |
+
theta = (torch.tensor(10000, device=device, dtype=dtype))
|
300 |
+
self.theta = nn.Parameter(theta, requires_grad=True)
|
301 |
+
self.theta_values = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
|
303 |
+
def mel_scale_scalar(self, freq: float) -> float:
|
304 |
return 1127.0 * math.log(1.0 + freq / 700.0)
|
305 |
|
306 |
+
def mel_scale(self, freq: Tensor) -> Tensor:
|
307 |
return 1127.0 * (1.0 + freq / 700.0).log()
|
308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
def pitch_bias(self, f0):
|
310 |
if f0 is None:
|
311 |
return None
|
|
|
315 |
f0_norm.unsqueeze(1)))
|
316 |
return f0_sim.unsqueeze(0).unsqueeze(0)
|
317 |
|
318 |
+
def theta_freqs(self, theta):
|
319 |
+
if theta.dim() == 0:
|
320 |
+
theta = theta.unsqueeze(0)
|
321 |
+
freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
|
322 |
+
torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
|
323 |
+
self.dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
|
324 |
+
|
325 |
+
return freq
|
326 |
+
|
327 |
+
def _apply_radii(self, freqs, f0, ctx):
|
328 |
+
if self.radii and f0 is not None:
|
329 |
+
radius = f0.to(device, dtype)
|
330 |
+
L = radius.shape[0]
|
331 |
+
if L != ctx:
|
332 |
+
F = L / ctx
|
333 |
+
idx = torch.arange(ctx, device=f0.device)
|
334 |
+
idx = (idx * F).long().clamp(0, L - 1)
|
335 |
+
radius = radius[idx]
|
336 |
+
return torch.polar(radius.unsqueeze(-1), freqs)
|
337 |
+
else:
|
338 |
+
return torch.polar(torch.ones_like(freqs), freqs)
|
339 |
|
340 |
def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
|
341 |
f0 = enc.get("f0") if enc is not None else None
|
342 |
+
|
343 |
if isinstance(x, int):
|
344 |
ctx = x
|
345 |
elif isinstance(x, torch.Tensor) and x.ndim == 2:
|
|
|
347 |
elif isinstance(x, torch.Tensor) and x.ndim == 3:
|
348 |
batch, ctx, dims = x.shape
|
349 |
else:
|
350 |
+
batch, head, ctx, head_dim = x.shape # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
|
352 |
if f0 is not None:
|
353 |
+
if f0.dim() == 2:
|
354 |
+
f0 = f0.squeeze(0)
|
355 |
+
theta = f0 + self.theta
|
356 |
else:
|
357 |
theta = self.theta
|
358 |
|
359 |
freqs = self.theta_freqs(theta)
|
360 |
+
t = torch.arange(ctx, device=device, dtype=dtype)
|
361 |
+
freqs = t[:, None] * freqs
|
362 |
+
|
363 |
if self.radii and f0 is not None:
|
364 |
radius = f0.to(device, dtype)
|
365 |
+
freqs = torch.polar(radius.unsqueeze(-1), freqs)
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
else:
|
367 |
+
radius = torch.ones_like(freqs)
|
368 |
+
freqs = torch.polar(radius, freqs)
|
|
|
|
|
|
|
369 |
|
370 |
+
if "radius" in self.debug and self.counter == 10:
|
371 |
+
theta_value = theta.mean()
|
372 |
+
radius_shape = radius.shape if 'radius' in locals() else "N/A"
|
373 |
+
radius_mean = radius.mean() if 'radius' in locals() else 0.0
|
374 |
+
print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
|
375 |
+
print(f" [{layer}] [Radius] {radius}")
|
376 |
+
# self.theta_values.append(theta.item())
|
377 |
self.counter += 1
|
378 |
return freqs.unsqueeze(0)
|
379 |
|
|
|
390 |
x1 = x1.view(orig_shape)
|
391 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
392 |
|
|
|
393 |
class MultiheadA(nn.Module):
|
394 |
+
|
395 |
rbf = False
|
396 |
def __init__(self, dims: int, head: int, rotary_emb: bool = True,
|
397 |
+
zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
|
398 |
super(MultiheadA, self).__init__()
|
399 |
|
400 |
self.dims = dims
|
|
|
426 |
)
|
427 |
else:
|
428 |
self.rope = None
|
429 |
+
|
430 |
+
def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
|
431 |
+
q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
|
432 |
+
k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
|
433 |
+
qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
|
434 |
+
qk_cosine = qk_cosine + mask
|
435 |
+
weights = F.softmax(qk_cosine, dim=-1)
|
436 |
+
out = torch.matmul(weights, v)
|
437 |
+
return out
|
438 |
+
|
439 |
+
def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
|
440 |
+
scale = (self.dims // self.head) ** -0.25
|
441 |
+
dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
|
442 |
+
if rbf_ratio <= 0.0:
|
443 |
+
return dot_scores
|
444 |
+
q_norm = q.pow(2).sum(dim=-1, keepdim=True)
|
445 |
+
k_norm = k.pow(2).sum(dim=-1, keepdim=True)
|
446 |
+
qk = torch.matmul(q, k.transpose(-1, -2))
|
447 |
+
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
|
448 |
+
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
|
449 |
+
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
|
450 |
|
451 |
+
def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:
|
452 |
|
453 |
x = x.to(device, dtype)
|
454 |
if xa is not None:
|
|
|
467 |
q2 = q.shape[2]
|
468 |
k2 = k.shape[2]
|
469 |
|
470 |
+
q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer))) # type: ignore
|
471 |
+
k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer))) # type: ignore
|
472 |
else:
|
473 |
q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
|
474 |
k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
|
|
|
476 |
|
477 |
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
478 |
|
479 |
+
if self.rbf:
|
480 |
+
qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
|
481 |
+
if self.use_pbias:
|
482 |
+
pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None) # type: ignore
|
483 |
+
if pbias is not None:
|
484 |
+
qk = qk + pbias[:,:,:q2,:q2]
|
485 |
+
|
486 |
token_ids = k[:, :, :, 0]
|
487 |
zscale = torch.ones_like(token_ids)
|
488 |
fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
|
489 |
zscale[token_ids.float() == self.pad_token] = fzero
|
490 |
|
491 |
if mask is not None:
|
492 |
+
# mask = mask[:q2, :q2]#torch.tril(torch.ones(q2, q2, device=q.device))
|
493 |
+
# audio_mask = torch.ones(q2, k2 - q2, device=q.device)
|
494 |
+
# mask = torch.cat([mask, audio_mask], dim=-1)
|
495 |
+
mask = mask.unsqueeze(0).unsqueeze(0)
|
496 |
+
qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
|
497 |
+
|
498 |
qk = qk * zscale.unsqueeze(-2)
|
499 |
w = F.softmax(qk, dim=-1).to(q.dtype)
|
500 |
wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
|
|
505 |
return self.o(wv), qk
|
506 |
|
507 |
class t_gate(nn.Module):
|
508 |
+
def __init__(self, dims, num_types=4, enabled=True):
|
509 |
super().__init__()
|
510 |
+
self.enabled = enabled
|
511 |
self.gate_projections = nn.ModuleList([
|
512 |
nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
513 |
for _ in range(num_types)])
|
|
|
515 |
Linear(dims, num_types),
|
516 |
nn.Softmax(dim=-1))
|
517 |
def forward(self, x):
|
518 |
+
if not self.enabled:
|
519 |
+
return None
|
520 |
type_probs = self.type_classifier(x)
|
521 |
gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
|
522 |
comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
|
523 |
return comb_gate
|
524 |
|
525 |
class m_gate(nn.Module):
|
526 |
+
def __init__(self, dims, mem_size=64, enabled=True):
|
527 |
super().__init__()
|
528 |
+
self.enabled = enabled
|
529 |
+
if enabled:
|
530 |
+
self.m_key = nn.Parameter(torch.randn(mem_size, dims))
|
531 |
+
self.m_val = nn.Parameter(torch.randn(mem_size, 1))
|
532 |
+
self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
|
533 |
+
|
534 |
def forward(self, x):
|
535 |
+
if not self.enabled:
|
536 |
+
return None
|
537 |
d_gate = torch.sigmoid(self.gate_proj(x))
|
538 |
attention = torch.matmul(x, self.m_key.transpose(0, 1))
|
539 |
attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
|
|
|
542 |
return 0.5 * (d_gate + m_gate)
|
543 |
|
544 |
class c_gate(nn.Module):
|
545 |
+
def __init__(self, dims, enabled=True):
|
546 |
super().__init__()
|
547 |
+
self.enabled = enabled
|
548 |
+
if enabled:
|
549 |
+
self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
550 |
+
self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
551 |
+
self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
552 |
+
self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
553 |
+
self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
554 |
+
self.integ = Linear(dims*5, dims)
|
555 |
|
556 |
def forward(self, x, features):
|
557 |
+
if not self.enabled:
|
558 |
+
return None
|
559 |
s_feat = features.get("spectrogram", x)
|
560 |
w_feat = features.get("waveform", x)
|
561 |
p_feat = features.get("pitch", x)
|
|
|
569 |
comb = torch.cat([s, w, p, e, ph], dim=-1)
|
570 |
return self.integ(comb)
|
571 |
|
572 |
+
class mlp_gate(nn.Module):
|
573 |
+
def __init__(self, dims, enabled=True):
|
574 |
+
super().__init__()
|
575 |
+
self.enabled = enabled
|
576 |
+
if enabled:
|
577 |
+
self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
578 |
+
|
579 |
+
def forward(self, x):
|
580 |
+
if not self.enabled:
|
581 |
+
return None
|
582 |
+
return self.gate(x)
|
583 |
+
|
584 |
class Residual(nn.Module):
|
585 |
_seen = set()
|
586 |
+
def __init__(self, ctx, dims, head, act, debug: List[str] = [],
|
587 |
tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
|
588 |
super().__init__()
|
589 |
|
|
|
591 |
self.head = head
|
592 |
self.ctx = ctx
|
593 |
self.head_dim = dims // head
|
|
|
594 |
self.features = features
|
595 |
self.debug = debug
|
596 |
self.counter = 0
|
597 |
self.dropout = 0.01
|
598 |
+
|
|
|
|
|
|
|
|
|
599 |
self.blend = nn.Parameter(torch.tensor(0.5))
|
600 |
+
act_fn = get_activation(act)
|
601 |
+
self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
|
602 |
+
|
603 |
+
if not any([tgate, mgate, cgate]):
|
604 |
+
self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
|
605 |
+
else:
|
606 |
+
self.mlp_gate = None
|
|
|
|
|
|
|
607 |
|
608 |
mlp = dims * 4
|
609 |
self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
|
610 |
|
611 |
+
self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
|
612 |
+
self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
|
613 |
+
self.c_gate = c_gate(dims=dims, enabled=cgate)
|
614 |
+
self.mlp_gate = mlp_gate(dims=dims, enabled=not any([tgate, mgate, cgate]))
|
615 |
|
616 |
self.lna = RMSNorm(dims)
|
617 |
+
self.lnb = RMSNorm(dims)
|
618 |
self.lnc = RMSNorm(dims)
|
619 |
|
|
|
|
|
|
|
620 |
def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
|
622 |
+
b = torch.sigmoid(self.blend)
|
623 |
+
ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0]
|
624 |
+
bx = b * ax + (1 - b) * x
|
625 |
+
cx = self.lnb(bx)
|
626 |
+
dx = self.mlp(cx)
|
627 |
+
ex = self.t_gate(cx) if not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
|
628 |
+
fx = x + ex + dx
|
629 |
+
gx = self.lnc(fx)
|
630 |
+
return gx
|
|
|
|
|
|
|
|
|
631 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
class FEncoder(nn.Module):
|
633 |
def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
|
634 |
super().__init__()
|
|
|
639 |
self.use_rope = use_rope
|
640 |
self.dims = dims
|
641 |
|
642 |
+
act_fn = get_activation(act)
|
|
|
643 |
|
644 |
self.encoder = nn.Sequential(
|
645 |
Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
|
|
|
650 |
if spec_shape is not None:
|
651 |
self.rope = rotary(
|
652 |
dims=self.head_dim,
|
653 |
+
head=self.head,
|
654 |
use_2d_axial=True,
|
655 |
spec_shape=spec_shape, debug=[])
|
656 |
else:
|
657 |
self.rope = rotary(
|
658 |
dims=self.head_dim,
|
659 |
+
head=self.head,
|
660 |
use_2d_axial=False, debug=[])
|
661 |
else:
|
662 |
self.rope = None
|
|
|
670 |
feature_type = "spectrogram"
|
671 |
batch, ctx, dims = x.shape
|
672 |
x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
|
673 |
+
if feature_type == "spectrogram" and self.rope is not None:
|
674 |
rope_freqs = self.rope(ctx, layer=layer, input_type="spectrogram")
|
675 |
else:
|
676 |
rope_freqs = self.rope(ctx, layer=layer, input_type="audio")
|
|
|
698 |
self.use_rope = use_rope
|
699 |
self.dims = dims
|
700 |
|
701 |
+
act_fn = get_activation(act)
|
|
|
702 |
|
703 |
self.downsample = nn.Sequential(
|
704 |
Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
|
|
|
711 |
if use_rope:
|
712 |
self.rope = rotary(
|
713 |
dims=self.head_dim,
|
714 |
+
head=self.head,
|
715 |
+
debug=[])
|
716 |
else:
|
717 |
self.rope = None
|
718 |
self.positional = lambda length: sinusoids(length, dims)
|
|
|
749 |
self.use_rope = use_rope
|
750 |
self.dims = dims
|
751 |
|
752 |
+
act_fn = get_activation(act)
|
|
|
753 |
|
754 |
self.encoder = nn.Sequential(
|
755 |
Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
|
|
|
759 |
if use_rope:
|
760 |
self.rope = rotary(
|
761 |
dims=self.head_dim,
|
762 |
+
head=self.head,
|
763 |
+
debug=[])
|
764 |
else:
|
765 |
self.rope = None
|
766 |
self.positional = lambda length: sinusoids(length, dims)
|
|
|
786 |
x = self.norm(x)
|
787 |
return x
|
788 |
|
789 |
+
class SpeechTransformer(nn.Module):
|
790 |
_seen = set()
|
791 |
+
def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
|
792 |
+
super(SpeechTransformer, self).__init__()
|
793 |
|
794 |
self.dims = dims
|
795 |
self.head = head
|
|
|
799 |
self.counter = 0
|
800 |
self.features = features
|
801 |
self.dropout = 0.01
|
802 |
+
self.sequential = "sequential" in debug
|
803 |
+
act_fn = get_activation(act)
|
804 |
|
805 |
+
self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
|
806 |
+
self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
|
807 |
+
self.register_buffer("audio_embedding", sinusoids(ctx, dims))
|
808 |
|
809 |
if features == ["spectrogram", "waveform", "pitch"]:
|
810 |
cgate=True
|
|
|
839 |
if "phase" in features else None),
|
840 |
})
|
841 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
842 |
self.block = nn.ModuleList([
|
843 |
+
Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
|
844 |
+
for _ in range(layer)])
|
|
|
|
|
|
|
|
|
845 |
|
846 |
+
self.blend = nn.Parameter(torch.tensor(0.5))
|
847 |
self.ln_dec = RMSNorm(dims)
|
848 |
|
849 |
+
def get_mask(text_ctx, aud_ctx):
|
850 |
+
mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
|
851 |
+
audio_mask = torch.ones(text_ctx, aud_ctx - text_ctx, device=device)
|
852 |
+
full_mask = torch.cat([mask, audio_mask], dim=-1)
|
853 |
+
return full_mask
|
854 |
+
self.register_buffer("mask_ax", get_mask(ctx, ctx), persistent=False)
|
855 |
+
|
856 |
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
|
857 |
self.register_buffer("mask", mask, persistent=False)
|
858 |
|
859 |
+
def forward(self, enc, layer="encoder"):
|
860 |
+
enc = dict_to(enc, device, dtype)
|
861 |
|
862 |
+
x = enc.get("input_ids").long()
|
|
|
|
|
|
|
863 |
x = self.token(x) + self.positional[:x.shape[1]]
|
864 |
x = F.dropout(x, p=self.dropout, training=self.training)
|
865 |
+
|
866 |
+
out = {}
|
867 |
+
out.update(enc)
|
868 |
+
|
869 |
+
for f in self.features:
|
870 |
+
if f in enc and f in self.blocks:
|
871 |
+
xa = enc[f]
|
872 |
+
for block in self.blocks[f]: # type: ignore
|
873 |
+
xa = block(xa, enc=enc, layer=layer)
|
874 |
+
out[f] = xa
|
875 |
+
xa = xa + self.audio_embedding[:xa.shape[1]]
|
876 |
|
877 |
for block in self.block:
|
878 |
+
mask = self.mask[:x.shape[1], :x.shape[1]]
|
879 |
x = block(x, xa=None, mask=mask, enc=None, layer=layer)
|
880 |
|
881 |
+
for f in self.features:
|
882 |
if f in enc:
|
883 |
+
mask = self.mask_ax[:x.shape[1], :xa.shape[1]]
|
884 |
+
for block in self.block:
|
885 |
+
out = block(x, xa=xa, mask=mask, enc=None, layer=layer)
|
|
|
886 |
if self.sequential:
|
887 |
x = out
|
888 |
else:
|
889 |
+
a = torch.sigmoid(self.blend)
|
890 |
x = a * out + (1 - a) * x
|
|
|
891 |
|
892 |
x = self.ln_dec(x)
|
893 |
return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
|
|
|
897 |
super().__init__()
|
898 |
self.param = param
|
899 |
|
900 |
+
self.SpeechTransformer = SpeechTransformer(
|
901 |
+
vocab=param.vocab,
|
902 |
mels=param.mels,
|
903 |
+
ctx=param.ctx,
|
904 |
+
dims=param.dims,
|
905 |
+
head=param.head,
|
906 |
+
layer=param.layer,
|
|
|
907 |
debug=param.debug,
|
908 |
features=param.features,
|
909 |
+
act=param.act,
|
910 |
)
|
911 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
912 |
def forward(self,
|
913 |
labels=None,
|
|
|
914 |
input_ids=None,
|
915 |
+
waveform: Optional[torch.Tensor]=None,
|
916 |
+
spectrogram: Optional[torch.Tensor]=None,
|
917 |
pitch: Optional[torch.Tensor]=None,
|
918 |
f0: Optional[torch.Tensor]=None,
|
919 |
envelope: Optional[torch.Tensor]=None,
|
920 |
phase: Optional[torch.Tensor]=None,
|
921 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
922 |
|
923 |
encoder_inputs = {}
|
924 |
if spectrogram is not None:
|
|
|
933 |
encoder_inputs["phase"] = phase
|
934 |
if f0 is not None:
|
935 |
encoder_inputs["f0"] = f0
|
936 |
+
if input_ids is not None:
|
937 |
+
encoder_inputs["input_ids"] = input_ids
|
938 |
|
939 |
+
logits = self.SpeechTransformer(encoder_inputs)
|
|
|
940 |
|
941 |
loss = None
|
942 |
if labels is not None:
|
|
|
        std = 0.02
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "SEBlock": 0, "SpeechTransformer": 0,
            "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
            "WEncoder": 0, "PEncoder": 0}
            # ...
                nn.init.zeros_(module.bias)
                self.init_counts["Conv2d"] += 1
            elif isinstance(module, MultiheadA):
                self.init_counts["MultiheadA"] += 1
            elif isinstance(module, SpeechTransformer):
                self.init_counts["SpeechTransformer"] += 1
            elif isinstance(module, Residual):
                self.init_counts["Residual"] += 1
        # ...
            encoder_inputs["phase"] = phase
        if f0 is not None:
            encoder_inputs["f0"] = f0

        for i in range(max_length - 1):
            with torch.no_grad():
                encoder_inputs["input_ids"] = ids
                logits = self.SpeechTransformer(encoder_inputs)
            next_token_logits = logits[:, -1, :]
            if i < min_length:
                # zeroing the EOS logit only damps it; masking with -inf
                # actually forbids emitting EOS before min_length
                next_token_logits[:, eos_token_id] = float("-inf")
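            # The rest of the decoding loop falls in the elided hunk; a plain
            # greedy continuation (an assumption, not necessarily the original
            # sampling logic) would look like:
            #   next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            #   ids = torch.cat([ids, next_token], dim=-1)
            #   if (next_token == eos_token_id).all():
            #       break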
# ...
    })
    return Config()

def setup_tokenizer(token: str):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file("./tokenizer.json")
    orig_encode = tokenizer.encode
    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
    # ...
            ids = ids[1:]
            while ids and ids[-1] in [0, 2]:
                ids = ids[:-1]

            # torch tensors and numpy arrays both expose .tolist(),
            # so one isinstance check covers both cases
            if isinstance(ids, (torch.Tensor, np.ndarray)):
                ids = ids.tolist()
            results.append(tokenizer.decode(ids))
        return results
    # ...
    tokenizer.eos_token_id = 2
    return tokenizer
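# Usage sketch (assumes ./tokenizer.json sits next to this script; the token
# string itself is only used for dataset access elsewhere):
#   tokenizer = setup_tokenizer(token="hf_...")
#   ids = tokenizer.encode("hello world")      # patched encode -> list of ids
#   text = tokenizer.batch_decode([ids])[0]    # patched batch decode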

def load_wave(wave_data, sample_rate):
    if isinstance(wave_data, str):
        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sr = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")
    # no resampling is done here: callers cast the dataset audio to the
    # target rate beforehand (see prepare_datasets), so sr is left unused
    return waveform

def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **dataset_config):

    audio = batch["audio"]
    sr = audio["sampling_rate"]
    wav = load_wave(wave_data=audio, sample_rate=sr)

    # mel-frontend defaults; caller overrides passed via **dataset_config win
    mel_config = {
        "hop_length": 256,
        "f_min": 150,
        "f_max": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sample_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 1.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}
    mel_config.update(dataset_config)

    transform = torchaudio.transforms.MelSpectrogram(**mel_config)

    mel_spectrogram = transform(wav)
    log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
    spec = (log_mel + 4.0) / 4.0  # already a tensor; no torch.tensor() re-wrap needed

    # pw is assumed to be pyworld (import pyworld as pw): dio estimates a raw
    # F0 track and stonemask refines it, with frames aligned to the mel hop
    wav_np = wav.numpy().astype(np.float64)
    f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length / sample_rate * 1000)
    f0 = pw.stonemask(wav_np, f0, t, sample_rate)
    f0 = torch.from_numpy(f0)

    labels = tokenizer.encode(batch["transcription"])

    return {
        "spectrogram": spec,
        "f0": f0,
        "labels": labels,
        # "waveform": wav,
        # "pitch": f0,
    }
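
# A hypothetical single-example invocation (fake audio, illustration only):
def _extract_features_example(tokenizer):
    batch = {
        "audio": {"array": np.random.randn(16000), "sampling_rate": 16000},
        "transcription": "hello world",
    }
    out = extract_features(batch, tokenizer)
    # out["spectrogram"]: (128, frames), out["f0"]: (frames,), out["labels"]: token ids
    return out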

def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):

    if sanity_check:
        test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test[:10]", trust_remote_code=True
        ).cast_column("audio", Audio(sampling_rate=sample_rate))

        dataset = test.map(
            lambda x: extract_features(x, tokenizer, **dataset_config),
            remove_columns=test.column_names)
        # the map above already removed the raw columns, so the processed
        # dataset only needs to be switched to torch formatting here
        dataset = dataset.with_format(type="torch")
        train_dataset = dataset
        test_dataset = dataset
    else:
        cache_dir = "./processed_datasets"
        os.makedirs(cache_dir, exist_ok=True)
        cache_file_train = os.path.join(cache_dir, "train.arrow")
        cache_file_test = os.path.join(cache_dir, "test.arrow")

        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
            from datasets import Dataset
            train_dataset = Dataset.load_from_disk(cache_file_train)
            test_dataset = Dataset.load_from_disk(cache_file_test)
            return train_dataset, test_dataset

        def filter_func(x):
            # keep non-empty transcripts under 512 characters and clips shorter
            # than 1500 frames at a 160-sample hop (about 15 s at 16 kHz)
            return (0 < len(x["transcription"]) < 512 and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < 1500 * 160)

        raw_train = load_dataset(
            "google/fleurs", "en_us", token=token, split="train[:1000]", trust_remote_code=True)
        raw_test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test[:100]", trust_remote_code=True)

        raw_train = raw_train.filter(filter_func)
        raw_test = raw_test.filter(filter_func)

        raw_train = raw_train.cast_column("audio", Audio(sampling_rate=sample_rate))
        raw_test = raw_test.cast_column("audio", Audio(sampling_rate=sample_rate))

        train_dataset = raw_train.map(
            lambda x: extract_features(x, tokenizer, **dataset_config),
            remove_columns=raw_train.column_names)
        test_dataset = raw_test.map(
            lambda x: extract_features(x, tokenizer, **dataset_config),
            remove_columns=raw_test.column_names)

        train_dataset.save_to_disk(cache_file_train)
        test_dataset.save_to_disk(cache_file_test)
    return train_dataset, test_dataset
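# Usage sketch (hypothetical HF token; sanity_check=True swaps the cached
# FLEURS subsets for ten quick test clips):
#   tokenizer = setup_tokenizer(token="hf_...")
#   train_ds, test_ds = prepare_datasets(tokenizer, token="hf_...", sanity_check=True)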

@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        all_keys = set()
        for f in features:
            all_keys.update(f.keys())
        batch = {}
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
        eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)

        for key in all_keys:
            if key == "labels":
                labels_list = [f["labels"] for f in features]
                max_len = max(len(l) for l in labels_list)
                all_ids, all_labels = [], []
                for label in labels_list:
                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                    # teacher forcing: inputs are BOS-shifted, targets EOS-terminated
                    decoder_input = [bos_token_id] + label_list
                    label_eos = label_list + [eos_token_id]
                    input_len = max_len + 1 - len(decoder_input)
                    label_len = max_len + 1 - len(label_eos)
                    padded_input = decoder_input + [pad_token_id] * input_len
                    padded_labels = label_eos + [pad_token_id] * label_len
                    all_ids.append(padded_input)
                    all_labels.append(padded_labels)
                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

            elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
                # right-pad every feature tensor to the longest last dimension
                items = [f[key] for f in features if key in f]
                items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
                max_len = max(item.shape[-1] for item in items)
                padded = []
                for item in items:
                    pad_width = max_len - item.shape[-1]
                    if pad_width > 0:
                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
                    else:
                        pad_item = item
                    padded.append(pad_item)
                batch[key] = torch.stack(padded)
        return batch
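
# A minimal collation sketch with two fake examples (hypothetical tokenizer
# attributes; real runs pass the tokenizer from setup_tokenizer):
def _collator_example():
    class _Tok:
        pad_token_id, bos_token_id, eos_token_id = 0, 1, 2
    collator = DataCollator(tokenizer=_Tok())
    feats = [
        {"labels": [5, 6, 7], "spectrogram": torch.randn(128, 40)},
        {"labels": [8, 9], "spectrogram": torch.randn(128, 33)},
    ]
    batch = collator(feats)
    # input_ids -> [[1, 5, 6, 7], [1, 8, 9, 0]]; labels -> [[5, 6, 7, 2], [8, 9, 2, 0]]
    return batch["input_ids"].shape, batch["spectrogram"].shape  # (2, 4), (2, 128, 40)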

def levenshtein(reference_words, hypothesis_words):
    m, n = len(reference_words), len(hypothesis_words)
# ...
        total_words += len(ref_words)
    return (total_errors / total_words) * 100 if total_words > 0 else 0.0
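# Example: against the reference "the cat sat", the hypothesis
# "the cat sat down" has one insertion over three reference words,
# so wer_batch reports 1/3 * 100 ≈ 33.3 (WER as a percentage).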

def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, optimizer=None, scheduler=None):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    if not isinstance(pred_ids, torch.Tensor):
        pred_ids = torch.tensor(pred_ids)
    pred_ids = pred_ids.argmax(dim=-1)

    pred_ids = pred_ids.tolist()
    label_ids = label_ids.tolist()
    pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
    # the trainer marks padding with -100; map it back before decoding
    label_ids = [[pad_token_id if token == -100 else token for token in seq] for seq in label_ids]

    if print_pred:
        # decode once, outside the loop, instead of re-decoding the whole
        # batch on every printed sample
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
        for i in range(min(num_samples, len(pred_ids))):
            print(f"Pred tokens: {pred_ids[i]}")
            print(f"Label tokens: {label_ids[i]}")
            print(f"Pred: '{pred_str[i]}'")
            print(f"Label: '{label_str[i]}'")
            print("-" * 40)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = wer_batch(label_str, pred_str)
    # ...
    else:
        trainable_params = 0.0
        efficiency_score = 0.0

    return {
        "wer": float(wer),
        # ...
        "efficiency_score": float(efficiency_score),
    }

def main():
    # ... (token, log_dir, etc. are set in the elided lines)
    tokenizer = setup_tokenizer(token)
    train_dataset, test_dataset = prepare_datasets(tokenizer, token)
    param = Dimensions(
        vocab=40000, ctx=2048, dims=512, head=4, layer=4,
        mels=128, act="swish",
        debug=[],  # Dimensions.debug is typed List[str]; {} would be an empty dict
        cross_attn=True,
        features=["spectrogram"]
    )

    model = Echo(param).to('cuda')
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    # ...
        logging_steps=10,
        logging_dir=log_dir,
        eval_strategy="steps",
        save_strategy="steps",
        report_to=["tensorboard"],
        push_to_hub=False,
        # ...
        batch_eval_metrics=False,
    )
    from functools import partial
    metrics_fn = partial(compute_metrics,
        print_pred=True,
        num_samples=1,
        tokenizer=tokenizer, model=model)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=0.00025, eps=1e-8, weight_decay=0.025,
        betas=(0.9, 0.999), amsgrad=False, foreach=False, fused=False,
        capturable=False, differentiable=False, maximize=False)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
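    # CosineAnnealingLR decays the learning rate from lr toward eta_min over
    # T_max steps following
    #   lr_t = eta_min + 0.5 * (lr - eta_min) * (1 + cos(pi * t / T_max)),
    # so with lr=2.5e-4 the rate is roughly halved around t = T_max / 2.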
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,  # type: ignore
        eval_dataset=test_dataset,  # type: ignore
        data_collator=DataCollator(tokenizer=tokenizer),  # type: ignore
        compute_metrics=metrics_fn,
        optimizers=(optimizer, scheduler),  # type: ignore
    )
    model.init_weights()
    trainer.train()

if __name__ == "__main__":
    main()