HemanM commited on
Commit
70b5bb7
·
verified ·
1 Parent(s): 3eaea0f

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -3
inference.py CHANGED
@@ -7,11 +7,23 @@ import os
7
 
8
  # Load Evo model and tokenizer
9
  model = EvoTransformerV22()
10
- model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
11
- model.eval()
12
-
13
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # GPT Setup
16
  openai.api_key = os.getenv("OPENAI_API_KEY") # 🔒 Load securely from environment
17
 
 
7
 
8
  # Load Evo model and tokenizer
9
  model = EvoTransformerV22()
 
 
 
10
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
11
 
12
+ # Smart load order
13
+ if os.path.exists("trained_model/evo_retrained.pt"):
14
+ model.load_state_dict(torch.load("trained_model/evo_retrained.pt", map_location="cpu"))
15
+ print("🔁 Loaded retrained Evo model.")
16
+ elif os.path.exists("trained_model/evo_pretrained.pt"):
17
+ model.load_state_dict(torch.load("trained_model/evo_pretrained.pt", map_location="cpu"))
18
+ print("📦 Loaded pretrained Evo model.")
19
+ elif os.path.exists("evo_hellaswag.pt"):
20
+ model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
21
+ print("📥 Loaded default Evo model.")
22
+ else:
23
+ raise FileNotFoundError("❌ No Evo model file found.")
24
+
25
+ model.eval()
26
+
27
  # GPT Setup
28
  openai.api_key = os.getenv("OPENAI_API_KEY") # 🔒 Load securely from environment
29