File size: 690 Bytes
4e96bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from model import EvoTransformer
from transformers import AutoTokenizer

# Tokenizer must match the vocabulary the checkpoint was trained with.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Hyperparameters must mirror the training configuration exactly,
# otherwise load_state_dict below will fail on shape mismatches.
model = EvoTransformer(vocab_size=tokenizer.vocab_size, d_model=256, nhead=4, dim_feedforward=512, num_layers=4)

# weights_only=True restricts unpickling to tensors/primitive containers,
# preventing arbitrary code execution from a tampered checkpoint file
# (requires torch >= 1.13; the checkpoint here is a plain state dict).
model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu"), weights_only=True))
model.eval()  # inference mode: disables dropout for deterministic outputs

def predict(goal, sol1, sol2):
    """Pick which of two candidate solutions better matches *goal*.

    The three strings are concatenated, tokenized to a fixed length of 64,
    and scored by the model; returns "Solution 1" or "Solution 2".
    """
    combined = f"{goal} {sol1} {sol2}"
    encoded = tokenizer(
        combined,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=64,
    )
    with torch.no_grad():
        logits = model(encoded["input_ids"])
    # Class index 0 corresponds to the first solution.
    return "Solution 2" if logits.argmax().item() else "Solution 1"