HemanM committed
Commit e7d2e38 · verified · 1 Parent(s): f39d1fb

Update inference.py

Files changed (1)
  1. inference.py +40 -18
inference.py CHANGED
@@ -1,35 +1,57 @@
 import torch
 from evo_model import EvoTransformer
+from transformers import AutoTokenizer
 
-# Load EvoTransformer model
-def load_model(model_path="evo_hellaswag.pt", device=None):
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
+def load_model(model_path="evo_hellaswag.pt"):
     model = EvoTransformer()
     model.load_state_dict(torch.load(model_path, map_location=device))
     model.to(device)
     model.eval()
-    return model, device
+    return model
 
-# Predict the best option (0 or 1)
-def predict(model, tokenizer, prompt, option1, option2, device):
-    inputs = [
-        f"{prompt} {option1}",
-        f"{prompt} {option2}",
-    ]
+evo_model = load_model()
 
+def get_evo_response(prompt, option1, option2):
+    inputs = [f"{prompt} {option1}", f"{prompt} {option2}"]
     encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
 
     with torch.no_grad():
-        outputs = model(encoded["input_ids"])  # already includes classifier
+        logits = evo_model(encoded["input_ids"]).squeeze(-1)
 
-    logits = outputs.squeeze(-1)  # shape: [2]
     probs = torch.softmax(logits, dim=0)
     best = torch.argmax(probs).item()
 
-    return {
-        "choice": best,
-        "confidence": probs[best].item(),
-        "scores": probs.tolist(),
-    }
+    explanations = [
+        f"🅰️ Option 1: {option1}\nConfidence: {probs[0]:.2f}",
+        f"🅱️ Option 2: {option2}\nConfidence: {probs[1]:.2f}"
+    ]
+
+    final = f"Evo suggests: Option {best + 1}\n\n{explanations[best]}"
+    return final
+
+def get_gpt_response(prompt, option1, option2):
+    import openai
+    import os
+    openai.api_key = os.getenv("OPENAI_API_KEY")
+
+    full_prompt = (
+        f"Question: {prompt}\n"
+        f"Option 1: {option1}\n"
+        f"Option 2: {option2}\n"
+        "Which option makes more sense and why?"
+    )
+
+    try:
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo",
+            messages=[
+                {"role": "user", "content": full_prompt}
+            ]
+        )
+        return response.choices[0].message["content"].strip()
+    except Exception as e:
+        return f"GPT Error: {e}"