Yanisadel committed
Commit 93927ba · 1 Parent(s): 57943d9

Upload model

chatNT.py CHANGED
@@ -747,7 +747,8 @@ class TorchRotaryEmbedding(torch.nn.Module):
         """
         # Create the inverse frequency based on theta and dim
         inv_freq = 1.0 / (
-            self.theta ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
+            self.theta
+            ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
         )
 
         # Compute sinusoidal input using the broadcasting
@@ -759,7 +760,9 @@ class TorchRotaryEmbedding(torch.nn.Module):
         sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
 
         # Allocate a tensor for the final sin-cos values
-        sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32, device=device)
+        sincos = torch.zeros(
+            (self.max_seq_len, self.dim), dtype=torch.float32, device=device
+        )
 
         # Fill the sincos tensor with sin and cos values
         sentinel = self.dim // 2 + self.dim % 2
@@ -827,7 +830,7 @@ class TorchRotaryEmbedding(torch.nn.Module):
         if self.sincos_cache is None:
             device = k.device
             self.sincos_cache = self._create_sinusoidal_positions(device=device)
-
+
         batch_size, seq_len, num_heads, head_dim = k.shape
 
         # Generate position ids
@@ -839,7 +842,7 @@ class TorchRotaryEmbedding(torch.nn.Module):
         position_ids += positions
 
         # Retrieve sincos values using the position_ids
-        sincos = self.sincos_cache[position_ids]
+        sincos = self.sincos_cache[position_ids]  # type: ignore
 
         # Split sincos into sin_pos and cos_pos
         sincos = torch.chunk(sincos, 2, dim=-1)
@@ -975,7 +978,9 @@ class TorchGptDecoder(nn.Module):
         self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> torch.Tensor:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, embeddings.shape[1], device=embeddings.device)
+            attention_mask = build_causal_attention_mask(
+                1, embeddings.shape[1], device=embeddings.device
+            )
         for layer in self.layers:
             embeddings = layer(embeddings, attention_mask)
 
@@ -985,7 +990,9 @@ class TorchGptDecoder(nn.Module):
         self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> dict[str, torch.Tensor]:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, token_ids.shape[1], device=token_ids.device)
+            attention_mask = build_causal_attention_mask(
+                1, token_ids.shape[1], device=token_ids.device
+            )
 
         tokens_embeddings = self.token_embed(token_ids)
 
@@ -1127,7 +1134,9 @@ def get_activation_fn(activation_name: str): # type: ignore
     return activations.get(activation_name, nn.functional.relu)
 
 
-def build_causal_attention_mask(batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
+def build_causal_attention_mask(
+    batch_size: int, seq_len: int, device: torch.device
+) -> torch.Tensor:
     """
     Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
     to an attention layer.
@@ -1218,14 +1227,16 @@ class RotaryEmbeddingBis(torch.nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.rescaling_factor is None:
             inv_freq = 1.0 / (
-                self.upper_freq ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
+                self.upper_freq
+                ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
             )
         else:
             updated_base = self.upper_freq * (
                 self.rescaling_factor ** (self.dim / (self.dim - 2))
             )
             inv_freq = 1.0 / (
-                updated_base ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
+                updated_base
+                ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
             )
 
         self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
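The chatNT.py edits above are formatting-only line wraps (the rotary inverse-frequency expressions of the form inv_freq = 1.0 / theta ** (2i / dim), the sincos allocation, and the build_causal_attention_mask calls and signature); behaviour is unchanged. For orientation, here is a minimal sketch of a causal mask builder that matches the signature and docstring shown in the diff. The body is an illustrative assumption, not the actual implementation shipped in chatNT.py:

import torch


def build_causal_attention_mask(
    batch_size: int, seq_len: int, device: torch.device
) -> torch.Tensor:
    # Lower-triangular (causal) mask: position i may attend only to positions <= i.
    # NOTE: illustrative sketch; the real body lives in chatNT.py.
    mask = torch.tril(
        torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
    )
    # Broadcast to (batch_size, 1, seq_len, seq_len) as stated in the docstring.
    return mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)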
config.json CHANGED
@@ -80,6 +80,6 @@
     "use_gradient_checkpointing": false
   },
   "seq_token_id": 32000,
-  "torch_dtype": "float32",
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.41.1"
 }
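config.json now declares the checkpoint dtype as bfloat16 instead of float32, matching the uploaded weight shards. A minimal loading sketch, assuming the repository is consumed through transformers with its bundled custom code (AutoModel and the repo id are placeholders, not confirmed by this commit):

from transformers import AutoModel

# torch_dtype="auto" picks up the "bfloat16" declared in config.json;
# trust_remote_code is assumed because the modeling code ships in chatNT.py.
model = AutoModel.from_pretrained(
    "<repo-id>",  # placeholder
    torch_dtype="auto",
    trust_remote_code=True,
)
print(next(model.parameters()).dtype)  # expected: torch.bfloat16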
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcde43f92cb4fba555a67426956b0c0c0b1e9ac7b04a7d6744580e4729cfd9e3
+size 4998275550
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:187615f3a8661430364e2e824d5b0a0363c9cf5b3d8512f33c44015b0be27343
+size 4890784808
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:916b86538557669e3a74c00d4d58ae44e494c4439aba8c2d6ee51baf05f62ebe
+size 4985672264
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8524670292b2f477cd558fd76b3372840949dadd0b0a6c386519b05a82faebe6
+size 1212565848
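The four weight shards are committed as Git LFS pointers; each pointer records only the blob's sha256 oid and byte size. A small, hypothetical verification sketch for a downloaded shard, reusing the values from the model-00001-of-00004.safetensors pointer above:

import hashlib
from pathlib import Path

# Expected values copied from the LFS pointer for model-00001-of-00004.safetensors.
EXPECTED_SHA256 = "fcde43f92cb4fba555a67426956b0c0c0b1e9ac7b04a7d6744580e4729cfd9e3"
EXPECTED_SIZE = 4998275550

shard = Path("model-00001-of-00004.safetensors")
assert shard.stat().st_size == EXPECTED_SIZE, "size mismatch"

digest = hashlib.sha256()
with shard.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)
assert digest.hexdigest() == EXPECTED_SHA256, "sha256 mismatch"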
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff