# EvoTransformer-v2.1 / inference.py
# NOTE(review): the lines below are Hugging Face web-page residue captured
# with the raw file ("HemanM's picture / Update inference.py / b530936
# verified / raw / history blame / 953 Bytes") — kept as comments so the
# module remains importable.
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification
import torch
# Load tokenizer and model.
# The stock BERT tokenizer is paired with a project-local EvoTransformer
# classification head whose weights live in the "trained_model" directory.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()  # inference mode: disables dropout etc. for deterministic outputs
def generate_response(goal, sol1, sol2):
    """Choose the more plausible option for *goal* with the Evo classifier.

    Parameters
    ----------
    goal : str
        Natural-language description of the task/goal.
    sol1 : str
        Candidate option A.
    sol2 : str
        Candidate option B.

    Returns
    -------
    str
        "A" if the model's predicted class is 0 (prefers ``sol1``),
        otherwise "B".
    """
    input_text = f"Goal: {goal}\nOption A: {sol1}\nOption B: {sol2}"

    # Tokenize. Asking the tokenizer not to emit token_type_ids (which the
    # Evo model does not accept) replaces the previous manual dict filter —
    # the output then contains exactly input_ids and attention_mask.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        return_token_type_ids=False,
    )

    # Inference only — skip gradient bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Class index 0 maps to option A, 1 to option B.
    predicted_class = torch.argmax(logits, dim=1).item()
    return "A" if predicted_class == 0 else "B"