# EvoTransformer-v2.1 / inference.py
import torch
from transformers import AutoTokenizer

from evo_model import EvoTransformerForClassification
from init_save import initialize_and_save_model

# Ensure the model is initialized and saved before loading
initialize_and_save_model("trained_model")

# Load the tokenizer and the trained EvoTransformer classifier
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()


def generate_response(goal, sol1, sol2):
    """Compare two candidate solutions for a goal and return the preferred one."""
    prompt = f"Goal: {goal}\nOption 1: {sol1}\nOption 2: {sol2}\nWhich is better?"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]  # EvoTransformer does not use token type IDs
    with torch.no_grad():
        logits = model(**inputs)  # forward pass returns raw class logits
    predicted = torch.argmax(logits, dim=1).item()
    return f"Option {predicted + 1} seems more reasonable based on EvoTransformer."
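

# Minimal usage sketch: the goal and candidate options below are hypothetical
# examples, assuming the "trained_model" checkpoint above loads successfully.
if __name__ == "__main__":
    print(generate_response(
        goal="Boil water quickly",
        sol1="Use an electric kettle",
        sol2="Leave a pot of water in the sun",
    ))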