import torch import torch.nn as nn class UstaLayerNorm(nn.Module): def __init__(self, embedding_dim, eps=1e-5, device="cpu"): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(embedding_dim, device=device)) self.device = device def forward(self, x): mean = x.mean(dim=-1, keepdim=True) variance = x.var(dim=-1, keepdim=True, unbiased=False) normalized_x = (x - mean) / torch.sqrt(variance + self.eps) return self.weight * normalized_x