Spaces:
Sleeping
Sleeping
import glob
import hashlib
import os
import time

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader, UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# --- Configuration ---
# Folder that is scanned for the knowledge-base documents.
DOCS_DIR: str = "docs"

# RecursiveCharacterTextSplitter settings: maximum chunk length and the overlap
# carried between neighbouring chunks.
CHUNK_SIZE: int = 1500
CHUNK_OVERLAP: int = 200

# Embedding model passed to HuggingFaceEmbeddings.
# Good default; consider others for specific needs.
EMBEDDING_MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"

# Directory holding the on-disk FAISS index cache.
CACHE_DIR_FAISS: str = "faiss_index_cache"
# --- Helper Functions ---
def get_api_key():
    """Return the Groq API key, halting the app with an error if it is not configured.

    Calls ``load_dotenv()`` so that a local ``.env`` file can supply the value,
    then reads the ``GROQ_API_KEY`` environment variable.
    """
    load_dotenv()
    key = os.getenv("GROQ_API_KEY")
    if key:
        return key

    # Missing or empty key: tell the user how to fix it and stop this script run.
    st.error("GROQ_API_KEY not found. Please set it in your environment variables or a .env file.")
    st.stop()
    return key
class _KnowledgeBaseUnavailable(Exception):
    """Raised by the cached builder on failure so that Streamlit does not memoize it."""


# Glob-style patterns of the document types handed to UnstructuredFileLoader.
# Matching is done on the lower-cased extension, so "REPORT.PDF" is found too.
_SUPPORTED_EXTENSIONS = [
    "*.pdf", "*.docx", "*.doc", "*.xlsx", "*.xls", "*.json",
    "*.txt", "*.md", "*.html", "*.csv", "*.pptx",
]  # Add more if needed

# Marker file written next to the FAISS index recording which inputs it was built from.
_FINGERPRINT_FILE = "docs_fingerprint.txt"


def _find_document_files(docs_dir: str) -> list[str]:
    """Return the sorted paths of supported, non-hidden files directly inside *docs_dir*."""
    suffixes = {pattern[1:].lower() for pattern in _SUPPORTED_EXTENSIONS}  # "*.pdf" -> ".pdf"
    with os.scandir(docs_dir) as entries:
        return sorted(
            entry.path
            for entry in entries
            if not entry.name.startswith(".")  # glob("*.ext") never matched dot-files either
            and entry.is_file()
            and os.path.splitext(entry.name)[1].lower() in suffixes
        )


def _fingerprint_files(file_paths) -> str:
    """Return a SHA-256 hex digest identifying the inputs of the vector store.

    Covers each file's name, size and modification time plus the embedding model
    and chunking settings, so adding, removing or touching a document (or changing
    those settings) produces a different digest.
    """
    digest = hashlib.sha256()
    digest.update(f"{EMBEDDING_MODEL_NAME}\0{CHUNK_SIZE}\0{CHUNK_OVERLAP}\n".encode("utf-8"))
    for path in sorted(file_paths):
        info = os.stat(path)
        record = f"{os.path.basename(path)}\0{info.st_size}\0{info.st_mtime_ns}\n"
        # surrogateescape keeps file names that are not valid UTF-8 hashable.
        digest.update(record.encode("utf-8", "surrogateescape"))
    return digest.hexdigest()


def _read_cached_fingerprint():
    """Return the fingerprint stored beside the on-disk FAISS index, or None if unavailable."""
    try:
        with open(os.path.join(CACHE_DIR_FAISS, _FINGERPRINT_FILE), encoding="utf-8") as fh:
            return fh.read().strip()
    except (OSError, UnicodeDecodeError):
        return None


def _load_and_split(file_paths, progress_bar):
    """Load *file_paths* with UnstructuredFileLoader and split them into chunks.

    Files that fail to load are reported with st.error and skipped.
    Raises _KnowledgeBaseUnavailable if nothing loads or splitting yields no chunks.
    """
    docs = []
    progress_bar.progress(0, text="Loading documents...")
    for done, file_path in enumerate(file_paths, start=1):
        name = os.path.basename(file_path)
        try:
            st.write(f"Processing: {name}")
            # UnstructuredFileLoader handles many formats and attempts table extraction.
            # For complex PDF tables try strategy="hi_res" (may require `detectron2`).
            # NOTE(review): mode="elements" typically yields one Document per element and
            # split_documents never merges Documents, so chunks may be far smaller than
            # CHUNK_SIZE; compare retrieval quality against mode="single" before changing.
            loader = UnstructuredFileLoader(file_path, mode="elements", strategy="fast")
            docs.extend(loader.load())
        except Exception as e:  # one unreadable file must not abort the whole build
            st.error(f"Error loading file {name}: {e}")
        progress_bar.progress(done / len(file_paths), text=f"Loaded {name}")

    if not docs:
        st.warning("No documents were successfully loaded or processed.")
        raise _KnowledgeBaseUnavailable("no documents loaded")

    progress_bar.progress(1.0, text="Documents loaded. Splitting into chunks...")
    time.sleep(0.5)  # For UX
    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    texts = splitter.split_documents(docs)
    if not texts:
        st.warning("Document splitting resulted in no text chunks. Check document content and splitter settings.")
        raise _KnowledgeBaseUnavailable("no chunks produced")
    st.write(f"Split documents into {len(texts)} chunks.")
    return texts


def _save_index(vector_store, fingerprint: str) -> None:
    """Persist *vector_store* and its fingerprint to CACHE_DIR_FAISS (best effort).

    The old fingerprint is removed before saving and the new one written last, so an
    interrupted save cannot leave a fingerprint vouching for a half-written index.
    A failure only costs a rebuild on the next start; the in-memory store is still used.
    """
    marker = os.path.join(CACHE_DIR_FAISS, _FINGERPRINT_FILE)
    try:
        os.makedirs(CACHE_DIR_FAISS, exist_ok=True)
        try:
            os.remove(marker)
        except FileNotFoundError:
            pass
        vector_store.save_local(CACHE_DIR_FAISS)
        with open(marker, "w", encoding="utf-8") as fh:
            fh.write(fingerprint)
    except Exception as e:  # best effort: losing the disk cache must not lose the index
        st.warning(f"FAISS index built but could not be cached to '{CACHE_DIR_FAISS}': {e}")
        return
    st.write(f"FAISS index created and saved to {CACHE_DIR_FAISS}.")


@st.cache_resource(show_spinner=False, max_entries=1)
def _build_vector_store(file_paths: tuple, fingerprint: str):
    """Return a FAISS store for *file_paths*, reusing the on-disk index when it matches.

    Memoized by Streamlit on (file_paths, fingerprint): reruns reuse the store, any
    document change triggers a rebuild, and max_entries=1 drops the stale entry.
    Failures raise _KnowledgeBaseUnavailable rather than returning None, so they are
    retried on the next rerun instead of being cached.
    """
    progress_bar = st.progress(0, text="Loading embedding model...")
    try:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    except Exception as e:
        st.error(f"Failed to load embedding model '{EMBEDDING_MODEL_NAME}': {e}")
        st.error("Please ensure you have an internet connection and the model name is correct.")
        st.stop()
        raise _KnowledgeBaseUnavailable(str(e)) from e  # only reached if st.stop() returns

    cached_fingerprint = _read_cached_fingerprint()
    if cached_fingerprint == fingerprint:
        try:
            st.write(f"Loading cached FAISS index from {CACHE_DIR_FAISS}...")
            # load_local unpickles index.pkl. Acceptable only because this app wrote the
            # file itself; never point CACHE_DIR_FAISS at untrusted data.
            vector_store = FAISS.load_local(CACHE_DIR_FAISS, embeddings, allow_dangerous_deserialization=True)
            st.write("FAISS index loaded from cache.")
            progress_bar.progress(1.0, text="Vector store ready.")
            return vector_store
        except Exception as e:
            st.warning(f"Could not load FAISS index from cache: {e}. Rebuilding index.")
    elif cached_fingerprint is not None:
        st.write("Documents or settings changed since the FAISS index was cached. Rebuilding index.")
    # An index without any fingerprint (e.g. written by an older version) is rebuilt silently.

    texts = _load_and_split(file_paths, progress_bar)

    progress_bar.progress(0, text="Generating embeddings and creating vector store... (This may take a while)")
    try:
        vector_store = FAISS.from_documents(texts, embeddings)
    except Exception as e:
        st.error(f"Error creating FAISS vector store: {e}")
        raise _KnowledgeBaseUnavailable(str(e)) from e

    _save_index(vector_store, fingerprint)
    progress_bar.progress(1.0, text="Vector store ready.")
    return vector_store


def load_and_process_documents(_docs_dir: str):
    """Build (or reuse) the FAISS vector store for the documents in *_docs_dir*.

    Scans the top level of the directory for supported files, then delegates to the
    Streamlit-cached ``_build_vector_store``. That builder persists the index in
    CACHE_DIR_FAISS together with a fingerprint of its inputs, so a stale index is
    never reused after documents or settings change.

    Args:
        _docs_dir: Directory containing the source documents.

    Returns:
        The FAISS vector store, or None when the directory is missing or empty,
        holds no supported files, or the build failed. The reason is shown in the
        UI via st.warning / st.error.
    """
    if not os.path.isdir(_docs_dir) or not os.listdir(_docs_dir):
        st.warning(f"The '{_docs_dir}' directory is empty or does not exist. Please add your documents.")
        return None

    st.write(f"Scanning for documents in '{_docs_dir}'...")
    all_files = _find_document_files(_docs_dir)
    if not all_files:
        st.warning(f"No supported documents found in '{_docs_dir}'. Supported types: {', '.join(_SUPPORTED_EXTENSIONS)}")
        return None
    st.write(f"Found {len(all_files)} files to process: {', '.join(os.path.basename(f) for f in all_files)}")

    try:
        return _build_vector_store(tuple(all_files), _fingerprint_files(all_files))
    except _KnowledgeBaseUnavailable:
        return None
def get_llm(_api_key: str, model_name=None, temperature: float = 0.2):
    """Initialize and return the ChatGroq chat model.

    Args:
        _api_key: Groq API key.
        model_name: Groq model id. When omitted, the GROQ_MODEL_NAME environment
            variable is used, falling back to "llama-3.3-70b-versatile". (The
            previously hard-coded "mixtral-8x7b-32768" has been retired by Groq.)
        temperature: Sampling temperature; adjust for creativity vs. factuality.

    Returns:
        A configured ChatGroq instance. If construction raises, the error is shown
        in the UI and the script run is stopped with st.stop().
    """
    resolved_model = model_name or os.getenv("GROQ_MODEL_NAME") or "llama-3.3-70b-versatile"
    try:
        return ChatGroq(
            groq_api_key=_api_key,
            model_name=resolved_model,
            temperature=temperature,
            # max_tokens=1024,  # Optional: set max tokens
        )
    except Exception as e:
        st.error(f"Error initializing GROQ LLM: {e}")
        st.stop()
# --- Streamlit App UI ---
# Page-wide settings; this is the first Streamlit command executed at module level.
st.set_page_config(page_title="SmartQuery RAG", layout="wide", initial_sidebar_state="expanded")

# --- Styling (Optional - for a "catchy" look) ---
# Custom CSS injected as raw HTML. The selectors target Streamlit's generated DOM
# (.stApp, .stTextInput, .stButton, .stSpinner) plus the helper classes used in
# this app's own HTML snippets (loader-text, ready-text, response-header, ...).
_APP_CSS = """
<style>
.stApp {
    background-color: #f0f2f6; /* Light grey background */
}
.stTextInput > div > div > input {
    background-color: #ffffff;
    border-radius: 10px;
}
.stButton > button {
    border-radius: 10px;
    background-color: #1E88E5; /* Blue */
    color: white;
    font-weight: bold;
    transition: background-color 0.3s ease;
}
.stButton > button:hover {
    background-color: #1565C0; /* Darker Blue */
}
.stSpinner > div > svg { /* Spinner color */
    fill: #1E88E5;
}
.loader-text {
    font-size: 1.2em;
    color: #333;
}
.ready-text {
    font-size: 1.2em;
    color: green;
    font-weight: bold;
}
.response-container {
    background-color: #ffffff;
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    margin-top: 20px;
}
.response-header {
    font-size: 1.5em;
    color: #1E88E5;
    margin-bottom: 10px;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# --- Main Application Logic ---
st.title("🔍 SmartQuery RAG Assistant")
st.markdown("Ask questions about your documents (Customer Orders, Company Policies, Financial Data, Products, Return Policies, etc.)")

# Load API Key (get_api_key shows an error and stops the run if it is missing).
groq_api_key = get_api_key()

# Sidebar for status and controls; the placeholder is overwritten once loading finishes.
st.sidebar.header("Knowledge Base Status")
status_placeholder = st.sidebar.empty()
status_placeholder.markdown("<p class='loader-text'>Knowledge Base is loading...</p>", unsafe_allow_html=True)

# Create docs directory if it doesn't exist
if not os.path.exists(DOCS_DIR):
    os.makedirs(DOCS_DIR)
    st.sidebar.info(f"'{DOCS_DIR}' directory created. Please add your documents there and refresh.")

# Load and process documents, then initialize RAG
vector_store = load_and_process_documents(DOCS_DIR)
if vector_store:
    llm = get_llm(groq_api_key)
    retriever = vector_store.as_retriever(
        search_type="similarity",  # "mmr" (Maximal Marginal Relevance) is another option
        search_kwargs={"k": 5},  # Retrieve top 5 relevant chunks
    )
    # NOTE(review): newer LangChain releases deprecate RetrievalQA in favour of
    # create_retrieval_chain; consider migrating when upgrading LangChain.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # Options: "stuff", "map_reduce", "refine", "map_rerank"
        retriever=retriever,
        return_source_documents=True,  # Needed for the "Show Retrieved Sources" expander
    )
    status_placeholder.markdown("<p class='ready-text'>✅ Application is Ready. Ask your questions!</p>", unsafe_allow_html=True)
    st.sidebar.success("Knowledge Base Loaded Successfully!")

    if st.sidebar.button("🔄 Clear Cache & Reload Documents"):
        import shutil

        # Drop every @st.cache_resource entry, then the on-disk FAISS index,
        # so the next run rebuilds everything from the documents.
        st.cache_resource.clear()
        if os.path.exists(CACHE_DIR_FAISS):
            shutil.rmtree(CACHE_DIR_FAISS)
            st.sidebar.info(f"Cache '{CACHE_DIR_FAISS}' cleared.")
        st.rerun()
else:
    status_placeholder.error(f"⚠️ Knowledge Base could not be loaded. Check messages above and ensure documents are in the '{DOCS_DIR}' folder.")
    st.stop()
# --- User Interaction ---
st.markdown("---")
query = st.text_input("Enter your question:", placeholder="e.g., What is the return policy for electronics?")

if st.button("Submit Query", type="primary"):
    question = query.strip()
    if not question:
        st.warning("Please enter a question.")
    else:
        try:
            # Only the LLM round-trip runs under the spinner.
            with st.spinner("🧠 Thinking... Fetching answer..."):
                start_time = time.perf_counter()
                response = qa_chain.invoke({"query": question})
                elapsed = time.perf_counter() - start_time

            # Each st.markdown call renders as its own element, so an opening "<div>"
            # in one call cannot wrap later elements; use a real bordered container.
            with st.container(border=True):
                st.markdown("<p class='response-header'>💡 Answer:</p>", unsafe_allow_html=True)
                st.write(response["result"])
            st.info(f"Response generated in {elapsed:.2f} seconds.")

            with st.expander("📚 Show Retrieved Sources"):
                sources = response.get("source_documents") or []
                if sources:
                    for i, doc in enumerate(sources, start=1):
                        source_path = doc.metadata.get("source")
                        # basename also copes with Windows paths; missing sources show "N/A".
                        source_name = os.path.basename(str(source_path)) if source_path else "N/A"
                        st.markdown(f"**Source {i} (from: {source_name})**")

                        preview = doc.page_content
                        if not preview:
                            preview = "N/A"
                        elif len(preview) > 500:
                            preview = preview[:500] + "..."  # Display first 500 chars
                        st.caption(preview)
                        st.markdown("---")
                else:
                    st.write("No specific source documents were identified for this query.")
        except Exception as e:
            st.error(f"An error occurred while processing your query: {e}")
# --- Suggestions for Improvement (as per prompt request) ---
# Static sidebar notes and page footer; nested bullets are indented four spaces so
# Markdown renders them as sub-items.
st.sidebar.markdown("---")
st.sidebar.subheader("💡 Suggestions & Notes:")
st.sidebar.markdown("""
- **Table Data:** `UnstructuredFileLoader` attempts to parse tables. For PDFs with very complex tables, if accuracy is insufficient, consider:
    - Pre-processing PDFs with tools like `Camelot` or `Tabula-py` to extract tables into CSV/Markdown, then load those.
    - Exploring `unstructured` with `strategy="hi_res"` (may require `detectron2` and `brew install poppler` or similar for your OS). This is more computationally intensive.
    - Fine-tuning embedding models or using models specialized for tabular data if table queries are critical.
- **Accuracy:** "100% accuracy" is an ideal. RAG systems are powerful but can make mistakes. Improve by:
    - Better chunking strategies.
    - More advanced retrieval (e.g., HyDE, re-ranking).
    - Prompt engineering for the QA chain.
    - Using more powerful (and potentially slower/costlier) LLMs if available via GROQ.
    - Regularly evaluating and curating the document set.
- **Performance:** The current FAISS caching helps significantly. For very large datasets, explore more scalable vector DBs.
- **UI/UX:** Added some basic styling. For more "catchy" UI, explore Streamlit Components or more elaborate CSS.
- **Error Handling:** Added basic error checks. Robust applications need more comprehensive error management.
- **Scalability:** For many concurrent users on Hugging Face, resource limits (CPU, RAM) for the free tier might be a bottleneck, especially during embedding.
- **Embedding Model:** `all-MiniLM-L6-v2` is efficient. For higher accuracy with more complex content, consider models like `sentence-transformers/all-mpnet-base-v2` or domain-specific embeddings.
- **Deployment:** Ensure `GROQ_API_KEY` is set as a secret in Hugging Face Spaces.
""")
st.markdown("---")
st.markdown("<p style='text-align: center; color: grey;'>Powered by Streamlit, Langchain & GROQ</p>", unsafe_allow_html=True)