Update modelA.py
modelA.py
CHANGED
@@ -32,6 +32,13 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
+PATH = 'E:/hf'
+os.environ['HF_HOME'] = PATH
+os.environ['HF_DATASETS_CACHE'] = PATH
+os.environ['TORCH_HOME'] = PATH
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
 def get_activation(act: str) -> nn.Module:
     """Get activation function by name."""
     act_map = {
@@ -193,7 +200,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
     axs[0].legend(loc='upper right', fontsize='small')
     axs[-1].set_xlabel("t (s)")
     fig.suptitle(title, fontsize=16)
-    plt.tight_layout(rect=[0, 0, 1, 0.97])
+    plt.tight_layout(rect=[0, 0, 1, 0.97])
     plt.show()
     return fig
 
@@ -245,17 +252,17 @@ class RMSNorm(nn.Module):
         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
-            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
         else:
            self.register_parameter("weight", None)
     def forward(self, x):
-        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
 
 def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
               weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
               eps: float = 1e-5) -> Tensor:
-    return F.layer_norm(x, normalized_shape, weight, bias, eps)
+    return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 def get_device():
     return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -339,7 +346,7 @@ class rotary(nn.Module):
         elif isinstance(x, torch.Tensor) and x.ndim == 3:
             batch, ctx, dims = x.shape
         else:
-            batch, head, ctx, head_dim = x.shape
+            batch, head, ctx, head_dim = x.shape
 
         if f0 is not None:
             if f0.dim() == 2:
@@ -365,7 +372,6 @@ class rotary(nn.Module):
             radius_mean = radius.mean() if 'radius' in locals() else 0.0
             print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
             print(f" [{layer}] [Radius] {radius}")
-            # self.theta_values.append(theta.item())
         self.counter += 1
         return freqs.unsqueeze(0)
 
@@ -386,7 +392,8 @@ class MultiheadA(nn.Module):
 
     rbf = False
     def __init__(self, dims: int, head: int, rotary_emb: bool = True,
-                 zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [],
+                 zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [],
+                 optim_attn=False, use_pbias=False, use_smart_sensor=False, use_focus_bias=False):
         super(MultiheadA, self).__init__()
 
         self.dims = dims
@@ -419,6 +426,16 @@ class MultiheadA(nn.Module):
         else:
             self.rope = None
 
+        self.use_smart_sensor = use_smart_sensor
+        if use_smart_sensor:
+            self.head_gate = nn.Parameter(torch.ones(head))
+            self.guidance_strength = nn.Parameter(torch.tensor(0.3))
+            self.lr_scale = nn.Parameter(torch.tensor(1.0))
+
+        self.use_focus_bias = use_focus_bias
+        if use_focus_bias:
+            self.focus_bias_strength = nn.Parameter(torch.tensor(0.3))
+
     def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
         q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
         k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
@@ -440,7 +457,7 @@ class MultiheadA(nn.Module):
         rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
         return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
 
-    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:
+    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, enc = None, layer = None, feature_type="audio", need_weights=True, focus_bias=None, head_weights=None, cross_guidance=None, attention_lr=None) -> tuple:
 
         x = x.to(device, dtype)
         if xa is not None:
@@ -459,8 +476,8 @@ class MultiheadA(nn.Module):
             q2 = q.shape[2]
             k2 = k.shape[2]
 
-            q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
-            k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
+            q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer)))
+            k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer)))
         else:
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -468,10 +485,26 @@ class MultiheadA(nn.Module):
 
         qk = (q * scale) @ (k * scale).transpose(-1, -2)
 
+        if self.use_focus_bias and focus_bias is not None:
+            bias_strength = torch.sigmoid(self.focus_bias_strength)
+            qk = qk + bias_strength * focus_bias
+
+        if self.use_smart_sensor and head_weights is not None:
+            head_gate = torch.sigmoid(self.head_gate) * head_weights
+            qk = qk * head_gate.unsqueeze(-1).unsqueeze(-1)
+
+        if self.use_smart_sensor and cross_guidance is not None:
+            guidance_strength = torch.sigmoid(self.guidance_strength)
+            qk = qk + guidance_strength * cross_guidance
+
+        if self.use_smart_sensor and attention_lr is not None:
+            lr_scale = torch.sigmoid(self.lr_scale)
+            self.register_buffer("predicted_lr", attention_lr * lr_scale)
+
         if self.rbf:
             qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
         if self.use_pbias:
-            pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
+            pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
             if pbias is not None:
                 qk = qk + pbias[:,:,:q2,:q2]
 
@@ -481,9 +514,6 @@ class MultiheadA(nn.Module):
             zscale[token_ids.float() == self.pad_token] = fzero
 
         if mask is not None:
-            # mask = mask[:q2, :q2]#torch.tril(torch.ones(q2, q2, device=q.device))
-            # audio_mask = torch.ones(q2, k2 - q2, device=q.device)
-            # mask = torch.cat([mask, audio_mask], dim=-1)
             mask = mask.unsqueeze(0).unsqueeze(0)
             qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
 
@@ -499,7 +529,7 @@ class MultiheadA(nn.Module):
 class FocusWindow(nn.Module):
 
     def __init__(self, dims: int, head: int, max_span: int = 512, max_dist: int = 256,
-                 feature_type: str = "waveform", debug: List[str] = []):
+                 feature_type: str = "waveform", debug: List[str] = [], learn_lr: bool = False, base_lr: float = 0.001):
         super().__init__()
         self.dims = dims
         self.head = head
@@ -508,22 +538,19 @@ class FocusWindow(nn.Module):
         self.max_dist = max_dist
         self.feature_type = feature_type
         self.debug = debug
-
-
+        self.learn_lr = learn_lr
+        self.base_lr = base_lr
         self.threshold = nn.Parameter(torch.tensor(0.01))
         self.s_factor = nn.Parameter(torch.tensor(0.1))
         self.temp_scale = nn.Parameter(torch.tensor(1.0))
         self.sharpen = True
 
-        # Feature-specific projections
         self.q_proj = Linear(dims, dims)
         self.k_proj = Linear(dims, dims)
         self.v_proj = Linear(dims, dims)
 
-        # Bias strength controller
         self.bias_strength = nn.Parameter(torch.tensor(0.5))
 
-        # Feature-specific window sizes
         self.window_sizes = {
             "spectrogram": 128,
             "waveform": 256,
@@ -532,7 +559,6 @@ class FocusWindow(nn.Module):
             "phase": 64
         }
 
-        # Feature-specific span lengths
         self.span_lengths = {
             "spectrogram": 256,
             "waveform": 512,
@@ -541,16 +567,32 @@ class FocusWindow(nn.Module):
             "phase": 128
         }
 
-
+        self.head_router = nn.Sequential(
+            Linear(dims, dims),
+            nn.SiLU(),
+            Linear(dims, head)
+        )
+
+        self.lr_predictor = nn.Sequential(
+            Linear(dims, dims // 4),
+            nn.SiLU(),
+            Linear(dims // 4, 1),
+            nn.Sigmoid()
+        )
+
+    def predict_attention_lr(self, x, feature_data=None):
+        lr_factor = self.lr_predictor(x.mean(dim=1))
+        return self.base_lr * lr_factor
 
+    def _focus(self, q, k, v, span_scale, mask=None):
+
         q_energy = torch.norm(q, dim=-1).mean()
         k_energy = torch.norm(k, dim=-1).mean()
         content_richness = (q_energy + k_energy) / 2
 
-        # Dynamic max iterations: more interesting content = more iterations
         base_iterations = 3
         max_iterations = int(base_iterations + content_richness * 12)
-        max_iterations = min(max_iterations, 20)
+        max_iterations = min(max_iterations, 20)
 
         iteration = 0
         prev_attn = torch.zeros_like(q)
@@ -570,13 +612,13 @@ class FocusWindow(nn.Module):
 
             q_span = q[:, :eff_span, :]
             k_span = k[:, :eff_span, :]
-            v_span =
+            v_span = v[:, :eff_span, :]
 
             batch, ctx, dims = q_span.size()
 
-
-
-
+            q_head = q_span.view(batch, ctx, self.head, -1).transpose(1, 2)
+            k_head = k_span.view(batch, ctx, self.head, -1).transpose(1, 2)
+            v_head = v_span.view(batch, ctx, self.head, -1).transpose(1, 2)
 
             if self.sharpen:
                 temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
@@ -584,11 +626,11 @@ class FocusWindow(nn.Module):
                 temperature = 0.5 + self.temp_scale * span_scale.mean().item()
 
             scale = (dims // self.head) ** -0.5
-            attn = torch.matmul(
+            attn = torch.matmul(q_head, k_head.transpose(-1, -2)) * scale
 
             if mask is not None:
                 if mask.dim() == 4:
-                    q_len, k_len =
+                    q_len, k_len = q_head.size(2), k_head.size(2)
                     mask_q_len = min(mask.size(2), q_len)
                     mask_k_len = min(mask.size(3), k_len)
 
@@ -603,7 +645,7 @@ class FocusWindow(nn.Module):
             attn = F.softmax(attn, dim=-1)
 
             if mask is not None and mask.dtype == torch.bool:
-                q_len, k_len =
+                q_len, k_len = q_head.size(2), k_head.size(2)
                 mask_q_len = min(mask.size(2), q_len)
                 mask_k_len = min(mask.size(3), k_len)
 
@@ -616,8 +658,11 @@ class FocusWindow(nn.Module):
 
                 attn[:, :, :mask_q_len, :mask_k_len] = attn_to_mask
 
-            attn_output = torch.matmul(attn,
-            attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx,
+            attn_output = torch.matmul(attn, v_head)
+            attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx, dims)
+
+            q = q.clone()
+            q[:, :eff_span, :] = q_span + attn_out
 
             diff = torch.abs(attn_out - prev_attn).mean()
             dynamic_threshold = threshold + s_factor * diff
@@ -626,7 +671,6 @@ class FocusWindow(nn.Module):
                 break
 
             prev_attn = attn_out
-            q = q + attn_out
             iteration += 1
 
         return attn_out, attn_weights
@@ -644,9 +688,9 @@ class FocusWindow(nn.Module):
             k_start = max(0, start_idx - span_len + win_size)
             k_end = min(start_idx + span_len, ctx)
 
-            q = x[:, start_idx:end_idx, :]
-            k = x[:, k_start:k_end, :]
-
+            q = x[:, start_idx:end_idx, :]
+            k = x[:, k_start:k_end, :]
+            v = x[:, k_start:k_end, :]
 
             window_mask = None
             if mask is not None:
@@ -656,32 +700,37 @@ class FocusWindow(nn.Module):
                 if window_mask.size(1) == 1:
                     window_mask = window_mask.expand(-1, self.head, -1, -1)
 
-            attn_out, _ = self._focus(
-                q=q, k=k, v=v, span_scale=span_scale, mask=window_mask
-            )
+            attn_out, _ = self._focus(q=q, k=k, v=v, span_scale=span_scale, mask=window_mask)
 
             output[:, start_idx:end_idx, :] = attn_out
 
         return output
 
-    def
+    def predict_head_importance(self, x, xa=None):
+        if xa is not None:
+            combined = x + 0.1 * xa
+        else:
+            combined = x
+        head_importance = self.head_router(combined.mean(dim=1))
+        return head_importance
+
+    def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=False, return_head_weights=False, learn_lr=False):
+
+        print(f"🎯 FocusWindow running! Input: {x.shape}, Feature: {xa.shape if xa is not None else None}")
+
         q = self.q_proj(x)
         k = self.k_proj(x if xa is None else xa)
         v = self.v_proj(x if xa is None else xa)
 
-        # Create span scale based on feature characteristics
         if xa is not None:
-            # Feature-specific span scaling
            feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)
            span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
         else:
            span_scale = torch.ones(x.size(0), 1, device=x.device)
 
-        # Get feature-specific parameters
         win_size = self.window_sizes.get(self.feature_type, 128)
         span_len = self.span_lengths.get(self.feature_type, 256)
 
-        # Apply sliding window with focus attention
         output = self.slide_win(
             x=q,
             win_size=win_size,
@@ -689,14 +738,133 @@ class FocusWindow(nn.Module):
             span_scale=span_scale,
             mask=mask
         )
-
+
+        if learn_lr:
+            lr_factor = self.lr_predictor(output.mean(dim=1))
+            return output, lr_factor
+
+        if return_head_weights:
+            head_weights = self.predict_head_importance(x, xa)
+            return output, head_weights
+
         if return_bias:
-            # Return as bias for main attention
            bias_strength = torch.sigmoid(self.bias_strength)
            return bias_strength * output
         else:
            return output
 
+class CrossFeatureFocusAttention(nn.Module):
+    def __init__(self, dims: int, head: int, features: List[str] = ["spectrogram", "pitch"]):
+        super().__init__()
+        self.dims = dims
+        self.head = head
+        self.features = features
+
+        self.cross_attn_layers = nn.ModuleDict({
+            feature: nn.MultiheadAttention(dims, head, batch_first=True)
+            for feature in features
+        })
+
+        self.feature_fusion = nn.Sequential(
+            Linear(dims * len(features), dims),
+            nn.SiLU(),
+            Linear(dims, dims)
+        )
+
+    def forward(self, x, enc, mask=None):
+        if enc is None:
+            return None
+
+        cross_features = []
+        for feature in self.features:
+            if feature in enc:
+                feature_data = enc[feature]
+                if feature_data is not None:
+                    attn_out, _ = self.cross_attn_layers[feature](
+                        x, feature_data, feature_data,
+                        attn_mask=mask
+                    )
+                    cross_features.append(attn_out)
+
+        if not cross_features:
+            return None
+
+        if len(cross_features) > 1:
+            fused = torch.cat(cross_features, dim=-1)
+            return self.feature_fusion(fused)
+        else:
+            return cross_features[0]
+
+class AdaptiveAttentionLR(nn.Module):
+    def __init__(self, dims: int, head: int):
+        super().__init__()
+        self.dims = dims
+        self.head = head
+
+        self.lr_predictor = nn.Sequential(
+            Linear(dims, dims // 4),
+            nn.SiLU(),
+            Linear(dims // 4, 1),
+            nn.Sigmoid()
+        )
+
+        self.quality_estimator = nn.Sequential(
+            Linear(dims, dims // 2),
+            nn.SiLU(),
+            Linear(dims // 2, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x, feature_data=None, mask=None):
+        quality = self.quality_estimator(x.mean(dim=1))
+
+        lr_factor = self.lr_predictor(x.mean(dim=1))
+
+        adaptive_lr = quality * lr_factor
+
+        return adaptive_lr, adaptive_lr
+
+class SmartSensorResidual(nn.Module):
+    def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
+                 use_smart_sensor=True):
+        super().__init__()
+        self.ctx = ctx
+        self.dims = dims
+        self.head = head
+        self.act = act
+        self.debug = debug
+
+        if use_smart_sensor:
+            self.focus_attn = FocusWindow(dims, head, feature_type="waveform")
+            self.cross_feature_guide = CrossFeatureFocusAttention(dims, head,
+                features=["spectrogram", "pitch"])
+            self.adaptive_lr = AdaptiveAttentionLR(dims, head)
+
+        self.attna = MultiheadA(dims, head, debug=debug)
+        self.lna = RMSNorm(dims)
+
+    def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio"):
+        if hasattr(self, 'focus_attn') and enc is not None:
+            focus_output, head_weights = self.focus_attn(x, enc.get("waveform"), mask,
+                return_head_weights=True)
+
+            cross_guidance = self.cross_feature_guide(x, enc, mask)
+
+            _, attention_lr = self.adaptive_lr(x, enc.get("waveform"), mask)
+
+            x = x + self.attna(
+                self.lna(x),
+                xa=None,
+                mask=mask,
+                head_weights=head_weights,
+                cross_guidance=cross_guidance,
+                attention_lr=attention_lr,
+                enc=enc,
+                layer=layer
+            )[0]
+
+        return x
+
 class t_gate(nn.Module):
     def __init__(self, dims, num_types=4, enabled=True):
         super().__init__()
@@ -821,6 +989,7 @@ class Residual(nn.Module):
         bx = b * ax + (1 - b) * x
         cx = self.lnb(bx)
         dx = self.mlp(cx)
+
         ex = self.t_gate(cx) if not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
         fx = x + ex + dx
         gx = self.lnc(fx)
@@ -1066,7 +1235,7 @@ class SpeechTransformer(nn.Module):
         for f in self.features:
             if f in enc and f in self.blocks:
                 xa = enc[f]
-                for block in self.blocks[f]:
+                for block in self.blocks[f]:
                     xa = block(xa, enc=enc, layer=layer)
                 out[f] = xa
                 xa = xa + self.audio_embedding[:xa.shape[1]]
@@ -1327,7 +1496,6 @@ def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **data
     log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
     spec = (log_mel + 4.0) / 4.0
     spec = torch.tensor(spec)
-    # batch["spectrogram"] = spec
 
     wav_np = wav.numpy().astype(np.float64)
     f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
@@ -1340,8 +1508,6 @@ def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **data
         "spectrogram": spec,
         "f0": f0,
         "labels": labels,
-        # "waveform": wav,
-        # "pitch": f0,
     }
 
 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
@@ -1569,11 +1735,11 @@ def main():
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
-       train_dataset=train_dataset,
-       eval_dataset=test_dataset,
-       data_collator=DataCollator(tokenizer=tokenizer),
+       train_dataset=train_dataset,
+       eval_dataset=test_dataset,
+       data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
-       optimizers=(optimizer, scheduler)
+       optimizers=(optimizer, scheduler)
    )
    model.init_weights()
    trainer.train()
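For orientation, a minimal sketch of how the FocusWindow added in this commit is driven, mirroring the call that SmartSensorResidual makes in its own forward pass. The sizes below are illustrative, the waveform features are assumed to already be projected to the model width, and actually running this depends on parts of modelA.py that this diff does not show (Linear, slide_win, and the device/dtype globals).

import torch

dims, head = 512, 8                     # illustrative sizes, not values from the diff
x = torch.randn(2, 100, dims)           # decoder states
wave_feats = torch.randn(2, 100, dims)  # waveform features, assumed already at model width

focus = FocusWindow(dims, head, feature_type="waveform")
# Same call pattern as SmartSensorResidual.forward in this commit:
focus_output, head_weights = focus(x, wave_feats, mask=None, return_head_weights=True)
# head_weights comes from the new head_router and holds one score per attention head.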