File size: 743 Bytes
4e96bf5
 
 
 
d17c48e
4e96bf5
d17c48e
 
 
 
 
 
 
 
 
4e96bf5
 
 
 
 
d17c48e
4e96bf5
d17c48e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from model import EvoTransformer
from transformers import AutoTokenizer

# Load the DistilBERT WordPiece tokenizer; its vocab_size is reused below so
# the model's embedding table matches the token ids this tokenizer produces.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Build the model with the same hyperparameters it was trained with.
# NOTE(review): these values must match the training configuration that
# produced "trained_model.pt" — confirm against the training script.
model = EvoTransformer(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=4,
    dim_feedforward=512,
    num_layers=4
)
# Load trained weights onto CPU (map_location makes a GPU-saved checkpoint
# loadable on a CPU-only machine) and switch to inference mode so dropout
# and similar training-only layers are disabled.
model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
model.eval()

def predict(goal, sol1, sol2):
    """Choose which of two candidate solutions better fits *goal*.

    The goal and both solutions are joined into one sequence, tokenized,
    and scored by the model; returns the literal string "Solution 1" when
    class 0 wins, otherwise "Solution 2".
    """
    combined = " ".join((goal, sol1, sol2))
    # Fixed-length encoding: pad/truncate to 64 tokens, PyTorch tensors.
    encoded = tokenizer(
        combined,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64,
    )
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        scores = model(encoded["input_ids"])
    winner = scores.argmax().item()
    return "Solution 1" if winner == 0 else "Solution 2"