Sin2pi committed on
Commit c12de10 · verified · 1 Parent(s): b62df1e

Delete echoutils.py

Files changed (1):
  1. echoutils.py +0 -1593
echoutils.py DELETED
@@ -1,1593 +0,0 @@
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pyworld as pw
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchaudio
from dataclasses import dataclass
from datasets import Audio, load_dataset
from torch import Tensor
from torch.nn.functional import scaled_dot_product_attention
from transformers import GenerationConfig  # used by get_generation_config below
from typing import Any, Dict, List, Optional, Tuple, Union

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# def shape(tensor: torch.Tensor, head: int, head_dim: int, batch: int, ctx: int):
#     return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()

# def reshape_to_output(attn_output, head: int, head_dim: int, batch: int, ctx: int, dims: int):
#     return attn_output.permute(0, 2, 1, 3).reshape(batch, ctx, dims).contiguous()

# Method-style helpers: written with `self` so they can be attached to an attention
# module that defines self.head, self.head_dim, and self.dims.
def shape(self, tensor: torch.Tensor, ctx: int, batch: int):
    return tensor.view(batch, ctx, self.head, self.head_dim).transpose(1, 2).contiguous()

def reshape_to_output(self, attn_output, batch, ctx):
    return attn_output.permute(0, 2, 1, 3).reshape(batch, ctx, self.dims).contiguous()

def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, device=None):
    # Boolean mask following torch SDPA semantics: True = position may be attended to.
    if is_causal:
        # tril keeps the diagonal visible so each position can attend to itself
        mask = torch.tril(torch.ones((ctx, ctx), device=device, dtype=torch.bool), diagonal=0)
        mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, ctx, ctx)
    else:
        mask = torch.ones((batch_size, 1, ctx, ctx), device=device, dtype=torch.bool)
    if padding_mask is not None:
        # padding_mask: (batch, ctx) with True/1 on real tokens, 0 on padding
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
        mask = mask & padding_mask
    return mask

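# --- Editor's example (not in the original file): a minimal sketch of how
# create_attention_mask is meant to be called under the SDPA convention used above
# (True = may attend). Shapes are assumptions inferred from the code. ---
def _demo_create_attention_mask():
    pad = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = padding
    m = create_attention_mask(batch_size=2, ctx=4, is_causal=True, padding_mask=pad, device="cpu")
    assert m.shape == (2, 1, 4, 4)  # (batch, 1, query, key)
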
def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask=None) -> Tensor:
    # Cosine-similarity attention; `mask` is additive (0 = keep, -inf = blocked).
    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
    if mask is not None:
        qk_cosine = qk_cosine + mask
    weights = F.softmax(qk_cosine, dim=-1)
    out = torch.matmul(weights, v)
    return out

def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
    dot_scores = torch.matmul(q, k.transpose(-1, -2))
    if rbf_ratio <= 0.0:
        return dot_scores
    q_norm = q.pow(2).sum(dim=-1, keepdim=True)
    k_norm = k.pow(2).sum(dim=-1, keepdim=True)
    qk = torch.matmul(q, k.transpose(-1, -2))
    dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
    rbf = torch.exp(-dist_sq / (2 * rbf_sigma**2))  # renamed to avoid shadowing the function
    return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf

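# --- Editor's example (not in the original file): rbf_ratio=0.0 reduces to plain
# dot-product scores; rbf_ratio=0.5 blends in the RBF kernel. ---
def _demo_rbf_scores():
    q = torch.randn(2, 4, 6, 8)  # (batch, head, ctx, head_dim)
    k = torch.randn(2, 4, 6, 8)
    dot = rbf_scores(q, k, rbf_ratio=0.0)
    blended = rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.5)
    assert dot.shape == blended.shape == (2, 4, 6, 6)
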
def sliding_window_mask(q_len, k_len, window, device):
    # mask[i, j] = 1 if j in [i-window+1, i], else 0
    idxs = torch.arange(q_len, device=device).unsqueeze(1)
    jdxs = torch.arange(k_len, device=device).unsqueeze(0)
    mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
    return mask.float()  # shape: (q_len, k_len)

def mask_win(text_ctx, aud_ctx):
    mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
    audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
    full_mask = torch.cat([mask, audio_mask], dim=-1)
    return full_mask

def maskc(ctx, device):
    return torch.tril(torch.ones(ctx, ctx, device=device, dtype=dtype), diagonal=0)

def qkv_init(dims: int, head: int):
    head_dim = dims // head
    scale = head_dim ** -0.5
    q = nn.Linear(dims, dims)
    k = nn.Linear(dims, dims, bias=False)
    v = nn.Linear(dims, dims)
    o = nn.Linear(dims, dims)
    return q, k, v, o, scale

def create_qkv(q, k, v, x, xa=None, head=8):
    head_dim = q.out_features // head
    scale = head_dim ** -0.5
    q = q(x) * scale
    k = k(xa if xa is not None else x) * scale
    v = v(xa if xa is not None else x)
    def _shape(tensor):
        # shape each tensor from its own length so cross-attention (xa) works
        batch, ctx, _ = tensor.shape
        return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
    return _shape(q), _shape(k), _shape(v)

def calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True):
    # q, k, v = create_qkv(q, k, v, dims, head)

    batch, head, ctx, dims = q.shape
    attn_mask = None
    if mask is not None:
        if mask.dim() <= 3:
            attn_mask = create_attention_mask(
                batch_size=batch,
                ctx=ctx,
                is_causal=is_causal,
                padding_mask=mask if mask.dim() > 1 else None,
                device=device)
        else:
            attn_mask = mask
    scaled_q = q
    if temperature != 1.0 and temperature > 0:
        scaled_q = q * (1.0 / temperature)**.5
    a = scaled_dot_product_attention(scaled_q, k, v, attn_mask=attn_mask, is_causal=is_causal if attn_mask is None else False)
    out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
    return out, None

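# --- Editor's example (not in the original file): wiring qkv_init, create_qkv and
# calculate_attention together; dims/head values are arbitrary. ---
def _demo_attention_pipeline():
    dims, head, batch, ctx = 64, 8, 2, 10
    q, k, v, o, scale = qkv_init(dims, head)
    x = torch.randn(batch, ctx, dims)
    qh, kh, vh = create_qkv(q, k, v, x, head=head)   # (batch, head, ctx, head_dim)
    out, _ = calculate_attention(qh, kh, vh, is_causal=True)
    assert out.shape == (batch, ctx, dims)
    return o(out)                                    # final output projection
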
class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

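# --- Editor's example (not in the original file): filling the first three
# positions of a KVCache; float32 is used here so it matches torch.randn. ---
def _demo_kvcache():
    cache = KVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=8, dtype=torch.float32)
    pos = torch.arange(3)                   # positions being written
    k_all, v_all = cache.update(pos, torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8))
    assert k_all.shape == (1, 4, 16, 8)     # full cache; steps 0-2 now filled
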
def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)

def mel_scale(freq: Tensor) -> Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()

def trace_x(func):
    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__}")
        result = func(*args, **kwargs)
        if isinstance(result, torch.Tensor):
            print(f"  {func.__name__} returned shape: {result.shape}")
        return result
    return wrapper

# Persistent id store: the original re-captured id(new_x) inside each call, so the
# FLOW branch could never fire; tracking across calls needs module-level state.
_last_ids: Dict[str, int] = {}

def _track(name, obj, operation=""):
    if obj is None:
        return obj
    current_id = id(obj)
    last_id = _last_ids.get(name)
    if last_id is not None and current_id != last_id:
        print(f"{name} FLOW: {last_id} → {current_id} in {operation}")
    else:
        print(f"{name} REUSE: {current_id} in {operation}")
    _last_ids[name] = current_id
    return obj

def track_x(new_x, operation=""):
    """ track_x(x, "x") """
    return _track("x", new_x, operation)

def track_xa(new_xa, operation=""):
    """ track_xa(xa, "xa - decoder") """
    return _track("xa", new_xa, operation)

def get_activation(act: str) -> nn.Module:
    """Get activation function by name."""
    act_map = {
        "gelu": nn.GELU(),
        "relu": nn.ReLU(),
        "sigmoid": nn.Sigmoid(),
        "tanh": nn.Tanh(),
        "swish": nn.SiLU(),
        "tanhshrink": nn.Tanhshrink(),
        "softplus": nn.Softplus(),
        "softshrink": nn.Softshrink(),
        "leaky_relu": nn.LeakyReLU(),
        "elu": nn.ELU()
    }
    return act_map.get(act, nn.GELU())

def get_generation_config(param):
    return GenerationConfig(
        max_length=param.text_ctx,
        pad_token_id=getattr(param, "pad_token_id", 0),
        bos_token_id=getattr(param, "bos_token_id", 1),
        eos_token_id=getattr(param, "eos_token_id", 2),
        do_sample=False,
        num_beams=1,
        early_stopping=False,
        length_penalty=1.0,
        no_repeat_ngram_size=0,
        repetition_penalty=1.0,
        temperature=1.0,
        decoder_start_token_id=1,
        is_multilingual=False,
        use_cache=False,
        return_timestamps=False)

# class rotary(nn.Module):
#     def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):

#         super(rotary, self).__init__()
#         self.use_pbias = use_pbias
#         self.dims = dims
#         self.head = head
#         self.head_dim = dims // head
#         self.radii = radii
#         self.debug = debug
#         self.counter = 0
#         self.last_theta = None
#         self.axial = axial

#         self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
#         theta = (torch.tensor(10000, device=device, dtype=dtype))
#         self.theta = nn.Parameter(theta, requires_grad=True)
#         self.theta_values = []

#         if axial and spec_shape is not None:
#             time_frames, freq_bins = spec_shape
#             self.time_frames = time_frames
#             self.freq_bins = freq_bins

#             time_theta = 50.0
#             time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
#             self.register_buffer('time_freqs', time_freqs)

#             freq_theta = 100.0
#             freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
#             self.register_buffer('freq_freqs', freq_freqs)

#     def pitch_bias(self, f0):
#         if f0 is None:
#             return None
#         f0_flat = f0.squeeze().float()
#         f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
#         f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
#                                         f0_norm.unsqueeze(1)))
#         return f0_sim.unsqueeze(0).unsqueeze(0)

#     def theta_freqs(self, theta):
#         if theta.dim() == 0:
#             theta = theta.unsqueeze(0)
#         freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
#             torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
#                                          self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
#         return freq

#     def _apply_radii(self, freqs, f0, ctx):
#         if self.radii and f0 is not None:
#             radius = f0.to(device, dtype)
#             L = radius.shape[0]
#             if L != ctx:
#                 feature = L / ctx
#                 idx = torch.arange(ctx, device=f0.device)
#                 idx = (idx * feature).long().clamp(0, L - 1)
#                 radius = radius[idx]
#                 return torch.polar(radius.unsqueeze(-1), freqs), radius
#             else:
#                 return torch.polar(radius.unsqueeze(-1), freqs), radius
#         else:
#             return torch.polar(torch.ones_like(freqs), freqs), None

#     def check_f0(self, f0, f0t, ctx):
#         if f0 is not None and f0.shape[1] == ctx:
#             return f0
#         elif f0t is not None and f0t.shape[1] == ctx:
#             return f0t
#         else:
#             return None

#     def axial_freqs(self, ctx):
#         if not self.axial:
#             return None
#         time_frames = self.time_frames
#         freq_bins = self.freq_bins

#         t = torch.arange(ctx, device=device, dtype=dtype)
#         t_x = (t % time_frames).float()
#         t_y = torch.div(t, time_frames, rounding_mode='floor').float()
#         freqs_x = torch.outer(t_x, self.time_freqs)
#         freqs_y = torch.outer(t_y, self.freq_freqs)
#         freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
#         freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
#         return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)

#     def forward(self, x=None, feats=None, feature=None, layer=None) -> Tensor:
#         ctx = x
#         f0 = feats.get("f0") if feats is not None else None
#         f0t = feats.get("f0t") if feats is not None else None

#         f0 = self.check_f0(f0, f0t, ctx)
#         if f0 is not None:
#             # if f0.dim() == 2:
#             #     f0 = f0.squeeze(0)
#             theta = f0 + self.theta
#         else:
#             theta = self.theta
#         freqs = self.theta_freqs(theta)
#         t = torch.arange(ctx, device=device, dtype=dtype)
#         freqs = t[:, None] * freqs
#         freqs, radius = self._apply_radii(freqs, f0, ctx)

#         if self.axial and feature == "spectrogram":
#             freqs_2d = self.axial_freqs(ctx)
#             if freqs_2d is not None:
#                 return freqs_2d.unsqueeze(0)

#         if "radius" in self.debug and self.counter == 10:
#             print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
#         self.counter += 1
#         return freqs.unsqueeze(0)

#     @staticmethod
#     def split(X: Tensor):
#         half_dim = X.shape[-1] // 2
#         return X[..., :half_dim], X[..., half_dim:]

#     @staticmethod
#     def apply_rotary(x, freqs):
#         x1 = x[..., :freqs.shape[-1]*2]
#         x2 = x[..., freqs.shape[-1]*2:]
#         orig_shape = x1.shape
#         if x1.ndim == 2:
#             x1 = x1.unsqueeze(0)
#         x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
#         x1 = torch.view_as_complex(x1) * freqs
#         x1 = torch.view_as_real(x1).flatten(-2)
#         x1 = x1.view(orig_shape)
#         return torch.cat([x1.type_as(x), x2], dim=-1)


# class feature_encoder(nn.Module):
#     def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None, use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None):
#         """
#         Feature encoder for audio processing.
#         """
#         super().__init__()

#         self.dims = dims
#         self.head = head
#         self.head_dim = dims // head
#         self.dropout = 0.01
#         self.use_rope = use_rope
#         self.attend_feature = attend_feature
#         self.target_length = target_length
#         self.feature = feature

#         self.debug = debug
#         act_fn = get_activation(act)

#         if self.attend_feature:
#             self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
#             self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
#         else:
#             self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
#             self.mlp = None

#         self.spectrogram = nn.Sequential(
#             Conv1d(mels, dims, kernel_size=3), act_fn,
#             Conv1d(dims, dims, kernel_size=3), act_fn,
#             Conv1d(dims, dims, kernel_size=3, groups=dims), act_fn)

#         self.waveform = nn.Sequential(
#             Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
#             Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
#             Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)

#         self.pitch = nn.Sequential(
#             Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
#             Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
#             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

#         if use_rope:
#             # if spec_shape is not None:
#             self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
#             self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
#         else:
#             self.rope = None
#             self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
#         self.norm = RMSNorm(dims)

#     def rope(self, x, xa=None, mask=None, feats=None, feature=None, layer=None):
#         if isinstance(x, int):
#             ctx = x
#         elif isinstance(x, torch.Tensor):
#             ctx = x.shape[1] if x.dim() > 1 else x.shape[0]
#         batch, ctx, dims = x.shape[0], ctx, x.shape[-1]

#         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
#         freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
#         x = self.rope.apply_rotary(x, freqs)
#         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
#         return x

#     def mel_scalar(self, freq: float) -> float:
#         return 1127.0 * math.log(1.0 + freq / 700.0)

#     def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
#         target_length = x.shape[1] if self.target_length is None else self.target_length

#         if feature == "pitch":
#             xp = x.clone()
#             enc_dict = feats if feats is not None else {}
#             enc_dict = dict(enc_dict)
#             enc_dict["f0"] = xp
#             # xp = self.mel_scalar(xp.mean())
#             # print(f"Using pitch scalar: {xp}")
#             # max_tscale = xp*300
#             # print(f"Using max_tscale: {max_tscale}")
#             feats = enc_dict
#             if x.dim() == 2:
#                 x = x.unsqueeze(0)
#             x = self.pitch(x).permute(0, 2, 1)

#         if feature == "phase":
#             if x.dim() == 2:
#                 x = x.unsqueeze(0)
#             x = self.pitch(x).permute(0, 2, 1)

#         if feature == "waveform":
#             if x.dim() == 2:
#                 x = x.unsqueeze(0)
#             x = self.waveform(x).permute(0, 2, 1)
#             if target_length and x.shape[1] != self.target_length:
#                 x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)

#         if feature == "harmonics":
#             if x.dim() == 2:
#                 x = x.unsqueeze(0)
#             x = self.spectrogram(x).permute(0, 2, 1)

#         if feature == "aperiodic":
#             if x.dim() == 2:
#                 x = x.unsqueeze(0)
#             x = self.spectrogram(x).permute(0, 2, 1)

#         if feature == "spectrogram":
#             if x.dim() == 2:
#                 x = x.unsqueeze(0)
#             x = self.spectrogram(x).permute(0, 2, 1)

#         if self.use_rope:
#             x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
#             x = self.rope(x=x, xa=None, mask=None, feats=feats, feature=feature, layer=layer)
#         else:
#             max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
#             x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
#         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
#         x = self.norm(x)

#         if self.attend_feature:
#             xa = feats[feature]
#             if xa is not None:
#                 q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
#                 out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
#                 x = x + out

#         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
#         x = self.norm(x)
#         return x

class OneShot(nn.Module):
    def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
        super().__init__()
        if features is None:
            features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
        self.head = head
        self.head_dim = dims // head
        # true division: `1.0 // len(features)` floored to 0.0 and zeroed the bias
        self.scale = 1.0 / len(features) if features else scale

        self.q = Linear(dims, dims)
        self.k = Linear(dims, dims)

    def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor:
        B, L, D = x.shape
        K = xa.size(1)
        q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1, 2)
        k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1, 2)
        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
        return bias

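# --- Editor's example (not in the original file): OneShot produces an additive
# per-head attention bias between a text stream x and an audio stream xa. ---
def _demo_oneshot_bias():
    layer = OneShot(dims=64, head=8)
    x = torch.randn(2, 10, 64)               # text stream
    xa = torch.randn(2, 20, 64)              # audio stream
    bias = layer(x, xa)
    assert bias.shape == (2, 8, 10, 20)      # (batch, head, text_ctx, audio_ctx)
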
class curiosity(nn.Module):
    def __init__(self, d, h, bias=True):
        super().__init__()
        self.h = h
        self.dh = d // h
        self.qkv = nn.Linear(d, d * 3, bias=bias)
        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
        self.o = nn.Linear(d, d, bias=bias)
        self.g = nn.Parameter(torch.zeros(h))

    def split(self, x):
        b, t, _ = x.shape
        return x.view(b, t, self.h, self.dh).transpose(1, 2)

    def merge(self, x):
        b, h, t, dh = x.shape
        return x.transpose(1, 2).contiguous().view(b, t, h * dh)

    def forward(self, x, xa, mask=None):
        q, k, v = self.qkv(x).chunk(3, -1)
        qa, ka, va = self.qkv_aux(xa).chunk(3, -1)  # qa is unused below; queries always come from x
        q, k, v = map(self.split, (q, k, v))
        qa, ka, va = map(self.split, (qa, ka, va))
        dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
        dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
        if mask is not None:
            dots = dots.masked_fill(mask, -9e15)
        p = dots.softmax(-1)
        pa = dots_aux.softmax(-1)
        h_main = p @ v
        h_aux = pa @ va
        g = torch.sigmoid(self.g).view(1, -1, 1, 1)
        out = self.merge(h_main * (1 - g) + h_aux * g)
        return self.o(out)

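# --- Editor's example (not in the original file): the learned gate g mixes main
# and auxiliary attention; both streams share the model width d. ---
def _demo_curiosity():
    attn = curiosity(d=64, h=8)
    x = torch.randn(2, 10, 64)               # main stream
    xa = torch.randn(2, 10, 64)              # auxiliary stream
    out = attn(x, xa)
    assert out.shape == (2, 10, 64)
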
class PositionalEncoding(nn.Module):
    def __init__(self, dims, ctx):
        super(PositionalEncoding, self).__init__()
        self.dims = dims
        self.ctx = ctx
        self.pe = self.get_positional_encoding(max_ctx=ctx)

    def get_positional_encoding(self, max_ctx):
        pe = torch.zeros(max_ctx, self.dims)
        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dims, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.dims)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        return pe.to(device)

    def forward(self, x):
        ctx = x.size(1)
        pe = self.pe[:, :ctx, :]
        x = x * math.sqrt(self.dims)
        x = x + pe
        return x


def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
                  title="", markers=None, marker_labels=None,
                  show_voiced_regions=True, show_energy=False):
    num_plots = sum([x is not None, w is not None, p is not None, per is not None])
    if num_plots == 0:
        raise ValueError("No data to plot. Please provide at least one input tensor.")
    t_spans = []

    if w is not None:
        w_np = w[sample_idx].detach().cpu().numpy()
        if w_np.ndim > 1:
            w_np = w_np.squeeze()
        t_spans.append(len(w_np) / sr)
    if x is not None:
        x_np = x[sample_idx].detach().cpu().numpy()
        if x_np.shape[0] < x_np.shape[1]:
            x_np = x_np.T
        t_spans.append(x_np.shape[0] * hop_length / sr)
    if p is not None:
        p_np = p[sample_idx].detach().cpu().numpy()
        if p_np.ndim > 1:
            p_np = p_np.squeeze()
        t_spans.append(len(p_np) * hop_length / sr)
    if per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_spans.append(len(per_np) * hop_length / sr)
    max_t = max(t_spans) if t_spans else 0
    fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
    if num_plots == 1:
        axs = [axs]
    if show_voiced_regions and per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_per = np.arange(len(per_np)) * hop_length / sr
        threshold = 0.5
        for ax in axs:
            for i in range(len(per_np)-1):
                if per_np[i] > threshold:
                    ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
    cu_ax = 0
    if w is not None:
        w_np = w[sample_idx].detach().cpu().numpy()
        if w_np.ndim > 1:
            w_np = w_np.squeeze()
        t = np.arange(len(w_np)) / sr
        axs[cu_ax].plot(t, w_np, color="tab:blue")

        if show_energy:
            frame_length = hop_length
            hop_length_energy = hop_length // 2
            energy = []
            for i in range(0, len(w_np)-frame_length, hop_length_energy):
                frame = w_np[i:i+frame_length]
                energy.append(np.sqrt(np.mean(frame**2)))
            energy = np.array(energy)
            energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
            t_energy = np.arange(len(energy)) * hop_length_energy / sr
            axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
            axs[cu_ax].legend(loc='upper right')
        axs[cu_ax].set_title("Waveform")
        axs[cu_ax].set_ylabel("Amplitude")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
        cu_ax += 1

    if x is not None:
        x_np = x[sample_idx].detach().cpu().numpy()
        if x_np.shape[0] < x_np.shape[1]:
            x_np = x_np.T
        axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
                          extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
        axs[cu_ax].set_title("Spectrogram")
        axs[cu_ax].set_ylabel("Mel Bin")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
        cu_ax += 1

    if p is not None:
        p_np = p[sample_idx].detach().cpu().numpy()
        if p_np.ndim > 1:
            p_np = p_np.squeeze()
        t_p = np.arange(len(p_np)) * hop_length / sr
        axs[cu_ax].plot(t_p, p_np, color="tab:green")
        axs[cu_ax].set_title("Pitch")
        axs[cu_ax].set_ylabel("Frequency (Hz)")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
        axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
        cu_ax += 1

    if per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_per = np.arange(len(per_np)) * hop_length / sr
        axs[cu_ax].plot(t_per, per_np, color="tab:red")
        axs[cu_ax].set_title("Period (Voice Activity)")
        axs[cu_ax].set_ylabel("periodicity")
        axs[cu_ax].set_xlim([0, max_t])
        axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
        axs[cu_ax].set_ylim([-0.05, 1.05])
        axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)

    if markers is not None:
        for i, t in enumerate(markers):
            label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
            for ax in axs:
                ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
        if marker_labels:
            axs[0].legend(loc='upper right', fontsize='small')
    axs[-1].set_xlabel("t (s)")
    fig.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
    return fig

def valid(default_value, *items):
    """Get first non-None item"""
    for item in items:
        if item is not None:
            return item
    return default_value

def dict_to(d, device, dtype=dtype):
    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
            for k, v in d.items()}

def exists(v):
    return v is not None

def default(v, b):
    return v if exists(v) else b

class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Conv2d(nn.Conv2d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple],
                 eps=1e-8, elementwise_affine=True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

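# --- Editor's example (not in the original file): RMSNorm preserves shape and
# normalizes over the trailing dimension(s) given at construction. ---
def _demo_rmsnorm():
    norm = RMSNorm(64)
    x = torch.randn(2, 10, 64)
    assert norm(x).shape == x.shape
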
def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64

def tox():
    return {"device": get_device(), "dtype": get_dtype()}

class Sinusoids(nn.Module):
    def __init__(self, ctx: int, dims: int):
        super().__init__()

        position = torch.arange(start=0, end=ctx, dtype=dtype).unsqueeze(dim=1)
        div_term = torch.exp(input=torch.arange(start=0, end=dims, step=2, dtype=dtype) * -(math.log(10000.0) / dims))
        features = torch.zeros(ctx, dims)
        features[:, 0::2] = torch.sin(position * div_term)
        features[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('sinusoid', tensor=features)
        self.positional_embeddings = nn.Parameter(self.sinusoid.clone())

    def forward(self, positions):
        position_embeddings = self.positional_embeddings[positions]
        return position_embeddings

def sinusoids(length, channels, max_tscale=10000):
    assert channels % 2 == 0
    log_tscale_increment = torch.log(torch.tensor(float(max_tscale))) / (channels // 2 - 1)
    inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2, device=device, dtype=torch.float32))
    scaled_t = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1) * inv_tscales.unsqueeze(0)
    return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)

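# --- Editor's example (not in the original file): sinusoids returns a
# (length, channels) table of concatenated sin/cos features on `device`. ---
def _demo_sinusoids():
    pe = sinusoids(length=100, channels=64, max_tscale=10000)
    assert pe.shape == (100, 64)
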
class SelfCriticalRL(nn.Module):
    # Self-critical sequence training: REINFORCE with the greedy decode as baseline.
    def __init__(self, model, tokenizer, reward_fn):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.reward_fn = reward_fn

    def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"):

        with torch.no_grad():
            greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len)
            greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids]
            sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len, do_sample=True, top_k=5)
            sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids]

        rewards = []
        baseline = []
        for s, g, ref in zip(sampled_text, greedy_text, labels):
            ref_text = self.tokenizer.decode(ref)
            rewards.append(self.reward_fn(s, ref_text))
            baseline.append(self.reward_fn(g, ref_text))
        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
        baseline = torch.tensor(baseline, device=device, dtype=torch.float)
        advantage = rewards - baseline
        logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"]  # [batch, sampled_seq_len, vocab_size]
        log_probs = F.log_softmax(logits, dim=-1)
        log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1)
        log_probs_sum = log_probs_seq.sum(dim=1)
        loss = -(advantage * log_probs_sum).mean()
        return loss

class SelfTrainingModule(nn.Module):
    def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.quality_fn = quality_fn
        self.threshold = threshold

    def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        with torch.no_grad():
            pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len)

        if self.quality_fn is not None:
            quality_scores = self.quality_fn(pred_ids, self.model, features)
            mask = quality_scores > self.threshold
            pred_ids = pred_ids[mask]
        return pred_ids

    def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name)
        logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"]
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0)
        return loss

def confidence_indicator(pred_ids, model, features):
    with torch.no_grad():
        logits = model(input_ids=pred_ids, **features)["logits"]
    probs = torch.softmax(logits, dim=-1)
    max_probs, _ = probs.max(dim=-1)
    return max_probs.mean(dim=1)

def wer_reward(hyp, ref):
    # Word-level Levenshtein distance, normalized by reference length.
    hyp_words = hyp.split()
    ref_words = ref.split()
    d = [[0] * (len(ref_words)+1) for _ in range(len(hyp_words)+1)]
    for i in range(len(hyp_words)+1):
        d[i][0] = i
    for j in range(len(ref_words)+1):
        d[0][j] = j
    for i in range(1, len(hyp_words)+1):
        for j in range(1, len(ref_words)+1):
            if hyp_words[i-1] == ref_words[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
    wer = d[-1][-1] / max(1, len(ref_words))
    return -wer  # negative WER as reward

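# --- Editor's example (not in the original file): one substitution against a
# three-word reference gives WER = 1/3, so the reward is -1/3. ---
def _demo_wer_reward():
    r = wer_reward("the cat sat", "the dog sat")
    assert abs(r + 1.0 / 3.0) < 1e-9
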
def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    return [int(id) for id in ids if id != -100 and id != pad_token_id and id != bos_token_id and id != eos_token_id]

def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]

def setup_tokenizer(dir: str):
    # Wraps a `tokenizers.Tokenizer` with HF-style encode/decode/batch_decode helpers.
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{dir}")
    orig_encode = tokenizer.encode
    orig_decode = tokenizer.decode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids

    def bdec(ids_list, pad_token_id=0, bos_token_id=1, eos_token_id=2, skip_special_tokens=True):
        results = []
        if isinstance(ids_list, torch.Tensor):
            ids_list = ids_list.tolist()
        elif isinstance(ids_list, np.ndarray):
            ids_list = ids_list.tolist()
        for ids in ids_list:
            ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
            results.append(orig_decode(ids))
        return results

    def dec(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
        ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
        return orig_decode(ids)

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.decode = dec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

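# --- Editor's example (not in the original file): using the patched tokenizer;
# "tokenizer.json" is a hypothetical path to a trained `tokenizers` file. ---
def _demo_setup_tokenizer(path="tokenizer.json"):
    tok = setup_tokenizer(path)
    ids = tok.encode("hello world")          # list[int]
    text = tok.decode(ids)                   # pad/bos/eos ids filtered out
    return ids, text
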
def tokenize_pitch(pitch_features, target_length):
    # expects (batch, channels, time); downsample by pooling, upsample by interpolation
    pitch_len = pitch_features.shape[-1]
    token_len = target_length
    if pitch_len > token_len:
        pitch_tokens = F.adaptive_avg_pool1d(pitch_features, token_len)
    else:
        pitch_tokens = F.interpolate(pitch_features, size=token_len)
    return pitch_tokens

def load_wave(wave_data, sample_rate=16000):

    if isinstance(wave_data, str):
        waveform, sample_rate = torchaudio.load(uri=wave_data, normalize=False)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sample_rate = wave_data["sampling_rate"]  # noqa: F841
    else:
        raise TypeError("Invalid wave_data format.")
    return waveform

def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
    import librosa
    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
    mel_basis = torch.from_numpy(mel_basis).float()
    sp_mel = torch.matmul(sp, mel_basis.T)  # (frames, 128)
    ap_mel = torch.matmul(ap, mel_basis.T)  # (frames, 128)
    return sp_mel, ap_mel

def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False):

    # import torchaudio
    # import torchaudio.functional
    # import torchaudio.transforms

    # torch_windows = {
    #     'hann': torch.hann_window,
    #     'hamming': torch.hamming_window,
    #     'blackman': torch.blackman_window,
    #     'bartlett': torch.bartlett_window,
    #     'ones': torch.ones,
    #     None: torch.ones,
    # }
    # if dummy:
    #     return {
    #         "spectrogram": torch.zeros((1, 128, 100)),
    #         "f0": torch.zeros((1, 100)),
    #         "f0t": torch.zeros((1, 100)),
    #         "pitch": torch.zeros((1, 100)),
    #         "harmonics": torch.zeros((1, 128, 100)),
    #         "aperiodics": torch.zeros((1, 128, 100)),
    #         "crepe_time": None,
    #         "crepe_frequency": None,
    #         "crepe_confidence": None,
    #         "crepe_activation": None,
    #     }

    audio = batch["audio"]
    sample_rate = audio["sampling_rate"]
    labels = tokenizer.encode(batch["transcription"])
    wav = load_wave(wave_data=audio, sample_rate=sample_rate)

    # Initialize every optional output up front so every later reference is defined,
    # whichever combination of flags is requested.
    spectrogram_tensor = None
    wave_tensor = None
    f0_tensor = None
    f0t_tensor = None
    p_tensor = None
    harmonic_tensor = None
    aperiodic_tensor = None
    phase = None

    spectrogram_config = {
        # "hop_length": 256,
        # "f_min": 150,
        # "f_max": 2000,
        # "n_mels": 128,
        # "n_fft": 1024,
        "sample_rate": 16000,
        # "pad_mode": "constant",
        # "center": True,
        # "power": 1.0,
        # "window_fn": torch.hann_window,
        # "mel_scale": "htk",
        # "norm": None,
        # "normalized": False,
    }

    def crepe_predict(wav, sample_rate, viterbi=False):
        # The (time, frequency, confidence, activation) 4-tuple is the API of the
        # `crepe` package; `torchcrepe.predict` (imported in the original) returns
        # pitch rather than this 4-tuple, so `crepe` is imported here.
        import crepe
        wav = wav.numpy().astype(np.float32)
        time, frequency, confidence, activation = crepe.predict(
            wav, sample_rate, viterbi=viterbi)
        crepe_time = torch.from_numpy(time)
        crepe_frequency = torch.from_numpy(frequency)
        crepe_confidence = torch.from_numpy(confidence)
        crepe_activation = torch.from_numpy(activation)
        return crepe_time, crepe_frequency, crepe_confidence, crepe_activation

    if crepe:
        crepe_time, crepe_frequency, crepe_confidence, crepe_activation = crepe_predict(wav, sample_rate, viterbi=True)
    else:
        crepe_time = None
        crepe_frequency = None
        crepe_confidence = None
        crepe_activation = None

    # def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
    #     if isinstance(window_fn, str):
    #         window_fn = torch_windows[window_fn]
    #     if window_fn is None:
    #         window_fn = torch.ones(n_fft)
    #     if isinstance(window_fn, torch.Tensor):
    #         window_fn = window_fn.to(device)
    #     return torchaudio.functional.spectrogram(
    #         wav, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
    #         window=window_fn, center=True, pad_mode="reflect", power=1.0)

    # def mel_spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
    #     transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
    #     mel_spectrogram = transform(wav)
    #     log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
    #     log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
    #     spectrogram_tensor = (log_mel + 4.0) / 4.0
    #     return spectrogram_tensor

    if spec:
        transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
        mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spectrogram_tensor = (log_mel + 4.0) / 4.0  # already a tensor; no torch.tensor() re-wrap needed

    # if spec:
    #     if isinstance(wav, torch.Tensor):
    #         wav = wav.to(device)
    #     spectrogram_tensor = mel_spectrogram(wav, sample_rate, **spectrogram_config)
    #     spectrogram_tensor = spectrogram_tensor.permute(1, 0)

    def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
        transform = torchaudio.transforms.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mels,
            melkwargs={
                "n_fft": n_fft,
                "hop_length": hop_length,
                "window_fn": window_fn,
                "n_mels": n_mels,
                "center": True,
                "pad_mode": "reflect",
                "norm": None,
                "mel_scale": "htk",
            }
        )
        mfcc_tensor = transform(wav)
        return mfcc_tensor

    def compute_pitch(wav, sample_rate, hop_length=256):
        import pyworld as pw
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length / sample_rate * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sample_rate)
        return f0, t

    def compute_harmonics_and_aperiodics(wav, f0, t, sample_rate):
        import pyworld as pw
        wav_np = wav.numpy().astype(np.float64)
        sp = pw.cheaptrick(wav_np, f0, t, sample_rate, fft_size=256)
        ap = pw.d4c(wav_np, f0, t, sample_rate, fft_size=256)
        harmonic_tensor = torch.from_numpy(sp)
        aperiodic_tensor = torch.from_numpy(ap)
        harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
        aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
        return harmonic_tensor, aperiodic_tensor

    if f0 or f0t or pitch or harmonics or aperiodics:
        wavnp = wav.numpy().astype(np.float64)
        f0_np, t = pw.dio(wavnp, sample_rate, frame_period=hop_length / sample_rate * 1000)
        f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
        t2 = torch.from_numpy(t)  # frame times, shared by the f0t and phase branches

        if f0:
            f0_tensor = torch.from_numpy(f0_np)

        if f0t:
            audio_duration = len(wavnp) / sample_rate
            T = len(labels)
            tok_dur_sec = audio_duration / T
            token_starts = torch.arange(T, dtype=t2.dtype) * tok_dur_sec
            token_ends = token_starts + tok_dur_sec
            start_idx = torch.searchsorted(t2, token_starts, side="left")
            end_idx = torch.searchsorted(t2, token_ends, side="right")
            pitch_tok = torch.zeros(T, dtype=torch.float32)
            for i in range(T):
                lo, hi = int(start_idx[i]), int(max(start_idx[i] + 1, end_idx[i]))
                segment = torch.from_numpy(f0_np[lo:hi])  # tensor so mean/median below work
                if mode == "mean":
                    pitch_tok[i] = segment.mean()
                elif mode == "median":
                    pitch_tok[i] = torch.median(segment)
                else:
                    pitch_tok[i] = segment[-1]
            pitch_tok[pitch_tok < 100.0] = 0.0
            bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
            f0t_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
            f0t_tensor = torch.where(f0t_tensor == 0.0, torch.zeros_like(f0t_tensor), (f0t_tensor - 71.0) / (500.0 - 71.0))

        if phase_mod:
            # accumulate instantaneous phase from the frame-level f0 track
            tframe = torch.mean(t2[1:] - t2[:-1])
            phi0 = 0.0
            omega = 2 * torch.pi * torch.from_numpy(f0_np)
            dphi = omega * tframe
            phi = torch.cumsum(dphi, dim=0) + phi0
            phase = torch.remainder(phi, 2 * torch.pi)

        if pitch:
            # reuse the f0 track computed above instead of running DIO a second time
            p_tensor = torch.from_numpy(f0_np).unsqueeze(0)

        if harmonics or aperiodics:
            spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
            apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
            harmonic_tensor = torch.from_numpy(spnp)
            aperiodic_tensor = torch.from_numpy(apnp)
            harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
            aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
            harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
            aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)

    if waveform:
        wave_tensor = wav

    if dummy:
        if spectrogram_tensor is not None:
            dummy_tensor = torch.ones_like(spectrogram_tensor)
        elif p_tensor is not None:
            dummy_tensor = torch.ones_like(p_tensor)
        elif f0_tensor is not None:
            dummy_tensor = torch.ones_like(f0_tensor)
        elif f0t_tensor is not None:
            dummy_tensor = torch.ones_like(f0t_tensor)
        else:
            batch_size = 128
            seq_len = 1024
            dummy_tensor = torch.ones(batch_size, seq_len)
        dummy_tensor = dummy_tensor.to(device)
    else:
        dummy_tensor = None

    if debug:
        print(f"['f0']: {f0_tensor.shape if f0 else None}")
        print(f"['f0t']: {f0t_tensor.shape if f0t else None}")
        print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
        print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")
        print(f"['spectrogram']: {spectrogram_tensor.shape if spec else None}")
        print(f"['waveform']: {wave_tensor.shape if waveform else None}")
        print(f"['labels']: {len(labels) if labels else None}")
        print(f"['phase']: {phase.shape if phase is not None else None}")
        print(f"['pitch']: {p_tensor.shape if pitch else None}")
        print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
        print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
        print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
        print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")
        print(f"['dummy']: {dummy_tensor.shape if dummy else None}")

    return {
        "waveform": wave_tensor if waveform else None,
        "spectrogram": spectrogram_tensor if spec else None,
        "f0": f0_tensor if f0 else None,
        "f0t": f0t_tensor if f0t else None,
        "pitch": p_tensor if pitch else None,
        "harmonic": harmonic_tensor if harmonics else None,
        "aperiodic": aperiodic_tensor if aperiodics else None,
        "labels": labels,
        "phase": phase if phase_mod else None,
        "crepe_time": crepe_time if crepe else None,
        "crepe_frequency": crepe_frequency if crepe else None,
        "crepe_confidence": crepe_confidence if crepe else None,
        "crepe_activation": crepe_activation if crepe else None,
        "dummy": dummy_tensor if dummy else None,
    }

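# --- Editor's example (not in the original file): the `batch` layout
# extract_features expects, inferred from the code above; the audio is synthetic
# and `tokenizer` is any object with an encode() method (e.g. from setup_tokenizer). ---
def _demo_extract_features(tokenizer):
    batch = {
        "audio": {"array": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000},
        "transcription": "hello world",
    }
    feats = extract_features(batch, tokenizer, spec=True, f0=True)
    return feats["spectrogram"].shape, feats["f0"].shape
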
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
                     load_saved=False, save_dataset=False, cache_dir=None, extract_args=None, max_ctx=2048):

    if extract_args is None:
        extract_args = {
            "waveform": False,
            "spec": False,
            "f0": False,
            "f0t": False,
            "pitch": False,
            "harmonics": False,   # keys must match extract_features parameter names
            "aperiodics": False,
            "sample_rate": 16000,
            "hop_length": 256,
            "mode": "mean",
            "debug": False,
            "phase_mod": False,
            "crepe": False,
            "dummy": False,
        }

    # Resolve cache paths up front so they also exist for save_dataset below.
    if cache_dir is None:
        cache_dir = "./processed_datasets"
    os.makedirs(cache_dir, exist_ok=True)
    cache_file_train = os.path.join(cache_dir, "train.arrow")
    cache_file_test = os.path.join(cache_dir, "test.arrow")

    if load_saved and os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
        from datasets import Dataset
        train_dataset = Dataset.load_from_disk(cache_file_train)
        test_dataset = Dataset.load_from_disk(cache_file_test)
        return train_dataset, test_dataset

    if sanity_check:
        test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)

        dataset = test.map(
            lambda x: extract_features(x, tokenizer, **extract_args),
            remove_columns=test.column_names)

        train_dataset = dataset
        test_dataset = dataset
        return train_dataset, test_dataset

    else:

        def filter_func(x):
            return (0 < len(x["transcription"]) < max_ctx and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < max_ctx * 160)

        raw_train = load_dataset(
            "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming).take(1000)
        raw_test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).take(100)

        raw_train = raw_train.filter(filter_func)
        raw_test = raw_test.filter(filter_func)
        raw_train = raw_train.cast_column("audio", Audio(sampling_rate=sample_rate))
        raw_test = raw_test.cast_column("audio", Audio(sampling_rate=sample_rate))

        train_dataset = raw_train.map(
            lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_train.column_names)

        test_dataset = raw_test.map(
            lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_test.column_names)

        if save_dataset:
            train_dataset.save_to_disk(cache_file_train)
            test_dataset.save_to_disk(cache_file_test)

        return train_dataset, test_dataset

def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int, layer: int, act=None, features=None) -> nn.Module:
    if feature == "spectrogram":
        return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head, layer=layer, act=act, feature=feature, features=features)
    elif feature == "waveform":
        # WEncoder/PEncoder take (input_dims, dims, head, layer, kernel_size, act, ...);
        # kernel_size is accepted but the conv stacks below use fixed kernel sizes,
        # so a placeholder value is passed here.
        return WEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, kernel_size=3, act=act)
    elif feature == "pitch":
        return PEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, kernel_size=3, act=act)
    else:
        raise ValueError(f"Unknown feature type: {feature}")

1284
- class FEncoder(nn.Module):
1285
- def __init__(self, mels, input_dims, dims, head, layer, act, feature, features, use_rope=False, spec_shape=None, debug=[]):
1286
- super().__init__()
1287
-
1288
- self.head = head
1289
- self.head_dim = dims // head
1290
- self.dropout = 0.01
1291
- self.use_rope = use_rope
1292
- self.dims = dims
1293
- self.debug = debug
1294
- self.feature = feature
1295
- self.mels = mels
1296
- self.input_dims = input_dims
1297
- act_fn = get_activation(act)
1298
-
1299
- self.encoder = nn.Sequential(
1300
- Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
1301
- Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
1302
- Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
1303
-
1304
- if use_rope:
1305
- if spec_shape is not None:
1306
- self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) # type: ignore
1307
- else:
1308
- self.rope = None
1309
- self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
1310
- self.norm = RMSNorm(dims)
1311
-
1312
- def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
1313
- batch, ctx, dims = x.shape
1314
- x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
1315
- freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore
1316
- x = self.rope.apply_rotary(x, freqs)# type: ignore
1317
- x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
1318
-
1319
- return x
1320
-
1321
- def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
1322
- x = self.encoder(x).permute(0, 2, 1)
1323
- if self.use_rope:
1324
- x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
1325
- else:
1326
- x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
1327
- x = nn.functional.dropout(x, p=self.dropout, training=self.training)
1328
- print(f"feature encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
1329
- x = self.norm(x)
1330
- return x
1331
-
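# --- Editor's example (not in the original file): FEncoder maps a mel
# spectrogram (batch, mels, frames) to (batch, frames, dims); values are arbitrary. ---
def _demo_fencoder():
    enc = FEncoder(mels=128, input_dims=128, dims=64, head=8, layer=1, act="gelu",
                   feature="spectrogram", features=None).to(device)
    spec = torch.randn(2, 128, 50, device=device)
    assert enc(spec).shape == (2, 50, 64)
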
class WEncoder(nn.Module):  # waveform encoder
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        self.debug = debug
        act_fn = get_activation(act)
        self.target_length = None
        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)

        if use_rope:
            if spec_shape is not None:
                self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore
        else:
            self.rope = None
            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)  # type: ignore
        x = self.rope.apply_rotary(x, freqs)  # type: ignore
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        x = self.encoder(x).permute(0, 2, 1)  # (batch, time, dims)
        if self.target_length and x.shape[1] != self.target_length:
            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
        if self.use_rope:
            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        if "fencoder" in self.debug:
            print(f"waveform encoder: {x.shape} {feature}")
        return self.norm(x)

1377
class PEncoder(nn.Module):  # pitch encoder
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], one_shot=False, spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dims = dims
        self.dropout = 0.01
        self.use_rope = use_rope
        self.debug = debug
        act_fn = get_activation(act)

        self.attend_pitch = False

        if self.attend_pitch:
            self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
            self.mlp = nn.Sequential(
                nn.Linear(dims, dims),
                nn.ReLU(),
                nn.Linear(dims, dims),
            )
        else:
            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
            self.mlp = None

        self.pitch_encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        # self.spectrogram_encoder = nn.Sequential(
        #     Conv1d(input_dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        # self.waveform_encoder = nn.Sequential(
        #     Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
        #     Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
        #     Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)

        if use_rope:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)  # type: ignore
        x = self.rope.apply_rotary(x, freqs)  # type: ignore
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
        # f0 = x
        # freqs = self.rope(f0.shape[1], feats=feats, feature=feature, layer=layer)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        if feature == "pitch":
            x = self.pitch_encoder(x).permute(0, 2, 1)
        # elif feature == "spectrogram":
        #     x = self.spectrogram_encoder(x).permute(0, 2, 1)
        # elif feature == "waveform":
        #     x = self.waveform_encoder(x).permute(0, 2, 1)

        # if self.target_length and x.shape[1] != self.target_length:
        #     x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)

        if self.use_rope:
            x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)

        # note: unlike FEncoder and WEncoder, the sinusoidal positional encoding
        # is added here even when RoPE is enabled
        x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        if self.mlp is not None:
            x = self.mlp(x)

        if self.attend_pitch:
            if xa is not None:
                # queries come from xa, so `out` has xa's sequence length; the
                # residual below assumes x and xa share the same length
                q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
                out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
                x = x + out

        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        if "fencoder" in self.debug:
            print(f"Pitch encoder: {x.shape} {feature}")
        return x


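# The last pitch_encoder layer uses groups=dims, i.e. a depthwise convolution.
# A minimal sketch of what that buys in parameter count, using nn.Conv1d
# directly (the model code uses a local Conv1d wrapper, which this sketch does
# not assume anything about); the sizes are illustrative.
def _demo_depthwise_param_count(dims: int = 256, k: int = 3):
    dense = nn.Conv1d(dims, dims, kernel_size=k, padding=k // 2)
    depthwise = nn.Conv1d(dims, dims, kernel_size=k, padding=k // 2, groups=dims)
    n_dense = sum(p.numel() for p in dense.parameters())          # dims*dims*k + dims
    n_depthwise = sum(p.numel() for p in depthwise.parameters())  # dims*k + dims
    return n_dense, n_depthwise                                   # (196864, 1024) for the defaults
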
@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        all_keys = set()
        for f in features:
            all_keys.update(f.keys())
        batch = {}
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
        eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)

        for key in all_keys:
            if key == "labels":
                # shift labels into decoder inputs (BOS + tokens) and targets
                # (tokens + EOS), padding both to the longest sequence
                labels_list = [f["labels"] for f in features]
                max_len = max(len(lab) for lab in labels_list)
                all_ids, all_labels = [], []
                for label in labels_list:
                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                    decoder_input = [bos_token_id] + label_list
                    label_eos = label_list + [eos_token_id]
                    input_len = max_len + 1 - len(decoder_input)
                    label_len = max_len + 1 - len(label_eos)
                    padded_input = decoder_input + [pad_token_id] * input_len
                    padded_labels = label_eos + [pad_token_id] * label_len
                    all_ids.append(padded_input)
                    all_labels.append(padded_labels)
                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

            elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation", "dummy"]:
                items = [f[key] for f in features if key in f]
                items = [item for item in items if item is not None]
                if not items:
                    continue
                items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
                max_len = max(item.shape[-1] for item in items)
                padded = []
                for item in items:
                    pad_width = max_len - item.shape[-1]
                    if pad_width > 0:
                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
                    else:
                        pad_item = item
                    padded.append(pad_item)
                batch[key] = torch.stack(padded)
        return batch

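# A small sketch of the label shifting DataCollator performs. The stand-in
# tokenizer is a hypothetical namespace; ids 0/1/2 for pad/bos/eos match the
# defaults the collator falls back to via getattr.
def _demo_collator_shift():
    from types import SimpleNamespace
    tok = SimpleNamespace(pad_token_id=0, bos_token_id=1, eos_token_id=2)
    collator = DataCollator(tokenizer=tok)
    batch = collator([{"labels": torch.tensor([5, 6, 7])}, {"labels": torch.tensor([8])}])
    # batch["input_ids"]: [[1, 5, 6, 7], [1, 8, 0, 0]]  (BOS-prefixed, padded)
    # batch["labels"]:    [[5, 6, 7, 2], [8, 2, 0, 0]]  (EOS-suffixed, padded)
    return batch
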
def levenshtein(reference_words, hypothesis_words):
    m, n = len(reference_words), len(hypothesis_words)
    dist_matrix = [[0 for _ in range(n+1)] for _ in range(m+1)]
    for i in range(m+1):
        dist_matrix[i][0] = i
    for j in range(n+1):
        dist_matrix[0][j] = j
    for i in range(1, m+1):
        for j in range(1, n+1):
            if reference_words[i-1] == hypothesis_words[j-1]:
                dist_matrix[i][j] = dist_matrix[i-1][j-1]
            else:
                substitution = dist_matrix[i-1][j-1] + 1
                insertion = dist_matrix[i][j-1] + 1
                deletion = dist_matrix[i-1][j] + 1
                dist_matrix[i][j] = min(substitution, insertion, deletion)
    return dist_matrix[m][n]

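# A quick check of the dynamic-programming edit distance above; the word
# sequences are made up for illustration.
def _demo_levenshtein():
    assert levenshtein(["a", "b", "c"], ["a", "x", "c"]) == 1  # one substitution
    assert levenshtein(["a", "b"], ["a", "b", "c"]) == 1       # one insertion
    assert levenshtein(["a", "b", "c"], ["a", "c"]) == 1       # one deletion
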
def wer_batch(references, hypotheses):
    total_errors = 0
    total_words = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words = ref.lower().split()
        errors = levenshtein(ref_words, hyp.lower().split())
        total_errors += errors
        total_words += len(ref_words)
    return (total_errors / total_words) * 100 if total_words > 0 else 0.0

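# A worked example for wer_batch; the sentences are made up. The first pair has
# one insertion against three reference words, the second pair matches exactly,
# so the batch WER is (1 + 0) / (3 + 2) * 100 = 20.0.
def _demo_wer():
    refs = ["the cat sat", "hello there"]
    hyps = ["the cat sat down", "hello there"]
    return wer_batch(refs, hyps)  # 20.0
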
def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0):
    def clean(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        if isinstance(ids[0], (list, torch.Tensor, np.ndarray)):
            return [[int(i) for i in seq if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] for seq in ids]
        else:
            return [int(i) for i in ids if i not in (-100, pad_token_id, bos_token_id, eos_token_id)]

    pred_ids = pred.predictions
    label_ids = pred.label_ids

    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    if not isinstance(pred_ids, torch.Tensor):
        pred_ids = torch.tensor(pred_ids)

    label_ids = clean(label_ids)
    pred_ids = clean(pred_ids)
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids)

    if print_pred:
        for i in range(min(num_samples, len(pred_ids))):
            print(f"Pred tokens: {pred_ids[i]}")
            print(f"Label tokens: {label_ids[i]}")
            print(f"Pred: '{pred_str[i]}'")
            print(f"Label: '{label_str[i]}'")
            print("-" * 40)

    wer = wer_batch(label_str, pred_str)
    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        efficiency_score = (100 - wer) / trainable_params if trainable_params > 0 else 0.0
    else:
        trainable_params = 0.0
        efficiency_score = 0.0

    return {
        "wer": float(wer),
        "efficiency_score": float(efficiency_score),
    }

def preprocess_logits_for_metrics(logits, labels):
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels
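
# A minimal shape check for preprocess_logits_for_metrics: greedy decoding is
# just an argmax over the vocabulary axis. All sizes here are illustrative.
def _demo_greedy_ids():
    logits = torch.randn(2, 5, 100)  # (batch, ctx, vocab)
    labels = torch.zeros(2, 5, dtype=torch.long)
    pred_ids, _ = preprocess_logits_for_metrics(logits, labels)
    assert pred_ids.shape == (2, 5)  # one predicted token id per position
    return pred_ids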