Update modelA.py
modelA.py
CHANGED
@@ -32,13 +32,6 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
-PATH = 'E:/hf'
-os.environ['HF_HOME'] = PATH
-os.environ['HF_DATASETS_CACHE'] = PATH
-os.environ['TORCH_HOME'] = PATH
-os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
 def get_activation(act: str) -> nn.Module:
     """Get activation function by name."""
     act_map = {
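The removed block pinned the Hugging Face and Torch cache locations inside the module itself. If those caches still need redirecting, a minimal sketch (assuming the same 'E:/hf' location is still wanted) is to set the variables in the launching code, before the module and the libraries that read them are imported:

import os

# Hypothetical launcher-side setup; 'E:/hf' mirrors the removed PATH constant.
# These must be in place before torch/transformers/datasets read them on import.
os.environ.setdefault("HF_HOME", "E:/hf")
os.environ.setdefault("HF_DATASETS_CACHE", "E:/hf")
os.environ.setdefault("TORCH_HOME", "E:/hf")
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import modelA  # noqa: E402  # import only after the environment is prepared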
@@ -65,7 +58,6 @@ class Dimensions:
     mels: int
     act: str
     debug: List[str]
-    cross_attn: bool
     features: List[str]
 
 def get_generation_config(param):
@@ -672,15 +664,15 @@ class FocusWindow(nn.Module):
 
         return output
 
-    def forward(self, x,
+    def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=True):
         q = self.q_proj(x)
-        k = self.k_proj(x if
-        v = self.v_proj(x if
+        k = self.k_proj(x if xa is None else xa)
+        v = self.v_proj(x if xa is None else xa)
 
         # Create span scale based on feature characteristics
-        if
+        if xa is not None:
             # Feature-specific span scaling
-            feature_energy = torch.norm(
+            feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)
             span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
         else:
             span_scale = torch.ones(x.size(0), 1, device=x.device)
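For reference, a self-contained sketch of what the reworked span scaling computes, using hypothetical shapes (batch 2, 100 frames, 512 dims) in place of real encoder features:

import torch

xa = torch.randn(2, 100, 512)  # hypothetical cross-attention features

# Per-sample energy: L2 norm over the feature dim, averaged over frames.
feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)  # shape (2, 1)

# Divide by the batch std and squash to (0, 1); the 1e-6 guards a zero std.
span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))

print(span_scale.shape)  # torch.Size([2, 1])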
@@ -785,7 +777,7 @@ class mlp_gate(nn.Module):
 class Residual(nn.Module):
     _seen = set()
     def __init__(self, ctx, dims, head, act, debug: List[str] = [],
-                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
+                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None, focus=True):
         super().__init__()
 
         self.dims = dims
@@ -799,7 +791,9 @@ class Residual(nn.Module):
 
         self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
+
         self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
+        self.focus = FocusWindow(dims, head, debug=debug) if focus else None
 
         if not any([tgate, mgate, cgate]):
             self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
@@ -819,9 +813,11 @@ class Residual(nn.Module):
         self.lnc = RMSNorm(dims)
 
     def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
-
+
+        focus = self.focus(x, xa=xa, mask=mask, enc=enc, layer=layer) if self.focus is not None else 0
+
         b = torch.sigmoid(self.blend)
-        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0]
+        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0] + focus
         bx = b * ax + (1 - b) * x
         cx = self.lnb(bx)
         dx = self.mlp(cx)
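A minimal, self-contained sketch of how the new focus path slots into the residual update. FocusStub and the plain nn.MultiheadAttention stand in for the file's FocusWindow and MultiheadA, so only the wiring (not the attention internals) matches the diff:

import torch
import torch.nn as nn

class FocusStub(nn.Module):
    # Placeholder for FocusWindow: any module mapping (B, T, D) -> (B, T, D).
    def __init__(self, dims):
        super().__init__()
        self.proj = nn.Linear(dims, dims)
    def forward(self, x):
        return self.proj(x)

class ResidualSketch(nn.Module):
    def __init__(self, dims, head, focus=True):
        super().__init__()
        self.attn = nn.MultiheadAttention(dims, head, batch_first=True)
        self.focus = FocusStub(dims) if focus else None  # the new optional branch
        self.blend = nn.Parameter(torch.tensor(0.5))
        self.lna = nn.LayerNorm(dims)

    def forward(self, x):
        # A disabled focus contributes 0, exactly as in the updated forward.
        focus = self.focus(x) if self.focus is not None else 0
        h = self.lna(x)
        ax = x + self.attn(h, h, h)[0] + focus
        b = torch.sigmoid(self.blend)
        return b * ax + (1 - b) * x

out = ResidualSketch(dims=512, head=8)(torch.randn(2, 100, 512))
print(out.shape)  # torch.Size([2, 100, 512])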
@@ -1016,32 +1012,32 @@ class SpeechTransformer(nn.Module):
 
             "spectrogram": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "spectrogram" in features else None),
 
             "waveform": nn.ModuleList(
                 [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "waveform" in features else None),
 
             "pitch": nn.ModuleList(
                 [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "pitch" in features else None),
 
             "envelope": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "envelope" in features else None),
 
             "phase": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "phase" in features else None),
         })
 
         self.block = nn.ModuleList([
-            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
+            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features, focus=False)
             for _ in range(layer)])
 
         self.blend = nn.Parameter(torch.tensor(0.5))
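Worth noting: every Residual constructed inside SpeechTransformer above, including the shared self.block stack, passes focus=False, so the FocusWindow branch stays inactive in these stacks even though the new flag defaults to True. A hypothetical construction showing both settings (the argument values are illustrative, not from the file):

from modelA import Residual

with_focus = Residual(ctx=1500, dims=512, head=8, act="gelu")                   # focus=True by default
without_focus = Residual(ctx=1500, dims=512, head=8, act="gelu", focus=False)   # as in SpeechTransformer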