Sin2pi committed · verified
Commit 078fc2d · 1 parent: d064036

Update model_hf.py

Files changed (1):
  1. model_hf.py  +16 -15
model_hf.py CHANGED
@@ -567,6 +567,7 @@ class c_gate(nn.Module):
         comb = torch.cat([s, w, p, e, ph], dim=-1)
         return self.integ(comb)
 
+
 class Residual(nn.Module):
     _seen = set()
     def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
@@ -586,10 +587,10 @@ class Residual(nn.Module):
         self.t_gate = tgate
         self.m_gate = mgate
         self.c_gate = cgate
-        self.skip_gates=True
-
-        self.blend = nn.Parameter(torch.tensor(0.5))
-
+        self.do_blend = "no_blend" not in self.debug
+        self.blend = nn.Parameter(torch.tensor(0.5))
+        self.skip_gates = True if "skip_gates" in self.debug else False
+
         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                    "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                    "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
@@ -618,21 +619,21 @@ class Residual(nn.Module):
         if xa is not None:
             xa = xa.to(device, dtype)
 
-        bln = self.blend
+        blend = self.blend
         x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
-
+        xb = x
         if self.attnb and xa is not None:
-            c = self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
-            b = torch.sigmoid(bln)
-            x = b * x + (1 - b) * c
+            x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
+
+            if self.do_blend:
+                b = torch.sigmoid(blend)
+                x = b * xb + (1 - b) * x
 
-        normx = self.lnc(x)
-        mlp_out = self.mlp(normx)
-
         if self.skip_gates:
-            x = x + mlp_out
-
+            x = x + self.mlp(self.lnc(x))
         else:
+            normx = self.lnc(x)
+            mlp_out = self.mlp(normx)
 
             if self.t_gate:
                 gate = self.t_gate(normx)
@@ -664,9 +665,9 @@ class Residual(nn.Module):
             else:
                 print(f"Step {self.counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
         self.counter += 1
-
         return x
 
+
 class FEncoder(nn.Module):
     def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
         super().__init__()
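In short: the learned blend is no longer applied to the raw cross-attention output but interpolates between the pre-cross-attention stream and the post-cross-attention residual stream, it can be switched off with a "no_blend" debug flag, and skip_gates (previously hard-coded True) is now opt-in via the debug list and routes around the t_gate / m_gate / c_gate machinery with a plain pre-norm MLP residual. The sketch below is not the repository's Residual class; the attention and gate modules are reduced to stand-ins and only the new control flow is kept, to illustrate how the two flags interact.

import torch
import torch.nn as nn

class BlendSketch(nn.Module):
    """Illustrative stand-in for the updated Residual control flow."""
    def __init__(self, dims, debug=()):
        super().__init__()
        self.do_blend = "no_blend" not in debug          # mirrors the new flag
        self.blend = nn.Parameter(torch.tensor(0.5))     # learned blend logit
        self.skip_gates = "skip_gates" in debug          # bypass the gate branches
        self.ln = nn.LayerNorm(dims)
        self.mlp = nn.Sequential(nn.Linear(dims, 4 * dims), nn.GELU(),
                                 nn.Linear(4 * dims, dims))

    def forward(self, x, cross):
        xb = x                                # stream after self-attention
        x = x + cross                         # stand-in for the cross-attention residual
        if self.do_blend:
            b = torch.sigmoid(self.blend)     # b ≈ 0.62 at init, since sigmoid(0.5) ≈ 0.62
            x = b * xb + (1 - b) * x          # interpolate old vs. new stream
        if self.skip_gates:
            x = x + self.mlp(self.ln(x))      # plain pre-norm MLP residual
        # otherwise the full class would run its t_gate / m_gate / c_gate branches here
        return x

x = torch.randn(2, 10, 64)
cross = torch.randn(2, 10, 64)
print(BlendSketch(64, debug=("skip_gates",))(x, cross).shape)  # torch.Size([2, 10, 64])

With the default blend logit of 0.5 the block starts out weighting the pre-cross-attention stream at roughly 0.62, so cross-attention is phased in gradually as the parameter is trained.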