HemanM commited on
Commit
4e96bf5
·
verified ·
1 Parent(s): f82b09c

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -0
inference.py CHANGED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import EvoTransformer
3
+ from transformers import AutoTokenizer
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
6
+ model = EvoTransformer(vocab_size=tokenizer.vocab_size, d_model=256, nhead=4, dim_feedforward=512, num_layers=4)
7
+ model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
8
+ model.eval()
9
+
10
+ def predict(goal, sol1, sol2):
11
+ text = goal + " " + sol1 + " " + sol2
12
+ tokens = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
13
+ with torch.no_grad():
14
+ output = model(tokens["input_ids"])
15
+ return "Solution 1" if output.argmax().item() == 0 else "Solution 2"