Sin2pi commited on
Commit
fc4dff3
·
verified ·
1 Parent(s): d56f7a6

Update modelA.py

Browse files
Files changed (1) hide show
  1. modelA.py +1 -15
modelA.py CHANGED
@@ -32,15 +32,7 @@ dtype = torch.float32
32
  warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
 
35
- PATH = 'E:/hf'
36
- os.environ['HF_HOME'] = PATH
37
- os.environ['HF_DATASETS_CACHE'] = PATH
38
- os.environ['TORCH_HOME'] = PATH
39
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
40
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
41
-
42
  def get_activation(act: str) -> nn.Module:
43
- """Get activation function by name."""
44
  act_map = {
45
  "gelu": nn.GELU(),
46
  "relu": nn.ReLU(),
@@ -205,7 +197,6 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
205
  return fig
206
 
207
  def valid(default_value, *items):
208
- """Get first non-None item"""
209
  for item in items:
210
  if item is not None:
211
  return item
@@ -715,8 +706,6 @@ class FocusWindow(nn.Module):
715
  return head_importance
716
 
717
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=False, return_head_weights=False, learn_lr=False):
718
-
719
- print(f"🎯 FocusWindow running! Input: {x.shape}, Feature: {xa.shape if xa is not None else None}")
720
 
721
  q = self.q_proj(x)
722
  k = self.k_proj(x if xa is None else xa)
@@ -817,11 +806,8 @@ class AdaptiveAttentionLR(nn.Module):
817
 
818
  def forward(self, x, feature_data=None, mask=None):
819
  quality = self.quality_estimator(x.mean(dim=1))
820
-
821
- lr_factor = self.lr_predictor(x.mean(dim=1))
822
-
823
  adaptive_lr = quality * lr_factor
824
-
825
  return adaptive_lr, adaptive_lr
826
 
827
  class SmartSensorResidual(nn.Module):
 
32
  warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
 
 
 
 
 
 
 
 
35
  def get_activation(act: str) -> nn.Module:
 
36
  act_map = {
37
  "gelu": nn.GELU(),
38
  "relu": nn.ReLU(),
 
197
  return fig
198
 
199
  def valid(default_value, *items):
 
200
  for item in items:
201
  if item is not None:
202
  return item
 
706
  return head_importance
707
 
708
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, return_bias=False, return_head_weights=False, learn_lr=False):
 
 
709
 
710
  q = self.q_proj(x)
711
  k = self.k_proj(x if xa is None else xa)
 
806
 
807
  def forward(self, x, feature_data=None, mask=None):
808
  quality = self.quality_estimator(x.mean(dim=1))
809
+ lr_factor = self.lr_predictor(x.mean(dim=1)
 
 
810
  adaptive_lr = quality * lr_factor
 
811
  return adaptive_lr, adaptive_lr
812
 
813
  class SmartSensorResidual(nn.Module):