File size: 728 Bytes
f02261f
4e96bf5
f02261f
4e96bf5
f02261f
 
 
4e96bf5
 
f02261f
 
 
4e96bf5
f02261f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from evo_model import EvoTransformerForClassification
from transformers import AutoTokenizer
import torch

# Load model and tokenizer once at import time so every call to
# generate_response() reuses the same objects (no per-request reload).
# Classifier weights come from the local "trained_model" checkpoint directory.
model = EvoTransformerForClassification.from_pretrained("trained_model")
# NOTE(review): tokenizer is bert-base-uncased, not loaded from "trained_model" —
# presumably the model was trained with this tokenizer; confirm they match.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Inference mode: disables dropout/batch-norm training behavior.
model.eval()

def generate_response(goal, sol1, sol2):
    """Choose which of two candidate solutions better fits a goal.

    Formats the goal and both options into a single prompt, runs the
    module-level classifier on it, and maps the argmax class to a label:
    class 0 -> "Solution 1", class 1 -> "Solution 2".

    Args:
        goal: Task description to evaluate the options against.
        sol1: First candidate solution text.
        sol2: Second candidate solution text.

    Returns:
        The string "Solution 1" or "Solution 2".
    """
    text = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # No gradients needed for pure inference.
    with torch.no_grad():
        output = model(**encoded)
    # Model may return a structured output object or a plain tuple.
    scores = output.logits if hasattr(output, "logits") else output[0]
    winner = int(torch.argmax(scores, dim=1))
    return "Solution 2" if winner else "Solution 1"