Update model.py
Browse files
model.py
CHANGED
@@ -35,7 +35,6 @@ dtype = torch.float32
|
|
35 |
warnings.filterwarnings("ignore")
|
36 |
logging.basicConfig(level=logging.ERROR)
|
37 |
|
38 |
-
|
39 |
extractor = None
|
40 |
tokenizer = None
|
41 |
optimizer = None
|
@@ -373,9 +372,9 @@ class rotary(nn.Module):
|
|
373 |
print(f"Radius[end]: {radius[-1][:5].cpu().numpy()}")
|
374 |
|
375 |
print(f"Final freqs shape: {freqs.shape}")
|
376 |
-
print(f"Freqs[0]: {freqs[0][:5].cpu().numpy()}")
|
377 |
-
print(f"Freqs[mid]: {freqs[ctx//2][:5].cpu().numpy()}")
|
378 |
-
print(f"Freqs[end]: {freqs[-1][:5].cpu().numpy()}")
|
379 |
print("================================\n")
|
380 |
|
381 |
self._counter += 1
|
@@ -394,7 +393,6 @@ class rotary(nn.Module):
|
|
394 |
x1 = x1.view(orig_shape)
|
395 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
396 |
|
397 |
-
|
398 |
# class FocusA(nn.Module):
|
399 |
# def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
|
400 |
# super().__init__()
|
@@ -1822,5 +1820,3 @@ def main():
|
|
1822 |
if __name__ == "__main__":
|
1823 |
main()
|
1824 |
|
1825 |
-
|
1826 |
-
|
|
|
35 |
warnings.filterwarnings("ignore")
|
36 |
logging.basicConfig(level=logging.ERROR)
|
37 |
|
|
|
38 |
extractor = None
|
39 |
tokenizer = None
|
40 |
optimizer = None
|
|
|
372 |
print(f"Radius[end]: {radius[-1][:5].cpu().numpy()}")
|
373 |
|
374 |
print(f"Final freqs shape: {freqs.shape}")
|
375 |
+
print(f"Freqs[0]: {freqs[0][:5].cpu().detach().numpy()}")
|
376 |
+
print(f"Freqs[mid]: {freqs[ctx//2][:5].cpu().detach().numpy()}")
|
377 |
+
print(f"Freqs[end]: {freqs[-1][:5].cpu().detach().numpy()}")
|
378 |
print("================================\n")
|
379 |
|
380 |
self._counter += 1
|
|
|
393 |
x1 = x1.view(orig_shape)
|
394 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
395 |
|
|
|
396 |
# class FocusA(nn.Module):
|
397 |
# def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
|
398 |
# super().__init__()
|
|
|
1820 |
if __name__ == "__main__":
|
1821 |
main()
|
1822 |
|
|
|
|