Spaces:
Running
Running
File size: 2,032 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import torch.nn as nn
from diffusers.utils import is_torch_npu_available, is_torch_version
def _patched_rms_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if is_torch_npu_available():
import torch_npu
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
if self.bias is not None:
hidden_states = hidden_states + self.bias
elif is_torch_version(">=", "2.4"):
### ===== <Modified> =======
input_dtype = hidden_states.dtype
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = nn.functional.rms_norm(
hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
)
if self.bias is not None:
hidden_states = hidden_states + self.bias
hidden_states = hidden_states.to(input_dtype)
### ===== </Modified> =====
else:
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
if self.bias is not None:
hidden_states = hidden_states + self.bias
else:
hidden_states = hidden_states.to(input_dtype)
return hidden_states
|