import os

import torch
from openai import OpenAI
from transformers import AutoTokenizer

from evo_model import EvoTransformerV22
from retriever import retrieve

# --- Load Evo Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evo_model = EvoTransformerV22()
evo_model.load_state_dict(torch.load("trained_model_evo_hellaswag.pt", map_location=device))
evo_model.to(device)
evo_model.eval()

# --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# --- EvoRAG Inference ---
def evo_rag_response(query):
    """Pick the more plausible option for `query` using Evo over retrieved context."""
    # Step 1: retrieve document chunks
    rag_context = retrieve(query)

    # Step 2: combine query with retrieved context
    combined = query + " " + rag_context
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    input_ids = inputs["input_ids"].to(device)

    # Step 3: predict using Evo (single-logit binary head: sigmoid > 0.5 means option 2)
    with torch.no_grad():
        logits = evo_model(input_ids)
        pred = int(torch.sigmoid(logits).item() > 0.5)

    return f"Evo suggests: Option {pred + 1}"

# --- GPT-3.5 Inference (OpenAI >= 1.0.0) ---
openai_api_key = os.environ.get("OPENAI_API_KEY", "sk-...")  # set via env var or a Hugging Face Spaces secret
client = OpenAI(api_key=openai_api_key)

def get_gpt_response(query, context):
    """Answer `query` with GPT-3.5, grounded in the retrieved `context`."""
    try:
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error from GPT: {e}"