File size: 1,803 Bytes
a457e2e
3bfaa31
a457e2e
b530936
d838202
981d63b
4e96bf5
a457e2e
b530936
4e96bf5
 
a457e2e
 
 
3bfaa31
854864a
a457e2e
 
3bfaa31
a457e2e
 
 
 
 
3bfaa31
a457e2e
 
854864a
a457e2e
e33deda
 
 
 
 
 
 
 
 
 
 
981d63b
4e96bf5
e33deda
 
 
 
 
 
 
 
a457e2e
 
e33deda
a457e2e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import openai
from openai import OpenAI
from transformers import BertTokenizer
from evo_model import EvoTransformerForClassification
import torch
import os

# Load Evo model — weights read from the local "trained_model" directory.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()  # inference mode: disables dropout/batch-norm updates

# Tokenizer (BERT-compatible) — presumably the same vocab Evo was trained
# with; TODO confirm against the training pipeline.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# OpenAI client; requires OPENAI_API_KEY to be set in the environment.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def query_gpt35(prompt):
    """Send *prompt* to gpt-3.5-turbo and return the trimmed reply text.

    Best-effort: any failure (network, auth, API) is caught and returned
    as a tagged error string instead of raising, so callers always get text.
    """
    try:
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=50,
            temperature=0.3,
        )
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as e:
        # Surface the failure inline rather than crashing the comparison UI.
        return f"[GPT-3.5 Error] {e}"

def generate_response(goal, option1, option2):
    """Score both options for *goal* with the Evo model and ask GPT-3.5 too.

    Returns a dict with "evo_suggestion" (the option Evo scores higher) and
    "gpt_suggestion" (GPT-3.5's free-text answer).
    """
    scores = []
    for option in (option1, option2):
        # Each candidate is scored independently: goal and option are joined
        # with a single space and tokenized on their own.
        encoded = tokenizer(
            goal + " " + option,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Drop token_type_ids — the model's forward() does not accept them.
        encoded.pop("token_type_ids", None)
        with torch.no_grad():
            output = model(**encoded)
        # NOTE(review): reads logits[0][0], i.e. a single scalar per input —
        # assumes the classification head emits one score; confirm against
        # EvoTransformerForClassification.
        scores.append(output[0][0].item())

    # Strictly-greater comparison: a tie falls through to option2.
    evo_result = option1 if scores[0] > scores[1] else option2

    # GPT-3.5 prediction on the same choice, phrased as a direct question.
    prompt = f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}\nWhich is better?"
    gpt_result = query_gpt35(prompt)

    return {
        "evo_suggestion": evo_result,
        "gpt_suggestion": gpt_result
    }