Update rmsnorm.py
Browse files
Modifies the forward pass of RMSNorm to avoid mixed-precision issues, as described in https://github.com/chandar-lab/AMPLIFY/issues/19
- rmsnorm.py +5 -1
rmsnorm.py
CHANGED
@@ -20,6 +20,9 @@ class RMSNorm(nn.Module):
|
|
20 |
self.eps = eps
|
21 |
self.weight = nn.Parameter(torch.ones(dim))
|
22 |
|
|
|
|
|
|
|
23 |
def forward(self, x):
|
24 |
"""
|
25 |
Forward pass through the RMSNorm layer.
|
@@ -31,4 +34,5 @@ class RMSNorm(nn.Module):
|
|
31 |
torch.Tensor: The output tensor after applying RMSNorm.
|
32 |
|
33 |
"""
|
34 |
-
|
|
|
|
20 |
self.eps = eps
|
21 |
self.weight = nn.Parameter(torch.ones(dim))
|
22 |
|
23 |
+
def _norm(self, x):
|
24 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
25 |
+
|
26 |
def forward(self, x):
|
27 |
"""
|
28 |
Forward pass through the RMSNorm layer.
|
|
|
34 |
torch.Tensor: The output tensor after applying RMSNorm.
|
35 |
|
36 |
"""
|
37 |
+
output = self._norm(x.float()).type_as(x) # Avoids mixed precision issues as in https://github.com/chandar-lab/AMPLIFY/issues/19
|
38 |
+
return output * self.weight
|