HemanM committed on
Commit
e33deda
·
verified ·
1 Parent(s): 981d63b

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +20 -10
inference.py CHANGED
@@ -27,20 +27,30 @@ def query_gpt35(prompt):
27
  return f"[GPT-3.5 Error] {e}"
28
 
29
  def generate_response(goal, option1, option2):
30
- # Evo prediction
31
- prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better?"
32
- inputs = tokenizer([goal + " " + option1, goal + " " + option2],
33
- return_tensors="pt", padding=True, truncation=True)
34
-
35
- # ✅ Remove token_type_ids if it exists
36
- inputs.pop("token_type_ids", None)
 
 
 
 
37
 
38
  with torch.no_grad():
39
- logits = model(**inputs)
40
- pred = torch.argmax(logits, dim=1).item()
41
- evo_result = option1 if pred == 0 else option2
 
 
 
 
 
42
 
43
  # GPT-3.5 prediction
 
44
  gpt_result = query_gpt35(prompt)
45
 
46
  return {
 
27
  return f"[GPT-3.5 Error] {e}"
28
 
29
  def generate_response(goal, option1, option2):
30
+ # Build inputs for option 1 and option 2
31
+ text1 = goal + " " + option1
32
+ text2 = goal + " " + option2
33
+
34
+ # Tokenize separately
35
+ input1 = tokenizer(text1, return_tensors="pt", padding=True, truncation=True)
36
+ input2 = tokenizer(text2, return_tensors="pt", padding=True, truncation=True)
37
+
38
+ # Remove token_type_ids to avoid forward() issues
39
+ input1.pop("token_type_ids", None)
40
+ input2.pop("token_type_ids", None)
41
 
42
  with torch.no_grad():
43
+ logit1 = model(**input1)
44
+ logit2 = model(**input2)
45
+
46
+ # Get logits[0][0] since we only expect 1 class output vector per input
47
+ score1 = logit1[0][0].item()
48
+ score2 = logit2[0][0].item()
49
+
50
+ evo_result = option1 if score1 > score2 else option2
51
 
52
  # GPT-3.5 prediction
53
+ prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better?"
54
  gpt_result = query_gpt35(prompt)
55
 
56
  return {