HemanM commited on
Commit
7d3dbed
·
verified ·
1 Parent(s): 2182155

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +77 -35
inference.py CHANGED
@@ -1,36 +1,78 @@
1
- import os
2
- import faiss
3
  import torch
4
- from transformers import AutoTokenizer, AutoModel
5
- from sentence_transformers import SentenceTransformer
6
- from PyPDF2 import PdfReader
7
-
8
- class RAGRetriever:
9
- def __init__(self):
10
- self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
11
- self.index = faiss.IndexFlatL2(384)
12
- self.contexts = []
13
- self.ids = []
14
-
15
- def add_document(self, text):
16
- sentences = text.split("\n")
17
- clean_sentences = [s.strip() for s in sentences if s.strip()]
18
- embeddings = self.encoder.encode(clean_sentences)
19
- self.index.add(embeddings)
20
- self.contexts.extend(clean_sentences)
21
-
22
- def retrieve(self, query, top_k=3):
23
- q_vec = self.encoder.encode([query])
24
- D, I = self.index.search(q_vec, top_k)
25
- return [self.contexts[i] for i in I[0]]
26
-
27
- def extract_text_from_file(file_path):
28
- ext = os.path.splitext(file_path)[-1].lower()
29
- if ext == ".txt":
30
- with open(file_path, "r", encoding="utf-8") as f:
31
- return f.read()
32
- elif ext == ".pdf":
33
- reader = PdfReader(file_path)
34
- return "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
35
- else:
36
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from evo_model import EvoTransformer
3
+ from transformers import AutoTokenizer, pipeline
4
+ from rag_utils import RAGRetriever, extract_text_from_file
5
+ import os
6
+
7
+ # Load Evo model
8
+ def load_evo_model(model_path="evo_hellaswag.pt", device=None):
9
+ if device is None:
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ model = EvoTransformer()
13
+ model.load_state_dict(torch.load(model_path, map_location=device))
14
+ model.to(device)
15
+ model.eval()
16
+ return model, device
17
+
18
+ evo_model, device = load_evo_model()
19
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
20
+
21
+ # Load GPT-3.5 (via OpenAI API)
22
+ import openai
23
+ openai.api_key = os.getenv("OPENAI_API_KEY")
24
+
25
+ # RAG Retriever
26
+ retriever = RAGRetriever()
27
+
28
+ def get_context_from_file(file_obj):
29
+ file_path = file_obj.name
30
+ text = extract_text_from_file(file_path)
31
+ retriever.add_document(text)
32
+ return text
33
+
34
+ # Evo prediction
35
+ def get_evo_response(prompt, file=None):
36
+ # Step 1: augment context if document is uploaded
37
+ context = ""
38
+ if file is not None:
39
+ context_list = retriever.retrieve(prompt)
40
+ context = "\n".join(context_list)
41
+
42
+ full_prompt = f"{prompt}\n{context}"
43
+
44
+ # Step 2: use Evo to predict
45
+ options = ["Yes, proceed with the action.", "No, maintain current strategy."]
46
+ inputs = [f"{full_prompt} {opt}" for opt in options]
47
+
48
+ encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
49
+
50
+ with torch.no_grad():
51
+ logits = evo_model(encoded["input_ids"]).squeeze(-1)
52
+ probs = torch.softmax(logits, dim=0)
53
+ best = torch.argmax(probs).item()
54
+
55
+ return f"Evo suggests: {options[best]} (Confidence: {probs[best]:.2f})"
56
+
57
+ # GPT-3.5 response
58
+ def get_gpt_response(prompt, file=None):
59
+ context = ""
60
+ if file is not None:
61
+ context_list = retriever.retrieve(prompt)
62
+ context = "\n".join(context_list)
63
+
64
+ full_prompt = (
65
+ f"Question: {prompt}\n"
66
+ f"Relevant Context:\n{context}\n"
67
+ f"Answer like a financial advisor."
68
+ )
69
+
70
+ response = openai.ChatCompletion.create(
71
+ model="gpt-3.5-turbo",
72
+ messages=[
73
+ {"role": "user", "content": full_prompt}
74
+ ],
75
+ temperature=0.4,
76
+ )
77
+
78
+ return response.choices[0].message.content.strip()