Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +10 -6
inference.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
from transformers import AutoTokenizer
|
3 |
from evo_model import EvoTransformerV22
|
4 |
from retriever import retrieve
|
|
|
5 |
from openai import OpenAI
|
6 |
import os
|
7 |
|
@@ -15,17 +16,20 @@ evo_model.eval()
|
|
15 |
# --- Load Tokenizer ---
|
16 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
17 |
|
18 |
-
# --- EvoRAG Inference ---
|
19 |
def evo_rag_response(query):
|
20 |
-
# Step 1:
|
21 |
rag_context = retrieve(query)
|
22 |
|
23 |
-
# Step 2:
|
24 |
-
|
|
|
|
|
|
|
25 |
inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
|
26 |
input_ids = inputs["input_ids"].to(device)
|
27 |
|
28 |
-
# Step
|
29 |
with torch.no_grad():
|
30 |
logits = evo_model(input_ids)
|
31 |
pred = int(torch.sigmoid(logits).item() > 0.5)
|
@@ -33,7 +37,7 @@ def evo_rag_response(query):
|
|
33 |
return f"Evo suggests: Option {pred + 1}"
|
34 |
|
35 |
# --- GPT-3.5 Inference (OpenAI >= 1.0.0) ---
|
36 |
-
openai_api_key = os.environ.get("OPENAI_API_KEY", "sk-REDACTED")  # hard-coded fallback key redacted; line was truncated in extraction
|
37 |
client = OpenAI(api_key=openai_api_key)
|
38 |
|
39 |
def get_gpt_response(query, context):
|
|
|
2 |
from transformers import AutoTokenizer
|
3 |
from evo_model import EvoTransformerV22
|
4 |
from retriever import retrieve
|
5 |
+
from websearch import web_search
|
6 |
from openai import OpenAI
|
7 |
import os
|
8 |
|
|
|
16 |
# --- Load Tokenizer ---
|
17 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
18 |
|
19 |
+
# --- EvoRAG+ Inference ---
|
20 |
def evo_rag_response(query):
|
21 |
+
# Step 1: get document context (from uploaded file)
|
22 |
rag_context = retrieve(query)
|
23 |
|
24 |
+
# Step 2: get online info (search/web)
|
25 |
+
web_context = web_search(query)
|
26 |
+
|
27 |
+
# Step 3: combine all into one input
|
28 |
+
combined = query + "\n\n" + rag_context + "\n\n" + web_context
|
29 |
inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
|
30 |
input_ids = inputs["input_ids"].to(device)
|
31 |
|
32 |
+
# Step 4: Evo prediction
|
33 |
with torch.no_grad():
|
34 |
logits = evo_model(input_ids)
|
35 |
pred = int(torch.sigmoid(logits).item() > 0.5)
|
|
|
37 |
return f"Evo suggests: Option {pred + 1}"
|
38 |
|
39 |
# --- GPT-3.5 Inference (OpenAI >= 1.0.0) ---
|
40 |
+
openai_api_key = os.environ.get("OPENAI_API_KEY", "sk-REDACTED")  # SECURITY: a live secret key was committed here and has been redacted — revoke it and use an HF Space secret / env var instead
|
41 |
client = OpenAI(api_key=openai_api_key)
|
42 |
|
43 |
def get_gpt_response(query, context):
|