File size: 507 Bytes
e611d1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch.nn as nn
def get_activation(activation: str, *, negative_slope: float = 0.2) -> nn.Module:
    """Return a new activation module selected by name.

    Matching is case-insensitive and ignores surrounding whitespace, so
    ``"ReLU"`` and ``" relu "`` both select ``nn.ReLU``.

    Args:
        activation: One of ``"relu"``, ``"leaky_relu"``, ``"gelu"``,
            ``"silu"`` or ``"tanh"``.
        negative_slope: Slope for negative inputs; only used by
            ``"leaky_relu"``. Defaults to 0.2 (the previously hard-coded value).

    Returns:
        A freshly constructed ``nn.Module`` (never a shared instance).

    Raises:
        ValueError: If ``activation`` is not a supported name (including
            non-string values).
    """
    # Map names to factories (not instances) so every call builds a new module.
    factories = {
        "relu": nn.ReLU,
        "leaky_relu": lambda: nn.LeakyReLU(negative_slope=negative_slope),
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "tanh": nn.Tanh,
    }
    # Non-strings map to None so they fall through to the ValueError below,
    # matching the original behavior instead of raising AttributeError/TypeError.
    key = activation.strip().lower() if isinstance(activation, str) else None
    factory = factories.get(key)
    if factory is None:
        supported = ", ".join(sorted(factories))
        raise ValueError(
            f"Unsupported activation function: {activation} "
            f"(supported: {supported})"
        )
    return factory()