File size: 953 Bytes
4e96bf5
b530936
f02261f
4e96bf5
b530936
f02261f
b530936
4e96bf5
 
f02261f
b530936
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e96bf5
b530936
 
 
f02261f
b530936
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification
import torch

# Load tokenizer and model.
# The tokenizer is stock BERT; the classifier weights come from the local
# "trained_model" directory. NOTE(review): assumes EvoTransformerForClassification
# implements a HuggingFace-style from_pretrained — confirm against evo_model.
# eval() switches off dropout/batch-norm updates for inference.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()

def generate_response(goal, sol1, sol2):
    """Pick the better of two candidate solutions for *goal*.

    The goal and both options are formatted into a single prompt, run
    through the classifier, and the argmax class decides the answer:
    class 0 maps to "A" (first option), anything else to "B".
    """
    prompt = f"Goal: {goal}\nOption A: {sol1}\nOption B: {sol2}"

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Pass only the tensors the model accepts; drops extras such as
    # token_type_ids that BERT tokenizers emit.
    model_inputs = {
        key: tensor
        for key, tensor in encoded.items()
        if key in ["input_ids", "attention_mask"]
    }

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**model_inputs).logits
        predicted = torch.argmax(logits, dim=1).item()

    return "A" if predicted == 0 else "B"