HemanM committed
Commit 2608adb · verified · 1 Parent(s): 2688ff5

Update inference.py

Files changed (1)
  1. inference.py +10 -6
inference.py CHANGED
@@ -2,6 +2,7 @@ import torch
 from transformers import AutoTokenizer
 from evo_model import EvoTransformerV22
 from retriever import retrieve
+from websearch import web_search
 from openai import OpenAI
 import os
 
@@ -15,17 +16,20 @@ evo_model.eval()
 # --- Load Tokenizer ---
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
-# --- EvoRAG Inference ---
+# --- EvoRAG+ Inference ---
 def evo_rag_response(query):
-    # Step 1: retrieve document chunks
+    # Step 1: get document context (from uploaded file)
     rag_context = retrieve(query)
 
-    # Step 2: combine query with retrieved context
-    combined = query + " " + rag_context
+    # Step 2: get online info (search/web)
+    web_context = web_search(query)
+
+    # Step 3: combine all into one input
+    combined = query + "\n\n" + rag_context + "\n\n" + web_context
     inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
     input_ids = inputs["input_ids"].to(device)
 
-    # Step 3: predict using Evo
+    # Step 4: Evo prediction
     with torch.no_grad():
         logits = evo_model(input_ids)
         pred = int(torch.sigmoid(logits).item() > 0.5)
@@ -33,7 +37,7 @@ def evo_rag_response(query):
     return f"Evo suggests: Option {pred + 1}"
 
 # --- GPT-3.5 Inference (OpenAI >= 1.0.0) ---
-openai_api_key = os.environ.get("OPENAI_API_KEY", "sk-...")  # Replace or use HF secret
+openai_api_key = os.environ.get("OPENAI_API_KEY", "sk-[REDACTED]")  # Replace or use HF secret
 client = OpenAI(api_key=openai_api_key)
 
 def get_gpt_response(query, context):
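
The new import pulls web_search from a websearch module that is not included in this commit. A minimal sketch of what such a module could look like, assuming the third-party duckduckgo_search package (the package choice, result count, and snippet formatting are guesses, not the repo's actual code):

# websearch.py -- hypothetical sketch, not the module from this repo
from duckduckgo_search import DDGS  # pip install duckduckgo-search

def web_search(query: str, max_results: int = 3) -> str:
    """Return the top search snippets for `query` as a single text block."""
    with DDGS() as ddgs:
        # Each result is a dict with "title", "href", and "body" keys.
        results = list(ddgs.text(query, max_results=max_results))
    return "\n".join(f"{r['title']}: {r['body']}" for r in results)

Whatever backend is used, returning one plain string keeps evo_rag_response unchanged. Note that with max_length=128 and truncation enabled, the tokenizer will cut off most of rag_context and web_context for all but the shortest inputs.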
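The diff ends at the get_gpt_response signature. A plausible body using the OpenAI >= 1.0.0 client already created above (the prompt wording, model name, and temperature are assumptions):

def get_gpt_response(query, context):
    # Chat Completions call against the client defined above.
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "Answer the question using only the provided context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
        ],
        temperature=0.2,
    )
    return response.choices[0].message.content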
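The key added on the new side of the last hunk is redacted above; shipping a live key as the os.environ.get fallback exposes it to anyone who can read the repo. A safer pattern for a Hugging Face Space (the error message below is illustrative) is to store the key as a repository secret and fail fast when it is absent:

import os
from openai import OpenAI

# No fallback value: better to crash loudly than to ship a hardcoded key.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set. On Hugging Face Spaces, add it under "
        "Settings -> Variables and secrets."
    )
client = OpenAI(api_key=openai_api_key)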
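For completeness, a hypothetical smoke test for the two entry points (the question string is invented; retrieve comes from the existing retriever module):

if __name__ == "__main__":
    question = "Does the uploaded policy favor option 1 or option 2?"
    print(evo_rag_response(question))   # e.g. "Evo suggests: Option 1"
    print(get_gpt_response(question, retrieve(question)))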