"""Pick the better of two candidate solutions for a goal with a trained EvoTransformer.

At import time this module loads the ``distilbert-base-uncased`` tokenizer and
restores an ``EvoTransformer`` from ``trained_model.pt`` onto the CPU in eval mode.
"""
import torch
from model import EvoTransformer
from transformers import AutoTokenizer

# Tokenizer that turns the joined goal/solution text into token ids.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Architecture hyperparameters; these must match the ones used to produce
# trained_model.pt, otherwise load_state_dict will reject the checkpoint.
model = EvoTransformer(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=4,
    dim_feedforward=512,
    num_layers=4,
)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(
    torch.load("trained_model.pt", map_location=torch.device("cpu"))
)
model.eval()


def predict(goal, sol1, sol2):
    """Return which of two candidate solutions the model prefers for ``goal``.

    The three strings are joined with single spaces, tokenized, and padded or
    truncated to exactly 64 tokens, then scored by the model without tracking
    gradients.

    Args:
        goal: Text describing the goal.
        sol1: Text of the first candidate solution.
        sol2: Text of the second candidate solution.

    Returns:
        ``"Solution 1"`` when the highest-scoring logit is at index 0, and
        ``"Solution 2"`` otherwise.
    """
    combined = " ".join((goal, sol1, sol2))
    encoded = tokenizer(
        combined,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64,
    )
    # NOTE(review): only input_ids are passed; the attention mask for the
    # padded positions is not forwarded. Confirm EvoTransformer does not expect one.
    with torch.no_grad():
        scores = model(encoded["input_ids"])
    # argmax() without a dim flattens the output. This assumes the model emits
    # one score per candidate for this single example (TODO confirm the shape).
    best = scores.argmax().item()
    if best == 0:
        return "Solution 1"
    return "Solution 2"