import torch.nn as nn | |
def get_activation(activation: str, *, negative_slope: float = 0.2) -> nn.Module:
    """Return a newly constructed activation module selected by name.

    Args:
        activation: Name of the activation. One of ``"relu"``,
            ``"leaky_relu"``, ``"gelu"``, ``"silu"``, ``"tanh"``.
            Matching is exact (case-sensitive).
        negative_slope: Slope applied to negative inputs. Only used when
            ``activation == "leaky_relu"``. Defaults to 0.2.

    Returns:
        A new ``nn.Module`` instance on every call, so separate layers never
        share the same module object.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # Map names to factories rather than instances so every call builds a
    # fresh module.
    factories = {
        "relu": nn.ReLU,
        "leaky_relu": lambda: nn.LeakyReLU(negative_slope=negative_slope),
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "tanh": nn.Tanh,
    }
    try:
        factory = factories[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list). This keeps the
        # original contract that any unsupported value raises ValueError.
        supported = ", ".join(factories)
        raise ValueError(
            f"Unsupported activation function: {activation} "
            f"(supported: {supported})"
        ) from None
    return factory()