Sin2pi committed on
Commit
094da86
·
verified ·
1 Parent(s): 1db8188

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -7
model.py CHANGED
@@ -35,7 +35,6 @@ dtype = torch.float32
35
  warnings.filterwarnings("ignore")
36
  logging.basicConfig(level=logging.ERROR)
37
 
38
-
39
  extractor = None
40
  tokenizer = None
41
  optimizer = None
@@ -373,9 +372,9 @@ class rotary(nn.Module):
373
  print(f"Radius[end]: {radius[-1][:5].cpu().numpy()}")
374
 
375
  print(f"Final freqs shape: {freqs.shape}")
376
- print(f"Freqs[0]: {freqs[0][:5].cpu().numpy()}")
377
- print(f"Freqs[mid]: {freqs[ctx//2][:5].cpu().numpy()}")
378
- print(f"Freqs[end]: {freqs[-1][:5].cpu().numpy()}")
379
  print("================================\n")
380
 
381
  self._counter += 1
@@ -394,7 +393,6 @@ class rotary(nn.Module):
394
  x1 = x1.view(orig_shape)
395
  return torch.cat([x1.type_as(x), x2], dim=-1)
396
 
397
-
398
  # class FocusA(nn.Module):
399
  # def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
400
  # super().__init__()
@@ -1822,5 +1820,3 @@ def main():
1822
  if __name__ == "__main__":
1823
  main()
1824
 
1825
-
1826
-
 
35
  warnings.filterwarnings("ignore")
36
  logging.basicConfig(level=logging.ERROR)
37
 
 
38
  extractor = None
39
  tokenizer = None
40
  optimizer = None
 
372
  print(f"Radius[end]: {radius[-1][:5].cpu().numpy()}")
373
 
374
  print(f"Final freqs shape: {freqs.shape}")
375
+ print(f"Freqs[0]: {freqs[0][:5].cpu().detach().numpy()}")
376
+ print(f"Freqs[mid]: {freqs[ctx//2][:5].cpu().detach().numpy()}")
377
+ print(f"Freqs[end]: {freqs[-1][:5].cpu().detach().numpy()}")
378
  print("================================\n")
379
 
380
  self._counter += 1
 
393
  x1 = x1.view(orig_shape)
394
  return torch.cat([x1.type_as(x), x2], dim=-1)
395
 
 
396
  # class FocusA(nn.Module):
397
  # def __init__(self, dims, head, max_dist=None, win_size=32, max_span=32, temp_scale=0.01, iterations=2):
398
  # super().__init__()
 
1820
  if __name__ == "__main__":
1821
  main()
1822