# EvoTransformer-v2.1 / inference.py
# -- file-viewer page metadata, kept as comments so the module parses --
# HemanM's picture
# Update inference.py
# f02261f verified
# raw / history blame / 728 Bytes
from evo_model import EvoTransformerForClassification
from transformers import AutoTokenizer
import torch
# Load model and tokenizer once at import time: module-level singletons so
# repeated calls to generate_response() reuse them instead of reloading
# weights on every request.
model = EvoTransformerForClassification.from_pretrained("trained_model")
# Tokenizer is the public BERT one; the fine-tuned weights come from the
# local "trained_model" directory — assumes the same tokenizer was used
# during training (TODO confirm against the training pipeline).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Inference-only mode: disables dropout / batch-norm training behavior.
model.eval()
def generate_response(goal, sol1, sol2):
    """Choose the more plausible of two candidate solutions for *goal*.

    Concatenates the goal and both options into one prompt, runs the
    module-level classifier under ``torch.no_grad()``, and maps predicted
    class 0 to "Solution 1" and class 1 to "Solution 2".
    """
    prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
    encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**encoded)
    # Some wrappers return a ModelOutput with .logits, others a plain tuple.
    if hasattr(outputs, "logits"):
        scores = outputs.logits
    else:
        scores = outputs[0]
    winner = torch.argmax(scores, dim=1).item()
    return "Solution 2" if winner else "Solution 1"