Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import time | |
| from datetime import datetime | |
| from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from pptx import Presentation | |
| from io import BytesIO | |
| import shutil | |
| import logging | |
| import chromadb | |
| import tempfile | |
| from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type | |
| import requests | |
| from transformers import BitsAndBytesConfig | |
| # Set up logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Environment setup for Hugging Face token | |
| os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN", "default-token") | |
| if os.environ["HUGGINGFACEHUB_API_TOKEN"] == "default-token": | |
| logger.warning("HUGGINGFACEHUB_API_TOKEN not set. Some models may not work.") | |
| # Model and embedding options | |
| LLM_MODELS = { | |
| "High Accuracy (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1", | |
| "Balanced (Gemma-2-2B)": "google/gemma-2-2b-it", | |
| "Lightweight (Mistral-7B)": "mistralai/Mistral-7B-Instruct-v0.2" | |
| } | |
| EMBEDDING_MODELS = { | |
| "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2", | |
| "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2", | |
| "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5" | |
| } | |
| # Global state | |
| vector_store = None | |
| qa_chain = None | |
| chat_history = [] | |
| memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) | |
| PERSIST_DIRECTORY = tempfile.mkdtemp() # Use temporary directory for ChromaDB | |
| # Custom PPTX loader | |
| class PPTXLoader: | |
| def __init__(self, file_path): | |
| self.file_path = file_path | |
| def load(self): | |
| docs = [] | |
| try: | |
| with open(self.file_path, "rb") as f: | |
| prs = Presentation(BytesIO(f.read())) | |
| for slide_num, slide in enumerate(prs.slides, 1): | |
| text = "" | |
| for shape in slide.shapes: | |
| if hasattr(shape, "text") and shape.text: | |
| text += shape.text + "\n" | |
| if text.strip(): | |
| docs.append({"page_content": text, "metadata": {"source": self.file_path, "slide": slide_num}}) | |
| except Exception as e: | |
| logger.error(f"Error loading PPTX {self.file_path}: {str(e)}") | |
| return [] | |
| return docs | |
| # Function to load documents | |
| def load_documents(files): | |
| documents = [] | |
| for file in files: | |
| try: | |
| file_path = file.name | |
| logger.info(f"Loading file: {file_path}") | |
| if file_path.endswith(".pdf"): | |
| loader = PyPDFLoader(file_path) | |
| documents.extend(loader.load()) | |
| elif file_path.endswith(".txt"): | |
| loader = TextLoader(file_path) | |
| documents.extend(loader.load()) | |
| elif file_path.endswith(".docx"): | |
| loader = Docx2txtLoader(file_path) | |
| documents.extend(loader.load()) | |
| elif file_path.endswith(".pptx"): | |
| loader = PPTXLoader(file_path) | |
| documents.extend([{"page_content": doc["page_content"], "metadata": doc["metadata"]} for doc in loader.load()]) | |
| except Exception as e: | |
| logger.error(f"Error loading file {file_path}: {str(e)}") | |
| continue | |
| return documents | |
| # Function to process documents and create vector store | |
| def process_documents(files, chunk_size, chunk_overlap, embedding_model): | |
| global vector_store | |
| if not files: | |
| return "Please upload at least one document.", None | |
| # Clear existing vector store | |
| if os.path.exists(PERSIST_DIRECTORY): | |
| try: | |
| shutil.rmtree(PERSIST_DIRECTORY) | |
| logger.info("Cleared existing ChromaDB directory.") | |
| except Exception as e: | |
| logger.error(f"Error clearing ChromaDB directory: {str(e)}") | |
| return f"Error clearing vector store: {str(e)}", None | |
| os.makedirs(PERSIST_DIRECTORY, exist_ok=True) | |
| # Load documents | |
| documents = load_documents(files) | |
| if not documents: | |
| return "No valid documents loaded. Check file formats or content.", None | |
| # Split documents | |
| try: | |
| text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=int(chunk_size), | |
| chunk_overlap=int(chunk_overlap), | |
| length_function=len | |
| ) | |
| doc_splits = text_splitter.split_documents(documents) | |
| logger.info(f"Split {len(documents)} documents into {len(doc_splits)} chunks.") | |
| except Exception as e: | |
| logger.error(f"Error splitting documents: {str(e)}") | |
| return f"Error splitting documents: {str(e)}", None | |
| # Create embeddings | |
| try: | |
| embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODELS[embedding_model]) | |
| except Exception as e: | |
| logger.error(f"Error initializing embeddings for {embedding_model}: {str(e)}") | |
| return f"Error initializing embeddings: {str(e)}", None | |
| # Create vector store | |
| try: | |
| # Use in-memory Chroma client to avoid filesystem issues | |
| collection_name = f"doctalk_collection_{int(time.time())}" | |
| client = chromadb.Client() | |
| vector_store = Chroma.from_documents( | |
| documents=doc_splits, | |
| embedding=embeddings, | |
| collection_name=collection_name | |
| ) | |
| return f"Processed {len(documents)} documents into {len(doc_splits)} chunks.", None | |
| except Exception as e: | |
| logger.error(f"Error creating vector store: {str(e)}") | |
| return f"Error creating vector store: {str(e)}", None | |
| # Function to initialize QA chain with retry logic | |
| def initialize_qa_chain(llm_model, temperature): | |
| global qa_chain | |
| if not vector_store: | |
| return "Please process documents first.", None | |
| try: | |
| # Enable 4-bit quantization for all models to reduce memory usage | |
| quantization_config = BitsAndBytesConfig(load_in_4bit=True) | |
| llm = HuggingFaceEndpoint( | |
| repo_id=LLM_MODELS[llm_model], | |
| task="text-generation", | |
| temperature=float(temperature), | |
| max_new_tokens=512, | |
| huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"], | |
| timeout=30, | |
| model_kwargs={"quantization_config": quantization_config} | |
| ) | |
| # Dynamically set k based on vector store size | |
| collection = vector_store._collection | |
| doc_count = collection.count() | |
| k = min(3, doc_count) if doc_count > 0 else 1 | |
| qa_chain = ConversationalRetrievalChain.from_llm( | |
| llm=llm, | |
| retriever=vector_store.as_retriever(search_kwargs={"k": k}), | |
| memory=memory | |
| ) | |
| logger.info(f"Initialized QA chain with {llm_model} and k={k}.") | |
| return "QA Doctor: QA chain initialized successfully.", None | |
| except requests.exceptions.HTTPError as e: | |
| logger.error(f"HTTP error initializing QA chain for {llm_model}: {str(e)}") | |
| if "503" in str(e): | |
| return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'High Accuracy (Mixtral-8x7B)' or wait and retry.", None | |
| elif "403" in str(e): | |
| return f"Error: Access denied for {llm_model}. Free-tier API limits models >10GB. Try 'High Accuracy (Mixtral-8x7B)' or upgrade to Pro at https://huggingface.co/settings/billing.", None | |
| return f"Error initializing QA chain: {str(e)}.", None | |
| except Exception as e: | |
| logger.error(f"Error initializing QA chain for {llm_model}: {str(e)}") | |
| return f"Error initializing QA chain: {str(e)}. Ensure your HF token is valid.", None | |
| # Function to handle user query with retry logic | |
| def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap): | |
| global chat_history | |
| if not vector_store: | |
| return "Please process documents first.", chat_history | |
| if not qa_chain: | |
| return "Please initialize the QA chain.", chat_history | |
| if not question.strip(): | |
| return "Please enter a valid question.", chat_history | |
| try: | |
| response = qa_chain.invoke({"question": question})["answer"] | |
| chat_history.append({"role": "user", "content": question}) | |
| chat_history.append({"role": "assistant", "content": response}) | |
| logger.info(f"Answered question: {question}") | |
| return response, chat_history | |
| except requests.exceptions.HTTPError as e: | |
| logger.error(f"HTTP error answering question: {str(e)}") | |
| if "503" in str(e): | |
| return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'High Accuracy (Mixtral-8x7B)' or wait and retry.", chat_history | |
| elif "403" in str(e): | |
| return f"Error: Access denied for {llm_model}. Free-tier API limits models >10GB. Try 'High Accuracy (Mixtral-8x7B)' or upgrade to Pro at https://huggingface.co/settings/billing.", chat_history | |
| return f"Error answering question: {str(e)}", chat_history | |
| except Exception as e: | |
| logger.error(f"Error answering question: {str(e)}") | |
| return f"Error answering question: {str(e)}", chat_history | |
| # Function to export chat history | |
| def export_chat(): | |
| if not chat_history: | |
| return "No chat history to export.", None | |
| try: | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| filename = f"chat_history_{timestamp}.txt" | |
| with open(filename, "w") as f: | |
| for message in chat_history: | |
| role = message["role"].capitalize() | |
| content = message["content"] | |
| f.write(f"{role}: {content}\n\n") | |
| logger.info(f"Exported chat history to {filename}.") | |
| return f"Chat history exported to {filename}.", filename | |
| except Exception as e: | |
| logger.error(f"Error exporting chat history: {str(e)}") | |
| return f"Error exporting chat history: {str(e)}", None | |
| # Function to reset the app | |
| def reset_app(): | |
| global vector_store, qa_chain, chat_history, memory | |
| try: | |
| vector_store = None | |
| qa_chain = None | |
| chat_history = [] | |
| memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) | |
| if os.path.exists(PERSIST_DIRECTORY): | |
| shutil.rmtree(PERSIST_DIRECTORY) | |
| os.makedirs(PERSIST_DIRECTORY, exist_ok=True) | |
| logger.info("Cleared ChromaDB directory on reset.") | |
| logger.info("App reset successfully.") | |
| return "App reset successfully.", None | |
| except Exception as e: | |
| logger.error(f"Error resetting app: {str(e)}") | |
| return f"Error resetting app: {str(e)}", None | |
| # Gradio interface | |
| with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo: | |
| gr.Markdown("# DocTalk: Document Q&A Chatbot") | |
| gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), select models, tune parameters, and ask questions!") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| file_upload = gr.Files(label="Upload Documents", file_types=[".pdf", ".txt", ".docx", ".pptx"]) | |
| with gr.Row(): | |
| process_button = gr.Button("Process Documents") | |
| reset_button = gr.Button("Reset App") | |
| status = gr.Textbox(label="Status", interactive=False) | |
| with gr.Column(scale=1): | |
| llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="High Accuracy (Mixtral-8x7B)") | |
| embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)") | |
| temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature") | |
| chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size") | |
| chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap") | |
| init_button = gr.Button("Initialize QA Chain") | |
| gr.Markdown("## Chat Interface") | |
| question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...") | |
| answer = gr.Textbox(label="Answer", interactive=False) | |
| chat_display = gr.Chatbot(label="Chat History", type="messages") | |
| export_button = gr.Button("Export Chat History") | |
| export_file = gr.File(label="Exported Chat File") | |
| # Event handlers | |
| process_button.click( | |
| fn=process_documents, | |
| inputs=[file_upload, chunk_size, chunk_overlap, embedding_model], | |
| outputs=[status, chat_display] | |
| ) | |
| init_button.click( | |
| fn=initialize_qa_chain, | |
| inputs=[llm_model, temperature], | |
| outputs=[status, chat_display] | |
| ) | |
| question.submit( | |
| fn=answer_question, | |
| inputs=[question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap], | |
| outputs=[answer, chat_display] | |
| ) | |
| export_button.click( | |
| fn=export_chat, | |
| outputs=[status, export_file] | |
| ) | |
| reset_button.click( | |
| fn=reset_app, | |
| outputs=[status, chat_display] | |
| ) | |
| demo.launch() |