HemanM commited on
Commit
fb120ec
·
verified ·
1 Parent(s): 6a94f97

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +41 -0
inference.py CHANGED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, OpenAIGPTLMHeadModel
3
+ from evo_model import EvoTransformerV22
4
+
5
+ # Load Evo model
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ evo_model = EvoTransformerV22()
8
+ evo_model.load_state_dict(torch.load("trained_model/evo_hellaswag.pt", map_location=device))
9
+ evo_model.to(device)
10
+ evo_model.eval()
11
+
12
+ # Load tokenizer
13
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
14
+
15
+ # 🧠 Evo logic
16
+ def get_evo_response(query, context):
17
+ combined = query + " " + context
18
+ inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
19
+ input_ids = inputs["input_ids"].to(device)
20
+
21
+ with torch.no_grad():
22
+ logits = evo_model(input_ids)
23
+ pred = torch.argmax(logits, dim=1).item()
24
+
25
+ return f"Evo suggests: Option {pred + 1}" # Assumes binary classification (0 or 1)
26
+
27
+ # 🤖 GPT-3.5 comparison (optional)
28
+ import openai
29
+ openai.api_key = "sk-..." # Replace with your OpenAI API key
30
+
31
+ def get_gpt_response(query, context):
32
+ try:
33
+ prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
34
+ response = openai.ChatCompletion.create(
35
+ model="gpt-3.5-turbo",
36
+ messages=[{"role": "user", "content": prompt}],
37
+ temperature=0.3
38
+ )
39
+ return response['choices'][0]['message']['content'].strip()
40
+ except Exception as e:
41
+ return f"Error from GPT: {e}"