import torch.nn as nn
import torch.nn.functional as F


def get_activation_layer(act_type):
    """Return a zero-argument factory that builds the requested activation module."""
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        # The tanh approximation requires torch >= 1.13.
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")


class SwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Initialize the SwiGLU FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            out_dim (int): Output dimension.

        Attributes:
            w1: Linear transformation for the gate projection.
            w2: Linear transformation for the output projection.
            w3: Linear transformation for the value projection.
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: gate the value projection with SiLU, i.e. w2(silu(w1(x)) * w3(x)).
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
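
# Minimal usage sketch (not part of the original module): instantiates an
# activation via the factory and runs a random batch through SwiGLU. The
# dimensions, batch shape, and the __main__ guard below are illustrative
# assumptions, not values taken from the surrounding codebase.
if __name__ == "__main__":
    import torch

    act_factory = get_activation_layer("gelu_tanh")
    act = act_factory()  # instantiate the activation module
    print(act)  # GELU(approximate='tanh')

    ff = SwiGLU(dim=512, hidden_dim=1376, out_dim=512)
    x = torch.randn(2, 16, 512)  # (batch, seq_len, dim)
    y = ff(x)
    print(y.shape)  # torch.Size([2, 16, 512])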