File size: 973 Bytes
4e96bf5
b530936
d838202
 
4e96bf5
d838202
 
 
 
f02261f
b530936
4e96bf5
 
d838202
 
 
854864a
d838202
 
854864a
4e96bf5
d838202
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os

import torch
from transformers import AutoTokenizer

from evo_model import EvoTransformerForClassification
from init_save import initialize_and_save_model  # Ensure this line is added

# Directory the classifier checkpoint is saved to and loaded from.
MODEL_DIR = "trained_model"

# ✅ Ensure a checkpoint exists BEFORE loading it. Only create one when the
# directory is missing or empty. Calling initialize_and_save_model()
# unconditionally at import time would presumably overwrite any weights already
# saved there (e.g. by a training run) with freshly initialized ones.
# NOTE(review): verify against init_save.initialize_and_save_model.
if not (os.path.isdir(MODEL_DIR) and os.listdir(MODEL_DIR)):
    initialize_and_save_model(MODEL_DIR)

# 🔁 Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoTransformerForClassification.from_pretrained(MODEL_DIR)
model.eval()  # inference mode: disables dropout and similar training-only layers

def generate_response(goal, sol1, sol2, *, tok=None, clf=None):
    """Ask the classifier which of two candidate solutions better fits a goal.

    The goal and both options are packed into one prompt, tokenized, and passed
    through the classifier. The arg-max class index ``i`` is reported as
    "Option i+1".

    Args:
        goal: The goal / question text.
        sol1: Text of the first candidate solution.
        sol2: Text of the second candidate solution.
        tok: Optional tokenizer override. Defaults to the module-level
            ``tokenizer``. It must be callable like a Hugging Face tokenizer
            and return a mutable mapping of tensors.
        clf: Optional model override. Defaults to the module-level ``model``.
            It may return a raw logits tensor, an object with a ``.logits``
            attribute, or a tuple whose first element is the logits.

    Returns:
        A sentence naming the option the classifier prefers.
    """
    # Resolve the module-level defaults lazily so callers can inject their own.
    tok = tokenizer if tok is None else tok
    clf = model if clf is None else clf

    prompt = f"Goal: {goal}\nOption 1: {sol1}\nOption 2: {sol2}\nWhich is better?"
    inputs = tok(prompt, return_tensors="pt", padding=True, truncation=True)
    # Evo's forward() is not given segment ids, so drop them if present.
    inputs.pop("token_type_ids", None)

    with torch.no_grad():
        outputs = clf(**inputs)

    # Accept a raw logits tensor as well as Hugging Face-style outputs.
    logits = getattr(outputs, "logits", outputs)
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    # Assumes logits are (1, num_labels) for the single prompt, with
    # num_labels == 2 — TODO confirm.
    predicted = torch.argmax(logits, dim=-1).item()
    return f"Option {predicted + 1} seems more reasonable based on EvoTransformer."