HemanM commited on
Commit
981d63b
·
verified ·
1 Parent(s): 3e1b974

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -1
inference.py CHANGED
@@ -2,6 +2,7 @@ import openai
2
  from transformers import BertTokenizer
3
  from evo_model import EvoTransformerForClassification
4
  import torch
 
5
 
6
  # Load Evo model
7
  model = EvoTransformerForClassification.from_pretrained("trained_model")
@@ -11,7 +12,6 @@ model.eval()
11
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
12
 
13
  # Set OpenAI key (assumes you have it set as ENV VAR or replace directly)
14
- import os
15
  openai.api_key = os.getenv("OPENAI_API_KEY")
16
 
17
  def query_gpt35(prompt):
@@ -31,6 +31,10 @@ def generate_response(goal, option1, option2):
31
  prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better?"
32
  inputs = tokenizer([goal + " " + option1, goal + " " + option2],
33
  return_tensors="pt", padding=True, truncation=True)
 
 
 
 
34
  with torch.no_grad():
35
  logits = model(**inputs)
36
  pred = torch.argmax(logits, dim=1).item()
 
2
  from transformers import BertTokenizer
3
  from evo_model import EvoTransformerForClassification
4
  import torch
5
+ import os
6
 
7
  # Load Evo model
8
  model = EvoTransformerForClassification.from_pretrained("trained_model")
 
12
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
13
 
14
  # Set OpenAI key (assumes you have it set as ENV VAR or replace directly)
 
15
  openai.api_key = os.getenv("OPENAI_API_KEY")
16
 
17
  def query_gpt35(prompt):
 
31
  prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better?"
32
  inputs = tokenizer([goal + " " + option1, goal + " " + option2],
33
  return_tensors="pt", padding=True, truncation=True)
34
+
35
+ # ✅ Remove token_type_ids if it exists
36
+ inputs.pop("token_type_ids", None)
37
+
38
  with torch.no_grad():
39
  logits = model(**inputs)
40
  pred = torch.argmax(logits, dim=1).item()