import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import PyPDF2
import docx
import io
import os
from typing import List, Optional

# Exact prefixes of the failure messages returned by the extract_* methods.
# process_documents matches on these (rather than on a bare "Error") so that a
# document whose text merely starts with "Error ..." is not treated as a failure.
_EXTRACTION_ERROR_PREFIXES = (
    "Error reading file:",
    "Error reading PDF:",
    "Error reading DOCX:",
    "Error reading TXT:",
    "Unsupported file format:",
)

# Maximum number of retrieved-context characters placed into the LLM prompt.
_MAX_PROMPT_CONTEXT_CHARS = 2000

# Number of context characters echoed back to the user alongside the answer.
_MAX_SOURCE_PREVIEW_CHARS = 500


def _uploaded_file_path(file) -> str:
    """Return the filesystem path of one Gradio upload.

    Gradio 3.x passes tempfile-like objects exposing ``.name``; Gradio 4.x
    (``gr.File`` with its default ``type="filepath"``) passes plain paths.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


class DocumentRAG:
    """Retrieval-augmented question answering over uploaded documents.

    Uploaded PDF/DOCX/TXT files are split into overlapping word chunks,
    embedded with a SentenceTransformer, and stored in a FAISS inner-product
    index over L2-normalised vectors (i.e. cosine similarity). A question
    retrieves the most similar chunks, which are passed as context to a
    causal language model.
    """

    def __init__(self):
        print("🚀 Initializing RAG System...")

        # Initialize embedding model (lightweight)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embedding model loaded")

        # Initialize quantized LLM (falls back to a small model on failure)
        self.setup_llm()

        # Document storage: self.documents[i] is the chunk text for index vector i
        self.documents = []
        self.index = None
        self.is_indexed = False

    def setup_llm(self):
        """Load 4-bit quantized Mistral-7B-Instruct; fall back to a small model on any error."""
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )

            model_name = "mistralai/Mistral-7B-Instruct-v0.1"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Mistral has a native transformers implementation, so
            # trust_remote_code is not needed; enabling it would let code from
            # the model repository execute locally.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            print("✅ Quantized Mistral model loaded")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fallback to a smaller model if Mistral fails
            self.setup_fallback_model()

    def setup_fallback_model(self):
        """Load a small fallback model; on failure set model and tokenizer to None."""
        try:
            model_name = "microsoft/DialoGPT-small"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # DialoGPT ships without a pad token; mirror setup_llm so both
            # models are configured the same way.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            print("✅ Fallback model loaded")
        except Exception as e:
            print(f"❌ Fallback model failed: {e}")
            self.model = None
            self.tokenizer = None

    def extract_text_from_file(self, file_path: str) -> str:
        """Extract text from a .pdf, .docx or .txt file.

        Returns:
            The extracted text, or a message starting with one of
            ``_EXTRACTION_ERROR_PREFIXES`` when the extension is unsupported
            or reading fails.
        """
        try:
            file_extension = os.path.splitext(file_path)[1].lower()

            if file_extension == '.pdf':
                return self.extract_from_pdf(file_path)
            elif file_extension == '.docx':
                return self.extract_from_docx(file_path)
            elif file_extension == '.txt':
                return self.extract_from_txt(file_path)
            else:
                return f"Unsupported file format: {file_extension}"
        except Exception as e:
            return f"Error reading file: {str(e)}"

    def extract_from_pdf(self, file_path: str) -> str:
        """Extract text from a PDF, appending a newline after each page.

        Returns "Error reading PDF: ..." if the file cannot be parsed.
        """
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # extract_text() may return None for pages without a text
                # layer (e.g. scanned images); treat those as empty pages.
                return "".join(
                    (page.extract_text() or "") + "\n" for page in pdf_reader.pages
                )
        except Exception as e:
            return f"Error reading PDF: {str(e)}"

    def extract_from_docx(self, file_path: str) -> str:
        """Extract paragraph text from a DOCX, one paragraph per line.

        Returns "Error reading DOCX: ..." if the file cannot be parsed.
        """
        try:
            doc = docx.Document(file_path)
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except Exception as e:
            return f"Error reading DOCX: {str(e)}"

    def extract_from_txt(self, file_path: str) -> str:
        """Read a text file as UTF-8, retrying as latin-1 if that fails.

        Returns "Error reading TXT: ..." if both attempts fail.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except Exception:
            # latin-1 maps every byte to a character, so it only fails on I/O errors.
            try:
                with open(file_path, 'r', encoding='latin-1') as file:
                    return file.read()
            except Exception as e:
                return f"Error reading TXT: {str(e)}"

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into chunks of up to ``chunk_size`` whitespace-separated words.

        Consecutive chunks share ``overlap`` words. Blank text yields [].

        Raises:
            ValueError: if chunk_size < 1, or overlap is negative or not
                smaller than chunk_size (the window would never advance).
        """
        if chunk_size < 1 or not 0 <= overlap < chunk_size:
            raise ValueError(
                f"need chunk_size >= 1 and 0 <= overlap < chunk_size, "
                f"got chunk_size={chunk_size}, overlap={overlap}"
            )
        if not text.strip():
            return []

        words = text.split()
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            # words from str.split() are non-empty, so each chunk is non-empty
            chunks.append(' '.join(words[start:start + chunk_size]))
            if start + chunk_size >= len(words):
                break  # this chunk already reached the last word

        return chunks

    def process_documents(self, files) -> str:
        """Extract, chunk, embed and index uploaded files, replacing any previous index.

        Args:
            files: The Gradio upload value: a list of file paths (Gradio 4) or
                tempfile-like objects with ``.name`` (Gradio 3). May be None or empty.

        Returns:
            A user-facing status message. If any file fails to extract,
            processing stops and that error is returned; the previously built
            index (if any) is left untouched.
        """
        if not files:
            return "❌ No files uploaded!"

        try:
            sections = []
            processed_files = []

            # Extract text from all files
            for file in files:
                if file is None:
                    continue

                file_path = _uploaded_file_path(file)
                file_name = os.path.basename(file_path)
                file_text = self.extract_text_from_file(file_path)

                if file_text.startswith(_EXTRACTION_ERROR_PREFIXES):
                    return f"❌ {file_text}"
                if not file_text.strip():
                    # Nothing to index (e.g. a scanned PDF with no text layer);
                    # skip it instead of indexing a header-only chunk.
                    continue

                sections.append(f"\n\n--- {file_name} ---\n\n{file_text}")
                processed_files.append(file_name)

            all_text = "".join(sections)
            if not all_text.strip():
                return "❌ No text extracted from files!"

            # Chunk the text
            documents = self.chunk_text(all_text)
            if not documents:
                return "❌ No valid text chunks created!"

            # Create embeddings
            print(f"📄 Creating embeddings for {len(documents)} chunks...")
            embeddings = self.embedder.encode(documents, show_progress_bar=True)

            # faiss.normalize_L2 works in place and requires contiguous float32
            embeddings = np.ascontiguousarray(embeddings, dtype='float32')
            faiss.normalize_L2(embeddings)

            # Inner product on unit-length vectors equals cosine similarity
            index = faiss.IndexFlatIP(embeddings.shape[1])
            index.add(embeddings)

            # Commit only after every step succeeded, so self.documents always
            # lines up position-for-position with the vectors in self.index.
            self.documents = documents
            self.index = index
            self.is_indexed = True

            return (
                f"✅ Successfully processed {len(processed_files)} files:\n"
                f"📄 Files: {', '.join(processed_files)}\n"
                f"📊 Created {len(self.documents)} text chunks\n"
                f"🔍 Ready for Q&A!"
            )

        except Exception as e:
            return f"❌ Error processing documents: {str(e)}"

    def retrieve_context(self, query: str, k: int = 3, min_score: float = 0.1) -> str:
        """Return up to ``k`` indexed chunks most similar to ``query``, joined by blank lines.

        Chunks with cosine similarity <= ``min_score`` are dropped. Returns ""
        when nothing is indexed, nothing qualifies, or retrieval fails.
        """
        if not self.is_indexed:
            return ""

        try:
            # Get query embedding, normalised like the indexed vectors
            query_embedding = np.ascontiguousarray(
                self.embedder.encode([query]), dtype='float32'
            )
            faiss.normalize_L2(query_embedding)

            # FAISS pads results with id -1 when k exceeds the index size,
            # so never ask for more neighbours than there are vectors.
            k = min(k, self.index.ntotal)
            if k <= 0:
                return ""

            scores, indices = self.index.search(query_embedding, k)

            relevant_docs = [
                self.documents[idx]
                for score, idx in zip(scores[0], indices[0])
                if 0 <= idx < len(self.documents) and score > min_score
            ]
            return "\n\n".join(relevant_docs)

        except Exception as e:
            print(f"Error in retrieval: {e}")
            return ""

    def generate_answer(self, query: str, context: str) -> str:
        """Generate an answer to ``query`` grounded in ``context`` using the loaded LLM.

        Returns the generated text, a fallback message if generation produced
        nothing, or an error message if no model is loaded or generation fails.
        """
        if self.model is None or self.tokenizer is None:
            return "❌ Model not available. Please try again."

        try:
            # Mistral-instruct style prompt. The context is capped so the
            # prompt stays well inside max_length below.
            prompt = (
                "[INST] Based on the following context, answer the question. "
                "If the answer is not in the context, say "
                "\"I don't have enough information to answer this question.\"\n\n"
                f"Context:\n{context[:_MAX_PROMPT_CONTEXT_CHARS]}\n\n"
                f"Question: {query}\n\n"
                "Answer: [/INST]"
            )

            # Tokenize and place the tensors on the model's device (a GPU when
            # the model was loaded with device_map="auto").
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
            ).to(self.model.device)
            prompt_length = inputs["input_ids"].shape[1]

            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens. Stripping the prompt from
            # the decoded full sequence breaks whenever detokenization does not
            # reproduce the prompt text exactly.
            answer = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

            return answer if answer else "I couldn't generate a proper response."

        except Exception as e:
            return f"❌ Error generating answer: {str(e)}"

    def answer_question(self, query: str) -> str:
        """Answer a user question from the indexed documents.

        Returns a formatted answer followed by a preview of the retrieved
        source context, or a user-facing message explaining why no answer
        could be produced.
        """
        if not query.strip():
            return "❓ Please ask a question!"

        if not self.is_indexed:
            return "📁 Please upload and process documents first!"

        try:
            # Retrieve relevant context
            context = self.retrieve_context(query)

            if not context:
                return "🔍 No relevant information found in the uploaded documents."

            # Generate answer
            answer = self.generate_answer(query, context)

            # Only mark the preview as truncated when it actually is
            preview = context[:_MAX_SOURCE_PREVIEW_CHARS]
            if len(context) > _MAX_SOURCE_PREVIEW_CHARS:
                preview += "..."

            return f"💡 **Answer:** {answer}\n\n📄 **Source Context:** {preview}"

        except Exception as e:
            return f"❌ Error answering question: {str(e)}"


# Initialize the RAG system. This loads the models at import time;
# create_interface wires its methods into the UI.
print("Initializing Document RAG System...")
rag_system = DocumentRAG()


# Gradio Interface
def create_interface():
    """Build the Gradio Blocks app: an upload tab and a question-answering tab."""
    with gr.Blocks(title="📚 Document Q&A with RAG", theme=gr.themes.Soft()) as demo:
        # Markdown is written without leading indentation so that it is not
        # rendered as an indented code block.
        gr.Markdown(
            "# 📚 Document Q&A System\n"
            "Upload your documents and ask questions about them!\n\n"
            "**Supported formats:** PDF, DOCX, TXT"
        )

        with gr.Tab("📤 Upload Documents"):
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="Upload Documents",
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"],
                    )
                    process_btn = gr.Button("🔄 Process Documents", variant="primary")

                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=8,
                        interactive=False,
                    )

            process_btn.click(
                fn=rag_system.process_documents,
                inputs=[file_upload],
                outputs=[process_status],
            )

        with gr.Tab("❓ Ask Questions"):
            with gr.Row():
                with gr.Column():
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="What would you like to know about your documents?",
                        lines=3,
                    )
                    ask_btn = gr.Button("🔍 Get Answer", variant="primary")

                with gr.Column():
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=10,
                        interactive=False,
                    )

            ask_btn.click(
                fn=rag_system.answer_question,
                inputs=[question_input],
                outputs=[answer_output],
            )

            # Example questions
            gr.Markdown(
                "### 💡 Example Questions:\n"
                "- What is the main topic of the document?\n"
                "- Can you summarize the key points?\n"
                "- What are the conclusions mentioned?\n"
                "- Are there any specific numbers or statistics?"
            )

    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    # NOTE(review): server_name="0.0.0.0" listens on all interfaces and
    # share=True requests a public Gradio share link — confirm this exposure
    # is intended.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )