Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -1188,11 +1188,6 @@ class RotaryEmbeddingBis(torch.nn.Module):
|
|
1188 |
heads[..., heads.shape[-1] // 2 :],
|
1189 |
)
|
1190 |
|
1191 |
-
print("x_first device : ", x_first.device)
|
1192 |
-
print("cos device : ", cos.device)
|
1193 |
-
print("x_second device : ", x_second.device)
|
1194 |
-
print("sin device : ", sin.device)
|
1195 |
-
|
1196 |
first_part = x_first * cos - x_second * sin
|
1197 |
second_part = x_second * cos + x_first * sin
|
1198 |
|
@@ -1201,13 +1196,11 @@ class RotaryEmbeddingBis(torch.nn.Module):
|
|
1201 |
def _compute_cos_sin_tables(
|
1202 |
self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
|
1203 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
1204 |
-
print("x device : ", x.device)
|
1205 |
seq_len = x.shape[seq_dimension]
|
1206 |
# Reset the tables if the sequence length has changed,
|
1207 |
# or if we're on a new device (possibly due to tracing for instance)
|
1208 |
self._seq_len_cached = seq_len
|
1209 |
t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
|
1210 |
-
print("t device : ", t.device)
|
1211 |
# freqs = torch.outer(t, inv_freq)
|
1212 |
freqs = torch.einsum("i, j -> ij", t, inv_freq)
|
1213 |
|
@@ -1235,8 +1228,6 @@ class RotaryEmbeddingBis(torch.nn.Module):
|
|
1235 |
updated_base ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
|
1236 |
)
|
1237 |
|
1238 |
-
print("q device : ", q.device)
|
1239 |
-
print("inv_freq device : ", inv_freq.device)
|
1240 |
self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
|
1241 |
q,
|
1242 |
inv_freq,
|
|
|
1188 |
heads[..., heads.shape[-1] // 2 :],
|
1189 |
)
|
1190 |
|
|
|
|
|
|
|
|
|
|
|
1191 |
first_part = x_first * cos - x_second * sin
|
1192 |
second_part = x_second * cos + x_first * sin
|
1193 |
|
|
|
1196 |
def _compute_cos_sin_tables(
|
1197 |
self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
|
1198 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
1199 |
seq_len = x.shape[seq_dimension]
|
1200 |
# Reset the tables if the sequence length has changed,
|
1201 |
# or if we're on a new device (possibly due to tracing for instance)
|
1202 |
self._seq_len_cached = seq_len
|
1203 |
t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
|
|
|
1204 |
# freqs = torch.outer(t, inv_freq)
|
1205 |
freqs = torch.einsum("i, j -> ij", t, inv_freq)
|
1206 |
|
|
|
1228 |
updated_base ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
|
1229 |
)
|
1230 |
|
|
|
|
|
1231 |
self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
|
1232 |
q,
|
1233 |
inv_freq,
|