James Zhou
[init]
9867d34
import torch.nn as nn
import torch.nn.functional as F
def get_activation_layer(act_type):
    """Return a zero-argument factory for the requested activation module.

    Args:
        act_type: Name of the activation. Recognised values are
            "gelu", "gelu_tanh", "relu" and "silu".

    Returns:
        A callable that takes no arguments and builds the activation
        module: the module class itself for "relu" and "silu", or a
        lambda wrapping the `nn.GELU` constructor for the GELU variants.

    Raises:
        ValueError: If `act_type` is not one of the recognised names.
    """
    # The plain module classes already work as zero-argument factories.
    if act_type == "relu":
        return nn.ReLU
    if act_type == "silu":
        return nn.SiLU

    # The GELU variants are wrapped in lambdas so the constructor arguments
    # are fixed here and the caller can still invoke the result with no args.
    if act_type == "gelu":
        return lambda: nn.GELU()
    if act_type == "gelu_tanh":
        # Approximate `tanh` requires torch >= 1.13
        return lambda: nn.GELU(approximate="tanh")

    raise ValueError(f"Unknown activation type: {act_type}")
class SwiGLU(nn.Module):
    """Gated feed-forward block: `w2(silu(w1(x)) * w3(x))`.

    All three projections are bias-free linear layers.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ) -> None:
        """
        Build the three bias-free projections of the SwiGLU block.

        Args:
            dim (int): Input feature dimension.
            hidden_dim (int): Width of the gated hidden representation.
            out_dim (int): Output feature dimension.

        Attributes:
            w1: `dim -> hidden_dim` projection passed through SiLU (the gate).
            w2: `hidden_dim -> out_dim` projection applied to the gated product.
            w3: `dim -> hidden_dim` projection multiplied with the gate.
        """
        super().__init__()
        # Keep the w1, w2, w3 creation order: it fixes the parameter
        # registration order and the order in which initial weights are drawn.
        self.w1 = nn.Linear(in_features=dim, out_features=hidden_dim, bias=False)
        self.w2 = nn.Linear(in_features=hidden_dim, out_features=out_dim, bias=False)
        self.w3 = nn.Linear(in_features=dim, out_features=hidden_dim, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward transform to `x`.

        Args:
            x: Input tensor whose last dimension is `dim`.

        Returns:
            Tensor with the same leading dimensions as `x` and a last
            dimension of `out_dim`.
        """
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)