Niveytha27 committed on
Commit 693c47c · verified · 1 Parent(s): aa94d18

Update app.py

Files changed (1)
  1. app.py +224 -200
app.py CHANGED
@@ -1,200 +1,224 @@
-import requests
-import io
-import re
-import numpy as np
-import faiss
-import torch
-import streamlit as st
-from pypdf import PdfReader
-from rank_bm25 import BM25Okapi
-from sentence_transformers import SentenceTransformer
-from accelerate import Accelerator
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-from bert_score import score
-
-st.title("Financial Document Q&A Chatbot")
-
-@st.cache_resource
-def load_models():
-    embedding_model = SentenceTransformer("BAAI/bge-large-en")
-    accelerator = Accelerator()
-    MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
-    model = accelerator.prepare(model)
-    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-    return embedding_model, tokenizer, generator, accelerator
-
-embedding_model, tokenizer, generator, accelerator = load_models()
-
-def download_pdf(url):
-    try:
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
-        return response.content
-    except requests.exceptions.RequestException as e:
-        st.error(f"Error downloading PDF from {url}: {e}")
-        return None
-
-def extract_text_from_pdf(pdf_bytes):
-    try:
-        pdf_file = io.BytesIO(pdf_bytes)
-        reader = PdfReader(pdf_file)
-        text = ""
-        for page in reader.pages:
-            text += page.extract_text() or ""
-        return text
-    except Exception as e:
-        st.error(f"Error extracting text from PDF: {e}")
-        return None
-
-def preprocess_text(text):
-    """Cleans text while retaining financial symbols and ensuring proper formatting."""
-    if not text:
-        return ""
-
-    # Define allowed financial symbols
-    financial_symbols = r"\$\€\₹\£\¥\₩\₽\₮\₦\₲"
-
-    # Allow numbers, letters, spaces, financial symbols, common punctuation (.,%/-)
-    text = re.sub(fr"[^\w\s{financial_symbols}.,%/₹$€¥£-]", "", text)
-
-    # Normalize spaces
-    text = re.sub(r'\s+', ' ', text).strip()
-
-    return text
-
-@st.cache_resource
-def load_and_index_data(pdf_urls):
-    all_data = []
-    for url in pdf_urls:
-        pdf_bytes = download_pdf(url)
-        if pdf_bytes:
-            text = extract_text_from_pdf(pdf_bytes)
-            if text:
-                preprocessed_text = preprocess_text(text)
-                all_data.append(preprocessed_text)
-
-    def chunk_text(text, chunk_size=700, overlap_size=150):
-        chunks = []
-        start = 0
-        text_length = len(text)
-        while start < text_length:
-            end = min(start + chunk_size, text_length)
-            if end < text_length and text[end].isalnum():
-                last_space = text.rfind(" ", start, end)
-                if last_space != -1:
-                    end = last_space
-            chunk = text[start:end].strip()
-            if chunk:
-                chunks.append(chunk)
-            if end == text_length:
-                break
-            overlap_start = max(0, end - overlap_size)
-            if overlap_start < end:
-                last_overlap_space = text.rfind(" ", 0, overlap_start)
-                if last_overlap_space != -1 and last_overlap_space > start:
-                    start = last_overlap_space + 1
-                else:
-                    start = end
-            else:
-                start = end
-        return chunks
-
-    chunks = []
-    for data in all_data:
-        chunks.extend(chunk_text(data))
-
-    embeddings = embedding_model.encode(chunks)
-    index = faiss.IndexFlatL2(embeddings.shape[1])
-    index.add(embeddings)
-    return index, chunks
-
-def bm25_retrieval(query, documents, top_k=3):
-    tokenized_docs = [doc.split() for doc in documents]
-    bm25 = BM25Okapi(tokenized_docs)
-    return [documents[i] for i in np.argsort(bm25.get_scores(query.split()))[::-1][:top_k]]
-
-def adaptive_retrieval(query, index, chunks, top_k=3, bm25_weight=0.5):
-    query_embedding = embedding_model.encode([query], convert_to_numpy=True, dtype=np.float16)
-    _, indices = index.search(query_embedding, top_k)
-    vector_results = [chunks[i] for i in indices[0]]
-    bm25_results = bm25_retrieval(query, chunks, top_k)
-    return list(set(vector_results + bm25_results))
-
-def rerank(query, results):
-    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
-    result_embeddings = embedding_model.encode(results, convert_to_numpy=True)
-    similarities = np.dot(result_embeddings, query_embedding.T).flatten()
-    return [results[i] for i in np.argsort(similarities)[::-1]], similarities
-
-def merge_chunks(retrieved_chunks, overlap_size=100):
-    merged_chunks = []
-    buffer = retrieved_chunks[0] if retrieved_chunks else ""
-    for i in range(1, len(retrieved_chunks)):
-        chunk = retrieved_chunks[i]
-        overlap_start = buffer[-overlap_size:]
-        overlap_index = chunk.find(overlap_start)
-        if overlap_index != -1:
-            buffer += chunk[overlap_index + overlap_size:]
-        else:
-            merged_chunks.append(buffer)
-            buffer = chunk
-    if buffer:
-        merged_chunks.append(buffer)
-    return merged_chunks
-
-def calculate_confidence(query, answer):
-    P, R, F1 = score([answer], [query], lang="en", verbose=False)
-    return F1.item()
-
-def generate_response(query, context):
-    prompt = f"""Your task is to analyze the given Context and answer the Question concisely in plain English.
-**Guidelines:**
-- Do NOT include </think> tag, just provide the final answer only.
-- Provide a direct, factual answer based strictly on the Context.
-- Avoid generating Python code, solutions, or any irrelevant information.
-Context: {context}
-Question: {query}
-Answer:
-"""
-    response = generator(prompt, max_new_tokens=150, num_return_sequences=1)[0]['generated_text']
-    answer = response.split("Answer:")[1].strip()
-    return answer
-
-if "messages" not in st.session_state:
-    st.session_state.messages = []
-
-for message in st.session_state.messages:
-    with st.chat_message(message["role"]):
-        st.markdown(message["content"])
-
-pdf_urls = st.text_area("Enter PDF URLs (one per line):", "")
-pdf_urls = [url.strip() for url in pdf_urls.split("\n") if url.strip()]
-
-if st.button("Load and Index PDFs"):
-    with st.spinner("Loading and indexing PDFs..."):
-        index, chunks = load_and_index_data(pdf_urls)
-        st.session_state.index = index
-        st.session_state.chunks = chunks
-        st.success("PDFs loaded and indexed successfully.")
-
-if "index" in st.session_state and "chunks" in st.session_state:
-    if prompt := st.chat_input("Enter your financial question:"):
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
-        with st.chat_message("assistant"):
-            message_placeholder = st.empty()
-            retrieved_chunks = adaptive_retrieval(prompt, st.session_state.index, st.session_state.chunks)
-            merged_chunks = merge_chunks(retrieved_chunks, 150)
-            reranked_chunks, similarities = rerank(prompt, merged_chunks)
-            context = " ".join(reranked_chunks[:3])
-            answer = generate_response(prompt, context)
-            confidence = calculate_confidence(prompt, answer)
-            full_response = f"{answer}\n\nConfidence: {confidence:.2f}"
-            message_placeholder.markdown(full_response)
-            st.session_state.messages.append({"role": "assistant", "content": full_response})
-
-accelerator.free_memory()
+import requests
+import io
+import re
+import numpy as np
+import faiss
+import torch
+import time
+import streamlit as st
+from pypdf import PdfReader
+from rank_bm25 import BM25Okapi
+from sentence_transformers import SentenceTransformer
+from accelerate import Accelerator
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from bert_score import score
+
+def download_pdf(url):
+    """Downloads a PDF from a URL and returns its content as bytes."""
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+        return response.content
+    except requests.exceptions.RequestException as e:
+        st.error(f"Error downloading PDF from {url}: {e}")
+        return None
+
+def extract_text_from_pdf(pdf_bytes):
+    """Extracts text from a PDF byte stream."""
+    try:
+        pdf_file = io.BytesIO(pdf_bytes)
+        reader = PdfReader(pdf_file)
+        text = ""
+        for page in reader.pages:
+            text += page.extract_text() or ""  # Handle None return
+        return text
+    except Exception as e:
+        st.error(f"Error extracting text from PDF: {e}")
+        return None
+
+def preprocess_text(text):
+    """Cleans text while retaining financial symbols and ensuring proper formatting."""
+    if not text:
+        return ""
+
+    # Define allowed financial symbols
+    financial_symbols = r"\$\€\₹\£\¥\₩\₽\₮\₦\₲"
+
+    # Allow numbers, letters, spaces, financial symbols, common punctuation (.,%/-)
+    text = re.sub(fr"[^\w\s{financial_symbols}.,%/₹$€¥£-]", "", text)
+
+    # Normalize spaces
+    text = re.sub(r'\s+', ' ', text).strip()
+
+    return text
+
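+# A quick sanity check of the cleaning rules above (hypothetical input, not part of the app flow):
+#   preprocess_text("Revenue:  $1,234.5M (+12%)")
+#   -> "Revenue $1,234.5M 12%"   # ':', '(', ')', '+' are stripped; '$', ',', '.', '%' survive
+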
+def load_financial_pdfs(pdf_urls):
+    """Downloads and extracts text from a list of PDF URLs."""
+    all_data = []
+    for url in pdf_urls:
+        pdf_bytes = download_pdf(url)
+        if pdf_bytes:
+            text = extract_text_from_pdf(pdf_bytes)
+            if text:
+                preprocessed_text = preprocess_text(text)
+                all_data.append(preprocessed_text)
+    return all_data
+
+# Example usage (replace with actual PDF URLs)
+pdf_urls = [
+    "https://www.latentview.com/wp-content/uploads/2023/07/LatentView-Annual-Report-2022-23.pdf",
+    "https://www.latentview.com/wp-content/uploads/2024/08/LatentView-Annual-Report-2023-24.pdf",
+]
+
+all_data = load_financial_pdfs(pdf_urls)
+
+def chunk_text(text, chunk_size=700, overlap_size=150):
+    """Chunks text without breaking words in the middle (corrected overlap)."""
+    chunks = []
+    start = 0
+    text_length = len(text)
+
+    while start < text_length:
+        end = min(start + chunk_size, text_length)
+
+        # Ensure we do not split words
+        if end < text_length and text[end].isalnum():
+            last_space = text.rfind(" ", start, end)  # Find the last space within the chunk
+            if last_space != -1:  # If a space is found, adjust the end
+                end = last_space
+
+        chunk = text[start:end].strip()
+        if chunk:  # Avoid empty chunks
+            chunks.append(chunk)
+
+        if end == text_length:
+            break
+
+        # Corrected overlap calculation
+        overlap_start = max(0, end - overlap_size)
+        if overlap_start < end:  # Prevent an infinite loop when overlap_start equals end
+            last_overlap_space = text.rfind(" ", 0, overlap_start)
+            if last_overlap_space != -1 and last_overlap_space > start:
+                start = last_overlap_space + 1
+            else:
+                start = end  # If no space is found, continue from the previous end
+        else:
+            start = end
+
+    return chunks
+
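+# Rough illustration of the sliding window (hypothetical input, not part of the app flow):
+# with chunk_size=700 and overlap_size=150, consecutive chunks share roughly the last 150
+# characters of the previous chunk, snapped back to a word boundary:
+#   cs = chunk_text("word " * 500, chunk_size=700, overlap_size=150)
+#   all(len(c) <= 700 for c in cs)  # True
+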
+chunks = []
+for data in all_data:
+    chunks.extend(chunk_text(data))
+
+embedding_model = SentenceTransformer("BAAI/bge-large-en")
+# embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
+embeddings = embedding_model.encode(chunks)
+
+index = faiss.IndexFlatL2(embeddings.shape[1])
+index.add(embeddings)
+
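+# Quick index sanity check (hypothetical query, not part of the app flow):
+#   q = embedding_model.encode(["total revenue for FY2023"])
+#   distances, ids = index.search(q, 3)   # L2 distances and positions into `chunks`
+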
+def bm25_retrieval(query, documents, top_k=3):
+    tokenized_docs = [doc.split() for doc in documents]
+    bm25 = BM25Okapi(tokenized_docs)
+    return [documents[i] for i in np.argsort(bm25.get_scores(query.split()))[::-1][:top_k]]
+
+def adaptive_retrieval(query, index, chunks, top_k=3, bm25_weight=0.5):
+    # Note: bm25_weight is currently unused; vector and BM25 results are simply unioned.
+    # encode() has no dtype kwarg, and FAISS expects float32 inputs.
+    query_embedding = embedding_model.encode([query], convert_to_numpy=True).astype(np.float32)
+    _, indices = index.search(query_embedding, top_k)
+    vector_results = [chunks[i] for i in indices[0]]
+    bm25_results = bm25_retrieval(query, chunks, top_k)
+    return list(set(vector_results + bm25_results))
+
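+# If weighted fusion is intended for bm25_weight, a rough sketch (hypothetical, assumes
+# min-max-normalized scores per chunk):
+#   final = bm25_weight * bm25_score + (1 - bm25_weight) * vector_score
+# and rank chunks by `final` before taking top_k, instead of the set union above.
+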
+def rerank(query, results):
+    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
+    result_embeddings = embedding_model.encode(results, convert_to_numpy=True)
+    similarities = np.dot(result_embeddings, query_embedding.T).flatten()
+    return [results[i] for i in np.argsort(similarities)[::-1]], similarities
+
+# Chunk merging
+def merge_chunks(retrieved_chunks, overlap_size=100):
+    """Merges overlapping chunks properly by detecting the actual overlap."""
+    merged_chunks = []
+    buffer = retrieved_chunks[0] if retrieved_chunks else ""
+
+    for i in range(1, len(retrieved_chunks)):
+        chunk = retrieved_chunks[i]
+
+        # Find the actual overlap
+        overlap_start = buffer[-overlap_size:]  # Last `overlap_size` chars of the previous chunk
+        overlap_index = chunk.find(overlap_start)  # Where this tail appears in the new chunk
+
+        if overlap_index != -1:
+            # Merge only the non-overlapping part
+            buffer += chunk[overlap_index + overlap_size:]
+        else:
+            # Store the completed merged chunk and start a new one
+            merged_chunks.append(buffer)
+            buffer = chunk
+
+    if buffer:
+        merged_chunks.append(buffer)
+
+    return merged_chunks
+
+# def calculate_confidence(query, context, similarities):
+#     return np.mean(similarities)  # Averaged similarity scores
+def calculate_confidence(query, answer):
+    P, R, F1 = score([answer], [query], lang="en", verbose=False)
+    return F1.item()
+
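+# Note: bert_score compares the answer against the *query* here, so this "confidence"
+# measures answer-question relevance rather than factual correctness against the documents.
+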
+# Load the SLM
+accelerator = Accelerator()
+accelerator.free_memory()
+MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", cache_dir="./my_models")
+model = accelerator.prepare(model)
+generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+def generate_response(query, context):
+    prompt = f"""Your task is to analyze the given Context and answer the Question concisely in plain English.
+**Guidelines:**
+- Do NOT include a </think> tag; provide the final answer only.
+- Provide a direct, factual answer based strictly on the Context.
+- Avoid generating Python code, solutions, or any irrelevant information.
+
+Context: {context}
+Question: {query}
+Answer:
+"""
+    response = generator(prompt, max_new_tokens=150, num_return_sequences=1)[0]['generated_text']
+    print(response)  # Debug: inspect the raw generation
+    answer = response.split("Answer:")[-1].strip()  # Take the text after the last "Answer:" marker
+    return answer
+
+import gradio as gr
+
+# The functions above (adaptive_retrieval, merge_chunks, rerank,
+# generate_response, calculate_confidence) must be defined before this point.
+
+def inference_pipeline(query):
+    retrieved_chunks = adaptive_retrieval(query, index, chunks)
+    merged_chunks = merge_chunks(retrieved_chunks, 150)
+    reranked_chunks, similarities = rerank(query, merged_chunks)
+    context = " ".join(reranked_chunks[:3])  # Take the top 3 most relevant chunks
+    response = generate_response(query, context)
+    confidence = calculate_confidence(query, response)  # calculate_confidence takes (query, answer)
+    return response, confidence
+
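+# Local smoke test (hypothetical query; bypasses the Gradio UI):
+#   response, confidence = inference_pipeline("What was LatentView's revenue growth in FY2023-24?")
+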
+# Define the Gradio UI
+ui = gr.Interface(
+    fn=inference_pipeline,
+    inputs=gr.Textbox(label="Enter your financial question"),
+    outputs=[
+        gr.Textbox(label="Generated Response"),
+        gr.Number(label="Confidence Score"),
+    ],
+    title="Financial Q&A Assistant",
+    description="Ask financial questions and get AI-powered responses with confidence scores.",
+)
+
+# Launch the UI
+ui.launch(share=True)  # share=True allows public access