Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -736,9 +736,9 @@ class TorchRotaryEmbedding(torch.nn.Module):
|
|
736 |
self.max_seq_len = config.max_seq_len
|
737 |
self.dim = config.dim
|
738 |
self.theta = config.theta
|
739 |
-
self.sincos_cache = self._create_sinusoidal_positions()
|
740 |
|
741 |
-
def _create_sinusoidal_positions(self) -> torch.Tensor:
|
742 |
"""
|
743 |
Create the sines and cosines for the RoPE.
|
744 |
|
@@ -747,19 +747,19 @@ class TorchRotaryEmbedding(torch.nn.Module):
|
|
747 |
"""
|
748 |
# Create the inverse frequency based on theta and dim
|
749 |
inv_freq = 1.0 / (
|
750 |
-
self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
|
751 |
)
|
752 |
|
753 |
# Compute sinusoidal input using the broadcasting
|
754 |
sinusoid_inp = torch.einsum(
|
755 |
-
"i,j->ij", torch.arange(self.max_seq_len).float(), inv_freq
|
756 |
)
|
757 |
|
758 |
# Apply sin and cos to the sinusoidal input
|
759 |
sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
|
760 |
|
761 |
# Allocate a tensor for the final sin-cos values
|
762 |
-
sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32)
|
763 |
|
764 |
# Fill the sincos tensor with sin and cos values
|
765 |
sentinel = self.dim // 2 + self.dim % 2
|
@@ -824,6 +824,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
|
|
824 |
Returns:
|
825 |
RoPE embeddings for the keys and values.
|
826 |
"""
|
|
|
|
|
|
|
|
|
827 |
batch_size, seq_len, num_heads, head_dim = k.shape
|
828 |
|
829 |
# Generate position ids
|
|
|
736 |
self.max_seq_len = config.max_seq_len
|
737 |
self.dim = config.dim
|
738 |
self.theta = config.theta
|
739 |
+
self.sincos_cache = None
|
740 |
|
741 |
+
def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
|
742 |
"""
|
743 |
Create the sines and cosines for the RoPE.
|
744 |
|
|
|
747 |
"""
|
748 |
# Create the inverse frequency based on theta and dim
|
749 |
inv_freq = 1.0 / (
|
750 |
+
self.theta ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
|
751 |
)
|
752 |
|
753 |
# Compute sinusoidal input using the broadcasting
|
754 |
sinusoid_inp = torch.einsum(
|
755 |
+
"i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
|
756 |
)
|
757 |
|
758 |
# Apply sin and cos to the sinusoidal input
|
759 |
sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
|
760 |
|
761 |
# Allocate a tensor for the final sin-cos values
|
762 |
+
sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32, device=device)
|
763 |
|
764 |
# Fill the sincos tensor with sin and cos values
|
765 |
sentinel = self.dim // 2 + self.dim % 2
|
|
|
824 |
Returns:
|
825 |
RoPE embeddings for the keys and values.
|
826 |
"""
|
827 |
+
if self.sincos_cache is None:
|
828 |
+
device = k.device
|
829 |
+
self.sincos_cache = self._create_sinusoidal_positions(device=device)
|
830 |
+
|
831 |
batch_size, seq_len, num_heads, head_dim = k.shape
|
832 |
|
833 |
# Generate position ids
|