Update model.py
Browse files
model.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import pyworld as pw
|
2 |
import os
|
3 |
import math
|
@@ -34,20 +35,6 @@ dtype = torch.float32
|
|
34 |
warnings.filterwarnings("ignore")
|
35 |
logging.basicConfig(level=logging.ERROR)
|
36 |
|
37 |
-
from rich.traceback import install
|
38 |
-
install(show_locals=True)
|
39 |
-
|
40 |
-
import pretty_errors
|
41 |
-
pretty_errors.configure(
|
42 |
-
separator_character = '*',
|
43 |
-
filename_display = pretty_errors.FILENAME_EXTENDED,
|
44 |
-
line_number_first = True,
|
45 |
-
display_link = True,
|
46 |
-
lines_before = 5,
|
47 |
-
lines_after = 2,
|
48 |
-
line_color = pretty_errors.RED + '> ' + pretty_errors.default_config.line_color,
|
49 |
-
code_color = ' ' + pretty_errors.default_config.line_color,
|
50 |
-
)
|
51 |
|
52 |
extractor = None
|
53 |
tokenizer = None
|
@@ -256,93 +243,12 @@ def get_dtype():
|
|
256 |
def tox():
|
257 |
return {"device": get_device(), "dtype": get_dtype()}
|
258 |
|
259 |
-
def sinusoids(length,
|
260 |
-
assert
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
return torch.cat([torch.sin(
|
265 |
-
|
266 |
-
def align_f0(f0, ctx):
|
267 |
-
b, l = f0.shape
|
268 |
-
if l == ctx:
|
269 |
-
return f0.squeeze(0).float()
|
270 |
-
frames_per_token = l / ctx
|
271 |
-
idx = torch.arange(ctx, device=device, dtype=dtype)
|
272 |
-
src_idx = (idx * frames_per_token).long().clamp(0, l-1)
|
273 |
-
batch_idx = torch.arange(b, device=device, dtype=dtype).unsqueeze(1)
|
274 |
-
f0 = f0[batch_idx, src_idx]
|
275 |
-
return f0.squeeze(0).float()
|
276 |
-
|
277 |
-
def align_f0(f0, target_length, method='nearest', device=device, dtype=dtype):
|
278 |
-
if device is None:
|
279 |
-
device = f0.device
|
280 |
-
if dtype is None:
|
281 |
-
dtype = f0.dtype
|
282 |
-
original_shape = f0.shape
|
283 |
-
squeeze_batch = False
|
284 |
-
reshape_back = None
|
285 |
-
|
286 |
-
if f0.dim() == 1:
|
287 |
-
f0 = f0.unsqueeze(0)
|
288 |
-
squeeze_batch = True
|
289 |
-
elif f0.dim() == 2:
|
290 |
-
pass
|
291 |
-
elif f0.dim() == 3:
|
292 |
-
batch_size, ctx, length = f0.shape
|
293 |
-
f0 = f0.view(-1, length)
|
294 |
-
reshape_back = (batch_size, ctx)
|
295 |
-
else:
|
296 |
-
raise ValueError(f"F0 tensor must be 1D, 2D, or 3D, got {f0.dim()}D")
|
297 |
-
batch_size, current_length = f0.shape
|
298 |
-
if current_length == target_length:
|
299 |
-
result = f0
|
300 |
-
elif method == 'nearest':
|
301 |
-
frames_per_token = current_length / target_length
|
302 |
-
target_indices = torch.arange(target_length, device=device, dtype=torch.float32)
|
303 |
-
source_indices = (target_indices * frames_per_token).long().clamp(0, current_length - 1)
|
304 |
-
batch_indices = torch.arange(batch_size, device=device, dtype=torch.long).unsqueeze(1)
|
305 |
-
result = f0[batch_indices, source_indices]
|
306 |
-
else:
|
307 |
-
import torch.nn.functional as F
|
308 |
-
f0_for_interp = f0.unsqueeze(1)
|
309 |
-
mode_map = {'linear': 'linear', 'cubic': 'bicubic'}
|
310 |
-
if method not in mode_map:
|
311 |
-
raise ValueError(f"Method '{method}' not supported. Use 'nearest', 'linear', or 'cubic'")
|
312 |
-
|
313 |
-
result = F.interpolate(
|
314 |
-
f0_for_interp.float(),
|
315 |
-
size=target_length,
|
316 |
-
mode=mode_map[method],
|
317 |
-
align_corners=False
|
318 |
-
).squeeze(1)
|
319 |
-
|
320 |
-
if reshape_back is not None:
|
321 |
-
result = result.view(reshape_back[0], reshape_back[1], target_length)
|
322 |
-
elif squeeze_batch:
|
323 |
-
result = result.squeeze(0)
|
324 |
-
return result.to(dtype)
|
325 |
-
|
326 |
-
# def update_base(self, f0):
|
327 |
-
# f0 = f0.to(device, dtype)
|
328 |
-
# f0_mean = f0.mean() + 1e-8
|
329 |
-
|
330 |
-
# # Standard RoPE calculation (keep this)
|
331 |
-
# theta_freqs = 1.0 / (f0_mean ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype)[:(self.dim // 2)].float() / self.dim))
|
332 |
-
|
333 |
-
# # Direct f0-adapted mel scale (new part)
|
334 |
-
# center_freq = f0_mean
|
335 |
-
# min_freq = center_freq * 0.25 # Lower bound
|
336 |
-
# max_freq = center_freq * 4.0 # Upper bound
|
337 |
-
|
338 |
-
# # Direct mel calculation centered on f0
|
339 |
-
# mel_min = 2595 * torch.log10(1 + min_freq/700)
|
340 |
-
# mel_max = 2595 * torch.log10(1 + max_freq/700)
|
341 |
-
# mel_freqs = 700 * (torch.pow(10, torch.linspace(mel_min, mel_max, self.dim//2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
342 |
-
|
343 |
-
# # Use a weighted combination
|
344 |
-
# self.inv_freq.data.copy_(0.5 * theta_freqs + 0.5 * mel_freqs)
|
345 |
-
# self.theta.data.copy_(f0_mean)
|
346 |
|
347 |
class rotary(nn.Module):
|
348 |
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
|
@@ -366,15 +272,23 @@ class rotary(nn.Module):
|
|
366 |
theta = torch.tensor(theta, device=device, dtype=dtype)
|
367 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
368 |
self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
369 |
-
inv_freq = (theta /
|
370 |
self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
|
371 |
|
372 |
def update_base(self, f0):
|
373 |
f0 = f0.squeeze(0).to(device, dtype)
|
374 |
theta = f0.mean() + 1e-8
|
375 |
-
inv_freq = (theta /
|
376 |
self.inv_freq.data.copy_(inv_freq)
|
377 |
-
self.theta.data.copy_(theta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
|
379 |
def get_pitch_bias(self, f0):
|
380 |
if f0 is None:
|
@@ -386,14 +300,26 @@ class rotary(nn.Module):
|
|
386 |
return f0_sim.unsqueeze(0).unsqueeze(0)
|
387 |
|
388 |
def f0proj(self, f0):
|
389 |
-
|
390 |
-
f0 = f0.
|
391 |
-
|
392 |
-
|
|
|
|
|
|
|
|
|
393 |
|
394 |
-
def align_f0(self,
|
|
|
395 |
f0 = self.f0proj(f0)
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
if f0.dim() == 1:
|
398 |
length = f0.shape[0]
|
399 |
if length == ctx:
|
@@ -410,64 +336,48 @@ class rotary(nn.Module):
|
|
410 |
idx = torch.arange(ctx, device=f0.device)
|
411 |
idx = (idx * frames).long().clamp(0, length - 1)
|
412 |
return f0[idx, :]
|
413 |
-
|
414 |
-
|
415 |
-
# R = torch.eye(dims).to(theta.device)
|
416 |
-
# R[i, i] = torch.cos(theta)
|
417 |
-
# R[i, j] = -torch.sin(theta)
|
418 |
-
# R[j, i] = torch.sin(theta)
|
419 |
-
# R[j, j] = torch.cos(theta)
|
420 |
-
# R = torch.eye(dims).to(theta.device) - 2 * torch.outer(R, R) / torch.dot(R, R)
|
421 |
-
# return R
|
422 |
-
|
423 |
-
# def orthogonal_regularization_term(self):
|
424 |
-
# loss = torch.tensor(0.0, device=self.r_matrix.device)
|
425 |
-
# if self.r_matrix.requires_grad:
|
426 |
-
# product = torch.matmul(self.r_matrix, self.r_matrix.t())
|
427 |
-
# identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)
|
428 |
-
# loss = ((product - identity) ** 2).sum()
|
429 |
-
# return self.orthogonal_reg_weight * loss
|
430 |
-
|
431 |
-
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
432 |
-
f0 = enc.get("f0", None) if enc is not None else None
|
433 |
-
|
434 |
if isinstance(x, int):
|
435 |
ctx = x
|
436 |
elif isinstance(x, torch.Tensor) and x.ndim == 3:
|
437 |
batch, ctx, dims = x.shape
|
438 |
else:
|
439 |
batch, head, ctx, head_dim = x.shape
|
440 |
-
|
441 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
442 |
-
|
443 |
-
if f0 is not None:
|
444 |
-
freqs = self.inv_freq
|
445 |
-
f0_mean = f0.mean()
|
446 |
-
theta = f0_mean + 1e-8
|
447 |
-
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
448 |
-
|
449 |
-
if "rotary1" in self.debug:
|
450 |
-
print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
|
451 |
-
else:
|
452 |
-
freqs = self.inv_freq
|
453 |
freqs = t[:, None] * freqs[None, :]
|
454 |
|
455 |
-
# sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(end=seq_len, device=x.device), self.inv_freq.to(device=x.device))
|
456 |
-
|
457 |
if self.radii:
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
radius = freqs
|
462 |
-
if "rotary2" in self.debug:
|
463 |
-
print(f"{layer} radius: {radius} ctx: {ctx}")
|
464 |
else:
|
465 |
radius = freqs
|
466 |
-
freqs = torch.polar(torch.ones_like(radius), freqs
|
467 |
|
468 |
if "rotary3" in self.debug:
|
469 |
-
print(f"{layer}
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
self._counter += 1
|
472 |
return freqs.unsqueeze(0)
|
473 |
|
@@ -484,8 +394,89 @@ class rotary(nn.Module):
|
|
484 |
x1 = x1.view(orig_shape)
|
485 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
class MultiheadA(nn.Module):
|
488 |
-
|
489 |
rbf = False
|
490 |
def __init__(self, dims: int, head: int, rotary_emb: bool = True,
|
491 |
zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
|
@@ -531,7 +522,7 @@ class MultiheadA(nn.Module):
|
|
531 |
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
|
532 |
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
|
533 |
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
|
534 |
-
|
535 |
def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
|
536 |
x = x.to(device, dtype)
|
537 |
if xa is not None:
|
@@ -733,7 +724,6 @@ class Residual(nn.Module):
|
|
733 |
else:
|
734 |
print(f"Step {self._counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
|
735 |
self._counter += 1
|
736 |
-
|
737 |
return x
|
738 |
|
739 |
class FEncoder(nn.Module):
|
@@ -930,42 +920,40 @@ class AudioEncoder(nn.Module):
|
|
930 |
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
|
931 |
),
|
932 |
"envelope": nn.ModuleList(
|
933 |
-
|
934 |
-
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
})
|
943 |
|
944 |
def forward(self, enc, layer="encoder"):
|
945 |
enc = dict_to(enc, device, dtype)
|
946 |
-
|
947 |
if self._counter < 1:
|
948 |
s = enc.get("spectrogram")
|
949 |
w = enc.get("waveform")
|
950 |
p = default(enc.get("pitch"), enc.get("f0"))
|
951 |
plot_waveform(x=s, w=w, p=p, hop_length=128)
|
952 |
|
953 |
-
|
954 |
-
out.update(enc)
|
955 |
|
956 |
for f in self.features:
|
957 |
if f in enc and f in self.blocks:
|
958 |
x = enc[f]
|
959 |
for block in self.blocks[f]:
|
960 |
x = block(x, enc=enc, layer=layer)
|
961 |
-
|
962 |
|
963 |
if "encoder" in self.debug and self._counter % 100 == 0:
|
964 |
names = list(x.keys())
|
965 |
shapes = {k: v.shape for k, v in x.items()}
|
966 |
print(f"Step {self._counter}: mode: {names}: shapes: {shapes}")
|
967 |
self._counter += 1
|
968 |
-
return
|
969 |
|
970 |
class TextDecoder(nn.Module):
|
971 |
def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
|
@@ -1001,8 +989,8 @@ class TextDecoder(nn.Module):
|
|
1001 |
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
|
1002 |
self.register_buffer("mask", mask, persistent=False)
|
1003 |
|
1004 |
-
def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
|
1005 |
-
|
1006 |
x = x.to(device)
|
1007 |
bln = self.blend
|
1008 |
|
@@ -1017,10 +1005,10 @@ class TextDecoder(nn.Module):
|
|
1017 |
x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
|
1018 |
|
1019 |
for f in order:
|
1020 |
-
if f in
|
1021 |
-
|
1022 |
for block in self.blocks[f]:
|
1023 |
-
out = block(x=x, xa=
|
1024 |
|
1025 |
a = torch.sigmoid(bln[f])
|
1026 |
x = a * out + (1 - a) * x
|
@@ -1068,7 +1056,13 @@ class Echo(nn.Module):
|
|
1068 |
for name, module in self.encoder.named_modules():
|
1069 |
if isinstance(module, (rotary)):
|
1070 |
module.update_base(f0)
|
|
|
1071 |
|
|
|
|
|
|
|
|
|
|
|
1072 |
def set_alignment_head(self, dump: bytes):
|
1073 |
array = np.frombuffer(
|
1074 |
gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
|
@@ -1109,12 +1103,10 @@ class Echo(nn.Module):
|
|
1109 |
encoder_inputs["phase"] = phase
|
1110 |
if f0 is not None:
|
1111 |
encoder_inputs["f0"] = f0
|
1112 |
-
|
1113 |
-
encoder_inputs["f0d"] = f0d
|
1114 |
-
|
1115 |
if f0 is not None:
|
1116 |
f0 = f0.squeeze(0)
|
1117 |
-
self.update_base(f0)
|
1118 |
|
1119 |
encoder_outputs = self.encoder(encoder_inputs)
|
1120 |
logits = self.decoder(input_ids, encoder_outputs)
|
@@ -1695,7 +1687,9 @@ def get_training_args(
|
|
1695 |
gradient_accumulation_steps=1,
|
1696 |
eval_accumulation_steps=1,
|
1697 |
eval_strategy="steps",
|
1698 |
-
save_strategy="
|
|
|
|
|
1699 |
max_steps=max_steps,
|
1700 |
save_steps=save_steps,
|
1701 |
eval_steps=eval_steps,
|
@@ -1769,19 +1763,20 @@ def main():
|
|
1769 |
text_dims=512,
|
1770 |
text_idx=4,
|
1771 |
act="swish",
|
1772 |
-
debug={"
|
1773 |
cross_attn=True,
|
1774 |
features = ["spectrogram"]
|
1775 |
)
|
1776 |
|
1777 |
-
sanity_check =
|
|
|
1778 |
training_args = sanity(sanity_check)
|
1779 |
dataset_config = {
|
1780 |
"spectrogram": True,
|
1781 |
"waveforms": False,
|
1782 |
"pitch": False,
|
1783 |
"downsamples": False,
|
1784 |
-
"frequency":
|
1785 |
"hilbert": False,
|
1786 |
"hop_length": 128,
|
1787 |
"fmin": 150,
|
@@ -1798,7 +1793,7 @@ def main():
|
|
1798 |
"normalized": False}
|
1799 |
|
1800 |
model = create_model(param)
|
1801 |
-
|
1802 |
global global_model
|
1803 |
global_model = model
|
1804 |
|
@@ -1827,9 +1822,5 @@ def main():
|
|
1827 |
if __name__ == "__main__":
|
1828 |
main()
|
1829 |
|
1830 |
-
|
1831 |
-
|
1832 |
-
tb = program.TensorBoard()
|
1833 |
-
tb.configure(argv=[None, '--logdir', log_dir])
|
1834 |
-
url = tb.launch()
|
1835 |
-
print(f"TensorBoard started at {url}")
|
|
|
1 |
+
|
2 |
import pyworld as pw
|
3 |
import os
|
4 |
import math
|
|
|
35 |
warnings.filterwarnings("ignore")
|
36 |
logging.basicConfig(level=logging.ERROR)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
extractor = None
|
40 |
tokenizer = None
|
|
|
243 |
def tox():
|
244 |
return {"device": get_device(), "dtype": get_dtype()}
|
245 |
|
246 |
+
def sinusoids(length, num_chan, max=10000):
    """Build a fixed sinusoidal position table.

    Row ``p`` holds ``sin(p * inv_time)`` in the first ``num_chan // 2``
    columns and ``cos(p * inv_time)`` in the remaining ones, where
    ``inv_time`` decays geometrically from 1 towards ``1 / max``.

    Args:
        length: Number of positions (rows of the table).
        num_chan: Number of channels (columns); must be even.
        max: Largest timescale. The name shadows the builtin but is kept
            so keyword callers keep working.

    Returns:
        Tensor of shape ``(length, num_chan)``.

    Raises:
        AssertionError: If ``num_chan`` is odd.
    """
    assert num_chan % 2 == 0, "num_chan must be even"
    half = num_chan // 2
    # With a single frequency (num_chan == 2) the original divisor
    # ``half - 1`` is zero; log(max) / 0 is inf and inf * 0 is NaN, which
    # poisoned the whole table. One frequency needs no spacing, so use 1.
    steps = half - 1 if half > 1 else 1
    time_x = np.log(max) / steps
    inv_time = torch.exp(-time_x * torch.arange(half))
    s_time = torch.arange(length)[:, np.newaxis] * inv_time[np.newaxis, :]
    return torch.cat([torch.sin(s_time), torch.cos(s_time)], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
|
253 |
class rotary(nn.Module):
|
254 |
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [],
|
|
|
272 |
theta = torch.tensor(theta, device=device, dtype=dtype)
|
273 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
274 |
self.radius = nn.Parameter(torch.ones(radius, device=device, dtype=dtype), requires_grad=True)
|
275 |
+
inv_freq = (theta / 140.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
276 |
self.inv_freq = nn.Parameter(torch.tensor(inv_freq, device=device, dtype=dtype), requires_grad=True)
|
277 |
|
278 |
def update_base(self, f0):
    """Overwrite ``inv_freq`` and ``theta`` in place from the mean of ``f0``.

    ``f0`` is squeezed on dim 0 and moved to the module-level
    ``device``/``dtype``; its mean (plus 1e-8) becomes the new base.
    The frequency grid uses the 2595/700 constants (looks like the usual
    mel-scale formula -- TODO confirm) over ``self.dim // 2`` points.
    """
    pitch = f0.squeeze(0).to(device, dtype)
    base = pitch.mean() + 1e-8

    # Same arithmetic as the original one-liner, split into named steps;
    # the left-to-right evaluation order is kept.
    mel_top = 2595 * torch.log10(torch.tensor(1 + 8000/700))
    mel_grid = torch.linspace(0, mel_top, self.dim // 2, device=device, dtype=dtype)
    hz_grid = torch.pow(10, mel_grid / 2595) - 1
    scaled = (base / 140.0) * 700 * hz_grid / 1000

    # Write through .data so the existing Parameter objects are kept.
    self.inv_freq.data.copy_(scaled)
    self.theta.data.copy_(base)
|
284 |
+
|
285 |
+
def return_f0(self, f0=None):
    """Store and/or fetch the pitch track kept on this module.

    If ``f0`` is given it is stored unchanged as ``self.f0`` and returned
    squeezed on dim 0 on the module-level ``device``/``dtype``. If it is
    omitted, the previously stored value is returned the same way, or
    ``None`` when nothing usable has been stored.
    """
    if f0 is None:
        cached = getattr(self, 'f0', None)
        if cached is None:
            return None
        return cached.squeeze(0).to(device, dtype)

    self.f0 = f0
    return f0.squeeze(0).to(device, dtype)
|
292 |
|
293 |
def get_pitch_bias(self, f0):
|
294 |
if f0 is None:
|
|
|
300 |
return f0_sim.unsqueeze(0).unsqueeze(0)
|
301 |
|
302 |
def f0proj(self, f0):
    """Project a pitch track to ``self.head_dim // 2`` features per frame.

    A 3-D input is squeezed on dim 0, a trailing unit axis is added, and
    the result goes through a ``Linear(1, head_dim // 2)`` layer. A 3-D
    output is squeezed on dim 0 again before being returned on the
    module-level ``device``/``dtype``.
    """
    if f0.ndim == 3:
        f0 = f0.squeeze(0)

    # Build the projection once and reuse it. The original created a new
    # nn.Linear on every call, so the weights were re-randomised on each
    # forward pass and could never be learned.
    # NOTE(review): a layer created lazily here is missing from any
    # optimizer built before the first call -- consider moving this into
    # __init__.
    if getattr(self, "f0_proj", None) is None:
        self.f0_proj = nn.Linear(1, self.head_dim // 2, device=device, dtype=dtype)

    f0 = f0.to(device, dtype)
    f0 = self.f0_proj(f0.unsqueeze(-1))
    if f0.ndim == 3:
        f0 = f0.squeeze(0)
    return f0.to(device=device, dtype=dtype)
|
311 |
|
312 |
+
def align_f0(self, ctx):
|
313 |
+
f0 = self.return_f0()
|
314 |
f0 = self.f0proj(f0)
|
315 |
+
if f0.dim() == 3:
|
316 |
+
batch, length, dims = f0.shape
|
317 |
+
if length == ctx:
|
318 |
+
return f0
|
319 |
+
frames = length / ctx
|
320 |
+
idx = torch.arange(ctx, device=f0.device)
|
321 |
+
idx = (idx * frames).long().clamp(0, length - 1)
|
322 |
+
return f0[:, idx, :]
|
323 |
if f0.dim() == 1:
|
324 |
length = f0.shape[0]
|
325 |
if length == ctx:
|
|
|
336 |
idx = torch.arange(ctx, device=f0.device)
|
337 |
idx = (idx * frames).long().clamp(0, length - 1)
|
338 |
return f0[idx, :]
|
339 |
+
|
340 |
+
def forward(self, x=None, f0=None, enc=None, layer=None, input_type="audio") -> Tensor:
    """Return complex rotary factors for a context length.

    Args:
        x: Either the context length as an int, a 3-D tensor whose dim 1
            is the context length, or a 4-D tensor whose dim 2 is.
        f0: Only read by the "rotary3" debug print in this method.
        enc: Not used in this method.
        layer: Label included in debug output.
        input_type: Not used in this method.

    Returns:
        ``torch.polar(ones, t[:, None] * inv_freq[None, :])`` with a
        leading axis added by ``unsqueeze(0)``.
    """
    if isinstance(x, int):
        ctx = x
    elif isinstance(x, torch.Tensor) and x.ndim == 3:
        _, ctx, _ = x.shape
    else:
        # Anything else is assumed to be 4-D; the unpack raises otherwise.
        _, _, ctx, _ = x.shape

    t = torch.arange(ctx, device=device, dtype=dtype)
    freqs = self.inv_freq
    freqs = t[:, None] * freqs[None, :]

    if self.radii:
        radius = self.align_f0(ctx)
        if "rotary2" in self.debug:
            print(f"{layer} radius: {radius} ctx: {ctx}")
    else:
        radius = freqs
    # NOTE(review): only the *shape* of radius is used here (ones_like),
    # so the f0-derived magnitudes never reach the output -- confirm this
    # is intended.
    freqs = torch.polar(torch.ones_like(radius), freqs)

    if "rotary3" in self.debug:
        print(f"{layer} count {self._counter} f0: {f0.shape if f0 is not None else None} freqs: {freqs.shape} radius: {radius.shape} ctx: {ctx}")
        print(f"freqs mean: {freqs.mean():.2f} inv_freq mean: {self.inv_freq.mean():.2f} theta: {self.theta.item():.2f} radius mean: {radius.mean():.2f} radius shape: {radius.shape} ctx: {ctx}")

    if "rotary_detail" in self.debug:
        print(f"\n==== Detailed RoPE Analysis ====")
        print(f"Layer: {layer}, Context Length: {ctx}")
        print(f"F0 stats: mean={self.theta.item():.2f}")
        print(f"inv_freq range: [{self.inv_freq.min().item():.4f}, {self.inv_freq.max().item():.4f}]")

        # These tensors derive from parameters that require grad, and
        # Tensor.numpy() refuses such tensors, so detach before converting.
        if self.radii:
            print(f"Radius Shape: {radius.shape}, Mean: {radius.mean().item():.4f}")
            print(f"Radius[0]: {radius[0][:5].detach().cpu().numpy()}")
            print(f"Radius[mid]: {radius[ctx//2][:5].detach().cpu().numpy()}")
            print(f"Radius[end]: {radius[-1][:5].detach().cpu().numpy()}")

        print(f"Final freqs shape: {freqs.shape}")
        print(f"Freqs[0]: {freqs[0][:5].detach().cpu().numpy()}")
        print(f"Freqs[mid]: {freqs[ctx//2][:5].detach().cpu().numpy()}")
        print(f"Freqs[end]: {freqs[-1][:5].detach().cpu().numpy()}")
        print("================================\n")

    self._counter += 1
    return freqs.unsqueeze(0)
|
383 |
|
|
|
394 |
x1 = x1.view(orig_shape)
|
395 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
396 |
|
397 |
+
|
398 |
+
# class FocusA(nn.Module):
|
399 |
+
# def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
|
400 |
+
# super().__init__()
|
401 |
+
# self.dims = dims
|
402 |
+
# self.head = head
|
403 |
+
# self.max_dist = max_dist
|
404 |
+
# self.win_size = win_size
|
405 |
+
# self.max_span = max_span
|
406 |
+
# self.temp_scale = temp_scale
|
407 |
+
# self.iterations = iterations
|
408 |
+
|
409 |
+
# self.span_predictor = nn.Linear(dims, 1)
|
410 |
+
|
411 |
+
# self.attn_l = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
|
412 |
+
# self.attn_g = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
|
413 |
+
|
414 |
+
# self.ln_l = nn.LayerNorm(dims)
|
415 |
+
# self.ln_g = nn.LayerNorm(dims)
|
416 |
+
# self.projection = nn.Linear(2 * dims, dims)
|
417 |
+
|
418 |
+
# def _focus(self, que, key, val, span_scale):
|
419 |
+
# attn_out = que
|
420 |
+
# span_len = max(1, int(self.max_span * span_scale.mean().item()))
|
421 |
+
# span_len = min(span_len, que.size(1), key.size(1), val.size(1))
|
422 |
+
|
423 |
+
# for _ in range(self.iterations):
|
424 |
+
# temp = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
|
425 |
+
# q = que / temp
|
426 |
+
# k = key / temp
|
427 |
+
# v = val / temp
|
428 |
+
# output, _ = self.attn_l(q, k, v)
|
429 |
+
# que = que + output
|
430 |
+
# return que
|
431 |
+
|
432 |
+
# def _window(self, x, win_size, span_len, span_scale):
|
433 |
+
# batch_size, ctx, dims = x.size()
|
434 |
+
# output = torch.zeros_like(x)
|
435 |
+
|
436 |
+
# for i in range(0, ctx, win_size // 2):
|
437 |
+
# end = min(i + win_size, ctx)
|
438 |
+
# que = x[:, i:end]
|
439 |
+
# start = max(0, i - span_len)
|
440 |
+
# end_con = min(i + win_size + span_len, ctx)
|
441 |
+
# con = x[:, start:end_con]
|
442 |
+
# win_out = self._focus(que, con, con, span_scale)
|
443 |
+
|
444 |
+
# if i > 0:
|
445 |
+
# start_over = i
|
446 |
+
# end_over = min(i + win_size // 2, ctx)
|
447 |
+
# blend = torch.linspace(0, 1, end_over - start_over).view(1, -1, 1)
|
448 |
+
# blend = blend.to(x.device)
|
449 |
+
# output[:, start_over:end_over] = (
|
450 |
+
# (1 - blend) * output[:, start_over:end_over] +
|
451 |
+
# blend * win_out[:, :end_over-start_over])
|
452 |
+
# if end_over < end:
|
453 |
+
# output[:, end_over:end] = win_out[:, end_over-i:end-i]
|
454 |
+
# else:
|
455 |
+
# output[:, i:end] = win_out
|
456 |
+
# return output
|
457 |
+
|
458 |
+
# def forward(self, x, mask=None):
|
459 |
+
# l_x = self.ln_l(x)
|
460 |
+
# g_x = self.ln_g(x)
|
461 |
+
# g_out, g_attn = self.attn_g(g_x, g_x, g_x, need_weights=True)
|
462 |
+
# g_focus = g_attn.sum(dim=1)
|
463 |
+
# f_score = g_focus.max(dim=-1)[0]
|
464 |
+
# b_scale = torch.sigmoid(self.span_predictor(x.mean(dim=1)))
|
465 |
+
# var = (f_score - f_score.mean(dim=1, keepdim=True)).abs()
|
466 |
+
# a_span = b_scale * (1.0 + 0.5 * var.mean(dim=1, keepdim=True))
|
467 |
+
|
468 |
+
# l_out = self._window(
|
469 |
+
# l_x,
|
470 |
+
# win_size=self.win_size,
|
471 |
+
# span_len=max(1, int(self.max_span * a_span.mean().item())),
|
472 |
+
# span_scale=a_span
|
473 |
+
# )
|
474 |
+
|
475 |
+
# combined = torch.cat([l_out, g_out], dim=-1)
|
476 |
+
# return self.projection(combined)
|
477 |
+
|
478 |
class MultiheadA(nn.Module):
|
479 |
+
|
480 |
rbf = False
|
481 |
def __init__(self, dims: int, head: int, rotary_emb: bool = True,
|
482 |
zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
|
|
|
522 |
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
|
523 |
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
|
524 |
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
|
525 |
+
|
526 |
def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
|
527 |
x = x.to(device, dtype)
|
528 |
if xa is not None:
|
|
|
724 |
else:
|
725 |
print(f"Step {self._counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
|
726 |
self._counter += 1
|
|
|
727 |
return x
|
728 |
|
729 |
class FEncoder(nn.Module):
|
|
|
920 |
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
|
921 |
),
|
922 |
"envelope": nn.ModuleList(
|
923 |
+
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
924 |
+
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
925 |
+
for _ in range(layer)] if "envelope" in features else None
|
926 |
+
),
|
927 |
+
"phase": nn.ModuleList(
|
928 |
+
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
929 |
+
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
930 |
+
for _ in range(layer)] if "phase" in features else None
|
931 |
+
)})
|
|
|
932 |
|
933 |
def forward(self, enc, layer="encoder"):
    """Run each configured feature through its block stack.

    Args:
        enc: Mapping of feature name to input; it is passed through
            ``dict_to(enc, device, dtype)`` first.
        layer: Label forwarded to every block.

    Returns:
        Dict mapping each feature that is present in both ``enc`` and
        ``self.blocks`` to the output of its last block.
    """
    enc = dict_to(enc, device, dtype)

    # Plot the inputs once, on the first call only.
    if self._counter < 1:
        s = enc.get("spectrogram")
        w = enc.get("waveform")
        p = default(enc.get("pitch"), enc.get("f0"))
        plot_waveform(x=s, w=w, p=p, hop_length=128)

    xa = {}

    for f in self.features:
        if f in enc and f in self.blocks:
            x = enc[f]
            for block in self.blocks[f]:
                x = block(x, enc=enc, layer=layer)
            xa[f] = x

    if "encoder" in self.debug and self._counter % 100 == 0:
        # Report on the result dict. The original read ``x`` here, which
        # is the last block output (or unbound when no feature matched)
        # and has no keys()/items().
        names = list(xa.keys())
        shapes = {k: v.shape for k, v in xa.items()}
        print(f"Step {self._counter}: mode: {names}: shapes: {shapes}")
    self._counter += 1
    return xa
|
957 |
|
958 |
class TextDecoder(nn.Module):
|
959 |
def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
|
|
|
989 |
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
|
990 |
self.register_buffer("mask", mask, persistent=False)
|
991 |
|
992 |
+
def forward(self, x, xa, enc=None, order=None, layer='decoder') -> Tensor:
|
993 |
+
xa = dict_to(xa, device, dtype)
|
994 |
x = x.to(device)
|
995 |
bln = self.blend
|
996 |
|
|
|
1005 |
x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
|
1006 |
|
1007 |
for f in order:
|
1008 |
+
if f in xa:
|
1009 |
+
ax = xa[f]
|
1010 |
for block in self.blocks[f]:
|
1011 |
+
out = block(x=x, xa=ax, mask=None, enc=enc, layer=layer)
|
1012 |
|
1013 |
a = torch.sigmoid(bln[f])
|
1014 |
x = a * out + (1 - a) * x
|
|
|
1056 |
for name, module in self.encoder.named_modules():
|
1057 |
if isinstance(module, (rotary)):
|
1058 |
module.update_base(f0)
|
1059 |
+
module.return_f0(f0)
|
1060 |
|
1061 |
+
for name, module in self.decoder.named_modules():
|
1062 |
+
if isinstance(module, (rotary)):
|
1063 |
+
module.update_base(f0)
|
1064 |
+
module.return_f0(f0)
|
1065 |
+
|
1066 |
def set_alignment_head(self, dump: bytes):
|
1067 |
array = np.frombuffer(
|
1068 |
gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
|
|
|
1103 |
encoder_inputs["phase"] = phase
|
1104 |
if f0 is not None:
|
1105 |
encoder_inputs["f0"] = f0
|
1106 |
+
|
|
|
|
|
1107 |
if f0 is not None:
|
1108 |
f0 = f0.squeeze(0)
|
1109 |
+
self.update_base(f0)
|
1110 |
|
1111 |
encoder_outputs = self.encoder(encoder_inputs)
|
1112 |
logits = self.decoder(input_ids, encoder_outputs)
|
|
|
1687 |
gradient_accumulation_steps=1,
|
1688 |
eval_accumulation_steps=1,
|
1689 |
eval_strategy="steps",
|
1690 |
+
save_strategy="no",
|
1691 |
+
include_tokens_per_second=True,
|
1692 |
+
include_num_input_tokens_seen=True,
|
1693 |
max_steps=max_steps,
|
1694 |
save_steps=save_steps,
|
1695 |
eval_steps=eval_steps,
|
|
|
1763 |
text_dims=512,
|
1764 |
text_idx=4,
|
1765 |
act="swish",
|
1766 |
+
debug={"rotary_detail"},
|
1767 |
cross_attn=True,
|
1768 |
features = ["spectrogram"]
|
1769 |
)
|
1770 |
|
1771 |
+
sanity_check = True
|
1772 |
+
|
1773 |
training_args = sanity(sanity_check)
|
1774 |
dataset_config = {
|
1775 |
"spectrogram": True,
|
1776 |
"waveforms": False,
|
1777 |
"pitch": False,
|
1778 |
"downsamples": False,
|
1779 |
+
"frequency": True,
|
1780 |
"hilbert": False,
|
1781 |
"hop_length": 128,
|
1782 |
"fmin": 150,
|
|
|
1793 |
"normalized": False}
|
1794 |
|
1795 |
model = create_model(param)
|
1796 |
+
|
1797 |
global global_model
|
1798 |
global_model = model
|
1799 |
|
|
|
1822 |
if __name__ == "__main__":
|
1823 |
main()
|
1824 |
|
1825 |
+
|
1826 |
+
|
|
|
|
|
|
|
|