import torch


class LayerNorm(torch.nn.Module):
""" |
|
Layer normalization module. |
|
|
|
Args: |
|
channels (int): Number of channels. |
|
eps (float, optional): Epsilon value for numerical stability. Defaults to 1e-5. |
|
""" |

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
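        # Learnable per-channel affine parameters: scale (gamma) and shift (beta).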
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
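        # Move the channel dimension (dim 1) to the end: layer_norm normalizes
        # over trailing dimensions, so this makes it act per channel.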
        x = x.transpose(1, -1)
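        # Normalize over the last dimension, which now has size `channels`.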
        x = torch.nn.functional.layer_norm(
            x, (x.size(-1),), self.gamma, self.beta, self.eps
        )
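        # Restore the original (batch, channels, ...) layout.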
        return x.transpose(1, -1)
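

# A minimal usage sketch (illustrative only; the channel count and shapes
# below are assumptions, not part of the module):
if __name__ == "__main__":
    ln = LayerNorm(channels=80)
    x = torch.randn(4, 80, 100)  # (batch, channels, time)
    y = ln(x)
    print(y.shape)  # torch.Size([4, 80, 100])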