HemanM commited on
Commit
785c4f7
·
verified ·
1 Parent(s): 09f0cd3

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +36 -0
inference.py CHANGED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from evo_model import EvoTransformer
3
+
4
+ # Load EvoTransformer model
5
+ def load_model(model_path="evo_hellaswag.pt", device=None):
6
+ if device is None:
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ model = EvoTransformer()
10
+ model.load_state_dict(torch.load(model_path, map_location=device))
11
+ model.to(device)
12
+ model.eval()
13
+ return model, device
14
+
15
+ # Predict the best option (0 or 1)
16
+ def predict(model, tokenizer, prompt, option1, option2, device):
17
+ inputs = [
18
+ f"{prompt} {option1}",
19
+ f"{prompt} {option2}",
20
+ ]
21
+
22
+ encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
23
+
24
+ with torch.no_grad():
25
+ outputs = model(encoded["input_ids"])
26
+
27
+ # Simple linear classifier logic
28
+ logits = torch.nn.functional.linear(outputs, model.classifier.weight, model.classifier.bias)
29
+ probs = torch.softmax(logits, dim=1)
30
+ best = torch.argmax(probs).item()
31
+
32
+ return {
33
+ "choice": best,
34
+ "confidence": probs[0][best].item(),
35
+ "scores": probs[0].tolist(),
36
+ }