Update inference.py
inference.py CHANGED (+13 -4)
@@ -2,14 +2,23 @@ import torch
 from model import EvoTransformer
 from transformers import AutoTokenizer
 
+# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+
+# Load model
+model = EvoTransformer(
+    vocab_size=tokenizer.vocab_size,
+    d_model=256,
+    nhead=4,
+    dim_feedforward=512,
+    num_layers=4
+)
 model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
 model.eval()
 
 def predict(goal, sol1, sol2):
     text = goal + " " + sol1 + " " + sol2
+    inputs = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=64)
     with torch.no_grad():
-        return "Solution 1" if
+        logits = model(inputs["input_ids"])
+    return "Solution 1" if logits.argmax().item() == 0 else "Solution 2"
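For reference, a minimal sketch of how the updated predict function could be called from another script in the Space (for example an app.py entry point; the goal and solution strings below are made-up placeholders, not taken from the repository):

from inference import predict

# Made-up example inputs: one goal and two candidate solutions.
goal = "Keep sliced apples from browning."
sol1 = "Toss the slices with a little lemon juice before storing them."
sol2 = "Leave the slices uncovered on the counter overnight."

# predict returns the string "Solution 1" or "Solution 2".
print(predict(goal, sol1, sol2))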