import os

import torch
from openai import OpenAI
from transformers import AutoTokenizer

from evo_model import EvoTransformerV22
from retriever import retrieve
from websearch import web_search

# --- Load Evo Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evo_model = EvoTransformerV22()
evo_model.load_state_dict(torch.load("trained_model_evo_hellaswag.pt", map_location=device))
evo_model.to(device)
evo_model.eval()
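
# Interface assumption (sketch): EvoTransformerV22.forward is taken to map a
# [batch, seq_len] tensor of token ids to a single binary-choice logit per
# example, which is how evo_rag_response below consumes it. A quick sanity
# check along these lines could be:
#   dummy_ids = torch.zeros((1, 128), dtype=torch.long, device=device)
#   with torch.no_grad():
#       assert evo_model(dummy_ids).numel() == 1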
# --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
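# Note: evo_rag_response truncates the combined query + RAG + web string to
# 128 BERT tokens, so only the head of the string (mostly the query itself)
# actually reaches the Evo model; long retrieved context is cut off.
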
# --- EvoRAG+ with Rich Reasoning ---
def evo_rag_response(query):
    # Step 1: Gather context from the document retriever (RAG) and the web
    rag_context = retrieve(query)
    web_context = web_search(query)

    # Step 2: Combine query and context for inference
    combined = query + "\n\n" + rag_context + "\n\n" + web_context
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    input_ids = inputs["input_ids"].to(device)

    # Step 3: Evo decision (sigmoid over a single logit: 0 -> Option 1, 1 -> Option 2)
    with torch.no_grad():
        logits = evo_model(input_ids)
        pred = int(torch.sigmoid(logits).item() > 0.5)

    # Step 4: Extract the option texts if the query embeds them
    option_text = ""
    if "Option 1:" in query and "Option 2:" in query:
        try:
            opt1 = query.split("Option 1:")[1].split("Option 2:")[0].strip()
            opt2 = query.split("Option 2:")[1].strip()
            option_text = opt1 if pred == 0 else opt2
        except Exception:
            pass  # Malformed option markers; fall back to the bare prediction

    # Step 5: Format output
    output = f"🧠 Evo suggests: Option {pred + 1}"
    if option_text:
        output += f"\n➡️ {option_text}"
    output += "\n\n📌 Reasoning:\n"
    if rag_context:
        first_line = rag_context.strip().splitlines()[0][:250]
        output += f"- {first_line}...\n"
    else:
        output += "- No document insight available.\n"
    output += "\n📂 Context used:\n" + (rag_context[:400] if rag_context else "[None]")
    output += "\n\n🌐 Web insight:\n" + (web_context[:400] if web_context else "[None]")
    return output
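
# Example call (sketch; assumes the retriever/websearch backends are configured
# and uses the "Option 1:/Option 2:" format that evo_rag_response parses):
#   print(evo_rag_response(
#       "She put the kettle on. Option 1: She poured the boiling water over "
#       "the tea leaves. Option 2: She poured the kettle into the mailbox."
#   ))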
# --- GPT-3.5 (OpenAI >= 1.0.0) ---
openai_api_key = os.environ.get("OPENAI_API_KEY")  # Set via HF Spaces secrets; never hardcode the key
client = OpenAI(api_key=openai_api_key)
def get_gpt_response(query, context):
    try:
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error from GPT: {e}"