import torch
from evo_model import EvoTransformer

# Load the trained EvoTransformer checkpoint and move it to the target device
def load_model(model_path="evo_hellaswag.pt", device=None):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = EvoTransformer()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model, device

# Score two candidate endings for a prompt and return the more likely one (0 or 1)
def predict(model, tokenizer, prompt, option1, option2, device):
    inputs = [
        f"{prompt} {option1}",
        f"{prompt} {option2}",
    ]

    encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(encoded["input_ids"])  # model's built-in classifier head returns one score per sequence

    logits = outputs.squeeze(-1)  # shape: [2], one logit per candidate ending
    probs = torch.softmax(logits, dim=0)
    best = torch.argmax(probs).item()

    return {
        "choice": best,
        "confidence": probs[best].item(),
        "scores": probs.tolist(),
    }
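
# Example usage: a minimal sketch, not part of the original script.
# Assumptions: the "transformers" package is installed, and a Hugging Face
# tokenizer such as "bert-base-uncased" matches the vocabulary EvoTransformer
# was trained with; swap in whichever tokenizer the checkpoint actually expects.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model, device = load_model("evo_hellaswag.pt")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical tokenizer choice

    result = predict(
        model,
        tokenizer,
        prompt="She cracked the eggs into the bowl and",
        option1="whisked them until the batter was smooth.",
        option2="parked the bowl in the garage overnight.",
        device=device,
    )
    print(f"Predicted option: {result['choice']} (confidence: {result['confidence']:.2%})")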