HemanM committed · verified
Commit 11f2d5b · 1 Parent(s): e26e95c

Update inference.py

Files changed (1)
  1. inference.py +36 -45
inference.py CHANGED
@@ -1,62 +1,53 @@
  import torch
- from transformers import AutoTokenizer
  from evo_model import EvoTransformer
  from rag_utils import extract_text_from_file
- from search_utils import web_search_and_format

- # Load Evo model and tokenizer
- model_path = "evo_hellaswag.pt"
- device = "cuda" if torch.cuda.is_available() else "cpu"
  model = EvoTransformer()
- model.load_state_dict(torch.load(model_path, map_location=device))
- model.to(device)
  model.eval()
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
- def get_evo_response(query, context="", file=None, enable_search=True):
-     rag_context = ""
-
-     if file is not None:
-         rag_context += extract_text_from_file(file)

      if enable_search:
-         search_context = web_search_and_format(query)
-         rag_context += "\n" + search_context

-     full_context = f"{context}\n{rag_context}".strip()

-     # Define hypothetical options (can be more sophisticated later)
-     option1 = "Yes, take action."
-     option2 = "No, do not take action."
-
-     inputs = [
-         f"Q: {query} Context: {full_context} A: {option1}",
-         f"Q: {query} Context: {full_context} A: {option2}",
-     ]
-
-     encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
      with torch.no_grad():
-         logits = model(encoded["input_ids"]).squeeze(-1)
          probs = torch.softmax(logits, dim=0)
-     best = torch.argmax(probs).item()

-     answer = option1 if best == 0 else option2
      reasoning = (
-         f"Evo suggests: **{answer}**\n\n"
-         f"🧠 Confidence: {probs[best]:.2f}\n"
-         f"📖 Context used:\n{full_context[:1000]}..."  # limit to 1000 chars
      )

-     return answer, reasoning
-
- def get_gpt_response(query, context=""):
      import openai
-     openai.api_key = "sk-proj-[REDACTED]"  # Make sure to secure this (live key redacted here)
-
-     prompt = f"Q: {query}\nContext: {context}\nA:"
-     response = openai.ChatCompletion.create(
-         model="gpt-3.5-turbo",
-         messages=[{"role": "user", "content": prompt}],
-         temperature=0.3
-     )
-     return response["choices"][0]["message"]["content"].strip()

  import torch
  from evo_model import EvoTransformer
+ from transformers import AutoTokenizer
  from rag_utils import extract_text_from_file
+ from search_utils import web_search

+ # Load Evo model
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
  model = EvoTransformer()
+ model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
  model.eval()

+ def get_evo_response(query, context=None, enable_search=True):
+     search_snippets = ""
      if enable_search:
+         snippets = web_search(query)
+         if snippets:
+             search_snippets = "\n".join(snippets)

+     full_context = f"{context or ''}\n\n{search_snippets}".strip()
+     input_1 = f"{query} Option 1"
+     input_2 = f"{query} Option 2"

+     inputs = tokenizer([input_1, input_2], padding=True, truncation=True, return_tensors="pt")
      with torch.no_grad():
+         logits = model(inputs["input_ids"]).squeeze(-1)
          probs = torch.softmax(logits, dim=0)
+     best_idx = torch.argmax(probs).item()

+     suggestion = f"Option {best_idx + 1}"
      reasoning = (
+         f"Evo suggests: **{suggestion}** (Confidence: {probs[best_idx]:.2f})\n\n"
+         f"Context used:\n{full_context}"
      )
+     return suggestion, reasoning

+ def get_gpt_response(query, context=None):
      import openai
+     import os  # fix: os.getenv below needs this import, which the commit omits
+     openai.api_key = os.getenv("OPENAI_API_KEY", "")
+     context = context or "None"
+
+     try:
+         response = openai.ChatCompletion.create(
+             model="gpt-3.5-turbo",
+             messages=[
+                 {"role": "system", "content": "You are a helpful expert advisor."},
+                 {"role": "user", "content": f"Context: {context}\n\nQuestion: {query}"}
+             ],
+             max_tokens=250
+         )
+         return response["choices"][0]["message"]["content"].strip()
+     except Exception as e:
+         return f"⚠️ GPT error: {str(e)}"
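The heart of the new get_evo_response is a two-candidate scoring step: the model emits one scalar logit per encoded option, and a softmax over dim=0 turns the pair into a probability distribution across the two options rather than across a batch. A minimal, self-contained sketch of just that step, with a stub tensor standing in for the EvoTransformer output:

import torch

# Stub logits standing in for model(inputs["input_ids"]).squeeze(-1):
# one scalar score per candidate ("Option 1", "Option 2").
logits = torch.tensor([1.2, -0.3])

# dim=0 normalizes across the two candidates, not across a batch dimension.
probs = torch.softmax(logits, dim=0)
best_idx = torch.argmax(probs).item()

print(f"Option {best_idx + 1} (Confidence: {probs[best_idx]:.2f})")
# -> Option 1 (Confidence: 0.82)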
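For reference, a hypothetical caller of the updated module might look like the sketch below. It assumes the repo's evo_model.py, search_utils.py, and the evo_hellaswag.pt checkpoint sit alongside inference.py; enable_search=False keeps the example offline.

from inference import get_evo_response, get_gpt_response

suggestion, reasoning = get_evo_response(
    "Should the service cache search results?",
    context="Latency is the main bottleneck.",
    enable_search=False,  # skip web_search() so the example runs offline
)
print(suggestion)  # "Option 1" or "Option 2"
print(reasoning)   # markdown summary with confidence and the context used

# get_gpt_response reads OPENAI_API_KEY from the environment (see below).
print(get_gpt_response("Summarize the trade-offs.", context=reasoning))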
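One caveat on the GPT fallback: openai.ChatCompletion.create is the legacy pre-1.0 interface and raises an error on openai>=1.0. If the Space ever pins a 1.x client, a sketch of the equivalent call (same model and messages, key still read from the environment as in this commit; the function name get_gpt_response_v1 is hypothetical) would be:

import os
from openai import OpenAI  # openai>=1.0 client

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", ""))

def get_gpt_response_v1(query, context=None):
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful expert advisor."},
                {"role": "user", "content": f"Context: {context or 'None'}\n\nQuestion: {query}"},
            ],
            max_tokens=250,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"⚠️ GPT error: {e}"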