EvoTransformer-v2.1 / inference.py
HemanM's picture
Update inference.py
a457e2e verified
raw
history blame
1.44 kB
import openai
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
import torch
# Load Evo model
# Runs once at import time: builds the classifier from the "trained_model"
# identifier and calls eval() on it before any request is served.
# NOTE(review): eval() presumably switches the model to inference mode
# (e.g. disabling dropout) -- confirm against evo_model.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()
# Tokenizer (BERT-compatible)
# Uses the stock "bert-base-uncased" vocabulary.
# NOTE(review): assumes the Evo model was trained with this same vocabulary -- confirm.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Set OpenAI key from the OPENAI_API_KEY environment variable.
# os.getenv returns None when the variable is unset, in which case
# openai.api_key is None and GPT requests are expected to fail (query_gpt35
# reports such failures as an error string). Prefer the environment variable
# over hard-coding a secret in this file.
import os
openai.api_key = os.getenv("OPENAI_API_KEY")
def query_gpt35(prompt):
    """Send a single-turn prompt to gpt-3.5-turbo and return the reply text.

    Args:
        prompt: Text sent as the only ``user`` message of the chat.

    Returns:
        The content of the first returned choice with surrounding whitespace
        stripped. If anything raises while sending the request or reading the
        response, the exception is swallowed and the string
        ``"[GPT-3.5 Error] <exception text>"`` is returned instead, so callers
        always receive a string.
    """
    # NOTE(review): openai.ChatCompletion looks like the pre-1.0 interface of
    # the openai package -- confirm the pinned version still provides it.
    chat_messages = [{"role": "user", "content": prompt}]
    try:
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=chat_messages,
            max_tokens=50,
            temperature=0.3,
        )
        # Response parsing stays inside the try so that a malformed response
        # is reported through the same error string as a failed request.
        first_choice = completion['choices'][0]
        reply_text = first_choice['message']['content']
        return reply_text.strip()
    except Exception as err:
        # Deliberate best-effort: report the failure as text, never raise.
        return f"[GPT-3.5 Error] {err}"
def generate_response(goal, option1, option2):
    """Pick between two options for a goal using the Evo model and GPT-3.5.

    Args:
        goal: Goal text; it is joined to each option with ``+``, so it must
            support string concatenation.
        option1: First candidate text.
        option2: Second candidate text.

    Returns:
        A dict with two keys:
        ``"evo_suggestion"`` -- ``option1`` when the model's argmax index is
        0, otherwise ``option2``;
        ``"gpt_suggestion"`` -- whatever ``query_gpt35`` returns for the
        "Which is better?" prompt built from the three arguments.
    """
    # Evo prediction: each option is encoded together with the goal, and both
    # sequences are sent through the model as one padded batch.
    candidate_texts = [goal + " " + option for option in (option1, option2)]
    encoded = tokenizer(
        candidate_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        scores = model(**encoded)
    # NOTE(review): .item() only works when the argmax result holds exactly
    # one element; whether that is true for a two-sequence batch depends on
    # the shape the Evo model returns, which is not visible here -- confirm.
    chosen_index = torch.argmax(scores, dim=1).item()
    if chosen_index == 0:
        evo_choice = option1
    else:
        evo_choice = option2

    # GPT-3.5 prediction
    question = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better?"
    return {
        "evo_suggestion": evo_choice,
        "gpt_suggestion": query_gpt35(question),
    }