import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerV22
from retriever import retrieve
from websearch import web_search
from openai import OpenAI
import os

# --- Load Evo Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evo_model = EvoTransformerV22()
evo_model.load_state_dict(torch.load("trained_model_evo_hellaswag.pt", map_location=device))
evo_model.to(device)
evo_model.eval()

# --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# --- EvoRAG+ Inference ---
def evo_rag_response(query):
    # Step 1: get document context (from uploaded file)
    rag_context = retrieve(query)

    # Step 2: get online info (search/web)
    web_context = web_search(query)

    # Step 3: combine all into one input
    combined = query + "\n\n" + rag_context + "\n\n" + web_context
    # Note: max_length=128 will truncate most of the retrieved/web context.
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    input_ids = inputs["input_ids"].to(device)

    # Step 4: Evo prediction
    with torch.no_grad():
        logits = evo_model(input_ids)
        pred = int(torch.sigmoid(logits).item() > 0.5)  # assumes the model emits a single logit per example

    return f"Evo suggests: Option {pred + 1}"

# --- GPT-3.5 Inference (OpenAI >= 1.0.0) ---
openai_api_key = os.environ.get("OPENAI_API_KEY")  # set via environment variable or HF Space secret; never hardcode keys
client = OpenAI(api_key=openai_api_key)

def get_gpt_response(query, context):
    try:
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error from GPT: {e}"