Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,285 Bytes
9867d34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch.nn as nn
import torch.nn.functional as F
def get_activation_layer(act_type):
    """Look up a zero-argument factory for the requested activation module.

    Args:
        act_type (str): One of "gelu", "gelu_tanh", "relu", "silu".

    Returns:
        A callable that, when invoked with no arguments, constructs the
        corresponding ``nn.Module`` activation instance.

    Raises:
        ValueError: If ``act_type`` is not a recognized activation name.
    """
    factories = {
        "gelu": lambda: nn.GELU(),
        # Approximate `tanh` requires torch >= 1.13
        "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    try:
        return factories[act_type]
    except KeyError:
        raise ValueError(f"Unknown activation type: {act_type}") from None
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``w2(SiLU(w1(x)) * w3(x))``.

    A gated feed-forward layer where ``w1`` produces the gate branch
    (passed through SiLU), ``w3`` produces the value branch, and ``w2``
    projects the gated product to the output dimension. All projections
    are bias-free.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Initialize the SwiGLU FeedForward module.
        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            out_dim (int): Output dimension of the final projection.
        Attributes:
            w1: Linear transformation for the gate branch (dim -> hidden_dim).
            w2: Linear transformation for the output projection (hidden_dim -> out_dim).
            w3: Linear transformation for the value branch (dim -> hidden_dim).
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the SwiGLU transform.

        Args:
            x: Input tensor whose last dimension is ``dim``.

        Returns:
            Tensor with the same leading dimensions and last dimension ``out_dim``.
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
|