masadonline committed on
Commit c832d1c · verified · 1 Parent(s): e0b47c7

Update app.py

Files changed (1)
  app.py +258 -103
app.py CHANGED
@@ -1,62 +1,123 @@
- import streamlit as st
- from PyPDF2 import PdfReader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain.vectorstores import FAISS
- import pandas as pd
  import os
- import io
  import requests

- # --- 1. Data Loading and Preprocessing ---

- @st.cache_data()
- def load_and_process_pdfs_from_folder(docs_folder="docs"):
-     """Loads and processes all PDF files from the specified folder."""
-     all_text = ""
      all_tables = []
-     for filename in os.listdir(docs_folder):
-         if filename.endswith(".pdf"):
-             filepath = os.path.join(docs_folder, filename)
-             try:
-                 with open(filepath, 'rb') as file:
-                     pdf_reader = PdfReader(file)
-                     for page in pdf_reader.pages:
-                         all_text += page.extract_text() + "\n"
-                         try:
-                             for table in page.extract_tables():
-                                 df = pd.DataFrame(table)
-                                 all_tables.append(df)
-                         except Exception as e:
-                             print(f"Could not extract tables from page in {filename}. Error: {e}")
-             except Exception as e:
-                 st.error(f"Error reading PDF {filename}: {e}")
-     return all_text, all_tables
-
- @st.cache_data()
- def split_text_into_chunks(text):
-     """Splits the text into smaller, manageable chunks."""
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-     chunks = text_splitter.split_text(text)
-     return chunks

- @st.cache_data()
- def create_vectorstore(chunks):
-     """Creates a vectorstore from the text chunks using HuggingFace embeddings."""
-     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-     vectorstore = FAISS.from_texts(chunks, embeddings)
-     return vectorstore

- # --- 2. Question Answering with Groq ---

  def generate_answer_with_groq(question, context):
-     """Generates an answer using the Groq API."""
      url = "https://api.groq.com/openai/v1/chat/completions"
      api_key = os.environ.get("GROQ_API_KEY")
-     if not api_key:
-         st.error("GROQ_API_KEY environment variable not found. Please set it.")
-         return None # Indicate failure
-
      headers = {
          "Authorization": f"Bearer {api_key}",
          "Content-Type": "application/json",
@@ -82,59 +143,153 @@ def generate_answer_with_groq(question, context):
          "temperature": 0.5,
          "max_tokens": 300,
      }
      try:
-         response = requests.post(url, headers=headers, json=payload)
-         response.raise_for_status() # Raise an exception for bad status codes
-         return response.json()['choices'][0]['message']['content'].strip()
-     except requests.exceptions.RequestException as e:
-         st.error(f"Error communicating with Groq API: {e}")
-         return "An error occurred while trying to get the answer."
-
- def perform_rag_groq(vectorstore, query):
-     """Performs retrieval and generates an answer using Groq."""
-     retriever = vectorstore.as_retriever()
-     relevant_docs = retriever.get_relevant_documents(query)
-     context = "\n\n".join([doc.page_content for doc in relevant_docs])
-     answer = generate_answer_with_groq(query, context)
-     return {"answer": answer, "sources": [doc.metadata['source'] for doc in relevant_docs] if relevant_docs else []}
-
- # --- 3. Streamlit UI ---
-
- def main():
-     st.title("PDF Q&A with Local Docs (Powered by Groq)")
-     st.info("Make sure you have a 'docs' folder in the same directory as this script containing your PDF files.")
-
-     with st.spinner("Loading and processing PDF(s)..."):
-         all_text, all_tables = load_and_process_pdfs_from_folder()
-
-     if all_text:
-         with st.spinner("Creating knowledge base..."):
-             chunks = split_text_into_chunks(all_text)
-             # We need to add metadata (source) to the chunks for accurate source tracking
-             metadatas = [{"source": f"doc_{i+1}"} for i in range(len(chunks))] # Basic source tracking
-             embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-             vectorstore = FAISS.from_texts(chunks, embeddings, metadatas=metadatas)
-
-         query = st.text_input("Ask a question about the documents:")
-         if query:
-             with st.spinner("Searching for answer..."):
-                 result = perform_rag_groq(vectorstore, query)
-             if result and result.get("answer"):
-                 st.subheader("Answer:")
-                 st.write(result["answer"])
-                 if "sources" in result and result["sources"]:
-                     st.subheader("Source:")
-                     st.write(", ".join(result["sources"]))
-             else:
-                 st.warning("Could not generate an answer.")
-
-         if all_tables:
-             st.subheader("Extracted Tables:")
-             for i, table_df in enumerate(all_tables):
-                 st.write(f"Table {i+1}:")
-                 st.dataframe(table_df)
-     elif not all_text:
-         st.warning("No PDF files found in the 'docs' folder.")
-
- if __name__ == "__main__":
-     main()
  import os
+ import time
+ import threading
+ import streamlit as st
+ from twilio.rest import Client
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer
+ import faiss
+ import numpy as np
+ import docx
+ from groq import Groq
  import requests
+ from io import StringIO
+ from pdfminer.high_level import extract_text_to_fp
+ from pdfminer.layout import LAParams
+ from twilio.base.exceptions import TwilioRestException # Add this at the top
+ import pdfplumber
+ import datetime
+ import csv

+ APP_START_TIME = datetime.datetime.now(datetime.timezone.utc)

+ os.environ["PYTORCH_JIT"] = "0"
+
+ # --- PDF Extraction ---
+ def _extract_tables_from_page(page):
+     """Extracts tables from a single page of a PDF."""
+
+     tables = page.extract_tables()
+     if not tables:
+         return []
+
+     formatted_tables = []
+     for table in tables:
+         formatted_table = []
+         for row in table:
+             if row: # Filter out empty rows
+                 formatted_row = [cell if cell is not None else "" for cell in row] # Replace None with ""
+                 formatted_table.append(formatted_row)
+             else:
+                 formatted_table.append([""]) # Append an empty row if the row is None
+         formatted_tables.append(formatted_table)
+     return formatted_tables
+
+ def extract_text_from_pdf(pdf_path):
+     text_output = StringIO()
      all_tables = []
+     try:
+         with pdfplumber.open(pdf_path) as pdf:
+             for page in pdf.pages:
+                 # Extract tables
+                 page_tables = _extract_tables_from_page(page)
+                 if page_tables:
+                     all_tables.extend(page_tables)
+                 # Extract text
+                 text = page.extract_text()
+                 if text:
+                     text_output.write(text + "\n\n")
+     except Exception as e:
+         print(f"Error extracting with pdfplumber: {e}")
+         # Fallback to pdfminer if pdfplumber fails
+         with open(pdf_path, 'rb') as file:
+             extract_text_to_fp(file, text_output, laparams=LAParams(), output_type='text', codec=None)
+     extracted_text = text_output.getvalue()
+     return extracted_text, all_tables # Return text and list of tables
+
+ def clean_extracted_text(text):
+     lines = text.splitlines()
+     cleaned = []
+     for line in lines:
+         line = line.strip()
+         if line:
+             line = ' '.join(line.split())
+             cleaned.append(line)
+     return '\n'.join(cleaned)

+ def _format_tables_internal(tables):
+     """Formats extracted tables into a string representation."""

+     formatted_tables_str = []
+     for table in tables:
+         # Use csv writer to handle commas and quotes correctly
+         with StringIO() as csvfile:
+             csvwriter = csv.writer(csvfile)
+             csvwriter.writerows(table)
+             formatted_tables_str.append(csvfile.getvalue())
+     return "\n\n".join(formatted_tables_str)

+ # --- DOCX Extraction ---
+ def extract_text_from_docx(docx_path):
+     try:
+         doc = docx.Document(docx_path)
+         return '\n'.join(para.text for para in doc.paragraphs)
+     except Exception:
+         return ""
+
+ # --- Chunking ---
+ def chunk_text(text, tokenizer, chunk_size=128, chunk_overlap=32, max_tokens=512):
+     tokens = tokenizer.tokenize(text)
+     chunks = []
+     start = 0
+     while start < len(tokens):
+         end = min(start + chunk_size, len(tokens))
+         chunk_tokens = tokens[start:end]
+         chunk_text = tokenizer.convert_tokens_to_string(chunk_tokens)
+         chunks.append(chunk_text)
+         if end == len(tokens):
+             break
+         start += chunk_size - chunk_overlap
+     return chunks
+
+ def retrieve_chunks(question, index, embed_model, text_chunks, k=3):
+     question_embedding = embed_model.encode(question)
+     D, I = index.search(np.array([question_embedding]), k)
+     return [text_chunks[i] for i in I[0]]
+
+ # --- Groq Answer Generator ---
  def generate_answer_with_groq(question, context):
      url = "https://api.groq.com/openai/v1/chat/completions"
      api_key = os.environ.get("GROQ_API_KEY")
      headers = {
          "Authorization": f"Bearer {api_key}",
          "Content-Type": "application/json",
          "temperature": 0.5,
          "max_tokens": 300,
      }
+     response = requests.post(url, headers=headers, json=payload)
+     response.raise_for_status()
+     return response.json()['choices'][0]['message']['content'].strip()
+
+ # --- Twilio Functions ---
+ def fetch_latest_incoming_message(client, conversation_sid):
      try:
+         messages = client.conversations.v1.conversations(conversation_sid).messages.list()
+         for msg in reversed(messages):
+             if msg.author.startswith("whatsapp:"):
+                 return {
+                     "sid": msg.sid,
+                     "body": msg.body,
+                     "author": msg.author,
+                     "timestamp": msg.date_created,
+                 }
+     except TwilioRestException as e:
+         if e.status == 404:
+             print(f"Conversation {conversation_sid} not found, skipping...")
+         else:
+             print(f"Twilio error fetching messages for {conversation_sid}:", e)
+     except Exception as e:
+         #print(f"Unexpected error in fetch_latest_incoming_message for {conversation_sid}:", e)
+         pass
+
+     return None
+
+ def send_twilio_message(client, conversation_sid, body):
+     return client.conversations.v1.conversations(conversation_sid).messages.create(
+         author="system", body=body
+     )
+
+ # --- Load Knowledge Base ---
+ def setup_knowledge_base():
+     folder_path = "docs"
+     all_text = ""
+
+     # Process PDFs
+     for filename in ["FAQ.pdf", "ProductReturnPolicy.pdf"]:
+         pdf_path = os.path.join(folder_path, filename)
+         text, tables = extract_text_from_pdf(pdf_path)
+         all_text += clean_extracted_text(text) + "\n"
+         all_text += _format_tables_internal(tables) + "\n"
+
+     # Process CSVs
+     for filename in ["CustomerOrders.csv"]:
+         csv_path = os.path.join(folder_path, filename)
+         try:
+             with open(csv_path, newline='', encoding='utf-8') as csvfile:
+                 reader = csv.DictReader(csvfile)
+                 for row in reader:
+                     line = f"Order ID: {row.get('OrderID')} | Customer Name: {row.get('CustomerName')} | Order Date: {row.get('OrderDate')} | ProductID: {row.get('ProductID')} | Date: {row.get('OrderDate')} | Quantity: {row.get('Quantity')} | UnitPrice(USD): {row.get('UnitPrice(USD)')} | TotalPrice(USD): {row.get('TotalPrice(USD)')} | ShippingAddress: {row.get('ShippingAddress')} | OrderStatus: {row.get('OrderStatus')}"
+                     all_text += line + "\n"
+         except Exception as e:
+             print(f"❌ Error reading {filename}: {e}")
+
+     for filename in ["Products.csv"]:
+         csv_path = os.path.join(folder_path, filename)
+         try:
+             with open(csv_path, newline='', encoding='utf-8') as csvfile:
+                 reader = csv.DictReader(csvfile)
+                 for row in reader:
+                     line = f"Product ID: {row.get('ProductID')} | Toy Name: {row.get('ToyName')} | Category: {row.get('Category')} | Price(USD): {row.get('Price(USD)')} | Stock Quantity: {row.get('StockQuantity')} | Description: {row.get('Description')}"
+                     all_text += line + "\n"
+         except Exception as e:
+             print(f"❌ Error reading {filename}: {e}")
+
+     # Tokenization & chunking
+     tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+     chunks = chunk_text(all_text, tokenizer)
+     model = SentenceTransformer('all-mpnet-base-v2')
+     embeddings = model.encode(chunks, show_progress_bar=False, truncation=True, max_length=512)
+     dim = embeddings[0].shape[0]
+     index = faiss.IndexFlatL2(dim)
+     index.add(np.array(embeddings).astype('float32'))
+     return index, model, chunks
+
+
+
+ # --- Monitor Conversations ---
+ def start_conversation_monitor(client, index, embed_model, text_chunks):
+     processed_convos = set()
+     last_processed_timestamp = {}
+
+     def poll_conversation(convo_sid):
+         while True:
+             try:
+                 latest_msg = fetch_latest_incoming_message(client, convo_sid)
+                 if latest_msg:
+                     msg_time = latest_msg["timestamp"]
+                     if convo_sid not in last_processed_timestamp or msg_time > last_processed_timestamp[convo_sid]:
+                         last_processed_timestamp[convo_sid] = msg_time
+                         question = latest_msg["body"]
+                         sender = latest_msg["author"]
+                         print(f"\n📥 New message from {sender} in {convo_sid}: {question}")
+                         context = "\n\n".join(retrieve_chunks(question, index, embed_model, text_chunks))
+                         answer = generate_answer_with_groq(question, context)
+                         send_twilio_message(client, convo_sid, answer)
+                         print(f"📤 Replied to {sender}: {answer}")
+                 time.sleep(3)
+             except Exception as e:
+                 print(f"❌ Error in convo {convo_sid} polling:", e)
+                 time.sleep(5)
+
+     def poll_new_conversations():
+         print("➡️ Monitoring for new WhatsApp conversations...")
+         while True:
+             try:
+                 conversations = client.conversations.v1.conversations.list(limit=20)
+                 for convo in conversations:
+                     convo_full = client.conversations.v1.conversations(convo.sid).fetch()
+                     if convo.sid not in processed_convos and convo_full.date_created > APP_START_TIME:
+                         participants = client.conversations.v1.conversations(convo.sid).participants.list()
+                         for p in participants:
+                             address = p.messaging_binding.get("address", "") if p.messaging_binding else ""
+                             if address.startswith("whatsapp:"):
+                                 print(f"🆕 New WhatsApp convo found: {convo.sid}")
+                                 processed_convos.add(convo.sid)
+                                 threading.Thread(target=poll_conversation, args=(convo.sid,), daemon=True).start()
+             except Exception as e:
+                 print("❌ Error polling conversations:", e)
+             time.sleep(5)
+
+     # ✅ Launch conversation polling monitor
+     threading.Thread(target=poll_new_conversations, daemon=True).start()
+
+
+
+ # --- Streamlit UI ---
+ st.set_page_config(page_title="Quasa – A Smart WhatsApp Chatbot", layout="wide")
+ st.title("📱 Quasa – A Smart WhatsApp Chatbot")
+
+ account_sid = st.secrets.get("TWILIO_SID")
+ auth_token = st.secrets.get("TWILIO_TOKEN")
+ GROQ_API_KEY = st.secrets.get("GROQ_API_KEY")
+
+ if not all([account_sid, auth_token, GROQ_API_KEY]):
+     st.warning("⚠️ Provide all credentials below:")
+     account_sid = st.text_input("Twilio SID", value=account_sid or "")
+     auth_token = st.text_input("Twilio Token", type="password", value=auth_token or "")
+     GROQ_API_KEY = st.text_input("GROQ API Key", type="password", value=GROQ_API_KEY or "")
+
+ if all([account_sid, auth_token, GROQ_API_KEY]):
+     os.environ["GROQ_API_KEY"] = GROQ_API_KEY
+     client = Client(account_sid, auth_token)
+
+     st.success("🟢 Monitoring new WhatsApp conversations...")
+     index, model, chunks = setup_knowledge_base()
+     threading.Thread(target=start_conversation_monitor, args=(client, index, model, chunks), daemon=True).start()
+     st.info("⏳ Waiting for new messages...")