Spaces:
Runtime error
Runtime error
Update inference.py
Browse files- inference.py +20 -10
inference.py
CHANGED
@@ -27,20 +27,30 @@ def query_gpt35(prompt):
|
|
27 |
return f"[GPT-3.5 Error] {e}"
|
28 |
|
29 |
def generate_response(goal, option1, option2):
|
30 |
-
#
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
37 |
|
38 |
with torch.no_grad():
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
# GPT-3.5 prediction
|
|
|
44 |
gpt_result = query_gpt35(prompt)
|
45 |
|
46 |
return {
|
|
|
27 |
return f"[GPT-3.5 Error] {e}"
|
28 |
|
29 |
def generate_response(goal, option1, option2):
|
30 |
+
# Build inputs for option 1 and option 2
|
31 |
+
text1 = goal + " " + option1
|
32 |
+
text2 = goal + " " + option2
|
33 |
+
|
34 |
+
# Tokenize separately
|
35 |
+
input1 = tokenizer(text1, return_tensors="pt", padding=True, truncation=True)
|
36 |
+
input2 = tokenizer(text2, return_tensors="pt", padding=True, truncation=True)
|
37 |
+
|
38 |
+
# Remove token_type_ids to avoid forward() issues
|
39 |
+
input1.pop("token_type_ids", None)
|
40 |
+
input2.pop("token_type_ids", None)
|
41 |
|
42 |
with torch.no_grad():
|
43 |
+
logit1 = model(**input1)
|
44 |
+
logit2 = model(**input2)
|
45 |
+
|
46 |
+
# Get logits[0][0] since we only expect 1 class output vector per input
|
47 |
+
score1 = logit1[0][0].item()
|
48 |
+
score2 = logit2[0][0].item()
|
49 |
+
|
50 |
+
evo_result = option1 if score1 > score2 else option2
|
51 |
|
52 |
# GPT-3.5 prediction
|
53 |
+
prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better?"
|
54 |
gpt_result = query_gpt35(prompt)
|
55 |
|
56 |
return {
|