Sin2pi committed · verified
Commit 294a87a · Parent(s): 22ef493

Create model_hf.py

Files changed (1): model_hf.py (+1741, −0)
model_hf.py ADDED
@@ -0,0 +1,1741 @@
import os
PATH = 'E:/hf'
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
import pyworld as pw
import math
import warnings
import logging
import gzip
import base64
import torch
import torchaudio
import torch.nn.functional as F
import torch.nn.init as init
from torch import nn, Tensor
import numpy as np
from einops import rearrange
import matplotlib.pyplot as plt
from typing import Optional, Dict, Union, List, Tuple, Any
from functools import partial
from datetime import datetime
from datasets import load_dataset, Audio
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
import transformers
import evaluate
from dataclasses import dataclass
import aiohttp
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
transformers.utils.logging.set_verbosity_error()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]
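
# A minimal example configuration (hypothetical values, not taken from the
# original training runs) showing how Dimensions is meant to be filled in:
#
#   param = Dimensions(
#       vocab=40000, text_ctx=512, text_dims=512, text_head=4, text_idx=4,
#       mels=128, aud_ctx=1500, aud_dims=512, aud_head=4, aud_idx=4,
#       act="gelu", debug=[], cross_attn=True,
#       features=["spectrogram", "waveform", "pitch"])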

def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
                  title="", markers=None, marker_labels=None,
                  show_voiced_regions=True, show_energy=False):
    num_plots = sum([x is not None, w is not None, p is not None, per is not None])
    if num_plots == 0:
        raise ValueError("No data to plot. Please provide at least one input tensor.")
    time_spans = []

    if w is not None:
        w_np = w[sample_idx].detach().cpu().numpy()
        if w_np.ndim > 1:
            w_np = w_np.squeeze()
        time_spans.append(len(w_np) / sr)
    if x is not None:
        x_np = x[sample_idx].detach().cpu().numpy()
        if x_np.shape[0] < x_np.shape[1]:
            x_np = x_np.T
        time_spans.append(x_np.shape[0] * hop_length / sr)
    if p is not None:
        p_np = p[sample_idx].detach().cpu().numpy()
        if p_np.ndim > 1:
            p_np = p_np.squeeze()
        time_spans.append(len(p_np) * hop_length / sr)
    if per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        time_spans.append(len(per_np) * hop_length / sr)
    max_time = max(time_spans) if time_spans else 0
    fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
    if num_plots == 1:
        axs = [axs]
    if show_voiced_regions and per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_per = np.arange(len(per_np)) * hop_length / sr
        threshold = 0.5
        for ax in axs:
            for i in range(len(per_np)-1):
                if per_np[i] > threshold:
                    ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
    current_ax = 0
    if w is not None:
        w_np = w[sample_idx].detach().cpu().numpy()
        if w_np.ndim > 1:
            w_np = w_np.squeeze()
        t = np.arange(len(w_np)) / sr
        axs[current_ax].plot(t, w_np, color="tab:blue")

        if show_energy:
            frame_length = hop_length
            hop_length_energy = hop_length // 2
            energy = []
            for i in range(0, len(w_np)-frame_length, hop_length_energy):
                frame = w_np[i:i+frame_length]
                energy.append(np.sqrt(np.mean(frame**2)))
            energy = np.array(energy)
            energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
            t_energy = np.arange(len(energy)) * hop_length_energy / sr
            axs[current_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
            axs[current_ax].legend(loc='upper right')
        axs[current_ax].set_title("Waveform")
        axs[current_ax].set_ylabel("Amplitude")
        axs[current_ax].set_xlim([0, max_time])
        axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
        current_ax += 1

    if x is not None:
        x_np = x[sample_idx].detach().cpu().numpy()
        if x_np.shape[0] < x_np.shape[1]:
            x_np = x_np.T
        im = axs[current_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
                                    extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
        axs[current_ax].set_title("Spectrogram")
        axs[current_ax].set_ylabel("Mel Bin")
        axs[current_ax].set_xlim([0, max_time])
        axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
        current_ax += 1

    if p is not None:
        p_np = p[sample_idx].detach().cpu().numpy()
        if p_np.ndim > 1:
            p_np = p_np.squeeze()
        t_p = np.arange(len(p_np)) * hop_length / sr
        axs[current_ax].plot(t_p, p_np, color="tab:green")
        axs[current_ax].set_title("Pitch")
        axs[current_ax].set_ylabel("Frequency (Hz)")
        axs[current_ax].set_xlim([0, max_time])
        axs[current_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
        axs[current_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
        current_ax += 1

    if per is not None:
        per_np = per[sample_idx].detach().cpu().numpy()
        if per_np.ndim > 1:
            per_np = per_np.squeeze()
        t_per = np.arange(len(per_np)) * hop_length / sr
        axs[current_ax].plot(t_per, per_np, color="tab:red")
        axs[current_ax].set_title("Period (Voice Activity)")
        axs[current_ax].set_ylabel("Periodicity")
        axs[current_ax].set_xlim([0, max_time])
        axs[current_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
        axs[current_ax].set_ylim([-0.05, 1.05])
        axs[current_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)

    if markers is not None:
        for i, t in enumerate(markers):
            label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
            for ax in axs:
                ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
        if marker_labels:
            axs[0].legend(loc='upper right', fontsize='small')
    axs[-1].set_xlabel("Time (s)")
    fig.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
    return fig

def dict_to(d, device, dtype=dtype):
    """Because PyTorch should have this built-in but doesn't"""
    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
            for k, v in d.items()}

def exists(v):
    return v is not None

def default(v, b):
    return v if exists(v) else b
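
# Usage sketch for the helpers above (illustrative values only):
#
#   batch = {"spectrogram": torch.randn(1, 128, 300), "id": "clip-0"}
#   batch = dict_to(batch, device)   # tensors moved/cast, non-tensors untouched
#   hop = default(None, 160)         # falls back to 160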

class Conv1d(nn.Conv1d):
    def _conv_forward(
            self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Conv2d(nn.Conv2d):
    def _conv_forward(
            self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple],
                 eps=1e-8, elementwise_affine=True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64

def tox():
    return {"device": get_device(), "dtype": get_dtype()}

def sinusoids(length, channels, max_timescale=10000):
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
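
# Shape sketch: sinusoids(length, channels) builds the classic fixed positional
# table, sin terms in the first channels//2 columns and cos terms in the rest:
#
#   pe = sinusoids(1500, 512)   # -> torch.Size([1500, 512])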

class rotary(nn.Module):
    def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False):
        super(rotary, self).__init__()

        self.use_pbias = use_pbias
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.radii = radii
        self.dim = self.head_dim
        self.debug = debug
        self.counter = 0
        self.last_theta = None

        self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
        self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)

    def theta_freqs(self, theta):
        # mel-spaced frequency ladder scaled by theta (220 Hz reference);
        # returned as a plain tensor so gradients flow back to self.theta
        freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
        return freq

    @staticmethod
    def inverse_mel_scale_scalar(mel_freq: float) -> float:
        return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)

    @staticmethod
    def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
        return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)

    @staticmethod
    def mel_scale_scalar(freq: float) -> float:
        return 1127.0 * math.log(1.0 + freq / 700.0)

    @staticmethod
    def mel_scale(freq: Tensor) -> Tensor:
        return 1127.0 * (1.0 + freq / 700.0).log()

    def return_f0(self, f0=None):
        if f0 is not None:
            self.f0 = f0
            return f0.squeeze(0).to(device, dtype)
        elif hasattr(self, 'f0') and self.f0 is not None:
            return self.f0.squeeze(0).to(device, dtype)
        return None

    def get_pitch_bias(self, f0):
        if f0 is None:
            return None
        f0_flat = f0.squeeze().float()
        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
                                        f0_norm.unsqueeze(1)))
        return f0_sim.unsqueeze(0).unsqueeze(0)

    def f0proj(self, f0):
        if f0.ndim == 3:
            f0 = f0.squeeze(0)
        if self.f0_proj is None:
            self.f0_proj = nn.Linear(1, self.head_dim // 2, device=device, dtype=dtype)
        f0 = f0.to(device, dtype)
        f0 = self.f0_proj(f0.unsqueeze(-1))
        if f0.ndim == 3:
            f0 = f0.squeeze(0)
        return f0.to(device=device, dtype=dtype)

    def synth_f0(self, f0, ctx):
        if f0.dim() == 1:
            length = f0.shape[0]
            if length == ctx:
                return f0
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            return f0[idx]

    def align_f0(self, ctx, f0):
        f0 = self.f0proj(f0)
        if f0.dim() == 3:
            batch, length, dims = f0.shape
            if length == ctx:
                return f0
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            return f0[:, idx, :]
        if f0.dim() == 1:
            length = f0.shape[0]
            if length == ctx:
                return f0
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            return f0[idx]
        else:
            length, dims = f0.shape
            if length == ctx:
                return f0
            frames = length / ctx
            idx = torch.arange(ctx, device=f0.device)
            idx = (idx * frames).long().clamp(0, length - 1)
            return f0[idx, :]

    def forward(self, x=None, enc=None, layer=None, feature_type="audio") -> Tensor:
        f0 = enc.get("f0") if enc is not None else None
        if isinstance(x, int):
            ctx = x
        elif isinstance(x, torch.Tensor) and x.ndim == 2:
            batch, ctx = x.shape
        elif isinstance(x, torch.Tensor) and x.ndim == 3:
            batch, ctx, dims = x.shape
        else:
            batch, head, ctx, head_dim = x.shape
        t = torch.arange(ctx, device=device, dtype=dtype)

        if f0 is not None and f0.dim() == 2:
            if f0.shape[0] == 1:
                f0 = f0.squeeze(0)
            else:
                f0 = f0.view(-1)

        if f0 is not None and layer == "encoder":
            f0_mean = f0.mean()
            theta = f0_mean + self.theta
        else:
            theta = self.theta
        freqs = self.theta_freqs(theta)

        freqs = t[:, None] * freqs[None, :]
        if self.radii and f0 is not None and layer == "encoder":
            radius = f0.to(device, dtype)
            L = radius.shape[0]
            if L != ctx:
                frames = L / ctx
                idx = torch.arange(ctx, device=f0.device)
                idx = (idx * frames).long().clamp(0, L - 1)
                radius = radius[idx]
            radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
            radius = torch.sigmoid(radius)
        else:
            radius = torch.ones_like(freqs)
        freqs = torch.polar(radius, freqs)

        if "radius" in self.debug and self.counter % 100 == 0:
            theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
            print(f" [{layer}] [Radius] {radius.shape} {radius.mean():.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.abs().mean():.2f} [ctx] {ctx}")

        if "theta" in self.debug and self.counter % 100 == 0:
            if self.last_theta is None or abs(self.last_theta - theta.item()) > 1.0:
                self.last_theta = theta.item()
                print(f"[Theta] {self.last_theta:.2f}")

        self.counter += 1
        return freqs.unsqueeze(0)

    @staticmethod
    def apply_rotary(x, freqs):
        x1 = x[..., :freqs.shape[-1]*2]
        x2 = x[..., freqs.shape[-1]*2:]
        orig_shape = x1.shape
        if x1.ndim == 2:
            x1 = x1.unsqueeze(0)
        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
        x1 = torch.view_as_complex(x1) * freqs
        x1 = torch.view_as_real(x1).flatten(-2)
        x1 = x1.view(orig_shape)
        return torch.cat([x1.type_as(x), x2], dim=-1)
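
# Rotation sketch: apply_rotary treats the first 2 * freqs.shape[-1] channels
# of each head as complex pairs, multiplies them by the phasors produced by
# rotary.forward (unit radius by default, sigmoid(f0) radius when radii=True),
# and passes any remaining channels through unchanged:
#
#   rope = rotary(dims=512, head=8)
#   q = torch.randn(1, 8, 100, 64, device=device)
#   q = rotary.apply_rotary(q, rope(100))   # same shape, rotated pairs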

class MultiheadA(nn.Module):
    _seen = set()
    rbf = False

    def __init__(self, dims: int, head: int, rotary_emb: bool = True,
                 zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
        super(MultiheadA, self).__init__()

        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0

        self.q = Linear(dims, dims).to(device, dtype)
        self.k = Linear(dims, dims, bias=False).to(device, dtype)
        self.v = Linear(dims, dims).to(device, dtype)
        self.o = Linear(dims, dims).to(device, dtype)

        self.pad_token = 0
        self.rotary_emb = rotary_emb
        self.minz = minz
        self.maxz = maxz
        self.zero_val = zero_val
        self.optim_attn = optim_attn
        self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)

        if rotary_emb:
            self.rope = rotary(
                dims=dims,
                head=head,
                debug=debug,
                radii=True,
            )
        else:
            self.rope = None

    def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
        scale = (self.dims // self.head) ** -0.25
        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
        if rbf_ratio <= 0.0:
            return dot_scores
        q_norm = q.pow(2).sum(dim=-1, keepdim=True)
        k_norm = k.pow(2).sum(dim=-1, keepdim=True)
        qk = torch.matmul(q, k.transpose(-1, -2))
        dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc=None, layer=None, feature_type="audio") -> tuple:
        x = x.to(device, dtype)
        if xa is not None:
            xa = xa.to(device, dtype)

        batch, ctx, dims = x.shape
        scale = (self.dims // self.head) ** -0.25

        z = default(xa, x).to(device, dtype)
        q = self.q(x)
        k = self.k(z)
        v = self.v(z)
        qlen = q.shape[1]
        klen = k.shape[1]

        if self.rotary_emb:
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            qlen = q.shape[2]
            klen = k.shape[2]

            q = self.rope.apply_rotary(q, (self.rope(qlen, enc=enc, layer=layer)))
            k = self.rope.apply_rotary(k, (self.rope(klen, enc=enc, layer=layer)))
        else:
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        batch, head, ctx, head_dim = q.shape

        if self.rbf:
            qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
        if self.rope is not None and self.rope.use_pbias:
            f0 = enc.get("f0", None) if enc is not None else None
            pbias = self.rope.get_pitch_bias(f0)
            if pbias is not None:
                qk = qk + pbias[:, :, :q.shape[2], :q.shape[2]]
        token_ids = k[:, :, :, 0]
        zscale = torch.ones_like(token_ids)
        fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
        zscale[token_ids.float() == self.pad_token] = fzero

        if mask is not None:
            mask = mask[:q.shape[2], :q.shape[2]]
            qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
        qk = qk * zscale.unsqueeze(-2)
        w = F.softmax(qk, dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

        if "multihead" in self.debug and self.counter % 100 == 0:
            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
        self.counter += 1
        return self.o(wv), qk.detach()
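
# Attention usage sketch (self-attention; pass xa for cross-attention):
#
#   attn = MultiheadA(dims=512, head=8)
#   out, scores = attn(torch.randn(2, 100, 512))   # out: (2, 100, 512)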

class t_gate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gate_projections = nn.ModuleList([
            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            for _ in range(num_types)])
        self.type_classifier = nn.Sequential(
            Linear(dims, num_types),
            nn.Softmax(dim=-1))

    def forward(self, x):
        type_probs = self.type_classifier(x)
        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
        comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
        return comb_gate

class m_gate(nn.Module):
    def __init__(self, dims, mem_size=64):
        super().__init__()
        self.m_key = nn.Parameter(torch.randn(mem_size, dims))
        self.m_val = nn.Parameter(torch.randn(mem_size, 1))
        self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))

    def forward(self, x):
        d_gate = torch.sigmoid(self.gate_proj(x))
        attention = torch.matmul(x, self.m_key.transpose(0, 1))
        attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
        m_gate = torch.matmul(attention, self.m_val)
        m_gate = torch.sigmoid(m_gate)
        return 0.5 * (d_gate + m_gate)

class c_gate(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.integ = Linear(dims*5, dims)

    def forward(self, x, features):
        s_feat = features.get("spectrogram", x)
        w_feat = features.get("waveform", x)
        p_feat = features.get("pitch", x)
        e_feat = features.get("envelope", x)
        ph_feat = features.get("phase", x)
        s = self.s_gate(x) * s_feat
        w = self.w_gate(x) * w_feat
        p = self.p_gate(x) * p_feat
        e = self.e_gate(x) * e_feat
        ph = self.ph_gate(x) * ph_feat
        comb = torch.cat([s, w, p, e, ph], dim=-1)
        return self.integ(comb)
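
# Gate variants at a glance: t_gate blends num_types sigmoid gates with a
# softmax type classifier, m_gate pools a learned memory into a scalar gate,
# and c_gate gates each acoustic feature stream before re-projecting.
#
#   g = t_gate(dims=512, num_types=4)
#   w = g(torch.randn(2, 100, 512))   # -> (2, 100, 1), values in (0, 1)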

class Residual(nn.Module):
    _seen = set()

    def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
        super().__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head
        self.cross_attn = cross_attn
        self.features = features
        self.debug = debug
        self.counter = 0
        self.dropout = 0.01
        self.skip_gates = True

        self.blend = nn.Parameter(torch.tensor(0.5))

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
        self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None)

        mlp = dims * 4
        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))

        self.t_gate = t_gate(dims=dims, num_types=4) if tgate else None
        self.m_gate = m_gate(dims=dims, mem_size=mem_size) if mgate else None
        self.c_gate = c_gate(dims=dims) if cgate else None

        self.lna = RMSNorm(dims)
        self.lnb = RMSNorm(dims) if cross_attn else None
        self.lnc = RMSNorm(dims)

        if not any([tgate, mgate, cgate]):
            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

    def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
        x = x.to(device, dtype)
        if xa is not None:
            xa = xa.to(device, dtype)

        bln = self.blend
        x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]

        if self.attnb and xa is not None:
            c = self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
            b = torch.sigmoid(bln)
            x = b * x + (1 - b) * c

        normx = self.lnc(x)
        mlp_out = self.mlp(normx)

        if self.skip_gates:
            x = x + mlp_out
        else:
            if self.t_gate:
                gate = self.t_gate(normx)
                x = x + gate * mlp_out
            elif self.m_gate:
                gate = self.m_gate(normx)
                x = x + gate * mlp_out
            elif self.c_gate:
                gate_output = self.c_gate(normx, self.features)
                x = x + gate_output
            else:
                if hasattr(self, 'mlp_gate'):
                    mlp_gate = self.mlp_gate(normx)
                    x = x + mlp_gate * mlp_out
                else:
                    x = x + mlp_out

        if "residual" in self.debug and self.counter % 100 == 0:
            print(f"Step {self.counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
            if self.t_gate:
                print(f"Step {self.counter}: Using t_gate: {self.t_gate}")
            elif self.m_gate:
                print(f"Step {self.counter}: Using m_gate: {self.m_gate}")
            elif self.c_gate:
                print(f"Step {self.counter}: Using c_gate: {self.c_gate}")
            else:
                print(f"Step {self.counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
        self.counter += 1

        return x
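
# Block usage sketch: masked self-attention, an optional learned blend with
# cross-attention over xa, then the (optionally gated) MLP:
#
#   block = Residual(ctx=512, dims=512, head=8, act="gelu").to(device)
#   y = block(torch.randn(2, 100, 512), xa=torch.randn(2, 300, 512))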

class FEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)

        if use_rope:
            # note: this rotary implementation is 1-D; spec_shape is accepted for
            # interface compatibility but no 2-D axial variant is applied here
            self.rope = rotary(dims=dims, head=head, debug=[])
        else:
            self.rope = None
            self.positional = lambda length: sinusoids(length, dims)

        self.norm = RMSNorm(dims)
        self._norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None, feature_type="audio"):
        if feature_type in ["envelope", "phase"]:
            feature_type = "spectrogram"
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer, feature_type=feature_type)
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="audio"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer, feature_type=feature_type)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self._norm(x)
        return x

class WEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.downsample = nn.Sequential(
            Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
            Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn,
            Conv1d(dims//4, dims, kernel_size=9, stride=5, padding=4), act_fn)

        self.encoder = nn.Sequential(
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn,
            Conv1d(dims, dims, kernel_size=1), act_fn)
        if use_rope:
            self.rope = rotary(dims=dims, head=head, theta=50.0, debug=[])
        else:
            self.rope = None
            self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None):
        if not self.use_rope or self.rope is None:
            return x
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer)
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="waveform"):
        x = self.downsample(x)
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        return self.norm(x)

class PEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)

        if use_rope:
            self.rope = rotary(dims=dims, head=head, theta=100.0, debug=[])
        else:
            self.rope = None
            self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, layer=None):
        if not self.use_rope or self.rope is None:
            return x
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        rope_freqs = self.rope(ctx, layer=layer)
        x = self.rope.apply_rotary(x, rope_freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x, enc=None, layer=None, feature_type="pitch"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, layer=layer)
        else:
            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x
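
# Downsampling note: WEncoder's conv strides (8, 4, 5) compress raw audio by a
# factor of 160, and PEncoder applies the same overall factor, which appears
# intended to align waveform/pitch frames with a 160-sample mel hop at 16 kHz.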

class AudioEncoder(nn.Module):
    _seen = set()

    def __init__(self, mels: int, ctx: int, dims: int, head: int, layer: int, debug: List[str], features: List[str], act: str = "gelu"):
        super(AudioEncoder, self).__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.features = features
        self.dropout = 0.01

        if features == ["spectrogram", "waveform", "pitch"]:
            cgate = True
        else:
            cgate = False

        self.blocks = nn.ModuleDict({
            "spectrogram": nn.ModuleList(
                ([FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                 [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]) if "spectrogram" in features else None
            ),
            "waveform": nn.ModuleList(
                ([WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act)] +
                 [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]) if "waveform" in features else None
            ),
            "pitch": nn.ModuleList(
                ([FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
                 [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]) if "pitch" in features else None
            ),
            "envelope": nn.ModuleList(
                ([FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                 [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]) if "envelope" in features else None
            ),
            "phase": nn.ModuleList(
                ([FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] +
                 [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]) if "phase" in features else None
            )
        })

    def forward(self, enc, layer="encoder"):
        enc = dict_to(enc, device, dtype)

        if self.counter < 1:
            s = enc.get("spectrogram")
            w = enc.get("waveform")
            p = default(enc.get("pitch"), enc.get("f0"))
            plot_waveform(x=s, w=w, p=p, hop_length=128)

        out = {}
        out.update(enc)

        for f in self.features:
            if f in enc and f in self.blocks:
                x = enc[f]
                for block in self.blocks[f]:
                    x = block(x, enc=enc, layer=layer)
                out[f] = x

        if "encoder" in self.debug and self.counter % 100 == 0:
            shapes = {k: v.shape for k, v in enc.items()}
            print(f"Step {self.counter}: mode: {list(enc.keys())}: shapes: {shapes}")
        self.counter += 1
        return out
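
# Encoder usage sketch: features go in as a dict and come back as a dict of
# per-feature hidden states (unprocessed keys such as "f0" pass through):
#
#   enc = AudioEncoder(mels=128, ctx=1500, dims=512, head=8, layer=4,
#                      debug=[], features=["spectrogram"]).to(device)
#   out = enc({"spectrogram": torch.randn(1, 128, 300)})
#   out["spectrogram"].shape   # -> torch.Size([1, 300, 512])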

class TextDecoder(nn.Module):
    def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
                 debug: List[str], features: List[str]):
        super(TextDecoder, self).__init__()

        self.ctx = ctx
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.dropout = 0.01
        self.features = features

        self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
        with torch.no_grad():
            self.token.weight[0].zero_()
        self.positional = nn.Parameter(data=torch.zeros(ctx, dims), requires_grad=True)

        self.block = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
            for _ in range(layer)])

        self.blocks = nn.ModuleDict({
            f: nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
                              for _ in range(layer)]) for f in features})

        self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
        self.ln_dec = RMSNorm(dims)

        # additive causal mask: 0 on/below the diagonal, -inf above it
        mask = torch.triu(torch.full((ctx, ctx), float("-inf")), diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, enc, order=None, layer='decoder', sequential=False) -> Tensor:
        enc = dict_to(enc, device, dtype)
        x = x.to(device)
        bln = self.blend

        if order is None:
            order = self.features

        mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x) + self.positional[:x.shape[1]]
        x = F.dropout(x, p=self.dropout, training=self.training)

        for block in self.block:
            x = block(x, xa=None, mask=mask, enc=None, layer=layer)

        for f in order:
            if f in enc:
                xa = enc[f]
                for block in self.blocks[f]:
                    out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)

                    if sequential:
                        x = out
                    else:
                        a = torch.sigmoid(bln[f])
                        x = a * out + (1 - a) * x

        if "decoder" in self.debug and self.counter % 100 == 0:
            print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
        self.counter += 1

        x = self.ln_dec(x)
        return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
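
# Decoder note: the output projection is tied to the token embedding, i.e.
# logits = x @ token.weight.T, so no separate LM head is stored.
#
#   dec = TextDecoder(vocab=1000, ctx=512, dims=512, head=8, layer=2,
#                     cross_attn=True, debug=[], features=["spectrogram"]).to(device)
#   logits = dec(torch.randint(0, 1000, (1, 10)),
#                {"spectrogram": torch.randn(1, 300, 512)})   # (1, 10, 1000)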

class Echo(nn.Module):
    def __init__(self, param: Dimensions):
        super().__init__()
        self.param = param
        self.count = 0

        self.encoder = AudioEncoder(
            mels=param.mels,
            ctx=param.aud_ctx,
            dims=param.aud_dims,
            head=param.aud_head,
            layer=param.aud_idx,
            act=param.act,
            debug=param.debug,
            features=param.features,
        )

        self.decoder = TextDecoder(
            vocab=param.vocab,
            ctx=param.text_ctx,
            dims=param.text_dims,
            head=param.text_head,
            layer=param.text_idx,
            cross_attn=param.cross_attn,
            debug=param.debug,
            features=param.features,
        )

        all_head = torch.zeros(self.param.text_idx, self.param.text_head, dtype=torch.bool)
        all_head[self.param.text_idx // 2:] = True
        self.register_buffer("alignment_head", all_head.to_sparse(), persistent=False)

    def update_base(self, f0):
        for name, module in self.encoder.named_modules():
            if isinstance(module, rotary) and hasattr(module, "update_base"):
                module.update_base(f0)

        for name, module in self.decoder.named_modules():
            if isinstance(module, rotary) and hasattr(module, "update_base"):
                module.update_base(f0)

    def set_alignment_head(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
        mask = torch.from_numpy(array).reshape(
            self.param.text_idx, self.param.text_head)
        self.register_buffer("alignment_head", mask.to_sparse(), persistent=False)

    def embed_audio(self, spectrogram: torch.Tensor):
        return self.encoder({"spectrogram": spectrogram})

    def logits(self, input_ids: torch.Tensor, encoder_output: torch.Tensor):
        return self.decoder(input_ids, encoder_output)

    def forward(self,
                decoder_input_ids=None,
                labels=None,
                waveform: Optional[torch.Tensor] = None,
                input_ids=None,
                spectrogram: torch.Tensor = None,
                pitch: Optional[torch.Tensor] = None,
                f0: Optional[torch.Tensor] = None,
                f0d: Optional[torch.Tensor] = None,
                envelope: Optional[torch.Tensor] = None,
                phase: Optional[torch.Tensor] = None,
                ) -> Dict[str, torch.Tensor]:

        input_ids = default(input_ids, decoder_input_ids)
        encoder_inputs = {}
        if spectrogram is not None:
            encoder_inputs["spectrogram"] = spectrogram
        if waveform is not None:
            encoder_inputs["waveform"] = waveform
        if pitch is not None:
            encoder_inputs["pitch"] = pitch
        if envelope is not None:
            encoder_inputs["envelope"] = envelope
        if phase is not None:
            encoder_inputs["phase"] = phase
        if f0 is not None:
            encoder_inputs["f0"] = f0

        encoder_outputs = self.encoder(encoder_inputs)
        logits = self.decoder(input_ids, encoder_outputs)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)

        self.count += 1
        return {
            "logits": logits,
            "loss": loss,
        }

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, nn.Linear):
            if module.weight is not None:
                nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, MultiheadA):
            self.init_counts["MultiheadA"] += 1
        elif isinstance(module, TextDecoder):
            self.init_counts["TextDecoder"] += 1
        elif isinstance(module, AudioEncoder):
            self.init_counts["AudioEncoder"] += 1
        elif isinstance(module, Residual):
            self.init_counts["Residual"] += 1

    def init_weights(self):
        print("Initializing model weights...")
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "SEBlock": 0, "TextDecoder": 0, "AudioEncoder": 0,
            "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
            "WEncoder": 0, "PEncoder": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")

    def register_gradient_hooks(self):
        for name, param in self.named_parameters():
            if param.requires_grad:
                if "encoder" in name:
                    param.register_hook(lambda grad, n=name: self._print_encoder_grad(n, grad))
                elif "decoder" in name:
                    param.register_hook(lambda grad, n=name: self._print_decoder_grad(n, grad))

        print("Gradient debugging hooks registered")
        return self

    def _print_encoder_grad(self, name, grad):
        if grad is not None and self.count == 10:
            norm = grad.median().item()
            print(f"ENCODER GRAD: {name} = {norm:.6f}")
        return None

    def _print_decoder_grad(self, name, grad):
        if grad is not None and self.count == 10:
            norm = grad.median().item()
            print(f"DECODER GRAD: {name} = {norm:.6f}")
        return None

    def resetcounter(self):
        self.count = 0
        print("Counter reset to 0.")

metric = evaluate.load(path="wer")

@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        pad_token_id = self.tokenizer.pad_token_id if hasattr(self.tokenizer, 'pad_token_id') else 0
        bos_token_id = self.tokenizer.bos_token_id if hasattr(self.tokenizer, 'bos_token_id') else 1

        batch = {}

        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
            spectrogram_list = [f["spectrogram"] for f in features]
            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
            pad_spectrogram = []
            for feat in spectrogram_list:
                current_len = feat.shape[-1]
                padding = max_len_feat - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_spectrogram.append(pad_feat)
            batch["spectrogram"] = torch.stack(pad_spectrogram)

        if "waveform" in features[0] and features[0]["waveform"] is not None:
            waveform_list = [f["waveform"] for f in features]
            max_len_wav = max(w.shape[-1] for w in waveform_list)
            pad_waveforms = []
            for wav in waveform_list:
                if wav.ndim == 1:
                    wav = wav.unsqueeze(0)
                current_len = wav.shape[-1]
                padding = max_len_wav - current_len
                if padding > 0:
                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_wav = wav
                pad_waveforms.append(pad_wav)
            batch["waveform"] = torch.stack(pad_waveforms)

        if "label" in features[0] and features[0]["label"] is not None:
            labels_list = [f["label"] for f in features]
            max_len = max(len(l) for l in labels_list)
            all_ids = []
            all_labels = []

            for label in labels_list:
                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                decoder_input = [bos_token_id] + label_list
                label_eos = label_list + [pad_token_id]
                input_len = max_len + 1 - len(decoder_input)
                label_len = max_len + 1 - len(label_eos)
                padded_input = decoder_input + [pad_token_id] * input_len
                padded_labels = label_eos + [pad_token_id] * label_len
                all_ids.append(padded_input)
                all_labels.append(padded_labels)
            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

        if "pitch" in features[0] and features[0]["pitch"] is not None:
            pitch_list = [f["pitch"] for f in features]
            max_len_pitch = max(e.shape[-1] for e in pitch_list)
            pad_pitch = []
            for pitch in pitch_list:
                current_len = pitch.shape[-1]
                padding = max_len_pitch - current_len
                if padding > 0:
                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_pitch_item = pitch
                pad_pitch.append(pad_pitch_item)
            batch["pitch"] = torch.stack(pad_pitch)

        if "f0" in features[0] and features[0]["f0"] is not None:
            f0_list = [f["f0"] for f in features]
            max_len_f0 = max(f.shape[-1] for f in f0_list)
            pad_f0 = []
            for f0 in f0_list:
                current_len = f0.shape[-1]
                padding = max_len_f0 - current_len
                if padding > 0:
                    pad_f0_item = F.pad(f0, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_f0_item = f0
                pad_f0.append(pad_f0_item)
            batch["f0"] = torch.stack(pad_f0)

        if "envelope" in features[0] and features[0]["envelope"] is not None:
            env_list = [f["envelope"] for f in features]
            max_len = max(f.shape[-1] for f in env_list)
            pad_env = []
            for feat in env_list:
                current_len = feat.shape[-1]
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_env.append(pad_feat)
            batch["envelope"] = torch.stack(pad_env)

        if "phase" in features[0] and features[0]["phase"] is not None:
            ph_list = [f["phase"] for f in features]
            max_len = max(f.shape[-1] for f in ph_list)
            pad_ph = []
            for feat in ph_list:
                current_len = feat.shape[-1]
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_ph.append(pad_feat)
            batch["phase"] = torch.stack(pad_ph)
        return batch
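
# Collator sketch: each modality is right-padded to the batch max length, and
# labels are turned into BOS-shifted decoder inputs plus EOS-padded targets:
#
#   collator = DataCollator(tokenizer=tokenizer)
#   batch = collator([dataset[0], dataset[1]])   # dict of stacked tensors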

def hilbert_transform(x):
    N = x.shape[-1]
    xf = torch.fft.rfft(x)
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N//2] = 1
        h[1:N//2] = 2
    else:
        h[0] = 1
        h[1:(N+1)//2] = 2
    return torch.fft.irfft(xf * h, n=N)

def analytic_signal(x):
    return x + 1j * hilbert_transform(x)

def hilbert_transform_2d(x, dim=-1):
    N = x.shape[dim]
    xf = torch.fft.rfft(x, dim=dim)
    # build the 1-D Hilbert multiplier, then reshape it to broadcast along dim
    h1d = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h1d[0] = h1d[-1] = 1
        h1d[1:-1] = 2
    else:
        h1d[0] = 1
        h1d[1:] = 2
    h_shape = [1] * len(x.shape)
    h_shape[dim] = N // 2 + 1
    h = h1d.view(h_shape)
    return torch.fft.irfft(xf * h, n=N, dim=dim)

def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j*h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device))

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase
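
# Hilbert sketch: the analytic signal x + i*H(x) splits each spectrogram row
# into a smooth magnitude envelope and an instantaneous phase:
#
#   env, phase = process_spectrogram_with_hilbert(torch.randn(128, 300))
#   env.shape, phase.shape   # both torch.Size([128, 300])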

def load_wave(wave_data, sample_rate):
    if isinstance(wave_data, str):
        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sr = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    return waveform.flatten()

def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
                     hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
                     pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
                     norm=None, normalized=False, downsamples=False, period=False, hilbert=False):

    audio = batch["audio"]
    sampling_rate = audio["sampling_rate"]
    sr = audio["sampling_rate"]
    wav = load_wave(wave_data=audio, sample_rate=sr)

    if spectrogram:
        transform = torchaudio.transforms.MelSpectrogram(
            f_max=fmax,
            f_min=fmin,
            n_mels=n_mels,
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            norm=norm,
            normalized=normalized,
            power=power,
            center=center,
            mel_scale=mel_scale,
            window_fn=window_fn,
            pad_mode=pad_mode)

        mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec

        if hilbert:
            envelope_list = []
            phase_list = []

            for ch_idx in range(spec.shape[0]):
                envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
                envelope_list.append(envelope)
                phase_list.append(phase)

            batch["envelope"] = torch.stack(envelope_list)
            batch["phase"] = torch.stack(phase_list)

    wav_1d = wav.unsqueeze(0)

    if waveforms:
        batch["waveform"] = wav_1d

    if pitch:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sampling_rate,
                       frame_period=hop_length/sampling_rate*1000)
        f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
        f0 = torch.from_numpy(f0)
        batch["pitch"] = f0.unsqueeze(0)

    if frequency:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sampling_rate, frame_period=hop_length/sampling_rate*1000)
        f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
        f0 = torch.from_numpy(f0)
        batch["f0"] = f0

    if spectrogram and waveforms and pitch:
        spec_mean = batch["spectrogram"].mean()
        spec_std = batch["spectrogram"].std() + 1e-6
        batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std

        wav_mean = batch["waveform"].mean()
        wav_std = batch["waveform"].std() + 1e-6
        batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std

        if batch["pitch"].max() > 1.0:
            pitch_min = 50.0
            pitch_max = 500.0
            batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)

    batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
    return batch
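
# Feature-extraction sketch: typically mapped over a Hugging Face dataset with
# flags selecting which tensors to attach (argument names as defined above):
#
#   prepare = partial(extract_features, tokenizer=tokenizer,
#                     spectrogram=True, waveforms=True, pitch=True)
#   dataset = dataset.map(prepare)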
1415
+
+ def compute_metrics(eval_pred, compute_result: bool = True,
+                     print_pred: bool = False, num_samples: int = 0, tokenizer=None, model=None):
+ 
+     pred_logits = eval_pred.predictions
+     label_ids = eval_pred.label_ids
+ 
+     # The Trainer may hand back a tuple of outputs; keep only the logits.
+     if isinstance(pred_logits, tuple):
+         pred_logits = pred_logits[0]
+ 
+     if hasattr(pred_logits, "cpu"):
+         pred_logits = pred_logits.cpu()
+     else:
+         pred_logits = torch.tensor(pred_logits).cpu()
+     if hasattr(label_ids, "cpu"):
+         label_ids = label_ids.cpu()
+     else:
+         label_ids = torch.tensor(label_ids).cpu()
+ 
+     pred_ids = pred_logits
+     if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
+         # Logits of shape (batch, seq, vocab): greedy-decode with argmax.
+         if not isinstance(pred_ids, torch.Tensor):
+             pred_ids = torch.tensor(pred_ids)
+         pred_ids = pred_ids.argmax(dim=-1)
+     pred_ids = pred_ids.tolist()
+ 
+     if hasattr(label_ids, "tolist"):
+         label_ids = label_ids.tolist()
+ 
+     # -100 is the loss ignore-index; map it back to the pad id (0) before decoding.
+     label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
+     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
+     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
+ 
+     if print_pred:
+         for i in range(min(num_samples, len(pred_str))):
+             print(f"Preds: {pred_str[i]}")
+             print(f"Label: {label_str[i]}")
+             print("--------------------------------")
+ 
+     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
+     # `metric` is the module-level WER metric loaded earlier in this file.
+     wer = 100 * metric.compute(predictions=pred_str, references=label_str)
+ 
+     if model is None:
+         global global_model
+         if 'global_model' in globals():
+             model = global_model
+ 
+     if model is not None:
+         trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
+         if trainable_params > 0:
+             efficiency_score = (100 - wer) / trainable_params
+         else:
+             print("Warning: Zero trainable parameters detected")
+             efficiency_score = 0.0
+     else:
+         print("Warning: Model not available for parameter counting")
+         trainable_params = 0.0
+         efficiency_score = 0.0
+ 
+     if hasattr(wer, "item"):
+         wer = wer.item()
+ 
+     metrics = {
+         "wer": float(wer),
+         "trainable_params_M": float(trainable_params),
+         "efficiency_score": float(efficiency_score),
+     }
+ 
+     return metrics
+ 
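+ # Worked example of the efficiency score with hypothetical numbers: at 25% WER
+ # and 36.8M trainable parameters,
+ # efficiency_score = (100 - 25) / 36.8 ~= 2.04 accuracy points per million parameters.
+ 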
+ logger = logging.getLogger(__name__)
+ 
+ def create_model(param: Dimensions) -> Echo:
+     model = Echo(param).to('cuda')
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     total_params = sum(p.numel() for p in model.parameters())
+     logger.info(f"Trainable parameters: {trainable_params:,}")
+     logger.info(f"Total parameters: {total_params:,}")
+     print(f"Trainable parameters: {trainable_params:,}")
+     print(f"Total parameters: {total_params:,}")
+ 
+     return model
+ 
+ def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
+     # `token` is unused here but kept for symmetry with the other setup calls.
+     from tokenizers import Tokenizer
+     tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
+     orig_encode = tokenizer.encode
+ 
+     def enc(text, add_special_tokens=True):
+         ids = orig_encode(text).ids
+         if not add_special_tokens:
+             sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
+             ids = [i for i in ids if i not in sp_ids]
+         return ids
+ 
+     def bdec(ids_list, skip_special_tokens=True):
+         results = []
+         for ids in ids_list:
+             if skip_special_tokens:
+                 # 0/1/2 are the <PAD>/<BOS>/<EOS> ids assigned below.
+                 ids = [i for i in ids if i not in [0, 1, 2]]
+             results.append(tokenizer.decode(ids))
+         return results
+ 
+     def save_pretrained(save_dir):
+         os.makedirs(save_dir, exist_ok=True)
+         tokenizer.save(f"{save_dir}/tokenizer.json")
+ 
+     # Patch a minimal Hugging Face-style surface onto the raw tokenizers.Tokenizer.
+     tokenizer.encode = enc
+     tokenizer.batch_decode = bdec
+     tokenizer.save_pretrained = save_pretrained
+     tokenizer.pad_token_id = 0
+     tokenizer.bos_token_id = 1
+     tokenizer.eos_token_id = 2
+     return tokenizer
+ 
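+ # Sketch: round-tripping text through the patched tokenizer (the decoded
+ # output is illustrative and depends on the local vocabulary):
+ #   tok = setup_tokenizer(token="")
+ #   ids = tok.encode("hello world", add_special_tokens=False)
+ #   tok.batch_decode([ids])  # -> ["hello world"]
+ 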
+ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
+     if dataset_config is None:
+         dataset_config = {
+             "spectrogram": True,
+             "waveforms": True,
+             "pitch": True,
+             "frequency": True,
+             "downsamples": True,
+             "hop_length": 128,
+             "fmin": 50,
+             "fmax": 2000,
+             "n_mels": 128,
+             "n_fft": 1024,
+             "sampling_rate": 16000,
+         }
+ 
+     dataset = load_dataset(
+         "google/fleurs",
+         "en_us",
+         token=token,
+         trust_remote_code=True,
+         streaming=False)
+ 
+     dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
+ 
+     if sanity_check:
+         # Overfit a single small split: train and eval on the same data.
+         dataset = dataset["test"]
+         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
+         dataset = dataset.map(function=prepare_fn, remove_columns=["audio", "transcription"]).with_format(type="torch")
+         train_dataset = dataset
+         test_dataset = dataset
+     else:
+         def filter_func(x):
+             # Keep non-empty clips under ~15 s at 16 kHz (1500 frames x 160 samples).
+             return (0 < len(x["transcription"]) < 512 and
+                     len(x["audio"]["array"]) > 0 and
+                     len(x["audio"]["array"]) < 1500 * 160)
+ 
+         dataset = dataset.filter(filter_func)
+         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
+         # Dataset.take requires a recent `datasets` release; .select(range(n))
+         # is the older equivalent.
+         train_dataset = dataset["train"].take(10000)
+         test_dataset = dataset["test"].take(1000)
+ 
+         train_dataset = train_dataset.map(
+             function=prepare_fn,
+             remove_columns=["audio", "transcription"]
+         ).with_format(type="torch")
+ 
+         test_dataset = test_dataset.map(
+             function=prepare_fn,
+             remove_columns=["audio", "transcription"]
+         ).with_format(type="torch")
+ 
+     return train_dataset, test_dataset
+ 
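+ # Sketch: building the splits for a quick run (token and tokenizer as in main()):
+ #   train_ds, test_ds = prepare_datasets(tokenizer, token="", sanity_check=True)
+ #   train_ds[0].keys()  # includes 'spectrogram' and 'label', plus any enabled features
+ 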
+ def get_training_args(
+     log_dir: str,
+     batch_eval_metrics: bool = False,
+     max_steps: int = 10,
+     save_steps: int = 1000,
+     eval_steps: int = 1,
+     warmup_steps: int = 0,
+     num_train_epochs: int = 1,
+     logging_steps: int = 1,
+     eval_on_start: bool = False,
+     learning_rate: float = 1e-4,
+     weight_decay: float = 0.01,
+     max_grad_norm: float = 1.0,
+ ) -> Seq2SeqTrainingArguments:
+ 
+     return Seq2SeqTrainingArguments(
+         output_dir=log_dir,
+         per_device_train_batch_size=1,
+         per_device_eval_batch_size=2,
+         gradient_accumulation_steps=1,
+         eval_accumulation_steps=4,
+         eval_strategy="steps",
+         save_strategy="no",
+         max_steps=max_steps,
+         save_steps=save_steps,
+         eval_steps=eval_steps,
+         warmup_steps=warmup_steps,
+         num_train_epochs=num_train_epochs,
+         logging_steps=logging_steps,
+         logging_dir=log_dir,
+         logging_strategy="steps",
+         report_to=["tensorboard"],
+         push_to_hub=False,
+         disable_tqdm=False,
+         save_total_limit=1,
+         label_names=["labels"],
+         optim="adamw_torch",
+         adam_beta1=0.9,
+         adam_beta2=0.999,
+         adam_epsilon=1e-8,
+         lr_scheduler_type="cosine",
+         learning_rate=learning_rate,
+         weight_decay=weight_decay,
+         save_safetensors=False,
+         eval_on_start=eval_on_start,
+         batch_eval_metrics=batch_eval_metrics,
+         max_grad_norm=max_grad_norm,
+     )
+ 
+ def main():
+ 
+     token = ""
+     log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H_%M_%S'))
+     os.makedirs(name=log_dir, exist_ok=True)
+     tokenizer = setup_tokenizer(token)
+ 
+     def sanity(sanity: bool):
+ 
+         if sanity:
+             training_args = get_training_args(
+                 log_dir,
+                 batch_eval_metrics=False,
+                 max_steps=10,
+                 save_steps=0,
+                 eval_steps=1,
+                 warmup_steps=0,
+                 logging_steps=1,
+                 eval_on_start=False,
+                 learning_rate=5e-6,
+                 weight_decay=0.01,
+                 max_grad_norm=0.6,
+             )
+         else:
+             training_args = get_training_args(
+                 log_dir,
+                 batch_eval_metrics=False,
+                 max_steps=10000,
+                 save_steps=1005,
+                 eval_steps=1000,
+                 warmup_steps=1000,
+                 logging_steps=100,
+                 eval_on_start=False,
+                 learning_rate=2.5e-4,
+                 weight_decay=0.01,
+                 max_grad_norm=0.6,
+             )
+ 
+         return training_args
+ 
+     param = Dimensions(
+         mels=128,
+         aud_ctx=1500,
+         aud_head=4,
+         aud_dims=512,
+         aud_idx=4,
+         vocab=40000,
+         text_ctx=512,
+         text_head=4,
+         text_dims=512,
+         text_idx=4,
+         act="swish",
+         debug={},
+         cross_attn=True,
+         features=["spectrogram"],
+     )
+ 
+     sanity_check = False
+ 
+     training_args = sanity(sanity_check)
+     dataset_config = {
+         "spectrogram": True,
+         "waveforms": False,
+         "pitch": False,
+         "downsamples": False,
+         "frequency": True,
+         "hilbert": False,
+         "hop_length": 128,
+         "fmin": 150,
+         "fmax": 2000,
+         "n_mels": 128,
+         "n_fft": 1024,
+         "sampling_rate": 16000,
+         "pad_mode": "constant",
+         "center": True,
+         "power": 2.0,
+         "window_fn": torch.hann_window,
+         "mel_scale": "htk",
+         "norm": None,
+         "normalized": False}
1711
+
1712
+ model = create_model(param)
1713
+
1714
+ global global_model
1715
+ global_model = model
1716
+
1717
+ metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
1718
+ tokenizer=tokenizer, model=model)
1719
+
1720
+ print(f"{'Sanity check' if sanity_check else 'Training'} mode")
1721
+ train_dataset, test_dataset = prepare_datasets(
1722
+ tokenizer=tokenizer,
1723
+ token=token,
1724
+ sanity_check=sanity_check,
1725
+ dataset_config=dataset_config)
1726
+
1727
+ trainer = Seq2SeqTrainer(
1728
+ args=training_args,
1729
+ model=model,
1730
+ train_dataset=train_dataset,
1731
+ eval_dataset=test_dataset,
1732
+ data_collator=DataCollator(tokenizer=tokenizer),
1733
+ compute_metrics=metrics_fn,
1734
+ )
1735
+
1736
+ model.init_weights()
1737
+ trainer.train()
1738
+
1739
+ if __name__ == "__main__":
1740
+ main()
1741
+