masadonline commited on
Commit
1086067
Β·
verified Β·
1 Parent(s): f31fbc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -59
app.py CHANGED
@@ -1,18 +1,68 @@
1
  import os
2
  import streamlit as st
3
- from glob import glob
4
- from langchain_community.document_loaders import PyPDFLoader
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain.vectorstores import FAISS
7
- from langchain.embeddings import HuggingFaceEmbeddings
8
- from langchain.chains import RetrievalQA
9
- from langchain_groq import ChatGroq # βœ… Correct import
10
-
11
- # Set page config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  st.set_page_config(page_title="SMEHelpBot πŸ€–", layout="wide")
13
  st.title("πŸ€– SMEHelpBot – Your AI Assistant for Small Businesses")
14
 
15
- # Load API key
16
  GROQ_API_KEY = st.secrets.get("GROQ_API_KEY") or os.getenv("GROQ_API_KEY")
17
  if not GROQ_API_KEY:
18
  st.error("❌ Please set your GROQ_API_KEY in environment or .streamlit/secrets.toml")
@@ -20,54 +70,58 @@ if not GROQ_API_KEY:
20
 
21
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
22
 
23
- # Load all PDFs from the 'docs' folder
24
- pdf_paths = glob("docs/*.pdf")
25
 
26
- if not pdf_paths:
27
- st.warning("πŸ“ Please place some PDF files in the `docs/` folder.")
28
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Load and split all PDFs
31
- documents = []
32
- for path in pdf_paths:
33
- loader = PyPDFLoader(path)
34
- documents.extend(loader.load())
35
-
36
- splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
37
- chunks = splitter.split_documents(documents)
38
-
39
- # Create vector store from chunks
40
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
41
- vectorstore = FAISS.from_documents(chunks, embeddings)
42
- retriever = vectorstore.as_retriever()
43
-
44
- # Set up LLM with Groq
45
- llm = ChatGroq(temperature=0.3, model_name="llama3-8b-8192")
46
-
47
- # Build QA chain
48
- qa_chain = RetrievalQA.from_chain_type(
49
- llm=llm,
50
- chain_type="stuff",
51
- retriever=retriever,
52
- return_source_documents=True
53
- )
54
-
55
- # User input
56
- user_question = st.text_input("πŸ’¬ Ask your question about SME documents:", key="user_question")
57
-
58
- # Button to trigger response
59
- if st.button("Ask") or user_question and st.session_state.get("user_question_submitted", False) is False:
60
- st.session_state["user_question_submitted"] = True
61
- with st.spinner("πŸ€” Thinking..."):
62
- result = qa_chain({"query": user_question})
63
- st.success("βœ… Answer:")
64
- st.write(result["result"])
65
-
66
- with st.expander("πŸ“„ Source Snippets"):
67
- for i, doc in enumerate(result["source_documents"]):
68
- st.markdown(f"**Source {i+1}:**\n{doc.page_content[:300]}...")
69
-
70
- # Reset the submit flag if input changes
71
- if user_question != st.session_state.get("last_input", ""):
72
- st.session_state["user_question_submitted"] = False
73
- st.session_state["last_input"] = user_question
 
1
  import os
2
  import streamlit as st
3
+ import PyPDF2
4
+ from pdfminer.high_level import extract_text
5
+ from transformers import AutoTokenizer
6
+ from sentence_transformers import SentenceTransformer
7
+ import faiss
8
+ import numpy as np
9
+ from groq import Groq
10
+
11
+ # --- Helper Functions ---
12
+
13
+ def extract_text_from_pdf(pdf_path):
14
+ try:
15
+ text = ""
16
+ with open(pdf_path, 'rb') as file:
17
+ pdf_reader = PyPDF2.PdfReader(file)
18
+ for page_num in range(len(pdf_reader.pages)):
19
+ page = pdf_reader.pages[page_num]
20
+ page_text = page.extract_text()
21
+ if page_text:
22
+ text += page_text
23
+ return text
24
+ except Exception as e:
25
+ st.warning(f"PyPDF2 failed with error: {e}. Trying pdfminer.six...")
26
+ return extract_text(pdf_path)
27
+
28
+ def chunk_text_with_tokenizer(text, tokenizer, chunk_size=150, chunk_overlap=30):
29
+ tokens = tokenizer.tokenize(text)
30
+ chunks = []
31
+ start = 0
32
+ while start < len(tokens):
33
+ end = min(start + chunk_size, len(tokens))
34
+ chunk_tokens = tokens[start:end]
35
+ chunk_text = tokenizer.convert_tokens_to_string(chunk_tokens)
36
+ chunks.append(chunk_text)
37
+ start += chunk_size - chunk_overlap
38
+ return chunks
39
+
40
+ def retrieve_relevant_chunks(question, index, embeddings_model, text_chunks, k=3):
41
+ question_embedding = embeddings_model.encode([question])[0]
42
+ D, I = index.search(np.array([question_embedding]), k)
43
+ relevant_chunks = [text_chunks[i] for i in I[0]]
44
+ return relevant_chunks
45
+
46
+ def generate_answer_with_groq(question, context):
47
+ prompt = f"Based on the following context, answer the question: '{question}'\n\nContext:\n{context}"
48
+ model_name = "llama-3.3-70b-versatile" # Adjust model if needed
49
+ try:
50
+ groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
51
+ response = groq_client.chat.completions.create(
52
+ model=model_name,
53
+ messages=[{"role": "user", "content": prompt}]
54
+ )
55
+ return response.choices[0].message.content
56
+ except Exception as e:
57
+ st.error(f"Error generating answer with Groq API: {e}")
58
+ return "I'm sorry, I couldn't generate an answer at this time."
59
+
60
+ # --- Streamlit UI & Logic ---
61
+
62
  st.set_page_config(page_title="SMEHelpBot πŸ€–", layout="wide")
63
  st.title("πŸ€– SMEHelpBot – Your AI Assistant for Small Businesses")
64
 
65
+ # GROQ API key check
66
  GROQ_API_KEY = st.secrets.get("GROQ_API_KEY") or os.getenv("GROQ_API_KEY")
67
  if not GROQ_API_KEY:
68
  st.error("❌ Please set your GROQ_API_KEY in environment or .streamlit/secrets.toml")
 
70
 
71
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
72
 
73
+ # File uploader
74
+ uploaded_pdf = st.file_uploader("πŸ“ Upload PDF document(s) for SME knowledge base", type=["pdf"], accept_multiple_files=False)
75
 
76
+ # Text input for question
77
+ user_question = st.text_input("πŸ’¬ Ask your question about SME documents:")
78
+
79
+ # Button to trigger processing
80
+ if st.button("Get Answer") or (user_question and uploaded_pdf):
81
+ if not uploaded_pdf:
82
+ st.warning("Please upload a PDF file first.")
83
+ elif not user_question:
84
+ st.warning("Please enter a question.")
85
+ else:
86
+ with st.spinner("Processing PDF and generating answer..."):
87
+ # Save uploaded file temporarily for PyPDF2/pdfminer
88
+ temp_path = f"/tmp/{uploaded_pdf.name}"
89
+ with open(temp_path, "wb") as f:
90
+ f.write(uploaded_pdf.getbuffer())
91
+
92
+ # Extract text
93
+ pdf_text = extract_text_from_pdf(temp_path)
94
+
95
+ # Tokenizer + Chunk
96
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
97
+ text_chunks = chunk_text_with_tokenizer(pdf_text, tokenizer)
98
+
99
+ # Embeddings
100
+ embedding_model = SentenceTransformer('all-mpnet-base-v2')
101
+ all_embeddings = embedding_model.encode(text_chunks) if text_chunks else []
102
+
103
+ if not all_embeddings:
104
+ st.error("No text chunks found to create embeddings.")
105
+ else:
106
+ # Create FAISS index
107
+ embedding_dim = all_embeddings[0].shape[0]
108
+ index = faiss.IndexFlatL2(embedding_dim)
109
+ index.add(np.array(all_embeddings))
110
+
111
+ # Retrieve relevant chunks
112
+ relevant_chunks = retrieve_relevant_chunks(user_question, index, embedding_model, text_chunks)
113
+ context = "\n\n".join(relevant_chunks)
114
+
115
+ # Generate answer with Groq
116
+ answer = generate_answer_with_groq(user_question, context)
117
+
118
+ # Display outputs
119
+ st.markdown("### Extracted Text Snippet:")
120
+ st.write(pdf_text[:500] + "...")
121
+
122
+ st.markdown("### Sample Text Chunks:")
123
+ for i, chunk in enumerate(text_chunks[:3]):
124
+ st.write(f"Chunk {i+1}: {chunk[:200]}...")
125
 
126
+ st.markdown("### Answer:")
127
+ st.success(answer)