davidhd committed
Commit dbc72bc · verified · 1 Parent(s): 5411a4b

Update rmsnorm.py


Modifies the forward pass of RMSNorm to avoid the mixed-precision issues described in https://github.com/chandar-lab/AMPLIFY/issues/19

Files changed (1)
rmsnorm.py +5 -1
rmsnorm.py CHANGED
@@ -20,6 +20,9 @@ class RMSNorm(nn.Module):
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
     def forward(self, x):
         """
         Forward pass through the RMSNorm layer.
@@ -31,4 +34,5 @@ class RMSNorm(nn.Module):
             torch.Tensor: The output tensor after applying RMSNorm.
 
         """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+        output = self._norm(x.float()).type_as(x)  # Avoids mixed precision issues as in https://github.com/chandar-lab/AMPLIFY/issues/19
+        return output * self.weight
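For context, a minimal standalone sketch of the module after this commit, plus a small float16 check. The `__init__` signature and the `eps` default shown here are assumptions, since the diff only covers part of rmsnorm.py; the key point is that the RMS statistics are computed in float32 and the result is cast back to the input dtype before the learned scale is applied.

```python
# Sketch of RMSNorm as updated by this commit (constructor details assumed).
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):  # eps default is an assumption
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Normalize by the root-mean-square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Compute the statistics in float32, then cast back to the input dtype
        # before applying the learned scale (the fix introduced in this commit).
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


if __name__ == "__main__":
    norm = RMSNorm(64).half()                        # fp16 weights, as in mixed-precision inference
    x = torch.randn(2, 8, 64, dtype=torch.float16)   # fp16 activations
    y = norm(x)
    print(y.dtype)  # torch.float16: the output keeps the input dtype
```

The intermediate upcast matters because squaring and averaging directly in float16 can lose precision or overflow for large activations; doing that reduction in float32 and casting back keeps the module's inputs and outputs in the original dtype.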