Yanisadel committed on
Commit b54a5dd · 1 Parent(s): 64c0358

Update chatNT.py

Files changed (1): chatNT.py (+9 -5)
chatNT.py CHANGED
@@ -736,9 +736,9 @@ class TorchRotaryEmbedding(torch.nn.Module):
         self.max_seq_len = config.max_seq_len
         self.dim = config.dim
         self.theta = config.theta
-        self.sincos_cache = self._create_sinusoidal_positions()
+        self.sincos_cache = None
 
-    def _create_sinusoidal_positions(self) -> torch.Tensor:
+    def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
         """
         Create the sines and cosines for the RoPE.
 
@@ -747,19 +747,19 @@ class TorchRotaryEmbedding(torch.nn.Module):
         """
         # Create the inverse frequency based on theta and dim
         inv_freq = 1.0 / (
-            self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
+            self.theta ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
         )
 
         # Compute sinusoidal input using the broadcasting
         sinusoid_inp = torch.einsum(
-            "i,j->ij", torch.arange(self.max_seq_len).float(), inv_freq
+            "i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
         )
 
         # Apply sin and cos to the sinusoidal input
         sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
 
         # Allocate a tensor for the final sin-cos values
-        sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32)
+        sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32, device=device)
 
         # Fill the sincos tensor with sin and cos values
         sentinel = self.dim // 2 + self.dim % 2
@@ -824,6 +824,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
         Returns:
             RoPE embeddings for the keys and values.
         """
+        if self.sincos_cache is None:
+            device = k.device
+            self.sincos_cache = self._create_sinusoidal_positions(device=device)
+
         batch_size, seq_len, num_heads, head_dim = k.shape
 
         # Generate position ids
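
The gist of the change: the RoPE sin/cos table is no longer built eagerly in __init__ (where it would always be allocated on CPU), but lazily on the first forward pass, on the device of the incoming key tensor k. Below is a minimal standalone sketch of that lazy, device-aware cache pattern, written under simplified assumptions: the class name, constructor arguments, and the concatenated sin/cos layout are illustrative only and differ from chatNT.py, which fills a zeros tensor split at a sentinel index.

import torch

class LazyRotaryCache(torch.nn.Module):
    """Illustrative module (not the ChatNT implementation): builds the RoPE
    sin/cos table on first use, on the same device as the incoming tensor."""

    def __init__(self, max_seq_len: int, dim: int, theta: float = 10000.0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.dim = dim
        self.theta = theta
        # Deferred allocation: nothing is created until forward() sees a
        # tensor, so the table lands on that tensor's device (CPU or GPU).
        self.sincos_cache = None

    def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
        # Inverse frequencies, one per pair of dimensions
        inv_freq = 1.0 / (
            self.theta ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
        )
        # Outer product of position indices and frequencies
        sinusoid_inp = torch.einsum(
            "i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
        )
        # Simplified layout: sin half followed by cos half, shape (max_seq_len, dim)
        return torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

    def forward(self, k: torch.Tensor) -> torch.Tensor:
        # Build the cache once, on k's device, then reuse it
        if self.sincos_cache is None:
            self.sincos_cache = self._create_sinusoidal_positions(device=k.device)
        seq_len = k.shape[1]
        return self.sincos_cache[:seq_len]

# Usage: the cache is created on whatever device the first input lives on.
keys = torch.randn(2, 16, 4, 8)            # (batch, seq, heads, head_dim)
table = LazyRotaryCache(max_seq_len=32, dim=8)(keys)
print(table.shape, table.device)           # torch.Size([16, 8]) cpu

This avoids the device mismatch that the removed code could hit when the module is moved to GPU after construction, since every tensor in the table is created with an explicit device argument.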