Update inference.py
inference.py (+40, -18) CHANGED

@@ -1,35 +1,57 @@
 import torch
 from evo_model import EvoTransformer
+from transformers import AutoTokenizer
 
-# Load
-
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
+def load_model(model_path="evo_hellaswag.pt"):
     model = EvoTransformer()
     model.load_state_dict(torch.load(model_path, map_location=device))
     model.to(device)
     model.eval()
-    return model
+    return model
 
-
-def predict(model, tokenizer, prompt, option1, option2, device):
-    inputs = [
-        f"{prompt} {option1}",
-        f"{prompt} {option2}",
-    ]
+evo_model = load_model()
 
+def get_evo_response(prompt, option1, option2):
+    inputs = [f"{prompt} {option1}", f"{prompt} {option2}"]
     encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
 
     with torch.no_grad():
-
+        logits = evo_model(encoded["input_ids"]).squeeze(-1)
 
-        logits = outputs.squeeze(-1)  # shape: [2]
     probs = torch.softmax(logits, dim=0)
     best = torch.argmax(probs).item()
 
-
-    "
-    "
-
-
+    explanations = [
+        f"🅰️ Option 1: {option1}\nConfidence: {probs[0]:.2f}",
+        f"🅱️ Option 2: {option2}\nConfidence: {probs[1]:.2f}"
+    ]
+
+    final = f"Evo suggests: Option {best + 1}\n\n{explanations[best]}"
+    return final
+
+def get_gpt_response(prompt, option1, option2):
+    import openai
+    import os
+    openai.api_key = os.getenv("OPENAI_API_KEY")
+
+    full_prompt = (
+        f"Question: {prompt}\n"
+        f"Option 1: {option1}\n"
+        f"Option 2: {option2}\n"
+        "Which option makes more sense and why?"
+    )
+
+    try:
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo",
+            messages=[
+                {"role": "user", "content": full_prompt}
+            ]
+        )
+        return response.choices[0].message["content"].strip()
+    except Exception as e:
+        return f"GPT Error: {e}"