# EvoTransformer-v2.1 / inference.py
import torch
from model import EvoTransformer
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Load model
model = EvoTransformer(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=4,
    dim_feedforward=512,
    num_layers=4,
)

# Load trained weights (mapped to CPU for portability) and switch to eval mode
model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
model.eval()


def predict(goal, sol1, sol2):
    # Concatenate the goal with both candidate solutions into one input sequence
    text = goal + " " + sol1 + " " + sol2
    inputs = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=64)
    with torch.no_grad():
        logits = model(inputs["input_ids"])
    # Class 0 corresponds to the first solution, class 1 to the second
    return "Solution 1" if logits.argmax().item() == 0 else "Solution 2"
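

# Example usage: a minimal sketch with made-up inputs. The goal/solution
# phrasing below is hypothetical and only illustrates the expected call shape.
if __name__ == "__main__":
    goal = "Remove a red wine stain from a shirt."
    sol1 = "Blot the stain with cold water and a little detergent."
    sol2 = "Iron the shirt on high heat right away."
    print(predict(goal, sol1, sol2))  # prints "Solution 1" or "Solution 2"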