Update modelA.py
modelA.py
CHANGED
@@ -33,6 +33,7 @@ warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)

 def get_activation(act: str) -> nn.Module:
     act_map = {
         "gelu": nn.GELU(),
         "relu": nn.ReLU(),
@@ -50,11 +51,11 @@ def get_activation(act: str) -> nn.Module:
 @dataclass
 class Dimensions:
     vocab: int
     ctx: int
     dims: int
     head: int
     layer: int
-    mels: int
     act: str
     debug: List[str]
     features: List[str]
@@ -197,6 +198,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
     return fig

 def valid(default_value, *items):
     for item in items:
         if item is not None:
             return item
@@ -264,6 +266,21 @@ def get_dtype():
 def tox():
     return {"device": get_device(), "dtype": get_dtype()}

 def sinusoids(length, channels, max_tscale=10000):
     assert channels % 2 == 0
     log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
@@ -323,46 +340,39 @@ class rotary(nn.Module):
             idx = torch.arange(ctx, device=f0.device)
             idx = (idx * F).long().clamp(0, L - 1)
             radius = radius[idx]
-
         else:
-            return torch.polar(torch.ones_like(freqs), freqs)

-    def
-        elif isinstance(x, torch.Tensor) and x.ndim == 2:
-            batch, ctx = x.shape
-        elif isinstance(x, torch.Tensor) and x.ndim == 3:
-            batch, ctx, dims = x.shape
         else:
-
         if f0 is not None:
             if f0.dim() == 2:
                 f0 = f0.squeeze(0)
             theta = f0 + self.theta
         else:
             theta = self.theta
-
         freqs = self.theta_freqs(theta)
         t = torch.arange(ctx, device=device, dtype=dtype)
         freqs = t[:, None] * freqs
-            radius = f0.to(device, dtype)
-            freqs = torch.polar(radius.unsqueeze(-1), freqs)
-        else:
-            radius = torch.ones_like(freqs)
-            freqs = torch.polar(radius, freqs)
-
         if "radius" in self.debug and self.counter == 10:
-            radius_shape = radius.shape if 'radius' in locals() else "N/A"
-            radius_mean = radius.mean() if 'radius' in locals() else 0.0
-            print(f" [{layer}] [Radius] {radius_shape} {radius_mean:.2f} [Theta] {theta_value:.2f} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
-            print(f" [{layer}] [Radius] {radius}")
         self.counter += 1
         return freqs.unsqueeze(0)
@@ -383,8 +393,7 @@ class MultiheadA(nn.Module):

     rbf = False
     def __init__(self, dims: int, head: int, rotary_emb: bool = True,
-                 zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [],
-                 optim_attn=False, use_pbias=False, use_smart_sensor=False, use_focus_bias=False):
         super(MultiheadA, self).__init__()

         self.dims = dims
@@ -417,16 +426,6 @@ class MultiheadA(nn.Module):
         else:
             self.rope = None

-        self.use_smart_sensor = use_smart_sensor
-        if use_smart_sensor:
-            self.head_gate = nn.Parameter(torch.ones(head))
-            self.guidance_strength = nn.Parameter(torch.tensor(0.3))
-            self.lr_scale = nn.Parameter(torch.tensor(1.0))
-
-        self.use_focus_bias = use_focus_bias
-        if use_focus_bias:
-            self.focus_bias_strength = nn.Parameter(torch.tensor(0.3))
-
     def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
         q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
         k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
@@ -448,7 +447,7 @@ class MultiheadA(nn.Module):
         rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
         return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

-    def forward(self, x: Tensor, xa
         x = x.to(device, dtype)
         if xa is not None:
@@ -476,26 +475,10 @@ class MultiheadA(nn.Module):

         qk = (q * scale) @ (k * scale).transpose(-1, -2)

-        if self.use_focus_bias and focus_bias is not None:
-            bias_strength = torch.sigmoid(self.focus_bias_strength)
-            qk = qk + bias_strength * focus_bias
-
-        if self.use_smart_sensor and head_weights is not None:
-            head_gate = torch.sigmoid(self.head_gate) * head_weights
-            qk = qk * head_gate.unsqueeze(-1).unsqueeze(-1)
-
-        if self.use_smart_sensor and cross_guidance is not None:
-            guidance_strength = torch.sigmoid(self.guidance_strength)
-            qk = qk + guidance_strength * cross_guidance
-
-        if self.use_smart_sensor and attention_lr is not None:
-            lr_scale = torch.sigmoid(self.lr_scale)
-            self.register_buffer("predicted_lr", attention_lr * lr_scale)
-
         if self.rbf:
             qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
         if self.use_pbias:
-            pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
             if pbias is not None:
                 qk = qk + pbias[:,:,:q2,:q2]
@@ -504,10 +487,8 @@ class MultiheadA(nn.Module):
             fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
             zscale[token_ids.float() == self.pad_token] = fzero

-        if
-            mask = mask.unsqueeze(0).unsqueeze(0)
             qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
-
         qk = qk * zscale.unsqueeze(-2)
         w = F.softmax(qk, dim=-1).to(q.dtype)
         wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
@@ -517,340 +498,6 @@ class MultiheadA(nn.Module):
         self.counter += 1
         return self.o(wv), qk

-class FocusWindow(nn.Module):
-
-    def __init__(self, dims: int, head: int, max_span: int = 512, max_dist: int = 256,
-                 feature_type: str = "waveform", debug: List[str] = [], learn_lr: bool = False, base_lr: float = 0.001):
-        super().__init__()
-        self.dims = dims
-        self.head = head
-        self.head_dim = dims // head
-        self.max_span = max_span
-        self.max_dist = max_dist
-        self.feature_type = feature_type
-        self.debug = debug
-        self.learn_lr = learn_lr
-        self.base_lr = base_lr
-        self.threshold = nn.Parameter(torch.tensor(0.01))
-        self.s_factor = nn.Parameter(torch.tensor(0.1))
-        self.temp_scale = nn.Parameter(torch.tensor(1.0))
-        self.sharpen = True
-
-        self.q_proj = Linear(dims, dims)
-        self.k_proj = Linear(dims, dims)
-        self.v_proj = Linear(dims, dims)
-
-        self.bias_strength = nn.Parameter(torch.tensor(0.5))
-
-        self.window_sizes = {
-            "spectrogram": 128,
-            "waveform": 256,
-            "pitch": 64,
-            "envelope": 64,
-            "phase": 64
-        }
-
-        self.span_lengths = {
-            "spectrogram": 256,
-            "waveform": 512,
-            "pitch": 128,
-            "envelope": 128,
-            "phase": 128
-        }
-
-        self.head_router = nn.Sequential(
-            Linear(dims, dims),
-            nn.SiLU(),
-            Linear(dims, head)
-        )
-
-        self.lr_predictor = nn.Sequential(
-            Linear(dims, dims // 4),
-            nn.SiLU(),
-            Linear(dims // 4, 1),
-            nn.Sigmoid()
-        )
-
-    def predict_attention_lr(self, x, feature_data=None):
-        lr_factor = self.lr_predictor(x.mean(dim=1))
-        return self.base_lr * lr_factor
-
-    def _focus(self, q, k, v, span_scale, mask=None):
-
-        q_energy = torch.norm(q, dim=-1).mean()
-        k_energy = torch.norm(k, dim=-1).mean()
-        content_richness = (q_energy + k_energy) / 2
-
-        base_iterations = 3
-        max_iterations = int(base_iterations + content_richness * 12)
-        max_iterations = min(max_iterations, 20)
-
-        iteration = 0
-        prev_attn = torch.zeros_like(q)
-        attn_out = torch.zeros_like(q)
-        attn_weights = None
-
-        threshold = self.threshold.item()
-        s_factor = self.s_factor.item()
-
-        while iteration < max_iterations:
-            span_len = int(self.max_span * span_scale.mean().item())
-            span_len = min(span_len, q.size(1), k.size(1), k.size(1))
-            eff_span = min(span_len, self.max_dist)
-
-            if eff_span == 0:
-                break
-
-            q_span = q[:, :eff_span, :]
-            k_span = k[:, :eff_span, :]
-            v_span = v[:, :eff_span, :]
-
-            batch, ctx, dims = q_span.size()
-
-            q_head = q_span.view(batch, ctx, self.head, -1).transpose(1, 2)
-            k_head = k_span.view(batch, ctx, self.head, -1).transpose(1, 2)
-            v_head = v_span.view(batch, ctx, self.head, -1).transpose(1, 2)
-
-            if self.sharpen:
-                temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
-            else:
-                temperature = 0.5 + self.temp_scale * span_scale.mean().item()
-
-            scale = (dims // self.head) ** -0.5
-            attn = torch.matmul(q_head, k_head.transpose(-1, -2)) * scale
-
-            if mask is not None:
-                if mask.dim() == 4:
-                    q_len, k_len = q_head.size(2), k_head.size(2)
-                    mask_q_len = min(mask.size(2), q_len)
-                    mask_k_len = min(mask.size(3), k_len)
-
-                    mask_part = mask[:, :, :mask_q_len, :mask_k_len]
-                    if mask_part.dtype == torch.bool:
-                        attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len].masked_fill(
-                            mask_part, float("-inf")
-                        )
-                    else:
-                        attn[:, :, :mask_q_len, :mask_k_len] = attn[:, :, :mask_q_len, :mask_k_len] + mask_part
-
-            attn = F.softmax(attn, dim=-1)
-
-            if mask is not None and mask.dtype == torch.bool:
-                q_len, k_len = q_head.size(2), k_head.size(2)
-                mask_q_len = min(mask.size(2), q_len)
-                mask_k_len = min(mask.size(3), k_len)
-
-                binary_mask = (~mask[:, :, :mask_q_len, :mask_k_len]).float()
-                attn_to_mask = attn[:, :, :mask_q_len, :mask_k_len]
-                attn_to_mask = attn_to_mask * binary_mask
-
-                attn_sum = attn_to_mask.sum(dim=-1, keepdim=True)
-                attn_to_mask = attn_to_mask / (attn_sum + 1e-6)
-
-                attn[:, :, :mask_q_len, :mask_k_len] = attn_to_mask
-
-            attn_output = torch.matmul(attn, v_head)
-            attn_out = attn_output.transpose(1, 2).contiguous().view(batch, ctx, dims)
-
-            q = q.clone()
-            q[:, :eff_span, :] = q_span + attn_out
-
-            diff = torch.abs(attn_out - prev_attn).mean()
-            dynamic_threshold = threshold + s_factor * diff
-
-            if diff < dynamic_threshold:
-                break
-
-            prev_attn = attn_out
-            iteration += 1
-
-        return attn_out, attn_weights
-
-    def slide_win(self, x, win_size, span_len, span_scale, mask=None):
-        batch, ctx, dims = x.size()
-        num_windows = (ctx + win_size - 1) // win_size
-        output = torch.zeros_like(x)
-
-        for i in range(num_windows):
-            start_idx = i * win_size
-            end_idx = min((i + 1) * win_size, ctx)
-            window_size = end_idx - start_idx
-
-            k_start = max(0, start_idx - span_len + win_size)
-            k_end = min(start_idx + span_len, ctx)
-
-            q = x[:, start_idx:end_idx, :]
-            k = x[:, k_start:k_end, :]
-            v = x[:, k_start:k_end, :]
-
-            window_mask = None
-            if mask is not None:
-                if mask.dim() == 4:
-                    window_mask = mask[:, :, start_idx:end_idx, k_start:k_end]
-
-                    if window_mask.size(1) == 1:
-                        window_mask = window_mask.expand(-1, self.head, -1, -1)
-
-            attn_out, _ = self._focus(q=q, k=k, v=v, span_scale=span_scale, mask=window_mask)
-
-            output[:, start_idx:end_idx, :] = attn_out
-
-        return output
-
-    def predict_head_importance(self, x, xa=None):
-        if xa is not None:
-            combined = x + 0.1 * xa
-        else:
-            combined = x
-        head_importance = self.head_router(combined.mean(dim=1))
-        return head_importance
-
-    def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=False, return_head_weights=False, learn_lr=False):
-
-        q = self.q_proj(x)
-        k = self.k_proj(x if xa is None else xa)
-        v = self.v_proj(x if xa is None else xa)
-
-        if xa is not None:
-            feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)
-            span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
-        else:
-            span_scale = torch.ones(x.size(0), 1, device=x.device)
-
-        win_size = self.window_sizes.get(self.feature_type, 128)
-        span_len = self.span_lengths.get(self.feature_type, 256)
-
-        output = self.slide_win(
-            x=q,
-            win_size=win_size,
-            span_len=span_len,
-            span_scale=span_scale,
-            mask=mask
-        )
-
-        if learn_lr:
-            lr_factor = self.lr_predictor(output.mean(dim=1))
-            return output, lr_factor
-
-        if return_head_weights:
-            head_weights = self.predict_head_importance(x, xa)
-            return output, head_weights
-
-        if return_bias:
-            bias_strength = torch.sigmoid(self.bias_strength)
-            return bias_strength * output
-        else:
-            return output
-
-class CrossFeatureFocusAttention(nn.Module):
-    def __init__(self, dims: int, head: int, features: List[str] = ["spectrogram", "pitch"]):
-        super().__init__()
-        self.dims = dims
-        self.head = head
-        self.features = features
-
-        self.cross_attn_layers = nn.ModuleDict({
-            feature: nn.MultiheadAttention(dims, head, batch_first=True)
-            for feature in features
-        })
-
-        self.feature_fusion = nn.Sequential(
-            Linear(dims * len(features), dims),
-            nn.SiLU(),
-            Linear(dims, dims)
-        )
-
-    def forward(self, x, enc, mask=None):
-        if enc is None:
-            return None
-
-        cross_features = []
-        for feature in self.features:
-            if feature in enc:
-                feature_data = enc[feature]
-                if feature_data is not None:
-                    attn_out, _ = self.cross_attn_layers[feature](
-                        x, feature_data, feature_data,
-                        attn_mask=mask
-                    )
-                    cross_features.append(attn_out)
-
-        if not cross_features:
-            return None
-
-        if len(cross_features) > 1:
-            fused = torch.cat(cross_features, dim=-1)
-            return self.feature_fusion(fused)
-        else:
-            return cross_features[0]
-
-class AdaptiveAttentionLR(nn.Module):
-    def __init__(self, dims: int, head: int):
-        super().__init__()
-        self.dims = dims
-        self.head = head
-
-        self.lr_predictor = nn.Sequential(
-            Linear(dims, dims // 4),
-            nn.SiLU(),
-            Linear(dims // 4, 1),
-            nn.Sigmoid()
-        )
-
-        self.quality_estimator = nn.Sequential(
-            Linear(dims, dims // 2),
-            nn.SiLU(),
-            Linear(dims // 2, 1),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x, feature_data=None, mask=None):
-        quality = self.quality_estimator(x.mean(dim=1))
-        lr_factor = self.lr_predictor(x.mean(dim=1))
-        adaptive_lr = quality * lr_factor
-        return adaptive_lr, adaptive_lr
-
-class SmartSensorResidual(nn.Module):
-    def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
-                 use_smart_sensor=True):
-        super().__init__()
-        self.ctx = ctx
-        self.dims = dims
-        self.head = head
-        self.act = act
-        self.debug = debug
-
-        if use_smart_sensor:
-            self.focus_attn = FocusWindow(dims, head, feature_type="waveform")
-            self.cross_feature_guide = CrossFeatureFocusAttention(dims, head,
-                                                                  features=["spectrogram", "pitch"])
-            self.adaptive_lr = AdaptiveAttentionLR(dims, head)
-
-        self.attna = MultiheadA(dims, head, debug=debug)
-        self.lna = RMSNorm(dims)
-
-    def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio"):
-        if hasattr(self, 'focus_attn') and enc is not None:
-            focus_output, head_weights = self.focus_attn(x, enc.get("waveform"), mask,
-                                                         return_head_weights=True)
-
-            cross_guidance = self.cross_feature_guide(x, enc, mask)
-
-            _, attention_lr = self.adaptive_lr(x, enc.get("waveform"), mask)
-
-            x = x + self.attna(
-                self.lna(x),
-                xa=None,
-                mask=mask,
-                head_weights=head_weights,
-                cross_guidance=cross_guidance,
-                attention_lr=attention_lr,
-                enc=enc,
-                layer=layer
-            )[0]
-
-        return x
-
 class t_gate(nn.Module):
     def __init__(self, dims, num_types=4, enabled=True):
         super().__init__()
@@ -931,7 +578,7 @@ class mlp_gate(nn.Module):
 class Residual(nn.Module):
     _seen = set()
     def __init__(self, ctx, dims, head, act, debug: List[str] = [],
-                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None
         super().__init__()

         self.dims = dims
@@ -945,9 +592,7 @@ class Residual(nn.Module):

         self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
-
         self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
-        self.focus = FocusWindow(dims, head, debug=debug) if focus else None

         if not any([tgate, mgate, cgate]):
             self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
@@ -966,16 +611,13 @@ class Residual(nn.Module):
         self.lnb = RMSNorm(dims)
         self.lnc = RMSNorm(dims)

-    def forward(self, x, xa=None, mask=None, enc=None, layer=None,
-
-        focus = self.focus(x, xa=xa, mask=mask, enc=enc, layer=layer) if self.focus is not None else 0

         b = torch.sigmoid(self.blend)
-        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0]
         bx = b * ax + (1 - b) * x
         cx = self.lnb(bx)
         dx = self.mlp(cx)
-
         ex = self.t_gate(cx) if self.t_gate is not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
         fx = x + ex + dx
         gx = self.lnc(fx)
@@ -1017,9 +659,9 @@ class FEncoder(nn.Module):
         self.norm = RMSNorm(dims)
         self._norm = RMSNorm(dims)

-    def apply_rope_to_features(self, x, layer=None,
-        if
-
         batch, ctx, dims = x.shape
         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
         if feature_type == "spectrogram" and self.rope is not None:
@@ -1030,10 +672,10 @@ class FEncoder(nn.Module):
         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
         return x

-    def forward(self, x, enc=None, layer=None,
         x = self.encoder(x).permute(0, 2, 1)
         if self.use_rope:
-            x = self.apply_rope_to_features(x, layer=layer,
         else:
             x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -1070,17 +712,17 @@ class WEncoder(nn.Module):
         self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)

-    def apply_rope_to_features(self, x, layer=None):
         if not self.use_rope or self.rope is None:
             return x
         batch, ctx, dims = x.shape
         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
-        rope_freqs = self.rope(ctx, layer=layer,
         x = self.rope.apply_rotary(x, rope_freqs)
         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
         return x

-    def forward(self, x, enc=None, layer=None,
         x = self.downsample(x)
         x = self.encoder(x)
         x = x.permute(0, 2, 1)
@@ -1118,17 +760,17 @@ class PEncoder(nn.Module):
         self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)

-    def apply_rope_to_features(self, x, layer=None):
         if not self.use_rope or self.rope is None:
             return x
         batch, ctx, dims = x.shape
         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
-        rope_freqs = self.rope(ctx, layer=layer,
         x = self.rope.apply_rotary(x, rope_freqs)
         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
         return x

-    def forward(self, x, enc=None, layer=None,
         x = self.encoder(x).permute(0, 2, 1)
         if self.use_rope:
             x = self.apply_rope_to_features(x, layer=layer)
@@ -1138,110 +780,148 @@ class PEncoder(nn.Module):
         x = self.norm(x)
         return x

-class
-    super(
         self.dims = dims
         self.head = head
-        self.ctx = ctx
         self.head_dim = dims // head
         self.debug = debug
         self.counter = 0
         self.features = features
-        self.
-        self.sequential = "sequential" in debug
-        act_fn = get_activation(act)

         self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
         self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
-        self.

         if features == ["spectrogram", "waveform", "pitch"]:
             cgate=True
         else:
             cgate = False

         self.blocks = nn.ModuleDict({
             "spectrogram": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate
-                if "spectrogram" in features else None),
             "waveform": nn.ModuleList(
                 [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate
-                if "waveform" in features else None),
             "pitch": nn.ModuleList(
                 [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate
-                if "pitch" in features else None),
             "envelope": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate
-                if "envelope" in features else None),
             "phase": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate

-    def forward(self, enc, layer="encoder"):
         enc = dict_to(enc, device, dtype)

-        x = enc.get("input_ids").long()
         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)

-        for f in self.features:
-            if f in enc and f in self.blocks:
                 xa = enc[f]
                 for block in self.blocks[f]:
-                    xa = block(xa, enc=enc, layer=layer)
-                out[f] = xa
-                xa = xa + self.audio_embedding[:xa.shape[1]]

         for block in self.block:
             mask = self.mask[:x.shape[1], :x.shape[1]]
-            x = block(x, xa=None, mask=mask, enc=

-        for
-            if
         if self.sequential:
             x = out
         else:
             a = torch.sigmoid(self.blend)
             x = a * out + (1 - a) * x

         return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

 class Echo(nn.Module):
@@ -1249,17 +929,17 @@ class Echo(nn.Module):
         super().__init__()
         self.param = param

-        self.
             vocab=param.vocab,
             mels=param.mels,
             ctx=param.ctx,
             dims=param.dims,
             head=param.head,
             layer=param.layer,
-            debug=param.debug,
             features=param.features,
             act=param.act,
-

     def forward(self,
         labels=None,
@@ -1268,33 +948,42 @@ class Echo(nn.Module):
         spectrogram: Optional[torch.Tensor]=None,
         pitch: Optional[torch.Tensor]=None,
         f0: Optional[torch.Tensor]=None,
-
         ) -> Dict[str, Optional[torch.Tensor]]:

-
         if spectrogram is not None:
-
         if waveform is not None:
-
         if pitch is not None:
-
-            encoder_inputs["envelope"] = envelope
-        if phase is not None:
-            encoder_inputs["phase"] = phase
         if f0 is not None:
-
         if input_ids is not None:
-

-        logits = self.

         loss = None
         if labels is not None:
             loss = F.cross_entropy(
                 logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
-
         return {"logits": logits, "loss": loss}

     @property
@@ -1335,8 +1024,6 @@ class Echo(nn.Module):
             self.init_counts["Conv2d"] += 1
         elif isinstance(module, MultiheadA):
             self.init_counts["MultiheadA"] += 1
-        elif isinstance(module, SpeechTransformer):
-            self.init_counts["SpeechTransformer"] += 1
         elif isinstance(module, Residual):
             self.init_counts["Residual"] += 1
@@ -1361,24 +1048,24 @@ class Echo(nn.Module):
             batch_size = x.shape[0]
             break
         ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
-
         if spectrogram is not None:
-
         if waveform is not None:
-
         if pitch is not None:
-
         if envelope is not None:
-
         if phase is not None:
-
         if f0 is not None:
-

         for i in range(max_length - 1):
             with torch.no_grad():
-
-                logits = self.SpeechTransformer(
                 next_token_logits = logits[:, -1, :]
                 if i < min_length:
                     next_token_logits[:, eos_token_id] = 0
@@ -1441,6 +1128,15 @@ def setup_tokenizer(token: str):
     tokenizer.eos_token_id = 2
     return tokenizer

 def load_wave(wave_data, sample_rate):
     if isinstance(wave_data, str):
         waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
@@ -1452,11 +1148,17 @@ def load_wave(wave_data, sample_rate):

     return waveform

-def
-
-        sr = audio["sampling_rate"]
-        wav = load_wave(wave_data=audio, sample_rate=sr)

     dataset_config = {
         "hop_length": 256,
@@ -1471,29 +1173,99 @@ def extract_features(batch, tokenizer, sample_rate=16000, hop_length=256, **dataset_config):
         "window_fn": torch.hann_window,
         "mel_scale": "htk",
         "norm": None,
-        "normalized": False
-
-    transform = torchaudio.transforms.MelSpectrogram(
-        **dataset_config
-    )
-
-    mel_spectrogram = transform(wav)
-    log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
-    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
-    spec = (log_mel + 4.0) / 4.0
-    spec = torch.tensor(spec)
-
-    wav_np = wav.numpy().astype(np.float64)
-    f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length/sample_rate*1000)
-    f0 = pw.stonemask(wav_np, f0, t, sample_rate)
-    f0 = torch.from_numpy(f0)

     labels = tokenizer.encode(batch["transcription"])

     return {
-        "spectrogram": spec,
         "f0": f0,
         "labels": labels,
     }

 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
@@ -1580,9 +1352,11 @@ class DataCollator:
             batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
             batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

-        elif key in ["spectrogram", "waveform", "pitch", "
-
             items = [f[key] for f in features if key in f]
             items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
             max_len = max(item.shape[-1] for item in items)
             padded = []
@@ -1675,11 +1449,17 @@ def main():
     os.makedirs(log_dir, exist_ok=True)
     tokenizer = setup_tokenizer(token)
     train_dataset, test_dataset = prepare_datasets(tokenizer, token)
     param = Dimensions(
-        vocab=40000,
-        mels=128,
-
-
     )

     model = Echo(param).to('cuda')
 logging.basicConfig(level=logging.ERROR)

 def get_activation(act: str) -> nn.Module:
+    """Get activation function by name."""
     act_map = {
         "gelu": nn.GELU(),
         "relu": nn.ReLU(),
 @dataclass
 class Dimensions:
     vocab: int
+    mels: int
     ctx: int
     dims: int
     head: int
     layer: int
     act: str
     debug: List[str]
     features: List[str]
     return fig

 def valid(default_value, *items):
+    """Get first non-None item."""
     for item in items:
         if item is not None:
             return item
 def tox():
     return {"device": get_device(), "dtype": get_dtype()}

+class sinus(nn.Module):
+    def __init__(self, ctx: int, dims: int):
+        super().__init__()
+
+        position = torch.arange(start=0, end=ctx, dtype=dtype).unsqueeze(dim=1)
+        div_term = torch.exp(input=torch.arange(start=0, end=dims, step=2, dtype=dtype) * -(math.log(10000.0) / dims))
+        features = torch.zeros(ctx, dims)
+        features[:, 0::2] = torch.sin(position * div_term)
+        features[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('sinusoid', tensor=features)
+        self.positional_embeddings = nn.Parameter(self.sinusoid.clone())
+
+    def forward(self, positions):
+        position_embeddings = self.positional_embeddings[positions]
+        return position_embeddings

 def sinusoids(length, channels, max_tscale=10000):
     assert channels % 2 == 0
     log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
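Note (not part of the commit): the new sinus module is a learnable positional table initialized from fixed sinusoids, then indexed by position. A minimal standalone sketch of the same construction, assuming the module-level dtype is float32:

    import math
    import torch
    import torch.nn as nn

    dtype = torch.float32  # modelA.py defines dtype globally; float32 assumed here

    ctx, dims = 8, 4
    position = torch.arange(0, ctx, dtype=dtype).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dims, 2, dtype=dtype) * -(math.log(10000.0) / dims))
    table = torch.zeros(ctx, dims)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    pos_emb = nn.Parameter(table.clone())  # trainable, sinusoid-initialized

    positions = torch.tensor([0, 1, 2])
    print(pos_emb[positions].shape)  # torch.Size([3, 4])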
             idx = torch.arange(ctx, device=f0.device)
             idx = (idx * F).long().clamp(0, L - 1)
             radius = radius[idx]
+            return torch.polar(radius.unsqueeze(-1), freqs), radius
+        else:
+            return torch.polar(radius.unsqueeze(-1), freqs), radius
         else:
+            return torch.polar(torch.ones_like(freqs), freqs), None

+    def check_f0(self, f0, f0t, ctx):
+        if f0 is not None and f0.shape[1] == ctx:
+            return f0
+        elif f0t is not None and f0t.shape[1] == ctx:
+            return f0t
         else:
+            return None
+
+    def forward(self, x=None, enc=None, layer=None, feature=None) -> Tensor:
+        ctx = x
+        f0 = enc.get("f0") if enc is not None else None
+        f0t = enc.get("f0t") if enc is not None else None

+        f0 = self.check_f0(f0, f0t, ctx)
         if f0 is not None:
             if f0.dim() == 2:
                 f0 = f0.squeeze(0)
             theta = f0 + self.theta
         else:
             theta = self.theta
         freqs = self.theta_freqs(theta)
         t = torch.arange(ctx, device=device, dtype=dtype)
         freqs = t[:, None] * freqs
+        freqs, radius = self._apply_radii(freqs, f0, ctx)
+
         if "radius" in self.debug and self.counter == 10:
+            print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
         self.counter += 1
         return freqs.unsqueeze(0)
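Note (illustrative, shapes assumed): rotary.forward now returns complex frequencies built with torch.polar, with the radius optionally driven by f0. Applying such frequencies to one head tensor looks roughly like this sketch:

    import torch

    ctx, half = 6, 4
    t = torch.arange(ctx, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(half, dtype=torch.float32) / half))
    angles = t[:, None] * inv_freq                 # (ctx, half) rotation angles
    radius = torch.ones_like(angles)               # f0-derived in _apply_radii
    freqs = torch.polar(radius, angles)            # complex tensor, |z| = radius

    x = torch.randn(1, ctx, half * 2)              # one head, paired channels
    x_c = torch.view_as_complex(x.view(1, ctx, half, 2).contiguous())
    x_rot = torch.view_as_real(x_c * freqs).flatten(-2)
    print(x_rot.shape)                             # torch.Size([1, 6, 8])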
     rbf = False
     def __init__(self, dims: int, head: int, rotary_emb: bool = True,
+                 zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
         super(MultiheadA, self).__init__()

         self.dims = dims
         else:
             self.rope = None

     def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
         q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
         k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
         rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
         return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

+    def forward(self, x: Tensor, xa = None, mask = None, enc = None, layer = None, feature=None) -> tuple:
         x = x.to(device, dtype)
         if xa is not None:

         qk = (q * scale) @ (k * scale).transpose(-1, -2)

         if self.rbf:
             qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
         if self.use_pbias:
+            pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
             if pbias is not None:
                 qk = qk + pbias[:,:,:q2,:q2]

             fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
             zscale[token_ids.float() == self.pad_token] = fzero

+        if xa is not None:
             qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
         qk = qk * zscale.unsqueeze(-2)
         w = F.softmax(qk, dim=-1).to(q.dtype)
         wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

         self.counter += 1
         return self.o(wv), qk
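Note (names and values assumed, not from the commit): the zscale/fzero path scales pad-token attention scores toward zero before the softmax. The same idea in isolation:

    import torch
    import torch.nn.functional as F

    token_ids = torch.tensor([[5, 9, 0, 0]])        # 0 = pad token
    zscale = torch.ones(1, 4)
    fzero = torch.clamp(F.softplus(torch.tensor(-12.0)), 1e-8, 1e-6)
    zscale[token_ids == 0] = fzero                  # pad keys get ~zero weight

    qk = torch.randn(1, 4, 4)
    qk = qk * zscale.unsqueeze(-2)                  # suppress pad columns
    w = F.softmax(qk, dim=-1)
    print(w.shape)  # torch.Size([1, 4, 4])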
 class t_gate(nn.Module):
     def __init__(self, dims, num_types=4, enabled=True):
         super().__init__()
 class Residual(nn.Module):
     _seen = set()
     def __init__(self, ctx, dims, head, act, debug: List[str] = [],
+                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
         super().__init__()

         self.dims = dims
         self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
         self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)

         if not any([tgate, mgate, cgate]):
             self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
         self.lnb = RMSNorm(dims)
         self.lnc = RMSNorm(dims)

+    def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature=None) -> Tensor:

         b = torch.sigmoid(self.blend)
+        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer, feature=feature)[0]
         bx = b * ax + (1 - b) * x
         cx = self.lnb(bx)
         dx = self.mlp(cx)
         ex = self.t_gate(cx) if self.t_gate is not None else self.default(self.m_gate(cx), self.mlp_gate(cx))
         fx = x + ex + dx
         gx = self.lnc(fx)
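Note (values assumed): Residual.forward mixes the attention branch with its input through one learned sigmoid gate, bx = b * ax + (1 - b) * x. A standalone sketch:

    import torch

    x = torch.randn(2, 10, 16)
    attn_out = torch.randn(2, 10, 16)
    blend = torch.nn.Parameter(torch.tensor(0.5))

    b = torch.sigmoid(blend)            # ~0.62 at init
    ax = x + attn_out                   # residual attention branch
    bx = b * ax + (1 - b) * x           # learned interpolation
    print(bx.shape)  # torch.Size([2, 10, 16])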
         self.norm = RMSNorm(dims)
         self._norm = RMSNorm(dims)

+    def apply_rope_to_features(self, x, layer=None, feature=None):
+        if feature in ["envelope", "phase"]:
+            feature = "spectrogram"
         batch, ctx, dims = x.shape
         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
         if feature == "spectrogram" and self.rope is not None:

         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
         return x

+    def forward(self, x, enc=None, layer=None, feature=None):
         x = self.encoder(x).permute(0, 2, 1)
         if self.use_rope:
+            x = self.apply_rope_to_features(x, layer=layer, feature=feature)
         else:
             x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
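Note (shapes assumed): before rope is applied, apply_rope_to_features reshapes (batch, ctx, dims) into per-head form and back. The reshape in isolation:

    import torch

    batch, ctx, dims, head = 2, 50, 64, 4
    x = torch.randn(batch, ctx, dims)
    x = x.view(batch, ctx, head, dims // head).permute(0, 2, 1, 3)
    print(x.shape)  # torch.Size([2, 4, 50, 16])
    # ... rope applied per head here ...
    x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)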
         self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)

+    def apply_rope_to_features(self, x, layer=None, feature=None):
         if not self.use_rope or self.rope is None:
             return x
         batch, ctx, dims = x.shape
         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+        rope_freqs = self.rope(ctx, layer=layer, feature=feature)
         x = self.rope.apply_rotary(x, rope_freqs)
         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
         return x

+    def forward(self, x, enc=None, layer=None, feature=None):
         x = self.downsample(x)
         x = self.encoder(x)
         x = x.permute(0, 2, 1)
         self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)

+    def apply_rope_to_features(self, x, layer=None, feature=None):
         if not self.use_rope or self.rope is None:
             return x
         batch, ctx, dims = x.shape
         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+        rope_freqs = self.rope(ctx, layer=layer, feature=feature)
         x = self.rope.apply_rotary(x, rope_freqs)
         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
         return x

+    def forward(self, x, enc=None, layer=None, feature=None):
         x = self.encoder(x).permute(0, 2, 1)
         if self.use_rope:
             x = self.apply_rope_to_features(x, layer=layer)
         x = self.norm(x)
         return x

+class theBridge(nn.Module):
+    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
+                 debug: List[str], features: List[str], act: str = "gelu"):
+        super(theBridge, self).__init__()
+
+        self.ctx = ctx
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
         self.debug = debug
         self.counter = 0
+        self.dropout = 0.01
         self.features = features
+        self.do_blend = "no_blend" not in self.debug
+        self.sequential = "sequential" in self.debug

         self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
         self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
+        self.sinusoid = lambda length: sinusoids(length, dims)
+        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
+        self.ln_dec = RMSNorm(dims)
+
+        with torch.no_grad():
+            self.token.weight[0].zero_()
+
+        self.block = nn.ModuleList([
+            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
+            for _ in range(layer)])
+
+        self.cross_attn = nn.ModuleList([
+            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
+            for _ in range(layer)])
+
+        self.cross_modal = nn.ModuleList([
+            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
+            for _ in range(layer)])

+        mask = torch.tril(torch.ones(ctx, ctx), diagonal=0).unsqueeze(0).unsqueeze(0)
+        self.register_buffer("mask", mask, persistent=False)
+        self.register_buffer("mask_win", self.window_mask(ctx, ctx), persistent=False)
+        self.register_buffer("mask_cat", self.modal_mask(ctx, ctx), persistent=False)
+        self.register_buffer("mask_cross", self.cross_mask(ctx, ctx), persistent=False)
+
+        act_fn = get_activation(act)
         if features == ["spectrogram", "waveform", "pitch"]:
             cgate=True
         else:
             cgate = False

         self.blocks = nn.ModuleDict({
             "spectrogram": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "spectrogram" in features else None),
             "waveform": nn.ModuleList(
                 [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "waveform" in features else None),
             "pitch": nn.ModuleList(
                 [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None),
             "envelope": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "envelope" in features else None),
             "phase": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None)})
+
+    def window_mask(self, text_ctx, aud_ctx):
+        mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
+        audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device))
+        full_mask = torch.cat([mask, audio_mask], dim=-1)
+        return full_mask.unsqueeze(0).unsqueeze(0)
+
+    def modal_mask(self, text_len, audio_len):
+        combined_mask = torch.ones(text_len + audio_len, text_len + audio_len, device=device)
+        combined_mask[:text_len, :text_len] = torch.tril(torch.ones(text_len, text_len, device=device))
+        combined_mask[:text_len, text_len:] = torch.tril(torch.ones(text_len, audio_len, device=device))
+        return combined_mask.unsqueeze(0).unsqueeze(0)
+
+    def cross_mask(self, text_len, audio_len):
+        mask = torch.tril(torch.ones(text_len, text_len, device=device))
+        audio_mask = torch.tril(torch.ones(text_len, audio_len, device=device))
+        full_mask = torch.cat([mask, audio_mask], dim=-1)
+        return full_mask.unsqueeze(0).unsqueeze(0)
+
+    def forward(self, x, enc, layer='decoder', feature=None) -> Tensor:
         enc = dict_to(enc, device, dtype)
+        _text_len = x.shape[1]

         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)

+        for f in enc:
+            if f in self.features:
                 xa = enc[f]
                 for block in self.blocks[f]:
+                    xa = block(xa, enc=enc, layer=layer, feature=feature)

         for block in self.block:
             mask = self.mask[:x.shape[1], :x.shape[1]]
+            x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
+            if feature in self.features:
+                xa = xa + self.sinusoid(xa.shape[1])
+                mask = self.mask_win[:x.shape[1], :xa.shape[1]]
+                out = block(x, xa=xa, mask=mask, enc=enc, layer=layer)
+                if self.sequential:
+                    x = out
+                else:
+                    a = torch.sigmoid(self.blend)
+                    x = a * out + (1 - a) * x

+        for block in self.cross_attn:
+            if feature in self.features:
+                xa = xa + self.sinusoid(xa.shape[1])
+                mask_x = self.cross_mask(x.shape[1], xa.shape[1])
+                mask_xa = self.cross_mask(xa.shape[1], x.shape[1])
+                x = block(x, xa=xa, mask=mask_x, enc=enc, layer=layer)
+                xa = block(xa, xa=x, mask=mask_xa, enc=enc, layer=layer)
+                out = block(x, xa=xa, mask=mask_x, enc=enc, layer=layer)
                 if self.sequential:
                     x = out
                 else:
                     a = torch.sigmoid(self.blend)
                     x = a * out + (1 - a) * x

+        for block in self.cross_modal:
+            if feature in enc:
+                xa = xa + self.sinusoid(xa.shape[1])
+                xcat = torch.cat([x, xa], dim=1)
+                mask = self.modal_mask(x.shape[1], xa.shape[1])
+                x = block(xcat, xa=None, mask=mask, enc=enc, layer=layer)
+                x = x[:, :_text_len]
+
+        if self.counter < 1 and "encoder" in self.debug:
+            s = enc.get("spectrogram")
+            w = enc.get("waveform")
+            p = default(enc.get("pitch"), enc.get("f0"))
+            plot_waveform(x=s, w=w, p=p, hop_length=128)
+            shapes = {k: v.shape for k, v in enc.items()}
+            print(f"Step {self.counter}: mode: {list(enc.keys())}: shapes: {shapes}")
+        self.counter += 1
+
+        x = self.ln_dec(x)
         return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

 class Echo(nn.Module):
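Note (a sketch, not the commit's code; device argument dropped): theBridge's cross_mask lets each text position attend causally to text and causally into the appended audio span:

    import torch

    def cross_mask(text_len, audio_len):
        m = torch.tril(torch.ones(text_len, text_len))
        am = torch.tril(torch.ones(text_len, audio_len))
        return torch.cat([m, am], dim=-1).unsqueeze(0).unsqueeze(0)

    print(cross_mask(3, 5).shape)   # torch.Size([1, 1, 3, 8])
    print(cross_mask(3, 5)[0, 0])   # lower-triangular text block + causal audio block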
         super().__init__()
         self.param = param

+        self.processor = theBridge(
             vocab=param.vocab,
             mels=param.mels,
             ctx=param.ctx,
             dims=param.dims,
             head=param.head,
             layer=param.layer,
             features=param.features,
             act=param.act,
+            debug=param.debug,
+            )

     def forward(self,
         labels=None,

         spectrogram: Optional[torch.Tensor]=None,
         pitch: Optional[torch.Tensor]=None,
         f0: Optional[torch.Tensor]=None,
+        f0t: Optional[torch.Tensor]=None,
+        harmonic: Optional[torch.Tensor]=None,
+        aperiodic: Optional[torch.Tensor]=None,
         ) -> Dict[str, Optional[torch.Tensor]]:

+        enc = {}
         if spectrogram is not None:
+            enc["spectrogram"] = spectrogram
+            feature = "spectrogram"
         if waveform is not None:
+            enc["waveform"] = waveform
+            feature = "waveform"
         if pitch is not None:
+            enc["pitch"] = pitch
+            feature = "pitch"
         if f0 is not None:
+            enc["f0"] = f0
+        if f0t is not None:
+            enc["f0t"] = f0t
+        if harmonic is not None:
+            enc["harmonic"] = harmonic
+        if aperiodic is not None:
+            enc["aperiodic"] = aperiodic
         if input_ids is not None:
+            enc["input_ids"] = input_ids
+            feature = "input_ids"
+        else:
+            feature = "spectrogram"

+        logits = self.processor(input_ids, enc, feature=feature)

         loss = None
         if labels is not None:
             loss = F.cross_entropy(
                 logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)

         return {"logits": logits, "loss": loss}

     @property
             self.init_counts["Conv2d"] += 1
         elif isinstance(module, MultiheadA):
             self.init_counts["MultiheadA"] += 1
         elif isinstance(module, Residual):
             self.init_counts["Residual"] += 1
             batch_size = x.shape[0]
             break
         ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
+        feature = {}
         if spectrogram is not None:
+            feature["spectrogram"] = spectrogram
         if waveform is not None:
+            feature["waveform"] = waveform
         if pitch is not None:
+            feature["pitch"] = pitch
         if envelope is not None:
+            feature["envelope"] = envelope
         if phase is not None:
+            feature["phase"] = phase
         if f0 is not None:
+            feature["f0"] = f0

         for i in range(max_length - 1):
             with torch.no_grad():
+                feature["input_ids"] = ids
+                logits = self.processor(ids, feature)
                 next_token_logits = logits[:, -1, :]
                 if i < min_length:
                     next_token_logits[:, eos_token_id] = 0
     tokenizer.eos_token_id = 2
     return tokenizer

+def tokenize_pitch(pitch_features, target_length):
+    pitch_len = pitch_features.shape[-1]
+    token_len = target_length
+    if pitch_len > token_len:
+        pitch_tokens = F.adaptive_avg_pool1d(pitch_features, token_len)
+    else:
+        pitch_tokens = F.interpolate(pitch_features, token_len)
+    return pitch_tokens

 def load_wave(wave_data, sample_rate):
     if isinstance(wave_data, str):
         waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
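Note (lengths assumed): tokenize_pitch pools when the pitch track is longer than the token sequence and interpolates otherwise:

    import torch
    import torch.nn.functional as F

    pitch = torch.randn(1, 1, 480)                  # (batch, 1, frames)
    down = F.adaptive_avg_pool1d(pitch, 100)        # frames > tokens: pool
    up = F.interpolate(pitch, size=200)             # frames < tokens: resample
    print(down.shape, up.shape)  # torch.Size([1, 1, 100]) torch.Size([1, 1, 200])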

     return waveform

+def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
+    import librosa
+    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
+    mel_basis = torch.from_numpy(mel_basis).float()
+
+    sp_mel = torch.matmul(sp, mel_basis.T)
+    ap_mel = torch.matmul(ap, mel_basis.T)
+
+    return sp_mel, ap_mel

+def extract_features(batch, tokenizer, waveform=False, spec=True, f0=False, f0t=False, sample_rate=16000, hop_length=256, mode="mean", debug=True, **dataset_config):

     dataset_config = {
         "hop_length": 256,

         "window_fn": torch.hann_window,
         "mel_scale": "htk",
         "norm": None,
+        "normalized": False,
+        }

+    audio = batch["audio"]
+    sr = audio["sampling_rate"]
+    wave = load_wave(wave_data=audio, sample_rate=sr)
     labels = tokenizer.encode(batch["transcription"])

+    if waveform:
+        wav = load_wave(wave_data=audio, sample_rate=sr)
+    else:
+        wav = None
+
+    if spec:
+        transform = torchaudio.transforms.MelSpectrogram(**dataset_config)
+        mel_spectrogram = transform(wave)
+        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
+        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
+        spec = (log_mel + 4.0) / 4.0
+    else:
+        spec = None
+
+    if f0:
+        wavnp = wave.numpy().astype(np.float64)
+        f0_np, t = pw.dio(wavnp, sample_rate,
+            frame_period = hop_length / sample_rate * 1000)
+        f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
+        f0 = torch.from_numpy(f0_np)
+
+        if f0t:
+            audio_duration = len(wavnp) / sample_rate
+            T = len(labels)
+            tok_dur_sec = audio_duration / T
+            token_starts = np.arange(T) * tok_dur_sec
+            token_ends = token_starts + tok_dur_sec
+            start_idx = np.searchsorted(t, token_starts, side="left")
+            end_idx = np.searchsorted(t, token_ends, side="right")
+            pitch_tok = np.zeros(T, dtype=np.float32)
+            for i in range(T):
+                lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
+                segment = f0_np[lo:hi]
+                pitch_tok[i] = segment.mean() if mode=="mean" else (np.median(segment) if mode=="median" else segment[-1])
+            pitch_tok[pitch_tok < 100.0] = 0.0
+
+            bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
+            f0t = torch.from_numpy(np.concatenate([[bos_pitch], pitch_tok]))
+            f0 = torch.from_numpy(f0_np)
+
+            spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
+            apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
+            sp = torch.from_numpy(spnp)
+            ap = torch.from_numpy(apnp)
+            sp = sp[:, :128].contiguous().T
+            ap = ap[:, :128].contiguous().T
+            f0t = torch.where(f0t == 0.0, torch.zeros_like(f0t), (f0t - 71.0) / (400.0 - 71.0))
+            sp = torch.where(sp == 0.0, torch.zeros_like(sp), sp / 1.0)
+            ap = torch.where(ap == 0.0, torch.zeros_like(ap), ap / 1.0)
+
+        else:
+            f0t = None
+            sp = None
+            ap = None
+            t = None
+            token_starts = None
+    else:
+        f0t = None
+        f0 = None
+        sp = None
+        ap = None
+        t = None
+        token_starts = None
+
+    if debug:
+        print(f"['f0t' values]: {f0t if f0t is not None else None}")
+        print(f"['f0']: {f0.shape if f0 is not None else None}")
+        print(f"['f0t']: {f0t.shape if f0t is not None else None}")
+        print(f"['harmonic']: {sp.shape if sp is not None else None}")
+        print(f"['aperiodic']: {ap.shape if ap is not None else None}")
+        print(f"['spec']: {spec.shape if spec is not None else None}")
+        print(f"['wav']: {wav.shape if wav is not None else None}")

     return {
         "f0": f0,
+        "f0t": f0t,
+        "pitch": f0,
+        "harmonic": sp,
+        "aperiodic": ap,
         "labels": labels,
+        "waveform": wav,
+        "spectrogram": spec,
     }

 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, **dataset_config):
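Note (synthetic contour, illustrative only): the per-token f0t pooling above reduces frame-level pyworld f0 to one value per token via searchsorted windows. The pooling step in isolation:

    import numpy as np

    t = np.linspace(0.0, 2.0, 125)            # pyworld frame times (s)
    f0_np = np.abs(np.sin(t)) * 200 + 80      # fake contour
    T = 10                                    # number of text tokens
    tok = 2.0 / T
    starts = np.arange(T) * tok
    lo = np.searchsorted(t, starts, side="left")
    hi = np.searchsorted(t, starts + tok, side="right")
    pitch_tok = np.array([f0_np[l:max(l + 1, h)].mean() for l, h in zip(lo, hi)])
    pitch_tok[pitch_tok < 100.0] = 0.0        # zero out unvoiced/low f0
    print(pitch_tok.round(1))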
         batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
         batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

+        elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0"]:
             items = [f[key] for f in features if key in f]
+            items = [item for item in items if item is not None]
+            if not items:
+                continue
             items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
             max_len = max(item.shape[-1] for item in items)
             padded = []
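Note (shapes assumed): the collator pads each feature tensor on its last axis to the longest item in the batch, then stacks:

    import torch
    import torch.nn.functional as F

    items = [torch.randn(128, 90), torch.randn(128, 120)]
    max_len = max(it.shape[-1] for it in items)
    padded = [F.pad(it, (0, max_len - it.shape[-1])) for it in items]
    batch = torch.stack(padded)
    print(batch.shape)  # torch.Size([2, 128, 120])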
     os.makedirs(log_dir, exist_ok=True)
     tokenizer = setup_tokenizer(token)
     train_dataset, test_dataset = prepare_datasets(tokenizer, token)
+
     param = Dimensions(
+        vocab=40000,
+        mels=128,
+        ctx=1500,
+        dims=512,
+        head=4,
+        layer=4,
+        act="swish",
+        debug={"decoder", "radius"},
+        features=["spectrogram"],
     )

     model = Echo(param).to('cuda')
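Note (hypothetical smoke test, not from the commit; shapes chosen to fit mels=128): after main() builds the model, a forward pass can be sanity-checked roughly like this:

    import torch

    model = Echo(param).to('cuda').eval()
    spec = torch.randn(1, 128, 300, device='cuda')     # (batch, mels, frames)
    ids = torch.randint(1, 100, (1, 12), device='cuda')
    with torch.no_grad():
        out = model(input_ids=ids, spectrogram=spec, labels=ids)
    print(out["logits"].shape, out["loss"])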