HemanM commited on
Commit
6a1a7af
·
verified ·
1 Parent(s): 3363688

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +32 -32
inference.py CHANGED
@@ -1,54 +1,54 @@
1
- import os
2
  import torch
3
- from evo_model import EvoTransformer
 
4
  from transformers import AutoTokenizer
 
5
  from rag_utils import extract_text_from_file
6
  from search_utils import web_search
7
 
8
- # Load Evo model
9
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
10
- model = EvoTransformer()
11
  model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
12
  model.eval()
13
 
14
- def get_evo_response(query, context=None, enable_search=True):
15
- search_snippets = ""
16
- if enable_search:
17
- snippets = web_search(query)
18
- if snippets:
19
- search_snippets = "\n".join(snippets)
20
-
21
- full_context = f"{context or ''}\n\n{search_snippets}".strip()
22
- input_1 = f"{query} Option 1"
23
- input_2 = f"{query} Option 2"
24
 
25
- inputs = tokenizer([input_1, input_2], padding=True, truncation=True, return_tensors="pt")
 
 
 
26
  with torch.no_grad():
27
- logits = model(inputs["input_ids"]).squeeze(-1)
28
- probs = torch.softmax(logits, dim=0)
29
- best_idx = torch.argmax(probs).item()
 
30
 
31
- suggestion = f"Option {best_idx + 1}"
32
- reasoning = (
33
- f"Evo suggests: **{suggestion}** (Confidence: {probs[best_idx]:.2f})\n\n"
34
- f"Context used:\n{full_context}"
35
- )
36
- return suggestion, reasoning
37
 
38
- def get_gpt_response(query, context=None):
39
- import openai
40
  openai.api_key = os.getenv("OPENAI_API_KEY", "")
41
- context = context or "None"
 
42
 
43
  try:
44
  response = openai.ChatCompletion.create(
45
  model="gpt-3.5-turbo",
46
  messages=[
47
- {"role": "system", "content": "You are a helpful expert advisor."},
48
- {"role": "user", "content": f"Context: {context}\n\nQuestion: {query}"}
49
- ],
50
- max_tokens=250
51
  )
52
- return response["choices"][0]["message"]["content"].strip()
53
  except Exception as e:
54
  return f"⚠️ GPT error: {str(e)}"
 
 
1
  import torch
2
+ import openai
3
+ import os
4
  from transformers import AutoTokenizer
5
+ from evo_model import EvoTransformerV22
6
  from rag_utils import extract_text_from_file
7
  from search_utils import web_search
8
 
 
9
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
10
+ model = EvoTransformerV22()
11
  model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
12
  model.eval()
13
 
14
+ def format_input(question, options, context, web_results):
15
+ prompt = f"{question}\n"
16
+ if context:
17
+ prompt += f"\nContext:\n{context}\n"
18
+ if web_results:
19
+ prompt += f"\nWeb Search Results:\n" + "\n".join(web_results)
20
+ prompt += "\nOptions:\n"
21
+ for idx, opt in enumerate(options):
22
+ prompt += f"{idx+1}. {opt}\n"
23
+ return prompt.strip()
24
 
25
+ def get_evo_response(question, context, options, enable_search=True):
26
+ web_results = web_search(question) if enable_search else []
27
+ input_text = format_input(question, options, context, web_results)
28
+ encoded = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=256)
29
  with torch.no_grad():
30
+ logits = model(encoded["input_ids"])
31
+ probs = torch.softmax(logits, dim=1).squeeze()
32
+ pred_index = torch.argmax(probs).item()
33
+ confidence = probs[pred_index].item()
34
 
35
+ suggestion = options[pred_index] if pred_index < len(options) else "N/A"
36
+ evo_reasoning = f"Evo suggests: **{suggestion}** (Confidence: {confidence:.2f})\n\nContext used:\n" + "\n".join(web_results)
37
+ return suggestion, evo_reasoning
 
 
 
38
 
39
+ def get_gpt_response(question, context, options):
 
40
  openai.api_key = os.getenv("OPENAI_API_KEY", "")
41
+ formatted_options = "\n".join([f"{i+1}. {opt}" for i, opt in enumerate(options)])
42
+ prompt = f"Question: {question}\n\nContext:\n{context}\n\nOptions:\n{formatted_options}\n\nWhich option makes the most sense and why?"
43
 
44
  try:
45
  response = openai.ChatCompletion.create(
46
  model="gpt-3.5-turbo",
47
  messages=[
48
+ {"role": "system", "content": "You are a helpful reasoning assistant."},
49
+ {"role": "user", "content": prompt}
50
+ ]
 
51
  )
52
+ return response['choices'][0]['message']['content']
53
  except Exception as e:
54
  return f"⚠️ GPT error: {str(e)}"