HemanM committed · Commit 63713d5 · verified · 1 Parent(s): afd132a

Update inference.py

Files changed (1):
  1. inference.py +14 -10
inference.py CHANGED
@@ -1,33 +1,39 @@
 import torch
 from transformers import AutoTokenizer
 from evo_model import EvoTransformerV22
+from retriever import retrieve
 from openai import OpenAI
 import os
 
-# Load Evo model
+# --- Load Evo Model ---
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 evo_model = EvoTransformerV22()
 evo_model.load_state_dict(torch.load("trained_model_evo_hellaswag.pt", map_location=device))
 evo_model.to(device)
 evo_model.eval()
 
-# Load tokenizer
+# --- Load Tokenizer ---
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
-# 🧠 Evo logic (binary classification with sigmoid)
-def get_evo_response(query, context):
-    combined = query + " " + context
+# --- EvoRAG Inference ---
+def evo_rag_response(query):
+    # Step 1: retrieve document chunks
+    rag_context = retrieve(query)
+
+    # Step 2: combine query with retrieved context
+    combined = query + " " + rag_context
     inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
     input_ids = inputs["input_ids"].to(device)
 
+    # Step 3: predict using Evo
     with torch.no_grad():
         logits = evo_model(input_ids)
         pred = int(torch.sigmoid(logits).item() > 0.5)
 
     return f"Evo suggests: Option {pred + 1}"
 
-# 🤖 GPT-3.5 comparison using openai>=1.0.0
-openai_api_key = os.environ.get("OPENAI_API_KEY", "sk-proj-[REDACTED]")  # Replace with real key or set via HF secrets
+# --- GPT-3.5 Inference (OpenAI >= 1.0.0) ---
+openai_api_key = os.environ.get("OPENAI_API_KEY", "sk-...")  # Replace or use HF secret
 client = OpenAI(api_key=openai_api_key)
 
 def get_gpt_response(query, context):
@@ -35,9 +41,7 @@ def get_gpt_response(query, context):
     prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    response = client.chat.completions.create(
         model="gpt-3.5-turbo",
-        messages=[
-            {"role": "user", "content": prompt}
-        ],
+        messages=[{"role": "user", "content": prompt}],
         temperature=0.3
     )
     return response.choices[0].message.content.strip()
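The updated file imports retrieve from a retriever module that is not part of this commit, so the retrieval step itself is opaque here. All the new code assumes is an interface: a function that takes the query string and returns the retrieved context as a single string. A minimal sketch under that assumption (the corpus and the keyword-overlap scoring below are hypothetical, not from the repository) could look like:

# retriever.py -- hypothetical sketch; the real module is not shown in this commit.
def retrieve(query, chunks=None, top_k=3):
    # Placeholder corpus; in the real app this would come from the indexed documents.
    chunks = chunks or [
        "Evo is a compact transformer trained on HellaSwag-style binary choices.",
        "GPT-3.5 is queried through the OpenAI chat completions API.",
    ]
    query_terms = set(query.lower().split())
    # Rank chunks by naive keyword overlap with the query, best matches first.
    ranked = sorted(chunks, key=lambda c: len(query_terms & set(c.lower().split())), reverse=True)
    return " ".join(ranked[:top_k])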
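Since evo_model emits a single logit, the sigmoid threshold at 0.5 is what maps the raw score to one of the two options, for example:

import torch

logits = torch.tensor([0.8])                    # example raw score from evo_model
pred = int(torch.sigmoid(logits).item() > 0.5)  # sigmoid(0.8) ≈ 0.69 > 0.5, so pred = 1
print(f"Evo suggests: Option {pred + 1}")       # prints "Evo suggests: Option 2"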
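For reference, one way to exercise both paths after this change, assuming the checkpoint and the retriever module are present and OPENAI_API_KEY is set as an environment variable (e.g., a Hugging Face Space secret) rather than hardcoded:

import os
os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder; use a real secret in deployment

# Importing inference loads the Evo checkpoint and builds the OpenAI client.
from inference import evo_rag_response, get_gpt_response

query = "Which option best completes the scenario?"
print(evo_rag_response(query))                 # e.g. "Evo suggests: Option 1"
print(get_gpt_response(query, context="..."))  # GPT-3.5 answer for comparison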