Update model_hf.py
model_hf.py CHANGED (+16 -15)
@@ -567,6 +567,7 @@ class c_gate(nn.Module):
         comb = torch.cat([s, w, p, e, ph], dim=-1)
         return self.integ(comb)

+
 class Residual(nn.Module):
     _seen = set()
     def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
@@ -586,10 +587,10 @@ class Residual(nn.Module):
         self.t_gate = tgate
         self.m_gate = mgate
         self.c_gate = cgate
-        self.
-
-        self.
-
+        self.do_blend = "no_blend" not in self.debug
+        self.blend = nn.Parameter(torch.tensor(0.5))
+        self.skip_gates = True if "skip_gates" in self.debug else False
+
         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                    "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                    "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
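The three added attributes set up a learnable blend weight and two debug-driven switches. The following is a minimal, self-contained sketch of that pattern, not an excerpt from model_hf.py; it assumes only that debug is a list of flag strings, as in the __init__ signature shown in the first hunk.

import torch
import torch.nn as nn

class BlendConfigDemo(nn.Module):
    # Hypothetical stand-in illustrating the attributes added to Residual.__init__.
    def __init__(self, debug=()):
        super().__init__()
        self.debug = list(debug)
        # Blending is on unless the "no_blend" flag is passed.
        self.do_blend = "no_blend" not in self.debug
        # Raw blend logit; sigmoid(0.5) is roughly 0.62, so the mix starts near even.
        self.blend = nn.Parameter(torch.tensor(0.5))
        # Gating is skipped only when explicitly requested.
        self.skip_gates = "skip_gates" in self.debug

demo = BlendConfigDemo(debug=["skip_gates"])
print(demo.do_blend, demo.skip_gates, torch.sigmoid(demo.blend).item())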
@@ -618,21 +619,21 @@ class Residual(nn.Module):
         if xa is not None:
             xa = xa.to(device, dtype)

-
+        blend = self.blend
         x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
-
+        xb = x
         if self.attnb and xa is not None:
-
-
-
+            x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
+
+            if self.do_blend:
+                b = torch.sigmoid(blend)
+                x = b * xb + (1 - b) * x

-        normx = self.lnc(x)
-        mlp_out = self.mlp(normx)
-
         if self.skip_gates:
-            x = x +
-
+            x = x + self.mlp(self.lnc(x))
         else:
+            normx = self.lnc(x)
+            mlp_out = self.mlp(normx)

             if self.t_gate:
                 gate = self.t_gate(normx)
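Taken together, the added lines in this hunk keep a copy of the hidden state right after self-attention (xb), run cross-attention against xa, and then mix the two states with a single learned weight squashed through a sigmoid, letting training decide how much the cross-attention path contributes. The sketch below reproduces only that blend on top of stock nn.MultiheadAttention layers; it is an illustration of the pattern rather than the actual Residual class, whose attna/attnb modules take mask/enc/layer arguments.

import torch
import torch.nn as nn

class BlendResidualDemo(nn.Module):
    # Hypothetical stand-in: nn.MultiheadAttention replaces the custom
    # attna/attnb modules used in model_hf.py.
    def __init__(self, dims=64, head=4):
        super().__init__()
        self.lna, self.lnb = nn.LayerNorm(dims), nn.LayerNorm(dims)
        self.attna = nn.MultiheadAttention(dims, head, batch_first=True)
        self.attnb = nn.MultiheadAttention(dims, head, batch_first=True)
        self.blend = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, xa=None):
        # Pre-norm self-attention residual, then remember the state before cross-attention.
        h = self.lna(x)
        x = x + self.attna(h, h, h)[0]
        xb = x
        if xa is not None:
            # Pre-norm cross-attention residual over the encoder output xa.
            x = x + self.attnb(self.lnb(x), xa, xa)[0]
            # Learned convex mix of "self-attention only" (xb) and "with cross-attention" (x).
            b = torch.sigmoid(self.blend)
            x = b * xb + (1 - b) * x
        return x

x = torch.randn(2, 10, 64)
xa = torch.randn(2, 20, 64)
print(BlendResidualDemo()(x, xa).shape)  # torch.Size([2, 10, 64])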
@@ -664,9 +665,9 @@ class Residual(nn.Module):
         else:
             print(f"Step {self.counter}: Using MLP gate: {self.mlp_gate if hasattr(self, 'mlp_gate') else None}")
         self.counter += 1
-
         return x

+
 class FEncoder(nn.Module):
     def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
         super().__init__()
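One note on the skip_gates branch: with the flag set, the MLP sub-block reduces to a plain pre-norm residual, while the default path still computes normx and mlp_out for the t_gate/m_gate/c_gate logic that continues further down the file and is only partially visible in this diff. As a rough, hypothetical illustration of the difference, the sketch below uses a simple Linear + Sigmoid gate in place of the real gate modules.

import torch
import torch.nn as nn

class GatedMLPDemo(nn.Module):
    # Hypothetical sketch of the two MLP paths; the Linear + Sigmoid gate is a
    # stand-in, not the t_gate/m_gate/c_gate classes from model_hf.py.
    def __init__(self, dims=64, skip_gates=False):
        super().__init__()
        self.skip_gates = skip_gates
        self.lnc = nn.LayerNorm(dims)
        self.mlp = nn.Sequential(nn.Linear(dims, 4 * dims), nn.GELU(), nn.Linear(4 * dims, dims))
        self.t_gate = nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid())

    def forward(self, x):
        if self.skip_gates:
            # Ungated path, matching x = x + self.mlp(self.lnc(x)) in the diff.
            return x + self.mlp(self.lnc(x))
        normx = self.lnc(x)
        mlp_out = self.mlp(normx)
        gate = self.t_gate(normx)  # shape (batch, ctx, 1), broadcast over channels
        return x + gate * mlp_out

x = torch.randn(2, 10, 64)
print(GatedMLPDemo(skip_gates=True)(x).shape, GatedMLPDemo()(x).shape)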