HemanM committed (verified)
Commit f02261f · 1 Parent(s): c615588

Update inference.py

Files changed (1):
  1. inference.py +13 -19
inference.py CHANGED
@@ -1,24 +1,18 @@
-import torch
-from model import EvoTransformer
+from evo_model import EvoTransformerForClassification
 from transformers import AutoTokenizer
+import torch
 
-# Load tokenizer
-tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
-
-# Load model
-model = EvoTransformer(
-    vocab_size=tokenizer.vocab_size,
-    d_model=256,
-    nhead=4,
-    dim_feedforward=512,
-    num_layers=4
-)
-model.load_state_dict(torch.load("trained_model.pt", map_location=torch.device("cpu")))
+# Load model and tokenizer once
+model = EvoTransformerForClassification.from_pretrained("trained_model")
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 model.eval()
 
-def predict(goal, sol1, sol2):
-    text = goal + " " + sol1 + " " + sol2
-    inputs = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=64)
+def generate_response(goal, sol1, sol2):
+    prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
     with torch.no_grad():
-        logits = model(inputs["input_ids"])
-        return "Solution 1" if logits.argmax().item() == 0 else "Solution 2"
+        outputs = model(**inputs)
+        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
+        prediction = torch.argmax(logits, dim=1).item()
+        return "Solution 1" if prediction == 0 else "Solution 2"
+
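For reference, a minimal usage sketch of the updated script. It assumes inference.py above is importable, that the evo_model module and the local "trained_model" checkpoint it references exist, and that the task is a PIQA-style choice between two candidate solutions (the goal and solution strings below are hypothetical examples, not from this commit):

# Minimal usage sketch, assuming inference.py and its "trained_model"
# checkpoint are available in the working directory.
from inference import generate_response

# Hypothetical goal with two candidate solutions.
goal = "Remove a stripped screw from a wooden board."
sol1 = "Press a rubber band into the screw head for grip, then turn slowly."
sol2 = "Pour cold water on the screw and wait for it to fall out."

answer = generate_response(goal, sol1, sol2)
print(answer)  # Prints "Solution 1" or "Solution 2"

Note that the model and tokenizer are loaded once at module import rather than per call, so the first import bears the checkpoint-loading cost and subsequent calls only run the forward pass.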