Niveytha27 commited on
Commit
aa94d18
·
verified ·
1 Parent(s): 0ead15e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import io
3
+ import re
4
+ import numpy as np
5
+ import faiss
6
+ import torch
7
+ import streamlit as st
8
+ from pypdf import PdfReader
9
+ from rank_bm25 import BM25Okapi
10
+ from sentence_transformers import SentenceTransformer
11
+ from accelerate import Accelerator
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
13
+ from bert_score import score
14
+
15
+ st.title("Financial Document Q&A Chatbot")
16
+
17
+ @st.cache_resource
18
+ def load_models():
19
+ embedding_model = SentenceTransformer("BAAI/bge-large-en")
20
+ accelerator = Accelerator()
21
+ MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
24
+ model = accelerator.prepare(model)
25
+ generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
26
+ return embedding_model, tokenizer, generator, accelerator
27
+
28
+ embedding_model, tokenizer, generator, accelerator = load_models()
29
+
30
+ def download_pdf(url):
31
+ try:
32
+ response = requests.get(url, stream=True)
33
+ response.raise_for_status()
34
+ return response.content
35
+ except requests.exceptions.RequestException as e:
36
+ st.error(f"Error downloading PDF from {url}: {e}")
37
+ return None
38
+
39
+ def extract_text_from_pdf(pdf_bytes):
40
+ try:
41
+ pdf_file = io.BytesIO(pdf_bytes)
42
+ reader = PdfReader(pdf_file)
43
+ text = ""
44
+ for page in reader.pages:
45
+ text += page.extract_text() or ""
46
+ return text
47
+ except Exception as e:
48
+ st.error(f"Error extracting text from PDF: {e}")
49
+ return None
50
+
51
+ def preprocess_text(text):
52
+ """Cleans text while retaining financial symbols and ensuring proper formatting."""
53
+ if not text:
54
+ return ""
55
+
56
+ # Define allowed financial symbols
57
+ financial_symbols = r"\$\€\₹\£\¥\₩\₽\₮\₦\₲"
58
+
59
+ # Allow numbers, letters, spaces, financial symbols, common punctuation (.,%/-)
60
+ text = re.sub(fr"[^\w\s{financial_symbols}.,%/₹$€¥£-]", "", text)
61
+
62
+ # Normalize spaces
63
+ text = re.sub(r'\s+', ' ', text).strip()
64
+
65
+ return text
66
+
67
+ @st.cache_resource
68
+ def load_and_index_data(pdf_urls):
69
+ all_data = []
70
+ for url in pdf_urls:
71
+ pdf_bytes = download_pdf(url)
72
+ if pdf_bytes:
73
+ text = extract_text_from_pdf(pdf_bytes)
74
+ if text:
75
+ preprocessed_text = preprocess_text(text)
76
+ all_data.append(preprocessed_text)
77
+
78
+ def chunk_text(text, chunk_size=700, overlap_size=150):
79
+ chunks = []
80
+ start = 0
81
+ text_length = len(text)
82
+ while start < text_length:
83
+ end = min(start + chunk_size, text_length)
84
+ if end < text_length and text[end].isalnum():
85
+ last_space = text.rfind(" ", start, end)
86
+ if last_space != -1:
87
+ end = last_space
88
+ chunk = text[start:end].strip()
89
+ if chunk:
90
+ chunks.append(chunk)
91
+ if end == text_length:
92
+ break
93
+ overlap_start = max(0, end - overlap_size)
94
+ if overlap_start < end:
95
+ last_overlap_space = text.rfind(" ", 0, overlap_start)
96
+ if last_overlap_space != -1 and last_overlap_space > start:
97
+ start = last_overlap_space + 1
98
+ else:
99
+ start = end
100
+ else:
101
+ start = end
102
+ return chunks
103
+
104
+ chunks = []
105
+ for data in all_data:
106
+ chunks.extend(chunk_text(data))
107
+
108
+ embeddings = embedding_model.encode(chunks)
109
+ index = faiss.IndexFlatL2(embeddings.shape[1])
110
+ index.add(embeddings)
111
+ return index, chunks
112
+
113
+ def bm25_retrieval(query, documents, top_k=3):
114
+ tokenized_docs = [doc.split() for doc in documents]
115
+ bm25 = BM25Okapi(tokenized_docs)
116
+ return [documents[i] for i in np.argsort(bm25.get_scores(query.split()))[::-1][:top_k]]
117
+
118
+ def adaptive_retrieval(query, index, chunks, top_k=3, bm25_weight=0.5):
119
+ query_embedding = embedding_model.encode([query], convert_to_numpy=True, dtype=np.float16)
120
+ _, indices = index.search(query_embedding, top_k)
121
+ vector_results = [chunks[i] for i in indices[0]]
122
+ bm25_results = bm25_retrieval(query, chunks, top_k)
123
+ return list(set(vector_results + bm25_results))
124
+
125
+ def rerank(query, results):
126
+ query_embedding = embedding_model.encode([query], convert_to_numpy=True)
127
+ result_embeddings = embedding_model.encode(results, convert_to_numpy=True)
128
+ similarities = np.dot(result_embeddings, query_embedding.T).flatten()
129
+ return [results[i] for i in np.argsort(similarities)[::-1]], similarities
130
+
131
+ def merge_chunks(retrieved_chunks, overlap_size=100):
132
+ merged_chunks = []
133
+ buffer = retrieved_chunks[0] if retrieved_chunks else ""
134
+ for i in range(1, len(retrieved_chunks)):
135
+ chunk = retrieved_chunks[i]
136
+ overlap_start = buffer[-overlap_size:]
137
+ overlap_index = chunk.find(overlap_start)
138
+ if overlap_index != -1:
139
+ buffer += chunk[overlap_index + overlap_size:]
140
+ else:
141
+ merged_chunks.append(buffer)
142
+ buffer = chunk
143
+ if buffer:
144
+ merged_chunks.append(buffer)
145
+ return merged_chunks
146
+
147
+ def calculate_confidence(query, answer):
148
+ P, R, F1 = score([answer], [query], lang="en", verbose=False)
149
+ return F1.item()
150
+
151
+ def generate_response(query, context):
152
+ prompt = f"""Your task is to analyze the given Context and answer the Question concisely in plain English.
153
+ **Guidelines:**
154
+ - Do NOT include </think> tag, just provide the final answer only.
155
+ - Provide a direct, factual answer based strictly on the Context.
156
+ - Avoid generating Python code, solutions, or any irrelevant information.
157
+ Context: {context}
158
+ Question: {query}
159
+ Answer:
160
+ """
161
+ response = generator(prompt, max_new_tokens=150, num_return_sequences=1)[0]['generated_text']
162
+ answer = response.split("Answer:")[1].strip()
163
+ return answer
164
+
165
+ if "messages" not in st.session_state:
166
+ st.session_state.messages = []
167
+
168
+ for message in st.session_state.messages:
169
+ with st.chat_message(message["role"]):
170
+ st.markdown(message["content"])
171
+
172
+ pdf_urls = st.text_area("Enter PDF URLs (one per line):", "")
173
+ pdf_urls = [url.strip() for url in pdf_urls.split("\n") if url.strip()]
174
+
175
+ if st.button("Load and Index PDFs"):
176
+ with st.spinner("Loading and indexing PDFs..."):
177
+ index, chunks = load_and_index_data(pdf_urls)
178
+ st.session_state.index = index
179
+ st.session_state.chunks = chunks
180
+ st.success("PDFs loaded and indexed successfully.")
181
+
182
+ if "index" in st.session_state and "chunks" in st.session_state:
183
+ if prompt := st.chat_input("Enter your financial question:"):
184
+ st.session_state.messages.append({"role": "user", "content": prompt})
185
+ with st.chat_message("user"):
186
+ st.markdown(prompt)
187
+
188
+ with st.chat_message("assistant"):
189
+ message_placeholder = st.empty()
190
+ retrieved_chunks = adaptive_retrieval(prompt, st.session_state.index, st.session_state.chunks)
191
+ merged_chunks = merge_chunks(retrieved_chunks, 150)
192
+ reranked_chunks, similarities = rerank(prompt, merged_chunks)
193
+ context = " ".join(reranked_chunks[:3])
194
+ answer = generate_response(prompt, context)
195
+ confidence = calculate_confidence(prompt, answer)
196
+ full_response = f"{answer}\n\nConfidence: {confidence:.2f}"
197
+ message_placeholder.markdown(full_response)
198
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
199
+
200
+ accelerator.free_memory()