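"""Streamlit app for question answering over local PDF documents, powered by Groq.

Pipeline: load every PDF in a local "docs" folder, split the extracted text into
chunks, embed the chunks with a HuggingFace sentence-transformer, index them in
FAISS, and answer user questions via the Groq chat completions API using the
retrieved chunks as context.
"""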
import streamlit as st
from PyPDF2 import PdfReader
import pdfplumber  # PyPDF2 has no table extraction; pdfplumber handles that step
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import pandas as pd
import os
import requests

# --- 1. Data Loading and Preprocessing ---
def load_and_process_pdfs_from_folder(docs_folder="docs"):
    """Loads and processes all PDF files from the specified folder."""
    all_text = ""
    all_tables = []
    if not os.path.isdir(docs_folder):
        st.error(f"Folder '{docs_folder}' not found.")
        return all_text, all_tables
    for filename in os.listdir(docs_folder):
        if filename.endswith(".pdf"):
            filepath = os.path.join(docs_folder, filename)
            try:
                with open(filepath, 'rb') as file:
                    pdf_reader = PdfReader(file)
                    for page in pdf_reader.pages:
                        page_text = page.extract_text()
                        if page_text:  # extract_text() can come back empty for image-only pages
                            all_text += page_text + "\n"
                # PyPDF2 pages have no extract_tables(); use pdfplumber for tables.
                try:
                    with pdfplumber.open(filepath) as pdf:
                        for page in pdf.pages:
                            for table in page.extract_tables():
                                all_tables.append(pd.DataFrame(table))
                except Exception as e:
                    print(f"Could not extract tables from {filename}. Error: {e}")
            except Exception as e:
                st.error(f"Error reading PDF {filename}: {e}")
    return all_text, all_tables

def split_text_into_chunks(text):
    """Splits the text into smaller, manageable chunks."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_text(text)
    return chunks
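
# Note on the settings above: chunk_size and chunk_overlap are measured in
# characters, so with chunk_size=1000 and chunk_overlap=100 consecutive chunks
# share roughly 100 characters, which keeps sentences that straddle a chunk
# boundary retrievable from either side. The exact values are a tuning choice.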

def create_vectorstore(chunks, metadatas=None):
    """Creates a vectorstore from the text chunks using HuggingFace embeddings."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    vectorstore = FAISS.from_texts(chunks, embeddings, metadatas=metadatas)
    return vectorstore
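
# Optional sketch (not wired in below): because Streamlit reruns the whole script
# on every interaction, the index is rebuilt each time. Wrapping the build in
# st.cache_resource, or persisting it with the FAISS wrapper's save_local()/
# load_local(), would avoid recomputing embeddings; the helper name here is
# illustrative only.
#
#     @st.cache_resource
#     def get_vectorstore(text):
#         return create_vectorstore(split_text_into_chunks(text))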

# --- 2. Question Answering with Groq ---
def generate_answer_with_groq(question, context):
    """Generates an answer using the Groq API."""
    url = "https://api.groq.com/openai/v1/chat/completions"
    api_key = os.environ.get("GROQ_API_KEY")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    prompt = (
        f"Customer asked: '{question}'\n\n"
        f"Here is the relevant product or policy info to help:\n{context}\n\n"
        f"Respond in a friendly and helpful tone as a toy shop support agent."
    )
    payload = {
        "model": "llama3-8b-8192",
        "messages": [
            {
                "role": "system",
                "content": (
                    "You are ToyBot, a friendly and helpful WhatsApp assistant for an online toy shop. "
                    "Your goal is to politely answer customer questions, help them choose the right toys, "
                    "provide order or delivery information, explain return policies, and guide them through purchases."
                )
            },
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.5,
        "max_tokens": 300,
    }
    try:
        # A timeout keeps the Streamlit app from hanging if the API is unreachable.
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()  # Raise an exception for bad status codes
        return response.json()['choices'][0]['message']['content'].strip()
    except requests.exceptions.RequestException as e:
        st.error(f"Error communicating with Groq API: {e}")
        return "An error occurred while trying to get the answer."

def perform_rag_groq(vectorstore, query):
    """Performs retrieval and generates an answer using Groq."""
    retriever = vectorstore.as_retriever()
    relevant_docs = retriever.get_relevant_documents(query)
    context = "\n\n".join([doc.page_content for doc in relevant_docs])
    answer = generate_answer_with_groq(query, context)
    # .get() avoids a KeyError for chunks that were indexed without a "source" entry.
    sources = [doc.metadata.get("source", "unknown") for doc in relevant_docs]
    return {"answer": answer, "sources": sources}

# --- 3. Streamlit UI ---
def main():
    st.title("PDF Q&A with Local Docs (Powered by Groq)")
    st.info("Make sure you have a 'docs' folder in the same directory as this script containing your PDF files.")

    groq_api_key = st.text_input("Enter your Groq API Key:", type="password")
    if not groq_api_key:
        st.warning("Please enter your Groq API key to ask questions.")
        return
    os.environ["GROQ_API_KEY"] = groq_api_key

    with st.spinner("Loading and processing PDF(s)..."):
        all_text, all_tables = load_and_process_pdfs_from_folder()
    if all_text:
        with st.spinner("Creating knowledge base..."):
            chunks = split_text_into_chunks(all_text)
            # Attach basic metadata (source) to the chunks for source tracking.
            metadatas = [{"source": f"doc_{i+1}"} for i in range(len(chunks))]
            vectorstore = create_vectorstore(chunks, metadatas)

        query = st.text_input("Ask a question about the documents:")
        if query:
            with st.spinner("Searching for answer..."):
                result = perform_rag_groq(vectorstore, query)
            st.subheader("Answer:")
            st.write(result["answer"])
            if "sources" in result:
                st.subheader("Source:")
                st.write(", ".join(result["sources"]))

        if all_tables:
            st.subheader("Extracted Tables:")
            for i, table_df in enumerate(all_tables):
                st.write(f"Table {i+1}:")
                st.dataframe(table_df)
    else:
        st.warning("No PDF files found in the 'docs' folder.")

if __name__ == "__main__":
    main()
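
# To try the app locally (assuming this script is saved as app.py):
#     streamlit run app.py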