ReCEP / src /bce /model /activation.py
NielTT's picture
Upload 108 files
e611d1f verified
import torch.nn as nn
def get_activation(activation: str) -> nn.Module:
"""Get activation function by name."""
if activation == "relu":
return nn.ReLU()
elif activation == "leaky_relu":
return nn.LeakyReLU(negative_slope=0.2)
elif activation == "gelu":
return nn.GELU()
elif activation == "silu":
return nn.SiLU()
elif activation == "tanh":
return nn.Tanh()
else:
raise ValueError(f"Unsupported activation function: {activation}")