AkashDataScience commited on
Commit
6c3debe
·
1 Parent(s): 59bc37f

Loading model properly

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -3,7 +3,6 @@ import tiktoken
3
  import gradio as gr
4
  import torch.nn.functional as F
5
  from model import GPT, GPTConfig
6
- torch._dynamo.reset()
7
 
8
  device = 'cpu'
9
  if torch.cuda.is_available():
@@ -12,10 +11,14 @@ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
12
  device = "mps"
13
 
14
  model = GPT(GPTConfig())
15
- model.load_state_dict(torch.load("gpt2.pt", map_location=torch.device(device)))
 
 
 
 
 
16
 
17
  model.to(device)
18
- model = torch.compile(model, fullgraph=True, backend="cudagraphs")
19
 
20
  enc = tiktoken.get_encoding('gpt2')
21
 
 
3
  import gradio as gr
4
  import torch.nn.functional as F
5
  from model import GPT, GPTConfig
 
6
 
7
  device = 'cpu'
8
  if torch.cuda.is_available():
 
11
  device = "mps"
12
 
13
  model = GPT(GPTConfig())
14
+ ckpt = torch.load("gpt2.pt", map_location=torch.device(device))
15
+ unwanted_prefix = '_orig_mod.'
16
+ for k,v in list(ckpt.items()):
17
+ if k.startswith(unwanted_prefix):
18
+ ckpt[k[len(unwanted_prefix):]] = ckpt.pop(k)
19
+ model.load_state_dict(ckpt)
20
 
21
  model.to(device)
 
22
 
23
  enc = tiktoken.get_encoding('gpt2')
24