Update model.py
Browse files
model.py
CHANGED
@@ -12,7 +12,6 @@ import torch.nn.functional as F
|
|
12 |
import torch.nn.init as init
|
13 |
from torch import nn, Tensor
|
14 |
import numpy as np
|
15 |
-
from einops import rearrange
|
16 |
import matplotlib.pyplot as plt
|
17 |
from typing import Optional, Dict, Union, List, Tuple, Any
|
18 |
from functools import partial
|
@@ -242,37 +241,34 @@ def get_dtype():
|
|
242 |
def tox():
|
243 |
return {"device": get_device(), "dtype": get_dtype()}
|
244 |
|
245 |
-
def sinusoids(length,
|
246 |
-
assert
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
return torch.cat([torch.sin(
|
251 |
-
|
252 |
class rotary(nn.Module):
|
253 |
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False, spec_shape=None):
|
254 |
super(rotary, self).__init__()
|
255 |
|
256 |
-
self.pitch_scale = 0.1
|
257 |
-
self.use_pbias = use_pbias
|
258 |
-
self.spec_shape = spec_shape
|
259 |
self.dims = dims
|
260 |
self.head = head
|
261 |
self.head_dim = dims // head
|
262 |
-
self.
|
263 |
-
self.theta = theta
|
264 |
self.max_ctx = max_ctx
|
|
|
|
|
|
|
|
|
|
|
265 |
self.debug = debug
|
266 |
self.counter = 0
|
267 |
self.last_theta = None
|
268 |
-
dim = self.head_dim
|
269 |
-
self.dim = dim
|
270 |
|
271 |
-
self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
|
272 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
273 |
-
|
274 |
-
|
275 |
-
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
276 |
self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
|
277 |
|
278 |
def return_f0(self, f0=None):
|
@@ -291,12 +287,11 @@ class rotary(nn.Module):
|
|
291 |
self.theta.data.copy_(theta)
|
292 |
|
293 |
def get_pitch_bias(self, f0):
|
294 |
-
|
295 |
-
return None
|
296 |
f0_flat = f0.squeeze().float()
|
297 |
f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
|
298 |
f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
|
299 |
-
f0_norm.unsqueeze(1))
|
300 |
return f0_sim.unsqueeze(0).unsqueeze(0)
|
301 |
|
302 |
def f0proj(self, f0):
|
@@ -311,7 +306,7 @@ class rotary(nn.Module):
|
|
311 |
|
312 |
def synth_f0(self, f0, ctx):
|
313 |
f0 = self.f0proj(f0)
|
314 |
-
|
315 |
if f0.dim() == 1:
|
316 |
length = f0.shape[0]
|
317 |
if length == ctx:
|
@@ -319,7 +314,7 @@ class rotary(nn.Module):
|
|
319 |
frames = length / ctx
|
320 |
idx = torch.arange(ctx, device=f0.device)
|
321 |
# return torch.arange(1, ctx+1, device=f0.device, dtype=torch.float)
|
322 |
-
return f0[
|
323 |
|
324 |
def align_f0(self, ctx, f0):
|
325 |
# f0 = self.return_f0()
|
@@ -351,6 +346,14 @@ class rotary(nn.Module):
|
|
351 |
|
352 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
353 |
f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
|
355 |
if isinstance(x, int):
|
356 |
ctx = x
|
@@ -359,48 +362,41 @@ class rotary(nn.Module):
|
|
359 |
else:
|
360 |
batch, head, ctx, head_dim = x.shape
|
361 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
|
|
362 |
if f0 is not None:
|
363 |
f0_mean = f0.mean()
|
364 |
-
theta = f0_mean +
|
365 |
else:
|
366 |
theta = self.theta
|
367 |
-
|
368 |
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
|
369 |
self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
|
|
|
|
|
|
370 |
|
371 |
freqs = t[:, None] * freqs[None, :]
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
if
|
376 |
-
|
|
|
|
|
|
|
|
|
377 |
else:
|
378 |
-
radius = freqs
|
379 |
-
freqs = torch.polar(
|
380 |
-
|
381 |
-
if "
|
382 |
-
print(f"{layer}
|
383 |
-
print(f"freqs mean: {freqs.mean():.2f} inv_freq mean: {self.inv_freq.mean():.2f} theta: {self.theta.item():.2f} radius mean: {radius.mean():.2f} radius shape: {radius.shape} ctx: {ctx}")
|
384 |
-
|
385 |
-
if "rotary_detail" in self.debug and self._counter == 5:
|
386 |
-
print(f"\n==== Detailed RoPE Analysis ====")
|
387 |
-
print(f"Layer: {layer}, Context Length: {ctx}")
|
388 |
-
print(f"F0 stats: mean={self.theta.item():.2f}")
|
389 |
-
print(f"inv_freq range: [{self.inv_freq.min().item():.4f}, {self.inv_freq.max().item():.4f}]")
|
390 |
-
|
391 |
-
if self.radii:
|
392 |
-
print(f"Radius Shape: {radius.shape}, Mean: {radius.mean().item():.4f}")
|
393 |
-
print(f"Radius[0]: {radius[0][:5].cpu().numpy()}")
|
394 |
-
print(f"Radius[mid]: {radius[ctx//2][:5].cpu().numpy()}")
|
395 |
-
print(f"Radius[end]: {radius[-1][:5].cpu().numpy()}")
|
396 |
-
|
397 |
-
print(f"Final freqs shape: {freqs.shape}")
|
398 |
-
print(f"Freqs[0]: {freqs[0][:5].cpu().detach().numpy()}")
|
399 |
-
print(f"Freqs[mid]: {freqs[ctx//2][:5].cpu().detach().numpy()}")
|
400 |
-
print(f"Freqs[end]: {freqs[-1][:5].cpu().detach().numpy()}")
|
401 |
-
print("================================\n")
|
402 |
|
403 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
404 |
return freqs.unsqueeze(0)
|
405 |
|
406 |
@staticmethod
|
@@ -416,88 +412,8 @@ class rotary(nn.Module):
|
|
416 |
x1 = x1.view(orig_shape)
|
417 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
418 |
|
419 |
-
# class FocusA(nn.Module):
|
420 |
-
# def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
|
421 |
-
# super().__init__()
|
422 |
-
# self.dims = dims
|
423 |
-
# self.head = head
|
424 |
-
# self.max_dist = max_dist
|
425 |
-
# self.win_size = win_size
|
426 |
-
# self.max_span = max_span
|
427 |
-
# self.temp_scale = temp_scale
|
428 |
-
# self.iterations = iterations
|
429 |
-
|
430 |
-
# self.span_predictor = nn.Linear(dims, 1)
|
431 |
-
|
432 |
-
# self.attn_l = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
|
433 |
-
# self.attn_g = nn.MultiheadAttention(embed_dim=dims, num_heads=head)
|
434 |
-
|
435 |
-
# self.ln_l = nn.LayerNorm(dims)
|
436 |
-
# self.ln_g = nn.LayerNorm(dims)
|
437 |
-
# self.projection = nn.Linear(2 * dims, dims)
|
438 |
-
|
439 |
-
# def _focus(self, que, key, val, span_scale):
|
440 |
-
# attn_out = que
|
441 |
-
# span_len = max(1, int(self.max_span * span_scale.mean().item()))
|
442 |
-
# span_len = min(span_len, que.size(1), key.size(1), val.size(1))
|
443 |
-
|
444 |
-
# for _ in range(self.iterations):
|
445 |
-
# temp = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
|
446 |
-
# q = que / temp
|
447 |
-
# k = key / temp
|
448 |
-
# v = val / temp
|
449 |
-
# output, _ = self.attn_l(q, k, v)
|
450 |
-
# que = que + output
|
451 |
-
# return que
|
452 |
-
|
453 |
-
# def _window(self, x, win_size, span_len, span_scale):
|
454 |
-
# batch_size, ctx, dims = x.size()
|
455 |
-
# output = torch.zeros_like(x)
|
456 |
-
|
457 |
-
# for i in range(0, ctx, win_size // 2):
|
458 |
-
# end = min(i + win_size, ctx)
|
459 |
-
# que = x[:, i:end]
|
460 |
-
# start = max(0, i - span_len)
|
461 |
-
# end_con = min(i + win_size + span_len, ctx)
|
462 |
-
# con = x[:, start:end_con]
|
463 |
-
# win_out = self._focus(que, con, con, span_scale)
|
464 |
-
|
465 |
-
# if i > 0:
|
466 |
-
# start_over = i
|
467 |
-
# end_over = min(i + win_size // 2, ctx)
|
468 |
-
# blend = torch.linspace(0, 1, end_over - start_over).view(1, -1, 1)
|
469 |
-
# blend = blend.to(x.device)
|
470 |
-
# output[:, start_over:end_over] = (
|
471 |
-
# (1 - blend) * output[:, start_over:end_over] +
|
472 |
-
# blend * win_out[:, :end_over-start_over])
|
473 |
-
# if end_over < end:
|
474 |
-
# output[:, end_over:end] = win_out[:, end_over-i:end-i]
|
475 |
-
# else:
|
476 |
-
# output[:, i:end] = win_out
|
477 |
-
# return output
|
478 |
-
|
479 |
-
# def forward(self, x, mask=None):
|
480 |
-
# l_x = self.ln_l(x)
|
481 |
-
# g_x = self.ln_g(x)
|
482 |
-
# g_out, g_attn = self.attn_g(g_x, g_x, g_x, need_weights=True)
|
483 |
-
# g_focus = g_attn.sum(dim=1)
|
484 |
-
# f_score = g_focus.max(dim=-1)[0]
|
485 |
-
# b_scale = torch.sigmoid(self.span_predictor(x.mean(dim=1)))
|
486 |
-
# var = (f_score - f_score.mean(dim=1, keepdim=True)).abs()
|
487 |
-
# a_span = b_scale * (1.0 + 0.5 * var.mean(dim=1, keepdim=True))
|
488 |
-
|
489 |
-
# l_out = self._window(
|
490 |
-
# l_x,
|
491 |
-
# win_size=self.win_size,
|
492 |
-
# span_len=max(1, int(self.max_span * a_span.mean().item())),
|
493 |
-
# span_scale=a_span
|
494 |
-
# )
|
495 |
-
|
496 |
-
# combined = torch.cat([l_out, g_out], dim=-1)
|
497 |
-
# return self.projection(combined)
|
498 |
-
|
499 |
class MultiheadA(nn.Module):
|
500 |
-
|
501 |
rbf = False
|
502 |
def __init__(self, dims: int, head: int, rotary_emb: bool = True,
|
503 |
zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
|
@@ -507,7 +423,7 @@ class MultiheadA(nn.Module):
|
|
507 |
self.head = head
|
508 |
self.head_dim = dims // head
|
509 |
self.debug = debug
|
510 |
-
self.
|
511 |
|
512 |
self.q = Linear(dims, dims).to(device, dtype)
|
513 |
self.k = Linear(dims, dims, bias=False).to(device, dtype)
|
@@ -543,7 +459,7 @@ class MultiheadA(nn.Module):
|
|
543 |
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
|
544 |
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
|
545 |
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
|
546 |
-
|
547 |
def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
|
548 |
x = x.to(device, dtype)
|
549 |
if xa is not None:
|
@@ -595,9 +511,9 @@ class MultiheadA(nn.Module):
|
|
595 |
w = F.softmax(qk, dim=-1).to(q.dtype)
|
596 |
wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
597 |
|
598 |
-
if "multihead" in self.debug and self.
|
599 |
print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
|
600 |
-
self.
|
601 |
return self.o(wv), qk.detach()
|
602 |
|
603 |
class t_gate(nn.Module):
|
@@ -667,7 +583,7 @@ class Residual(nn.Module):
|
|
667 |
self.cross_attn = cross_attn
|
668 |
self.features = features
|
669 |
self.debug = debug
|
670 |
-
self.
|
671 |
self.dropout = 0.01
|
672 |
|
673 |
self.t_gate = tgate
|
@@ -734,17 +650,18 @@ class Residual(nn.Module):
|
|
734 |
else:
|
735 |
x = x + mlp_out
|
736 |
|
737 |
-
if "residual" in self.debug and self.
|
738 |
-
print(f"Step {self.
|
739 |
if self.t_gate:
|
740 |
-
print(f"Step {self.
|
741 |
elif self.m_gate:
|
742 |
-
print(f"Step {self.
|
743 |
elif self.c_gate:
|
744 |
-
print(f"Step {self.
|
745 |
else:
|
746 |
-
print(f"Step {self.
|
747 |
-
self.
|
|
|
748 |
return x
|
749 |
|
750 |
class FEncoder(nn.Module):
|
@@ -915,7 +832,7 @@ class AudioEncoder(nn.Module):
|
|
915 |
self.ctx = ctx
|
916 |
self.head_dim = dims // head
|
917 |
self.debug = debug
|
918 |
-
self.
|
919 |
self.features = features
|
920 |
self.dropout = 0.01
|
921 |
|
@@ -943,38 +860,42 @@ class AudioEncoder(nn.Module):
|
|
943 |
"envelope": nn.ModuleList(
|
944 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
945 |
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
946 |
-
|
947 |
),
|
948 |
"phase": nn.ModuleList(
|
949 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
950 |
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
951 |
for _ in range(layer)] if "phase" in features else None
|
952 |
-
|
|
|
953 |
|
954 |
-
def forward(self, enc, layer="encoder"):
|
955 |
enc = dict_to(enc, device, dtype)
|
956 |
-
|
957 |
-
if self.
|
958 |
s = enc.get("spectrogram")
|
959 |
w = enc.get("waveform")
|
960 |
p = default(enc.get("pitch"), enc.get("f0"))
|
961 |
plot_waveform(x=s, w=w, p=p, hop_length=128)
|
962 |
|
963 |
-
|
|
|
|
|
|
|
|
|
964 |
|
965 |
-
for f in
|
966 |
if f in enc and f in self.blocks:
|
967 |
x = enc[f]
|
968 |
for block in self.blocks[f]:
|
969 |
x = block(x, enc=enc, layer=layer)
|
970 |
-
|
971 |
|
972 |
-
if "encoder" in self.debug and self.
|
973 |
-
|
974 |
-
|
975 |
-
|
976 |
-
|
977 |
-
return xa
|
978 |
|
979 |
class TextDecoder(nn.Module):
|
980 |
def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
|
@@ -986,7 +907,7 @@ class TextDecoder(nn.Module):
|
|
986 |
self.head = head
|
987 |
self.head_dim = dims // head
|
988 |
self.debug = debug
|
989 |
-
self.
|
990 |
self.dropout = 0.01
|
991 |
self.sequential = sequential
|
992 |
self.features = features
|
@@ -1010,8 +931,8 @@ class TextDecoder(nn.Module):
|
|
1010 |
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
|
1011 |
self.register_buffer("mask", mask, persistent=False)
|
1012 |
|
1013 |
-
def forward(self, x,
|
1014 |
-
|
1015 |
x = x.to(device)
|
1016 |
bln = self.blend
|
1017 |
|
@@ -1026,17 +947,17 @@ class TextDecoder(nn.Module):
|
|
1026 |
x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
|
1027 |
|
1028 |
for f in order:
|
1029 |
-
if f in
|
1030 |
-
|
1031 |
for block in self.blocks[f]:
|
1032 |
-
out = block(x=x, xa=
|
1033 |
|
1034 |
a = torch.sigmoid(bln[f])
|
1035 |
x = a * out + (1 - a) * x
|
1036 |
|
1037 |
-
if "decoder" in self.debug and self.
|
1038 |
-
print(f"Step {self.
|
1039 |
-
self.
|
1040 |
|
1041 |
x = self.ln_dec(x)
|
1042 |
return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
|
@@ -1076,14 +997,14 @@ class Echo(nn.Module):
|
|
1076 |
def update_base(self, f0):
|
1077 |
for name, module in self.encoder.named_modules():
|
1078 |
if isinstance(module, (rotary)):
|
1079 |
-
module.return_f0(f0)
|
1080 |
module.update_base(f0)
|
|
|
1081 |
|
1082 |
for name, module in self.decoder.named_modules():
|
1083 |
if isinstance(module, (rotary)):
|
1084 |
-
module.return_f0(f0)
|
1085 |
module.update_base(f0)
|
1086 |
-
|
|
|
1087 |
def set_alignment_head(self, dump: bytes):
|
1088 |
array = np.frombuffer(
|
1089 |
gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
|
@@ -1125,9 +1046,10 @@ class Echo(nn.Module):
|
|
1125 |
if f0 is not None:
|
1126 |
encoder_inputs["f0"] = f0
|
1127 |
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
|
|
1131 |
|
1132 |
encoder_outputs = self.encoder(encoder_inputs)
|
1133 |
logits = self.decoder(input_ids, encoder_outputs)
|
@@ -1225,22 +1147,11 @@ class Echo(nn.Module):
|
|
1225 |
print(f"DECODER GRAD: {name} = {norm:.6f}")
|
1226 |
return None
|
1227 |
|
1228 |
-
def
|
1229 |
-
self.
|
1230 |
print("Counter reset to 0.")
|
1231 |
|
1232 |
metric = evaluate.load(path="wer")
|
1233 |
-
|
1234 |
-
def align_f0(f0, ctx):
|
1235 |
-
ctx = torch.tensor(ctx)
|
1236 |
-
bat, length = f0.shape
|
1237 |
-
if length == ctx:
|
1238 |
-
return f0
|
1239 |
-
frames = length / ctx
|
1240 |
-
idx = torch.arange(ctx, device=f0.device)
|
1241 |
-
idx = (idx * frames).long()
|
1242 |
-
batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
|
1243 |
-
return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
|
1244 |
|
1245 |
@dataclass
|
1246 |
class DataCollator:
|
@@ -1486,14 +1397,14 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=
|
|
1486 |
f0, t = pw.dio(wav_np, sampling_rate,
|
1487 |
frame_period=hop_length/sampling_rate*1000)
|
1488 |
f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
|
1489 |
-
f0 = torch.from_numpy(f0)
|
1490 |
batch["pitch"] = f0
|
1491 |
|
1492 |
if frequency:
|
1493 |
wav_np = wav.numpy().astype(np.float64)
|
1494 |
f0, t = pw.dio(wav_np, sampling_rate, frame_period=hop_length/sampling_rate*1000)
|
1495 |
f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
|
1496 |
-
f0 = torch.from_numpy(f0)
|
1497 |
batch["f0"] = f0
|
1498 |
|
1499 |
if spectrogram and waveforms and pitch:
|
@@ -1708,9 +1619,7 @@ def get_training_args(
|
|
1708 |
gradient_accumulation_steps=1,
|
1709 |
eval_accumulation_steps=1,
|
1710 |
eval_strategy="steps",
|
1711 |
-
save_strategy="
|
1712 |
-
include_tokens_per_second=True,
|
1713 |
-
include_num_input_tokens_seen=True,
|
1714 |
max_steps=max_steps,
|
1715 |
save_steps=save_steps,
|
1716 |
eval_steps=eval_steps,
|
@@ -1784,13 +1693,13 @@ def main():
|
|
1784 |
text_dims=512,
|
1785 |
text_idx=4,
|
1786 |
act="swish",
|
1787 |
-
debug={
|
1788 |
cross_attn=True,
|
1789 |
features = ["spectrogram"]
|
1790 |
)
|
1791 |
|
1792 |
-
sanity_check =
|
1793 |
-
|
1794 |
training_args = sanity(sanity_check)
|
1795 |
dataset_config = {
|
1796 |
"spectrogram": True,
|
@@ -1814,7 +1723,7 @@ def main():
|
|
1814 |
"normalized": False}
|
1815 |
|
1816 |
model = create_model(param)
|
1817 |
-
|
1818 |
global global_model
|
1819 |
global_model = model
|
1820 |
|
@@ -1843,3 +1752,6 @@ def main():
|
|
1843 |
if __name__ == "__main__":
|
1844 |
main()
|
1845 |
|
|
|
|
|
|
|
|
12 |
import torch.nn.init as init
|
13 |
from torch import nn, Tensor
|
14 |
import numpy as np
|
|
|
15 |
import matplotlib.pyplot as plt
|
16 |
from typing import Optional, Dict, Union, List, Tuple, Any
|
17 |
from functools import partial
|
|
|
241 |
def tox():
|
242 |
return {"device": get_device(), "dtype": get_dtype()}
|
243 |
|
244 |
+
def sinusoids(length, channels, max_timescale=10000):
|
245 |
+
assert channels % 2 == 0
|
246 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
247 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
248 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
249 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
250 |
+
|
251 |
class rotary(nn.Module):
|
252 |
def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False, spec_shape=None):
|
253 |
super(rotary, self).__init__()
|
254 |
|
|
|
|
|
|
|
255 |
self.dims = dims
|
256 |
self.head = head
|
257 |
self.head_dim = dims // head
|
258 |
+
self.dim = self.head_dim
|
|
|
259 |
self.max_ctx = max_ctx
|
260 |
+
self.theta = theta
|
261 |
+
self.radii = radii
|
262 |
+
self.pitch_scale = 0.1
|
263 |
+
self.use_pbias = use_pbias
|
264 |
+
self.spec_shape = spec_shape
|
265 |
self.debug = debug
|
266 |
self.counter = 0
|
267 |
self.last_theta = None
|
|
|
|
|
268 |
|
269 |
+
# self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
|
270 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
271 |
+
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
|
|
|
|
272 |
self.freqs = nn.Parameter(torch.tensor(freqs, device=device, dtype=dtype), requires_grad=True)
|
273 |
|
274 |
def return_f0(self, f0=None):
|
|
|
287 |
self.theta.data.copy_(theta)
|
288 |
|
289 |
def get_pitch_bias(self, f0):
|
290 |
+
f0 = self.return_f0()
|
|
|
291 |
f0_flat = f0.squeeze().float()
|
292 |
f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
|
293 |
f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
|
294 |
+
f0_norm.unsqueeze(1)))
|
295 |
return f0_sim.unsqueeze(0).unsqueeze(0)
|
296 |
|
297 |
def f0proj(self, f0):
|
|
|
306 |
|
307 |
def synth_f0(self, f0, ctx):
|
308 |
f0 = self.f0proj(f0)
|
309 |
+
|
310 |
if f0.dim() == 1:
|
311 |
length = f0.shape[0]
|
312 |
if length == ctx:
|
|
|
314 |
frames = length / ctx
|
315 |
idx = torch.arange(ctx, device=f0.device)
|
316 |
# return torch.arange(1, ctx+1, device=f0.device, dtype=torch.float)
|
317 |
+
return f0[idx]
|
318 |
|
319 |
def align_f0(self, ctx, f0):
|
320 |
# f0 = self.return_f0()
|
|
|
346 |
|
347 |
def forward(self, x=None, enc=None, layer=None, input_type="audio") -> Tensor:
|
348 |
f0 = default(enc.get("pitch"), enc.get("f0")) if enc is not None else None
|
349 |
+
if f0 is not None and f0.dim() == 2:
|
350 |
+
if f0.shape[0] == 1:
|
351 |
+
f0 = f0.squeeze(0)
|
352 |
+
else:
|
353 |
+
f0 = f0.view(-1)
|
354 |
+
|
355 |
+
if "rot1" in self.debug and self.counter % 100 == 0:
|
356 |
+
print(f"Rotary forward: {x if x is not None else None}, f0: {f0.shape if f0 is not None else None}")
|
357 |
|
358 |
if isinstance(x, int):
|
359 |
ctx = x
|
|
|
362 |
else:
|
363 |
batch, head, ctx, head_dim = x.shape
|
364 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
365 |
+
|
366 |
if f0 is not None:
|
367 |
f0_mean = f0.mean()
|
368 |
+
theta = f0_mean + self.theta
|
369 |
else:
|
370 |
theta = self.theta
|
|
|
371 |
freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
|
372 |
self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
373 |
+
|
374 |
+
if "rot2" in self.debug and self.counter % 100 == 0:
|
375 |
+
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
|
376 |
|
377 |
freqs = t[:, None] * freqs[None, :]
|
378 |
+
if self.radii and f0 is not None:
|
379 |
+
radius = f0.to(device, dtype)
|
380 |
+
L = radius.shape[0]
|
381 |
+
if L != ctx:
|
382 |
+
F = L / ctx
|
383 |
+
idx = torch.arange(ctx, device=f0.device)
|
384 |
+
idx = (idx * F).long().clamp(0, L - 1)
|
385 |
+
radius = radius[idx]
|
386 |
+
radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
|
387 |
else:
|
388 |
+
radius = torch.ones_like(freqs)
|
389 |
+
freqs = torch.polar(radius, freqs)
|
390 |
+
|
391 |
+
if "rot3" in self.debug and self.counter % 100 == 0:
|
392 |
+
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
|
394 |
+
if "theta" in self.debug and self.counter % 100 == 0:
|
395 |
+
if self.last_theta is None or abs(self.last_theta - theta.item()) > 1.0:
|
396 |
+
self.last_theta = theta.item()
|
397 |
+
print(f"[Theta] {self.last_theta:.2f}")
|
398 |
+
|
399 |
+
self.counter += 1
|
400 |
return freqs.unsqueeze(0)
|
401 |
|
402 |
@staticmethod
|
|
|
412 |
x1 = x1.view(orig_shape)
|
413 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
class MultiheadA(nn.Module):
|
416 |
+
_seen = set()
|
417 |
rbf = False
|
418 |
def __init__(self, dims: int, head: int, rotary_emb: bool = True,
|
419 |
zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
|
|
|
423 |
self.head = head
|
424 |
self.head_dim = dims // head
|
425 |
self.debug = debug
|
426 |
+
self.counter = 0
|
427 |
|
428 |
self.q = Linear(dims, dims).to(device, dtype)
|
429 |
self.k = Linear(dims, dims, bias=False).to(device, dtype)
|
|
|
459 |
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
|
460 |
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
|
461 |
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
|
462 |
+
|
463 |
def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
|
464 |
x = x.to(device, dtype)
|
465 |
if xa is not None:
|
|
|
511 |
w = F.softmax(qk, dim=-1).to(q.dtype)
|
512 |
wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
513 |
|
514 |
+
if "multihead" in self.debug and self.counter % 100 == 0:
|
515 |
print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
|
516 |
+
self.counter += 1
|
517 |
return self.o(wv), qk.detach()
|
518 |
|
519 |
class t_gate(nn.Module):
|
|
|
583 |
self.cross_attn = cross_attn
|
584 |
self.features = features
|
585 |
self.debug = debug
|
586 |
+
self.counter = 0
|
587 |
self.dropout = 0.01
|
588 |
|
589 |
self.t_gate = tgate
|
|
|
650 |
else:
|
651 |
x = x + mlp_out
|
652 |
|
653 |
+
if "residual" in self.debug and self.counter % 100 == 0:
|
654 |
+
print(f"Step {self.counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
|
655 |
if self.t_gate:
|
656 |
+
print(f"Step {self.counter}: Using t_gate: {self.t_gate}")
|
657 |
elif self.m_gate:
|
658 |
+
print(f"Step {self.counter}: Using m_gate: {self.m_gate}")
|
659 |
elif self.c_gate:
|
660 |
+
print(f"Step {self.counter}: Using c_gate: {self.c_gate}")
|
661 |
else:
|
662 |
+
print(f"Step {self.counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
|
663 |
+
self.counter += 1
|
664 |
+
|
665 |
return x
|
666 |
|
667 |
class FEncoder(nn.Module):
|
|
|
832 |
self.ctx = ctx
|
833 |
self.head_dim = dims // head
|
834 |
self.debug = debug
|
835 |
+
self.counter = 0
|
836 |
self.features = features
|
837 |
self.dropout = 0.01
|
838 |
|
|
|
860 |
"envelope": nn.ModuleList(
|
861 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
862 |
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
863 |
+
for _ in range(layer)] if "envelope" in features else None
|
864 |
),
|
865 |
"phase": nn.ModuleList(
|
866 |
[FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
|
867 |
[Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate)
|
868 |
for _ in range(layer)] if "phase" in features else None
|
869 |
+
)
|
870 |
+
})
|
871 |
|
872 |
+
def forward(self, enc, order=None, layer="encoder"):
|
873 |
enc = dict_to(enc, device, dtype)
|
874 |
+
|
875 |
+
if self.counter < 1:
|
876 |
s = enc.get("spectrogram")
|
877 |
w = enc.get("waveform")
|
878 |
p = default(enc.get("pitch"), enc.get("f0"))
|
879 |
plot_waveform(x=s, w=w, p=p, hop_length=128)
|
880 |
|
881 |
+
if order is None:
|
882 |
+
order = self.features
|
883 |
+
|
884 |
+
out = {}
|
885 |
+
out.update(enc)
|
886 |
|
887 |
+
for f in order:
|
888 |
if f in enc and f in self.blocks:
|
889 |
x = enc[f]
|
890 |
for block in self.blocks[f]:
|
891 |
x = block(x, enc=enc, layer=layer)
|
892 |
+
out[f] = x
|
893 |
|
894 |
+
if "encoder" in self.debug and self.counter % 100 == 0:
|
895 |
+
shapes = {k: v.shape for k, v in enc.items()}
|
896 |
+
print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}, order: {order}")
|
897 |
+
self.counter += 1
|
898 |
+
return out
|
|
|
899 |
|
900 |
class TextDecoder(nn.Module):
|
901 |
def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
|
|
|
907 |
self.head = head
|
908 |
self.head_dim = dims // head
|
909 |
self.debug = debug
|
910 |
+
self.counter = 0
|
911 |
self.dropout = 0.01
|
912 |
self.sequential = sequential
|
913 |
self.features = features
|
|
|
931 |
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
|
932 |
self.register_buffer("mask", mask, persistent=False)
|
933 |
|
934 |
+
def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
|
935 |
+
enc = dict_to(enc, device, dtype)
|
936 |
x = x.to(device)
|
937 |
bln = self.blend
|
938 |
|
|
|
947 |
x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
|
948 |
|
949 |
for f in order:
|
950 |
+
if f in enc:
|
951 |
+
xa = enc[f]
|
952 |
for block in self.blocks[f]:
|
953 |
+
out = block(x=x, xa=xa, mask=None, enc=enc, layer=layer)
|
954 |
|
955 |
a = torch.sigmoid(bln[f])
|
956 |
x = a * out + (1 - a) * x
|
957 |
|
958 |
+
if "decoder" in self.debug and self.counter % 100 == 0:
|
959 |
+
print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
|
960 |
+
self.counter += 1
|
961 |
|
962 |
x = self.ln_dec(x)
|
963 |
return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
|
|
|
997 |
def update_base(self, f0):
|
998 |
for name, module in self.encoder.named_modules():
|
999 |
if isinstance(module, (rotary)):
|
|
|
1000 |
module.update_base(f0)
|
1001 |
+
module.return_f0(f0)
|
1002 |
|
1003 |
for name, module in self.decoder.named_modules():
|
1004 |
if isinstance(module, (rotary)):
|
|
|
1005 |
module.update_base(f0)
|
1006 |
+
module.return_f0(f0)
|
1007 |
+
|
1008 |
def set_alignment_head(self, dump: bytes):
|
1009 |
array = np.frombuffer(
|
1010 |
gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
|
|
|
1046 |
if f0 is not None:
|
1047 |
encoder_inputs["f0"] = f0
|
1048 |
|
1049 |
+
|
1050 |
+
# if f0 is not None:
|
1051 |
+
# f0 = f0.squeeze(0)
|
1052 |
+
# self.update_base(f0)
|
1053 |
|
1054 |
encoder_outputs = self.encoder(encoder_inputs)
|
1055 |
logits = self.decoder(input_ids, encoder_outputs)
|
|
|
1147 |
print(f"DECODER GRAD: {name} = {norm:.6f}")
|
1148 |
return None
|
1149 |
|
1150 |
+
def resetcounter(self):
|
1151 |
+
self.counter = 0
|
1152 |
print("Counter reset to 0.")
|
1153 |
|
1154 |
metric = evaluate.load(path="wer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1155 |
|
1156 |
@dataclass
|
1157 |
class DataCollator:
|
|
|
1397 |
f0, t = pw.dio(wav_np, sampling_rate,
|
1398 |
frame_period=hop_length/sampling_rate*1000)
|
1399 |
f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
|
1400 |
+
f0 = torch.from_numpy(f0)
|
1401 |
batch["pitch"] = f0
|
1402 |
|
1403 |
if frequency:
|
1404 |
wav_np = wav.numpy().astype(np.float64)
|
1405 |
f0, t = pw.dio(wav_np, sampling_rate, frame_period=hop_length/sampling_rate*1000)
|
1406 |
f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
|
1407 |
+
f0 = torch.from_numpy(f0)
|
1408 |
batch["f0"] = f0
|
1409 |
|
1410 |
if spectrogram and waveforms and pitch:
|
|
|
1619 |
gradient_accumulation_steps=1,
|
1620 |
eval_accumulation_steps=1,
|
1621 |
eval_strategy="steps",
|
1622 |
+
save_strategy="steps",
|
|
|
|
|
1623 |
max_steps=max_steps,
|
1624 |
save_steps=save_steps,
|
1625 |
eval_steps=eval_steps,
|
|
|
1693 |
text_dims=512,
|
1694 |
text_idx=4,
|
1695 |
act="swish",
|
1696 |
+
debug={},
|
1697 |
cross_attn=True,
|
1698 |
features = ["spectrogram"]
|
1699 |
)
|
1700 |
|
1701 |
+
sanity_check = False
|
1702 |
+
|
1703 |
training_args = sanity(sanity_check)
|
1704 |
dataset_config = {
|
1705 |
"spectrogram": True,
|
|
|
1723 |
"normalized": False}
|
1724 |
|
1725 |
model = create_model(param)
|
1726 |
+
|
1727 |
global global_model
|
1728 |
global_model = model
|
1729 |
|
|
|
1752 |
if __name__ == "__main__":
|
1753 |
main()
|
1754 |
|
1755 |
+
|
1756 |
+
|
1757 |
+
|