Yanisadel commited on
Commit
ba069f0
·
1 Parent(s): 02765e9

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +0 -9
chatNT.py CHANGED
@@ -1188,11 +1188,6 @@ class RotaryEmbeddingBis(torch.nn.Module):
1188
  heads[..., heads.shape[-1] // 2 :],
1189
  )
1190
 
1191
- print("x_first device : ", x_first.device)
1192
- print("cos device : ", cos.device)
1193
- print("x_second device : ", x_second.device)
1194
- print("sin device : ", sin.device)
1195
-
1196
  first_part = x_first * cos - x_second * sin
1197
  second_part = x_second * cos + x_first * sin
1198
 
@@ -1201,13 +1196,11 @@ class RotaryEmbeddingBis(torch.nn.Module):
1201
  def _compute_cos_sin_tables(
1202
  self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1203
  ) -> tuple[torch.Tensor, torch.Tensor]:
1204
- print("x device : ", x.device)
1205
  seq_len = x.shape[seq_dimension]
1206
  # Reset the tables if the sequence length has changed,
1207
  # or if we're on a new device (possibly due to tracing for instance)
1208
  self._seq_len_cached = seq_len
1209
  t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1210
- print("t device : ", t.device)
1211
  # freqs = torch.outer(t, inv_freq)
1212
  freqs = torch.einsum("i, j -> ij", t, inv_freq)
1213
 
@@ -1235,8 +1228,6 @@ class RotaryEmbeddingBis(torch.nn.Module):
1235
  updated_base ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1236
  )
1237
 
1238
- print("q device : ", q.device)
1239
- print("inv_freq device : ", inv_freq.device)
1240
  self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1241
  q,
1242
  inv_freq,
 
1188
  heads[..., heads.shape[-1] // 2 :],
1189
  )
1190
 
 
 
 
 
 
1191
  first_part = x_first * cos - x_second * sin
1192
  second_part = x_second * cos + x_first * sin
1193
 
 
1196
  def _compute_cos_sin_tables(
1197
  self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1198
  ) -> tuple[torch.Tensor, torch.Tensor]:
 
1199
  seq_len = x.shape[seq_dimension]
1200
  # Reset the tables if the sequence length has changed,
1201
  # or if we're on a new device (possibly due to tracing for instance)
1202
  self._seq_len_cached = seq_len
1203
  t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
 
1204
  # freqs = torch.outer(t, inv_freq)
1205
  freqs = torch.einsum("i, j -> ij", t, inv_freq)
1206
 
 
1228
  updated_base ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1229
  )
1230
 
 
 
1231
  self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1232
  q,
1233
  inv_freq,