File size: 1,713 Bytes
785c4f7
 
e7d2e38
785c4f7
e7d2e38
 
 
785c4f7
e7d2e38
785c4f7
 
 
 
e7d2e38
785c4f7
e7d2e38
785c4f7
e7d2e38
 
785c4f7
 
 
e7d2e38
cdcb82a
f87535f
785c4f7
 
e7d2e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from evo_model import EvoTransformer
from transformers import AutoTokenizer

# Pick the compute device once; the model and every encoded batch are moved to it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer shared by the Evo inference path (the Evo model itself is loaded below).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def load_model(model_path="evo_hellaswag.pt"):
    """Build an EvoTransformer, restore its weights, and prepare it for inference.

    Args:
        model_path: Path to a saved ``state_dict`` checkpoint.

    Returns:
        The EvoTransformer instance with weights loaded, moved to ``device``
        and switched to eval mode.
    """
    net = EvoTransformer()
    # map_location lets a checkpoint saved on another device load on this one.
    # NOTE(review): torch.load unpickles the file - only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint)
    net.to(device)
    net.eval()
    return net


# Loaded once at import time and reused by every get_evo_response call.
evo_model = load_model()

def get_evo_response(prompt, option1, option2):
    """Let the Evo model pick which of two options best continues *prompt*.

    Each option is appended to the prompt, and both sequences are scored in one
    batch. A softmax across the two scores turns them into relative confidences.

    Args:
        prompt: Context / question text.
        option1: First candidate continuation.
        option2: Second candidate continuation.

    Returns:
        A string naming the preferred option (1 or 2), followed by that
        option's text and its confidence to two decimals.
    """
    candidates = [f"{prompt} {option1}", f"{prompt} {option2}"]
    batch = tokenizer(candidates, padding=True, truncation=True, return_tensors="pt")
    batch = batch.to(device)

    with torch.no_grad():
        # NOTE(review): assumes evo_model yields one score per sequence,
        # shaped (2, 1) or (2,) - TODO confirm against EvoTransformer.
        scores = evo_model(batch["input_ids"]).squeeze(-1)

    # Softmax over the batch axis: the two options compete with each other.
    confidence = torch.softmax(scores, dim=0)
    winner = torch.argmax(confidence).item()

    summaries = (
        f"🅰️ Option 1: {option1}\nConfidence: {confidence[0]:.2f}",
        f"🅱️ Option 2: {option2}\nConfidence: {confidence[1]:.2f}",
    )
    return f"Evo suggests: Option {winner + 1}\n\n{summaries[winner]}"

def get_gpt_response(prompt, option1, option2):
    """Ask an OpenAI chat model which of two options fits *prompt* better, and why.

    Supports both the legacy ``openai<1.0`` module-level API
    (``openai.ChatCompletion``) and the ``openai>=1.0`` client API, which
    removed ``ChatCompletion``. The API key is read from the
    ``OPENAI_API_KEY`` environment variable.

    Args:
        prompt: Question / context text.
        option1: First candidate answer.
        option2: Second candidate answer.

    Returns:
        The model's reply with surrounding whitespace stripped. If anything
        fails (including the ``openai`` package being unavailable), returns a
        string starting with ``"GPT Error: "`` instead of raising.
    """
    import os

    full_prompt = (
        f"Question: {prompt}\n"
        f"Option 1: {option1}\n"
        f"Option 2: {option2}\n"
        "Which option makes more sense and why?"
    )
    messages = [{"role": "user", "content": full_prompt}]

    try:
        # Imported inside the try so a missing package is reported through the
        # same "GPT Error" text channel as any API failure.
        import openai

        api_key = os.getenv("OPENAI_API_KEY")
        openai.api_key = api_key

        if hasattr(openai, "OpenAI"):
            # openai>=1.0: the module-level ChatCompletion API was removed.
            client = openai.OpenAI(api_key=api_key)
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=messages,
            )
            content = response.choices[0].message.content
        else:
            # openai<1.0: legacy module-level API.
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
            )
            content = response.choices[0].message["content"]
        return content.strip()
    except Exception as e:
        # Deliberate catch-all: failures are returned as text rather than raised.
        return f"GPT Error: {e}"