Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
class UstaLayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale.

    Each vector along the last axis of the input is shifted to zero mean and
    scaled to unit (biased) variance, then multiplied element-wise by a
    learnable ``weight`` (initialised to ones). A learnable additive ``bias``
    is optional and disabled by default, so the default module holds only
    ``weight`` in its parameters / ``state_dict``.

    Args:
        embedding_dim: Size of the normalised (last) dimension; also the
            length of ``weight`` (and ``bias`` when enabled).
        eps: Small constant added to the variance before the square root
            to avoid division by zero.
        device: Device on which the parameters are created.
        bias: If True, add a learnable shift (initialised to zeros) after
            scaling, giving the usual affine LayerNorm.
    """

    def __init__(self, embedding_dim, eps=1e-5, device="cpu", bias=False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(embedding_dim, device=device))
        if bias:
            self.bias = nn.Parameter(torch.zeros(embedding_dim, device=device))
        else:
            # Registering None keeps `self.bias` accessible while leaving it
            # out of parameters() and state_dict(), so the default module is
            # checkpoint-compatible with the original weight-only version.
            self.register_parameter("bias", None)
        self.device = device

    def forward(self, x):
        """Normalise ``x`` over its last dimension and apply the affine terms.

        Args:
            x: Input tensor; its last dimension is expected to equal
                ``embedding_dim`` so it broadcasts against ``weight``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching torch.nn.LayerNorm.
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized_x = (x - mean) / torch.sqrt(variance + self.eps)
        out = self.weight * normalized_x
        if self.bias is not None:
            out = out + self.bias
        return out