Shriti09 committed
Commit 707f5ae · verified · 1 Parent(s): f271aef

Update app.py

Files changed (1)
  app.py +18 -6
app.py CHANGED
@@ -2,15 +2,27 @@ import torch
  import gradio as gr
  from transformers import AutoTokenizer
  from model_smol2 import LlamaForCausalLM, config_model
- 
+ import requests
  # Instantiate the model
  model = LlamaForCausalLM(config_model)
  
- # Load the checkpoint
- checkpoint_path = "/Users/shriti/Downloads/Assign13_ERAV3/deply/final_checkpoint.pt"
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
- model.load_state_dict(checkpoint['model_state_dict'])
- model.eval()
+ # Correct URL for Google Drive direct download
+ url = "https://drive.google.com/uc?id=1tyZhudOcZRMaSLkzAWksVyQmbPQ0fi50"
+ 
+ response = requests.get(url)
+ if response.status_code == 200:
+     with open("final_checkpoint.pt", "wb") as f:
+         f.write(response.content)
+ else:
+     print(f"Failed to download the file. Status code: {response.status_code}")
+ 
+ # Now load the checkpoint
+ try:
+     checkpoint = torch.load("final_checkpoint.pt", map_location="cpu", weights_only=True)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+ except Exception as e:
+     print(f"Error loading checkpoint: {e}")
  
  # Load tokenizer (replace with the appropriate tokenizer if you're using a custom one)
  # Load the tokenizer
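A note on the download step (not part of the commit): requests.get against a drive.google.com/uc?id=... link returns the raw file only for small files; for larger files Google Drive responds with an HTML virus-scan confirmation page, which the code above would silently write to final_checkpoint.pt. A minimal sketch of a more robust download, assuming the third-party gdown package is available (pip install gdown); the URL is the same one used in the commit:

import gdown

# Sketch only, not part of the commit: gdown follows Google Drive's
# large-file confirmation redirect that a plain requests.get does not.
url = "https://drive.google.com/uc?id=1tyZhudOcZRMaSLkzAWksVyQmbPQ0fi50"
gdown.download(url, "final_checkpoint.pt", quiet=False)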
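On the loading side, torch.load(..., weights_only=True) restricts unpickling to tensors and plain containers, which is fine for a checkpoint saved as a dict holding 'model_state_dict', but would fail on checkpoints that pickle arbitrary Python objects. A short sketch for checking that the downloaded weights actually match the model definition; strict=False is standard PyTorch, while the checkpoint key name is taken from the commit above:

import torch
from model_smol2 import LlamaForCausalLM, config_model

model = LlamaForCausalLM(config_model)
checkpoint = torch.load("final_checkpoint.pt", map_location="cpu", weights_only=True)

# strict=False reports mismatched parameter names instead of raising,
# which makes architecture/checkpoint mismatches easy to spot.
result = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
model.eval()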