HemanM committed · Commit 21afb35 · verified · 1 Parent(s): a3b5a39

Update inference.py

Files changed (1)
  1. inference.py +39 -56
inference.py CHANGED
@@ -1,80 +1,63 @@
 import torch
 from evo_model import EvoTransformer
-from transformers import AutoTokenizer, pipeline
-from rag_utils import RAGRetriever, extract_text_from_file
+from transformers import AutoTokenizer
+from rag_utils import extract_text_from_file
+from search_utils import web_search
 import os

-# Load Evo model
-def load_evo_model(model_path="evo_hellaswag.pt", device=None):
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+def load_model(model_path="evo_hellaswag.pt", device=None):
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
     model = EvoTransformer()
     model.load_state_dict(torch.load(model_path, map_location=device))
     model.to(device)
     model.eval()
     return model, device

-evo_model, device = load_evo_model()
-tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-
-# Load GPT-3.5 (via OpenAI API)
-import openai
-openai.api_key = os.getenv("OPENAI_API_KEY")
-
-# RAG Retriever
-retriever = RAGRetriever()
-
-def get_context_from_file(file_obj):
-    file_path = file_obj.name
-    text = extract_text_from_file(file_path)
-    retriever.add_document(text)
-    return text
-
-# Evo prediction
-def get_evo_response(prompt, file=None):
-    # Step 1: augment context if document is uploaded
-    context = ""
-    if file is not None:
-        context_list = retriever.retrieve(prompt)
-        context = "\n".join(context_list)
-
-    full_prompt = f"{prompt}\n{context}"
-
-    # Step 2: use Evo to predict
-    options = ["Yes, proceed with the action.", "No, maintain current strategy."]
-    inputs = [f"{full_prompt} {opt}" for opt in options]
+evo_model, device = load_model()
+
+def get_evo_response(query, file=None, enable_search=True):
+    context = ""
+
+    # Pull context from an uploaded document, capped at 800 characters
+    if file:
+        try:
+            context += extract_text_from_file(file)[:800]
+        except Exception:
+            pass
+
+    # Optionally augment the context with live web-search snippets
+    if enable_search:
+        search_snippets = web_search(query)
+        context += "\n".join(search_snippets)
+
+    combined_prompt = f"{query}\nContext:\n{context}"
+
+    # Score two candidate continuations with Evo
+    inputs = [
+        f"{combined_prompt} Option 1:",
+        f"{combined_prompt} Option 2:",
+    ]

     encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
-
     with torch.no_grad():
-        logits = evo_model(encoded["input_ids"]).squeeze(-1)
-        probs = torch.softmax(logits, dim=0)
+        outputs = evo_model(encoded["input_ids"]).squeeze(-1)
+        probs = torch.softmax(outputs, dim=0)
         best = torch.argmax(probs).item()

-    return f"Evo suggests: {options[best]} (Confidence: {probs[best]:.2f})"
+    return f"Option {best + 1} with {probs[best]:.2%} confidence.\n\nReasoning:\n{inputs[best]}"

-# GPT-3.5 response
-def get_gpt_response(prompt, file=None):
-    context = ""
-    if file is not None:
-        context_list = retriever.retrieve(prompt)
-        context = "\n".join(context_list)
+def get_gpt_response(query, context=""):
+    import openai
+    openai.api_key = os.getenv("OPENAI_API_KEY")

-    full_prompt = (
-        f"Question: {prompt}\n"
-        f"Relevant Context:\n{context}\n"
-        f"Answer like a financial advisor."
-    )
+    prompt = f"{query}\nContext:\n{context}\nGive a thoughtful recommendation with reasons."

-    response = openai.ChatCompletion.create(
-        model="gpt-3.5-turbo",
-        messages=[
-            {"role": "user", "content": full_prompt}
-        ],
-        temperature=0.4,
-    )
-
-    return response.choices[0].message.content.strip()
-
-#
+    try:
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=300,
+            temperature=0.7,
+        )
+        return response.choices[0].message.content.strip()
+    except Exception as e:
+        return f"Error: {str(e)}"