Sin2pi committed (verified)
Commit 8d65545 · Parent: 87411f1

Update modelA.py

Files changed (1): modelA.py (+18 −22)
modelA.py CHANGED
@@ -32,13 +32,6 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
-PATH = 'E:/hf'
-os.environ['HF_HOME'] = PATH
-os.environ['HF_DATASETS_CACHE'] = PATH
-os.environ['TORCH_HOME'] = PATH
-os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
 def get_activation(act: str) -> nn.Module:
     """Get activation function by name."""
     act_map = {
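The removed block hard-coded a Windows cache path into the module at import time. If those settings are still wanted, they can be applied by the launching environment before the module is imported — a sketch, with an illustrative path (the variable names are the ones removed above):

import os

# Sketch: set cache/allocator config in the launcher, not in the model module.
os.environ.setdefault("HF_HOME", "/path/to/hf")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")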
@@ -65,7 +58,6 @@ class Dimensions:
     mels: int
     act: str
     debug: List[str]
-    cross_attn: bool
     features: List[str]
 
 def get_generation_config(param):
@@ -672,15 +664,15 @@ class FocusWindow(nn.Module):
 
         return output
 
-    def forward(self, x, feature_data=None, mask=None, return_bias=True):
+    def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=True):
         q = self.q_proj(x)
-        k = self.k_proj(x if feature_data is None else feature_data)
-        v = self.v_proj(x if feature_data is None else feature_data)
+        k = self.k_proj(x if xa is None else xa)
+        v = self.v_proj(x if xa is None else xa)
 
         # Create span scale based on feature characteristics
-        if feature_data is not None:
+        if xa is not None:
             # Feature-specific span scaling
-            feature_energy = torch.norm(feature_data, dim=-1).mean(dim=-1, keepdim=True)
+            feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)
             span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
         else:
             span_scale = torch.ones(x.size(0), 1, device=x.device)
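For reference, a minimal standalone sketch of the span-scale computation after the rename, assuming xa is a (batch, ctx, dims) feature tensor (the shapes and values here are illustrative):

import torch

xa = torch.randn(2, 100, 64)                                        # stand-in feature tensor
feature_energy = torch.norm(xa, dim=-1).mean(dim=-1, keepdim=True)  # (batch, 1)
span_scale = torch.sigmoid(feature_energy / (feature_energy.std() + 1e-6))
print(span_scale.shape)                                             # torch.Size([2, 1])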
@@ -785,7 +777,7 @@ class mlp_gate(nn.Module):
 class Residual(nn.Module):
     _seen = set()
     def __init__(self, ctx, dims, head, act, debug: List[str] = [],
-                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
+                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None, focus=True):
         super().__init__()
 
         self.dims = dims
@@ -799,7 +791,9 @@ class Residual(nn.Module):
 
         self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
+
         self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
+        self.focus = FocusWindow(dims, head, debug=debug) if focus else None
 
         if not any([tgate, mgate, cgate]):
             self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
@@ -819,9 +813,11 @@ class Residual(nn.Module):
         self.lnc = RMSNorm(dims)
 
     def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
-
+
+        focus = self.focus(x, xa=xa, mask=mask, enc=enc, layer=layer) if self.focus is not None else 0
+
         b = torch.sigmoid(self.blend)
-        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0]
+        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer)[0] + focus
         bx = b * ax + (1 - b) * x
         cx = self.lnb(bx)
         dx = self.mlp(cx)
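A self-contained sketch of the new composition — attention output plus an optional focus term, blended with the block input by a learned sigmoid gate. ToyBlock and its linear stand-ins are illustrative only, not the repo's MultiheadA/FocusWindow/RMSNorm:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dims, focus=True):
        super().__init__()
        self.attn = nn.Linear(dims, dims)                      # stand-in for MultiheadA
        self.focus = nn.Linear(dims, dims) if focus else None  # stand-in for FocusWindow
        self.blend = nn.Parameter(torch.tensor(0.5))
        self.lna = nn.LayerNorm(dims)                          # stand-in for RMSNorm

    def forward(self, x):
        focus = self.focus(x) if self.focus is not None else 0
        ax = x + self.attn(self.lna(x)) + focus                # attention + optional focus
        b = torch.sigmoid(self.blend)                          # learned gate in (0, 1)
        return b * ax + (1 - b) * x                            # gated residual mix

x = torch.randn(2, 10, 64)
print(ToyBlock(64)(x).shape)            # torch.Size([2, 10, 64])
print(ToyBlock(64, focus=False).focus)  # None — the branch is skipped entirely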
@@ -1016,32 +1012,32 @@ class SpeechTransformer(nn.Module):
 
             "spectrogram": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "spectrogram" in features else None),
 
             "waveform": nn.ModuleList(
                 [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "waveform" in features else None),
 
             "pitch": nn.ModuleList(
                 [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "pitch" in features else None),
 
             "envelope": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "envelope" in features else None),
 
             "phase": nn.ModuleList(
                 [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
-                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
+                [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate, focus=False) for _ in range(layer)]
                 if "phase" in features else None),
         })
 
         self.block = nn.ModuleList([
-            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
+            Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features, focus=False)
             for _ in range(layer)])
 
         self.blend = nn.Parameter(torch.tensor(0.5))
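Note that every Residual constructed in this file passes focus=False, so the new FocusWindow branch is only built where a caller relies on the new default focus=True. The feature gating itself follows a simple dict pattern — a sketch with illustrative names and sizes (the original wraps the dict in nn.ModuleDict):

import torch.nn as nn

requested = ["spectrogram", "pitch"]   # illustrative feature list
stacks = {
    name: nn.ModuleList([nn.Linear(64, 64) for _ in range(2)])
          if name in requested else None
    for name in ["spectrogram", "waveform", "pitch", "envelope", "phase"]
}
print([k for k, v in stacks.items() if v is not None])  # ['spectrogram', 'pitch']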