HemanM committed
Commit 453eba8 · verified · 1 Parent(s): ffe1458

Update inference.py

Files changed (1)
  1. inference.py +42 -43
inference.py CHANGED
@@ -1,63 +1,62 @@
 import torch
-from evo_model import EvoTransformer
 from transformers import AutoTokenizer
+from evo_model import EvoTransformer
 from rag_utils import extract_text_from_file
-from search_utils import web_search
-
+from search_utils import web_search_and_format
+
+# Load Evo model and tokenizer
+model_path = "evo_hellaswag.pt"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = EvoTransformer()
+model.load_state_dict(torch.load(model_path, map_location=device))
+model.to(device)
+model.eval()
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
-def load_model(model_path="evo_hellaswag.pt", device=None):
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = EvoTransformer()
-    model.load_state_dict(torch.load(model_path, map_location=device))
-    model.to(device)
-    model.eval()
-    return model, device
-
-evo_model, device = load_model()
-
-def get_evo_response(query, file=None, enable_search=True):
-    context = ""
-
-    if file:
-        try:
-            context += extract_text_from_file(file)[:800]
-        except:
-            pass
+def get_evo_response(query, context="", file=None, enable_search=True):
+    rag_context = ""
+
+    if file is not None:
+        rag_context += extract_text_from_file(file)
 
     if enable_search:
-        search_snippets = web_search(query)
-        context += "\n".join(search_snippets)
-
-    combined_prompt = f"{query}\nContext:\n{context}"
+        search_context = web_search_and_format(query)
+        rag_context += "\n" + search_context
+
+    full_context = f"{context}\n{rag_context}".strip()
+
+    # Define hypothetical options (can be more sophisticated later)
+    option1 = "Yes, take action."
+    option2 = "No, do not take action."
 
     inputs = [
-        f"{combined_prompt} Option 1:",
-        f"{combined_prompt} Option 2:",
+        f"Q: {query} Context: {full_context} A: {option1}",
+        f"Q: {query} Context: {full_context} A: {option2}",
     ]
 
     encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
     with torch.no_grad():
-        outputs = evo_model(encoded["input_ids"]).squeeze(-1)
-        probs = torch.softmax(outputs, dim=0)
+        logits = model(encoded["input_ids"]).squeeze(-1)
+        probs = torch.softmax(logits, dim=0)
     best = torch.argmax(probs).item()
 
-    return f"Option {best + 1} with {probs[best]:.2%} confidence.\n\nReasoning:\n{inputs[best]}"
+    answer = option1 if best == 0 else option2
+    reasoning = (
+        f"✅ Evo suggests: **{answer}**\n\n"
+        f"🧠 Confidence: {probs[best]:.2f}\n"
+        f"📖 Context used:\n{full_context[:1000]}..."  # limit to 1000 chars
+    )
+
+    return answer, reasoning
 
 def get_gpt_response(query, context=""):
     import openai
-    openai.api_key = os.getenv("OPENAI_API_KEY")
-
-    prompt = f"{query}\nContext:\n{context}\nGive a thoughtful recommendation with reasons."
-
-    try:
-        response = openai.ChatCompletion.create(
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": prompt}],
-            max_tokens=300,
-            temperature=0.7,
-        )
-        return response.choices[0].message.content.strip()
-    except Exception as e:
-        return f"Error: {str(e)}"
+    openai.api_key = "sk-..."  # Make sure to secure this
+
+    prompt = f"Q: {query}\nContext: {context}\nA:"
+    response = openai.ChatCompletion.create(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": prompt}],
+        temperature=0.3
+    )
+    return response["choices"][0]["message"]["content"].strip()
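
For context, a minimal usage sketch of the updated module (not part of the commit), assuming evo_model.py, rag_utils.py, search_utils.py, and evo_hellaswag.pt are available alongside inference.py; the query and context strings below are hypothetical.

# Hypothetical caller for the updated inference.py (assumption, not from the commit).
from inference import get_evo_response, get_gpt_response

query = "Should we enable caching for repeated searches?"
context = "Latency is the main concern."

# Evo path: returns (answer, reasoning); web search enabled, no uploaded file.
answer, reasoning = get_evo_response(query, context=context, enable_search=True)
print(answer)
print(reasoning)

# GPT path: returns a single string (requires a valid OpenAI API key in inference.py).
print(get_gpt_response(query, context=context))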