# AIToyBot / app.py
import streamlit as st
import os
from dotenv import load_dotenv
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_groq import ChatGroq
import time
import glob
# --- Configuration ---
DOCS_DIR = "docs"
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 200
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" # Good default, consider others for specific needs
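# A heavier alternative embedding model, often more accurate on complex content
# (an optional swap suggested in the sidebar notes below, not a requirement):
# EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"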
CACHE_DIR_FAISS = "faiss_index_cache" # Directory to cache FAISS index
# --- Helper Functions ---
def get_api_key():
    """Loads the GROQ API key from a .env file or environment variables."""
    load_dotenv()
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        st.error("GROQ_API_KEY not found. Please set it in your environment variables or a .env file.")
        st.stop()
    return groq_api_key

@st.cache_resource(show_spinner="Loading and processing documents...")
def load_and_process_documents(_docs_dir: str):
"""
Loads documents from the specified directory, processes them,
creates embeddings, and stores them in a FAISS vector store.
Caches the FAISS index to disk for faster subsequent loads.
"""
if not os.path.exists(_docs_dir) or not os.listdir(_docs_dir):
st.warning(f"The '{_docs_dir}' directory is empty or does not exist. Please add your documents.")
return None
st.write(f"Scanning for documents in '{_docs_dir}'...")
# Using UnstructuredFileLoader for broader file type support including tables
# We'll use glob to find all files and pass them to UnstructuredFileLoader
all_files = []
supported_extensions = ["*.pdf", "*.docx", "*.doc", "*.xlsx", "*.xls", "*.json", "*.txt", "*.md", "*.html", "*.csv", "*.pptx"] # Add more if needed
for ext in supported_extensions:
all_files.extend(glob.glob(os.path.join(_docs_dir, ext)))
if not all_files:
st.warning(f"No supported documents found in '{_docs_dir}'. Supported types: {', '.join(supported_extensions)}")
return None
st.write(f"Found {len(all_files)} files to process: {', '.join([os.path.basename(f) for f in all_files])}")
docs = []
progress_bar = st.progress(0, text="Loading documents...")
for i, file_path in enumerate(all_files):
try:
st.write(f"Processing: {os.path.basename(file_path)}")
# UnstructuredFileLoader is good for various formats and attempts to handle tables.
# For complex tables in PDFs, more specialized parsers might be needed if Unstructured is insufficient.
# Common arguments for UnstructuredFileLoader for better table extraction:
# strategy="hi_res" (for PDFs with complex layouts, may require `detectron2` installation)
# mode="elements" or "paged"
# pdf_infer_table_structure=True (if using unstructured[pdf])
loader = UnstructuredFileLoader(file_path, mode="elements", strategy="fast") # Start with "fast", try "hi_res" if table extraction is poor
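            # A hedged alternative if "fast" misses table structure (assumes the
            # `unstructured[pdf]` extra and its heavier dependencies are installed):
            # loader = UnstructuredFileLoader(
            #     file_path,
            #     mode="elements",
            #     strategy="hi_res",
            #     pdf_infer_table_structure=True,
            # )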
            loaded_docs = loader.load()
            docs.extend(loaded_docs)
        except Exception as e:
            st.error(f"Error loading file {os.path.basename(file_path)}: {e}")
        progress_bar.progress((i + 1) / len(all_files), text=f"Loaded {os.path.basename(file_path)}")

    if not docs:
        st.warning("No documents were successfully loaded or processed.")
        return None

    progress_bar.progress(1.0, text="Documents loaded. Splitting into chunks...")
    time.sleep(0.5)  # Brief pause so the progress message stays visible

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    texts = text_splitter.split_documents(docs)
    if not texts:
        st.warning("Document splitting resulted in no text chunks. Check document content and splitter settings.")
        return None
    st.write(f"Split documents into {len(texts)} chunks.")

    progress_bar.progress(0, text="Generating embeddings and creating vector store... (This may take a while)")

    # Initialize the embedding model.
    try:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    except Exception as e:
        st.error(f"Failed to load embedding model '{EMBEDDING_MODEL_NAME}': {e}")
        st.error("Please ensure you have an internet connection and the model name is correct.")
        st.stop()

    # Load the FAISS index from the on-disk cache if one exists.
    if os.path.exists(CACHE_DIR_FAISS) and os.listdir(CACHE_DIR_FAISS):
        try:
            st.write(f"Loading cached FAISS index from {CACHE_DIR_FAISS}...")
            vector_store = FAISS.load_local(CACHE_DIR_FAISS, embeddings, allow_dangerous_deserialization=True)  # Required to load the pickled index in recent LangChain versions
            st.write("FAISS index loaded from cache.")
            progress_bar.progress(1.0, text="Vector store ready.")
            return vector_store
        except Exception as e:
            st.warning(f"Could not load FAISS index from cache: {e}. Rebuilding index.")

    # Otherwise, build the index from the chunks and cache it to disk.
    try:
        vector_store = FAISS.from_documents(texts, embeddings)
        os.makedirs(CACHE_DIR_FAISS, exist_ok=True)
        vector_store.save_local(CACHE_DIR_FAISS)
        st.write(f"FAISS index created and saved to {CACHE_DIR_FAISS}.")
        progress_bar.progress(1.0, text="Vector store ready.")
        return vector_store
    except Exception as e:
        st.error(f"Error creating FAISS vector store: {e}")
        return None

@st.cache_resource(show_spinner="Initializing LLM...")
def get_llm(_api_key: str):
"""Initializes and returns the ChatGroq LLM."""
try:
llm = ChatGroq(
groq_api_key=_api_key,
model_name="mixtral-8x7b-32768", # Or "llama3-70b-8192", "llama3-8b-8192", "gemma-7b-it"
temperature=0.2, # Adjust for creativity vs. factuality
# max_tokens=1024, # Optional: set max tokens
)
return llm
except Exception as e:
st.error(f"Error initializing GROQ LLM: {e}")
st.stop()
# --- Streamlit App UI ---
st.set_page_config(page_title="SmartQuery RAG", layout="wide", initial_sidebar_state="expanded")
# --- Styling (Optional - for a "catchy" look) ---
st.markdown("""
<style>
.stApp {
background-color: #f0f2f6; /* Light grey background */
}
.stTextInput > div > div > input {
background-color: #ffffff;
border-radius: 10px;
}
.stButton > button {
border-radius: 10px;
background-color: #1E88E5; /* Blue */
color: white;
font-weight: bold;
transition: background-color 0.3s ease;
}
.stButton > button:hover {
background-color: #1565C0; /* Darker Blue */
}
.stSpinner > div > svg { /* Spinner color */
fill: #1E88E5;
}
.loader-text {
font-size: 1.2em;
color: #333;
}
.ready-text {
font-size: 1.2em;
color: green;
font-weight: bold;
}
.response-container {
background-color: #ffffff;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
margin-top: 20px;
}
.response-header {
font-size: 1.5em;
color: #1E88E5;
margin-bottom: 10px;
}
</style>
""", unsafe_allow_html=True)
# --- Main Application Logic ---
st.title("πŸ“„ SmartQuery RAG Assistant")
st.markdown("Ask questions about your documents (Customer Orders, Company Policies, Financial Data, Products, Return Policies, etc.)")
# Load API Key
groq_api_key = get_api_key()
# Sidebar for status and controls
st.sidebar.header("Knowledge Base Status")
status_placeholder = st.sidebar.empty()
status_placeholder.markdown("<p class='loader-text'>Knowledge Base is loading...</p>", unsafe_allow_html=True)
# Create docs directory if it doesn't exist
if not os.path.exists(DOCS_DIR):
    os.makedirs(DOCS_DIR)
    st.sidebar.info(f"'{DOCS_DIR}' directory created. Please add your documents there and refresh.")
# Load and process documents, then initialize RAG
vector_store = load_and_process_documents(DOCS_DIR)
if vector_store:
    llm = get_llm(groq_api_key)
    retriever = vector_store.as_retriever(
        search_type="similarity",  # "mmr" (Maximal Marginal Relevance) is another option; see the commented sketch below
        search_kwargs={"k": 5}  # Retrieve the top 5 relevant chunks
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # Options: "stuff", "map_reduce", "refine", "map_rerank"
        retriever=retriever,
        return_source_documents=True  # Return the retrieved chunks alongside the answer
    )
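    # A hedged sketch of the "prompt engineering" idea from the sidebar notes: a
    # "stuff" chain accepts a custom prompt via chain_type_kwargs. The template
    # text below is illustrative only.
    # from langchain.prompts import PromptTemplate
    # qa_prompt = PromptTemplate(
    #     input_variables=["context", "question"],
    #     template=(
    #         "Answer the question using only the context below. If the answer "
    #         "is not in the context, say you don't know.\n\n"
    #         "Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    #     ),
    # )
    # qa_chain = RetrievalQA.from_chain_type(
    #     llm=llm, chain_type="stuff", retriever=retriever,
    #     return_source_documents=True, chain_type_kwargs={"prompt": qa_prompt},
    # )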
    status_placeholder.markdown("<p class='ready-text'>✅ Application is Ready. Ask your questions!</p>", unsafe_allow_html=True)
    st.sidebar.success("Knowledge Base Loaded Successfully!")

    if st.sidebar.button("🔄 Clear Cache & Reload Documents"):
        st.cache_resource.clear()  # Clears everything cached with @st.cache_resource
        # (st.cache_data could be cleared selectively here too, if it were used.)
        # Delete the FAISS cache directory so the index is rebuilt on reload.
        if os.path.exists(CACHE_DIR_FAISS):
            import shutil
            shutil.rmtree(CACHE_DIR_FAISS)
        st.sidebar.info(f"Cache '{CACHE_DIR_FAISS}' cleared.")
        st.rerun()
else:
    status_placeholder.error("⚠️ Knowledge Base could not be loaded. Check the messages above and ensure your documents are in the 'docs' folder.")
    st.stop()

# --- User Interaction ---
st.markdown("---")
query = st.text_input("Enter your question:", placeholder="e.g., What is the return policy for electronics?")
if st.button("Submit Query", type="primary"):
    if not query:
        st.warning("Please enter a question.")
    else:
        with st.spinner("🧠 Thinking... Fetching answer..."):
            try:
                start_time = time.time()
                response = qa_chain.invoke({"query": query})
                end_time = time.time()

                st.markdown("<div class='response-container'>", unsafe_allow_html=True)
                st.markdown("<p class='response-header'>💡 Answer:</p>", unsafe_allow_html=True)
                st.write(response["result"])
                st.markdown("</div>", unsafe_allow_html=True)
                st.info(f"Response generated in {end_time - start_time:.2f} seconds.")

                with st.expander("📚 Show Retrieved Sources"):
                    if response.get("source_documents"):
                        for i, doc in enumerate(response["source_documents"]):
                            source_name = os.path.basename(doc.metadata.get("source", "N/A"))  # show just the file name
                            st.markdown(f"**Source {i + 1} (from: {source_name})**")
                            st.caption(doc.page_content[:500] + "..." if doc.page_content else "N/A")  # First 500 characters
                            st.markdown("---")
                    else:
                        st.write("No specific source documents were identified for this query.")
            except Exception as e:
                st.error(f"An error occurred while processing your query: {e}")

# --- Suggestions for Improvement ---
st.sidebar.markdown("---")
st.sidebar.subheader("💡 Suggestions & Notes:")
st.sidebar.markdown("""
- **Table Data:** `UnstructuredFileLoader` attempts to parse tables. For PDFs with very complex tables, if accuracy is insufficient, consider:
- Pre-processing PDFs with tools like `Camelot` or `Tabula-py` to extract tables into CSV/Markdown, then load those.
- Exploring `unstructured` with `strategy="hi_res"` (may require `detectron2` and `brew install poppler` or similar for your OS). This is more computationally intensive.
- Fine-tuning embedding models or using models specialized for tabular data if table queries are critical.
- **Accuracy:** "100% accuracy" is an ideal. RAG systems are powerful but can make mistakes. Improve by:
- Better chunking strategies.
- More advanced retrieval (e.g., HyDE, re-ranking).
- Prompt engineering for the QA chain.
- Using more powerful (and potentially slower/costlier) LLMs if available via GROQ.
- Regularly evaluating and curating the document set.
- **Performance:** The current FAISS caching helps significantly. For very large datasets, explore more scalable vector DBs.
- **UI/UX:** Added some basic styling. For more "catchy" UI, explore Streamlit Components or more elaborate CSS.
- **Error Handling:** Added basic error checks. Robust applications need more comprehensive error management.
- **Scalability:** For many concurrent users on Hugging Face, resource limits (CPU, RAM) for the free tier might be a bottleneck, especially during embedding.
- **Embedding Model:** `all-MiniLM-L6-v2` is efficient. For higher accuracy with more complex content, consider models like `sentence-transformers/all-mpnet-base-v2` or domain-specific embeddings.
- **Deployment:** Ensure `GROQ_API_KEY` is set as a secret in Hugging Face Spaces.
""")
st.markdown("---")
st.markdown("<p style='text-align: center; color: grey;'>Powered by Streamlit, Langchain & GROQ</p>", unsafe_allow_html=True)