File size: 2,819 Bytes
fb120ec
d590322
fb120ec
63713d5
2608adb
d590322
 
fb120ec
63713d5
fb120ec
 
a233ab6
fb120ec
 
 
63713d5
fb120ec
 
21c89fe
63713d5
21c89fe
63713d5
2608adb
 
21c89fe
2608adb
fb120ec
 
 
21c89fe
fb120ec
 
d590322
fb120ec
21c89fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb120ec
21c89fe
9e38e5c
d590322
fb120ec
 
 
 
d590322
fb120ec
63713d5
fb120ec
 
d590322
fb120ec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerV22
from retriever import retrieve
from websearch import web_search
from openai import OpenAI
import os

# --- Load Evo Model ---
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

evo_model = EvoTransformerV22()
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True once the minimum torch version allows it).
_evo_state = torch.load("trained_model_evo_hellaswag.pt", map_location=device)
evo_model.load_state_dict(_evo_state)
del _evo_state  # weights are now copied into the model; drop the temporary reference
evo_model.to(device)
evo_model.eval()  # inference only: put layers such as dropout into eval mode

# --- Load Tokenizer ---
# The tokenizer is loaded by name via transformers' AutoTokenizer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# --- EvoRAG+ with Rich Reasoning ---
def _pick_option_text(query, pred):
    """Return the text of the option chosen by ``pred``, or "" when ``query`` has no option markers.

    ``query`` is expected to contain ``"Option 1:" <text> ... "Option 2:" <text>``.
    ``pred`` is 0 for Option 1 and anything else for Option 2.
    """
    if "Option 1:" not in query or "Option 2:" not in query:
        return ""
    # Option 1 runs from its marker up to the next "Option 2:" marker (or to the end).
    first = query.partition("Option 1:")[2].partition("Option 2:")[0].strip()
    # Option 2 runs from its first marker to the end of the query.
    second = query.partition("Option 2:")[2].strip()
    return first if pred == 0 else second


def _format_evo_output(pred, option_text, rag_context, web_context):
    """Render Evo's decision plus the retrieved and web contexts as display text."""
    output = f"🧠 Evo suggests: Option {pred + 1}"
    if option_text:
        output += f"\n➡️ {option_text}"

    output += "\n\n📌 Reasoning:\n"
    rag_stripped = rag_context.strip()
    # Check the stripped text: whitespace-only context is truthy, but
    # splitlines() on it yields no lines, so indexing [0] would raise IndexError.
    if rag_stripped:
        first_line = rag_stripped.splitlines()[0][:250]
        output += f"- {first_line}...\n"
    else:
        output += "- No document insight available.\n"

    output += "\n📂 Context used:\n" + (rag_context[:400] if rag_context else "[None]")

    output += "\n\n🌐 Web insight:\n" + (web_context[:400] if web_context else "[None]")

    return output


def evo_rag_response(query):
    """Answer ``query`` with the Evo model, using document retrieval and web search as context.

    The query and both contexts are joined, tokenized (truncated/padded to 128 tokens)
    and scored by ``evo_model``. A sigmoid score above 0.5 selects Option 2,
    otherwise Option 1.

    Args:
        query: User text; may embed "Option 1:" / "Option 2:" markers, in which case the
            chosen option's text is echoed in the output.

    Returns:
        A formatted multi-line string with Evo's choice, a short reasoning line taken from
        the retrieved document context, and previews (first 400 chars) of both contexts.
    """
    # Step 1: Get context from RAG (doc) + web. A None/empty result means "no context";
    # normalising to "" keeps the string concatenation below from failing on None.
    rag_context = retrieve(query) or ""
    web_context = web_search(query) or ""

    # Step 2: Combine for inference (query first, so truncation drops context, not the query)
    combined = "\n\n".join((query, rag_context, web_context))
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    input_ids = inputs["input_ids"].to(device)

    # Step 3: Evo decision
    with torch.no_grad():
        logits = evo_model(input_ids)
        # NOTE(review): .item() assumes the model emits exactly one logit per call — confirm.
        pred = int(torch.sigmoid(logits).item() > 0.5)

    # Step 4: Extract the chosen option's text, if the query lists options
    option_text = _pick_option_text(query, pred)

    # Step 5: Format output
    return _format_evo_output(pred, option_text, rag_context, web_context)

# --- GPT-3.5 (OpenAI >= 1.0.0) ---
# The API key must come from the environment (e.g. a Hugging Face Space secret).
# Never hard-code it: a key committed to source control is leaked and must be revoked.
openai_api_key = os.environ.get("OPENAI_API_KEY")
# Without a key, leave the client unset rather than failing at import time; any call
# made through `client` will then fail at request time instead.
client = OpenAI(api_key=openai_api_key) if openai_api_key else None

def get_gpt_response(query, context, *, model="gpt-3.5-turbo", temperature=0.3):
    """Ask an OpenAI chat model to answer ``query`` using ``context``.

    Sends a single user message through the module-level ``client``.

    Args:
        query: The question to answer.
        context: Supporting text placed before the question in the prompt.
        model: Chat model name (keyword-only; defaults to "gpt-3.5-turbo").
        temperature: Sampling temperature (keyword-only; defaults to 0.3).

    Returns:
        The model's reply with surrounding whitespace stripped. The reply is ""
        when the response carries no text content. On any failure (missing or
        unconfigured client, network/API error, malformed response) the function
        returns an ``"Error from GPT: ..."`` string instead of raising, so callers
        can display it directly.
    """
    try:
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )
        # message.content is Optional in the SDK (e.g. refusals/tool calls), so guard None.
        content = response.choices[0].message.content
    except Exception as e:  # deliberate best-effort boundary: surface the error as text
        return f"Error from GPT: {e}"
    return (content or "").strip()