File size: 507 Bytes
e611d1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch.nn as nn

def get_activation(activation: str) -> nn.Module:
    """Get activation function by name."""
    # Map each supported name to a zero-arg factory so every call
    # returns a fresh module instance (mirrors the original branches).
    factories = {
        "relu": nn.ReLU,
        "leaky_relu": lambda: nn.LeakyReLU(negative_slope=0.2),
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "tanh": nn.Tanh,
    }
    try:
        factory = factories[activation]
    except KeyError:
        raise ValueError(f"Unsupported activation function: {activation}") from None
    return factory()