Update modelA.py
Browse files
modelA.py
CHANGED
@@ -32,7 +32,15 @@ dtype = torch.float32
|
|
32 |
warnings.filterwarnings("ignore")
|
33 |
logging.basicConfig(level=logging.ERROR)
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def get_activation(act: str) -> nn.Module:
|
|
|
36 |
act_map = {
|
37 |
"gelu": nn.GELU(),
|
38 |
"relu": nn.ReLU(),
|
@@ -496,6 +504,207 @@ class MultiheadA(nn.Module):
|
|
496 |
self.counter += 1
|
497 |
return self.o(wv), qk
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
class t_gate(nn.Module):
|
500 |
def __init__(self, dims, num_types=4, enabled=True):
|
501 |
super().__init__()
|
@@ -949,7 +1158,7 @@ class Echo(nn.Module):
|
|
949 |
self.init_counts = {
|
950 |
"Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
|
951 |
"Conv2d": 0, "SEBlock": 0, "SpeechTransformer": 0,
|
952 |
-
"Residual": 0, "MultiheadA": 0,
|
953 |
"MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
|
954 |
"WEncoder": 0, "PEncoder": 0}
|
955 |
|
@@ -1166,9 +1375,9 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **
|
|
1166 |
return train_dataset, test_dataset
|
1167 |
|
1168 |
def filter_func(x):
|
1169 |
-
return (0 < len(x["transcription"]) <
|
1170 |
len(x["audio"]["array"]) > 0 and
|
1171 |
-
len(x["audio"]["array"]) <
|
1172 |
|
1173 |
raw_train = load_dataset(
|
1174 |
"google/fleurs", "en_us", token=token, split="train[:1000]", trust_remote_code=True)
|
@@ -1322,7 +1531,6 @@ def main():
|
|
1322 |
vocab=40000, ctx=2048, dims=512, head=4, layer=4,
|
1323 |
mels=128, act="swish",
|
1324 |
debug={},
|
1325 |
-
cross_attn=True,
|
1326 |
features=["spectrogram"]
|
1327 |
)
|
1328 |
|
|
|
32 |
warnings.filterwarnings("ignore")
|
33 |
logging.basicConfig(level=logging.ERROR)
|
34 |
|
35 |
+
PATH = 'E:/hf'
|
36 |
+
os.environ['HF_HOME'] = PATH
|
37 |
+
os.environ['HF_DATASETS_CACHE'] = PATH
|
38 |
+
os.environ['TORCH_HOME'] = PATH
|
39 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
40 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
41 |
+
|
42 |
def get_activation(act: str) -> nn.Module:
|
43 |
+
"""Get activation function by name."""
|
44 |
act_map = {
|
45 |
"gelu": nn.GELU(),
|
46 |
"relu": nn.ReLU(),
|
|
|
504 |
self.counter += 1
|
505 |
return self.o(wv), qk
|
506 |
|
507 |
+
class FocusWindow(nn.Module):
|
508 |
+
|
509 |
+
def __init__(self, dims: int, head: int, max_span: int = 512, max_dist: int = 256,
|
510 |
+
feature_type: str = "waveform", debug: List[str] = []):
|
511 |
+
super().__init__()
|
512 |
+
self.dims = dims
|
513 |
+
self.head = head
|
514 |
+
self.head_dim = dims // head
|
515 |
+
self.max_span = max_span
|
516 |
+
self.max_dist = max_dist
|
517 |
+
self.feature_type = feature_type
|
518 |
+
self.debug = debug
|
519 |
+
|
520 |
+
# Adaptive parameters for focus control
|
521 |
+
self.threshold = nn.Parameter(torch.tensor(0.01))
|
522 |
+
self.s_factor = nn.Parameter(torch.tensor(0.1))
|
523 |
+
self.temp_scale = nn.Parameter(torch.tensor(1.0))
|
524 |
+
self.sharpen = True
|
525 |
+
|
526 |
+
# Feature-specific projections
|
527 |
+
self.q_proj = Linear(dims, dims)
|
528 |
+
self.k_proj = Linear(dims, dims)
|
529 |
+
self.v_proj = Linear(dims, dims)
|
530 |
+
|
531 |
+
# Bias strength controller
|
532 |
+
self.bias_strength = nn.Parameter(torch.tensor(0.5))
|
533 |
+
|
534 |
+
# Feature-specific window sizes
|
535 |
+
self.window_sizes = {
|
536 |
+
"spectrogram": 128,
|
537 |
+
"waveform": 256,
|
538 |
+
"pitch": 64,
|
539 |
+
"envelope": 64,
|
540 |
+
"phase": 64
|
541 |
+
}
|
542 |
+
|
543 |
+
# Feature-specific span lengths
|
544 |
+
self.span_lengths = {
|
545 |
+
"spectrogram": 256,
|
546 |
+
"waveform": 512,
|
547 |
+
"pitch": 128,
|
548 |
+
"envelope": 128,
|
549 |
+
"phase": 128
|
550 |
+
}
|
551 |
+
|
552 |
+
def _focus(self, q, k, v, span_scale, mask=None):
|
553 |
+
|
554 |
+
q_energy = torch.norm(q, dim=-1).mean()
|
555 |
+
k_energy = torch.norm(k, dim=-1).mean()
|
556 |
+
content_richness = (q_energy + k_energy) / 2
|
557 |
+
|
558 |
+
# Dynamic max iterations: more interesting content = more iterations
|
559 |
+
base_iterations = 3
|
560 |
+
max_iterations = int(base_iterations + content_richness * 12)
|
561 |
+
max_iterations = min(max_iterations, 20) # Cap at 20
|
562 |
+
|
563 |
+
iteration = 0
|
564 |
+
prev_attn = torch.zeros_like(q)
|
565 |
+
attn_out = torch.zeros_like(q)
|
566 |
+
attn_weights = None
|
567 |
+
|
568 |
+
threshold = self.threshold.item()
|
569 |
+
s_factor = self.s_factor.item()
|
570 |
+
|
571 |
+
while iteration < max_iterations:
|
572 |
+
span_len = int(self.max_span * span_scale.mean().item())
|
573 |
+
span_len = min(span_len, q.size(1), k.size(1), k.size(1))
|
574 |
+
eff_span = min(span_len, self.max_dist)
|
575 |
+
|
576 |
+
if eff_span == 0:
|
577 |
+
break
|
578 |
+
|
579 |
+
q_span = q[:, :eff_span, :]
|
580 |
+
k_span = k[:, :eff_span, :]
|
581 |
+
v_span = k[:, :eff_span, :]
|
582 |
+
|
583 |
+
batch, ctx, dims = q_span.size()
|
584 |
+
|
585 |
+
q = q_span.view(batch, ctx, self.head, -1).transpose(1, 2)
|
586 |
+
k = k_span.view(batch, ctx, self.head, -1).transpose(1, 2)
|
587 |
+
v = v_span.view(batch, ctx, self.head, -1).transpose(1, 2)
|
588 |
+
|
589 |
+
if self.sharpen:
|
590 |
+
temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
|
591 |
+
else:
|
592 |
+
temperature = 0.5 + self.temp_scale * span_scale.mean().item()
|
593 |
+
|
594 |
+
scale = (dims // self.head) ** -0.5
|
595 |
+
attn = torch.matmul(q, k.transpose(-1, -2)) * scale
|
596 |
+
|
597 |
+
if mask is not None:
|
598 |
+
if mask.dim() == 4:
|
599 |
+
q_len, k_len = q.size(2), k.size(2)
|
600 |
+
mask_q_len = min(mask.size(2), q_len)
|
601 |
+
mask_k_len = min(mask.size(3), k_len)
|
602 |
+
|
603 |
+
mask_part = mask[:, :, :mask_q_len, :mask_k_len]
|
604 |
+
if mask_part.dtype == torch.bool:
|
605 |
+
attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len].masked_fill(
|
606 |
+
mask_part, float("-inf")
|
607 |
+
)
|
608 |
+
else:
|
609 |
+
attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len] + mask_part
|
610 |
+
|
611 |
+
attn = F.softmax(attn, dim=-1)
|
612 |
+
|
613 |
+
if mask is not None and mask.dtype == torch.bool:
|
614 |
+
q_len, k_len = q.size(2), k.size(2)
|
615 |
+
mask_q_len = min(mask.size(2), q_len)
|
616 |
+
mask_k_len = min(mask.size(3), k_len)
|
617 |
+
|
618 |
+
binary_mask = (~mask[:, :, :mask_q_len, :mask_k_len]).float()
|
619 |
+
attn_to_mask = attn[:, :, :mask_q_len, :mask_k_len]
|
620 |
+
attn_to_mask = attn_to_mask * binary_mask
|
621 |
+
|
622 |
+
attn_sum = attn_to_mask.sum(dim=-1, keepdim=True)
|
623 |
+
attn_to_mask = attn_to_mask / (attn_sum + 1e-6)
|
624 |
+
|
625 |
+
attn[:, :, :mask_q_len, :mask_k_len] = attn_to_mask
|
626 |
+
|
627 |
+
attn_output = torch.matmul(attn, v)
|
628 |
+
attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx, -1)
|
629 |
+
|
630 |
+
diff = torch.abs(attn_out - prev_attn).mean()
|
631 |
+
dynamic_threshold = threshold + s_factor * diff
|
632 |
+
|
633 |
+
if diff < dynamic_threshold:
|
634 |
+
break
|
635 |
+
|
636 |
+
prev_attn = attn_out
|
637 |
+
q = q + attn_out
|
638 |
+
iteration += 1
|
639 |
+
|
640 |
+
return attn_out, attn_weights
|
641 |
+
|
642 |
+
def slide_win(self, x, win_size, span_len, span_scale, mask=None):
|
643 |
+
batch, ctx, dims = x.size()
|
644 |
+
num_windows = (ctx + win_size - 1) // win_size
|
645 |
+
output = torch.zeros_like(x)
|
646 |
+
|
647 |
+
for i in range(num_windows):
|
648 |
+
start_idx = i * win_size
|
649 |
+
end_idx = min((i + 1) * win_size, ctx)
|
650 |
+
window_size = end_idx - start_idx
|
651 |
+
|
652 |
+
k_start = max(0, start_idx - span_len + win_size)
|
653 |
+
k_end = min(start_idx + span_len, ctx)
|
654 |
+
|
655 |
+
q = x[:, start_idx:end_idx, :]
|
656 |
+
k = x[:, k_start:k_end, :]
|
657 |
+
k = k
|
658 |
+
|
659 |
+
window_mask = None
|
660 |
+
if mask is not None:
|
661 |
+
if mask.dim() == 4:
|
662 |
+
window_mask = mask[:, :, start_idx:end_idx, k_start:k_end]
|
663 |
+
|
664 |
+
if window_mask.size(1) == 1:
|
665 |
+
window_mask = window_mask.expand(-1, self.head, -1, -1)
|
666 |
+
|
667 |
+
attn_out, _ = self._focus(
|
668 |
+
q=q, k=k, v=v, span_scale=span_scale, mask=window_mask
|
669 |
+
)
|
670 |
+
|
671 |
+
output[:, start_idx:end_idx, :] = attn_out
|
672 |
+
|
673 |
+
return output
|
674 |
+
|
675 |
+
def forward(self, x, feature_data=None, mask=None, return_bias=True):
|
676 |
+
q = self.q_proj(x)
|
677 |
+
k = self.k_proj(x if feature_data is None else feature_data)
|
678 |
+
v = self.v_proj(x if feature_data is None else feature_data)
|
679 |
+
|
680 |
+
# Create span scale based on feature characteristics
|
681 |
+
if feature_data is not None:
|
682 |
+
# Feature-specific span scaling
|
683 |
+
feature_energy = torch.norm(feature_data, dim=-1).mean(dim=-1, keepdim=True)
|
684 |
+
span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
|
685 |
+
else:
|
686 |
+
span_scale = torch.ones(x.size(0), 1, device=x.device)
|
687 |
+
|
688 |
+
# Get feature-specific parameters
|
689 |
+
win_size = self.window_sizes.get(self.feature_type, 128)
|
690 |
+
span_len = self.span_lengths.get(self.feature_type, 256)
|
691 |
+
|
692 |
+
# Apply sliding window with focus attention
|
693 |
+
output = self.slide_win(
|
694 |
+
x=q,
|
695 |
+
win_size=win_size,
|
696 |
+
span_len=span_len,
|
697 |
+
span_scale=span_scale,
|
698 |
+
mask=mask
|
699 |
+
)
|
700 |
+
|
701 |
+
if return_bias:
|
702 |
+
# Return as bias for main attention
|
703 |
+
bias_strength = torch.sigmoid(self.bias_strength)
|
704 |
+
return bias_strength * output
|
705 |
+
else:
|
706 |
+
return output
|
707 |
+
|
708 |
class t_gate(nn.Module):
|
709 |
def __init__(self, dims, num_types=4, enabled=True):
|
710 |
super().__init__()
|
|
|
1158 |
self.init_counts = {
|
1159 |
"Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
|
1160 |
"Conv2d": 0, "SEBlock": 0, "SpeechTransformer": 0,
|
1161 |
+
"Residual": 0, "MultiheadA": 0,
|
1162 |
"MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
|
1163 |
"WEncoder": 0, "PEncoder": 0}
|
1164 |
|
|
|
1375 |
return train_dataset, test_dataset
|
1376 |
|
1377 |
def filter_func(x):
|
1378 |
+
return (0 < len(x["transcription"]) < 2048 and
|
1379 |
len(x["audio"]["array"]) > 0 and
|
1380 |
+
len(x["audio"]["array"]) < 2048 * 160)
|
1381 |
|
1382 |
raw_train = load_dataset(
|
1383 |
"google/fleurs", "en_us", token=token, split="train[:1000]", trust_remote_code=True)
|
|
|
1531 |
vocab=40000, ctx=2048, dims=512, head=4, layer=4,
|
1532 |
mels=128, act="swish",
|
1533 |
debug={},
|
|
|
1534 |
features=["spectrogram"]
|
1535 |
)
|
1536 |
|