HemanM commited on
Commit
d17c48e
·
verified ·
1 Parent(s): 1ce7a2d

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +13 -4
inference.py CHANGED
@@ -2,14 +2,23 @@ import torch
2
  from model import EvoTransformer
3
  from transformers import AutoTokenizer
4
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
6
- model = EvoTransformer(vocab_size=tokenizer.vocab_size, d_model=256, nhead=4, dim_feedforward=512, num_layers=4)
 
 
 
 
 
 
 
 
7
  model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
8
  model.eval()
9
 
10
  def predict(goal, sol1, sol2):
11
  text = goal + " " + sol1 + " " + sol2
12
- tokens = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
13
  with torch.no_grad():
14
- output = model(tokens["input_ids"])
15
- return "Solution 1" if output.argmax().item() == 0 else "Solution 2"
 
2
  from model import EvoTransformer
3
  from transformers import AutoTokenizer
4
 
5
+ # Load tokenizer
6
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
7
+
8
+ # Load model
9
+ model = EvoTransformer(
10
+ vocab_size=tokenizer.vocab_size,
11
+ d_model=256,
12
+ nhead=4,
13
+ dim_feedforward=512,
14
+ num_layers=4
15
+ )
16
  model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
17
  model.eval()
18
 
19
  def predict(goal, sol1, sol2):
20
  text = goal + " " + sol1 + " " + sol2
21
+ inputs = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=64)
22
  with torch.no_grad():
23
+ logits = model(inputs["input_ids"])
24
+ return "Solution 1" if logits.argmax().item() == 0 else "Solution 2"