# Custom_Rag_Bot / app.py
# Hugging Face Space page header captured along with the source (not code):
#   pradeepsengarr's picture | Update app.py | 1dedfac verified | raw | history blame | 13.3 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import PyPDF2
import docx
import io
import os
from typing import List, Optional
class DocumentRAG:
    """Retrieval-augmented question answering over uploaded documents.

    The pipeline is: extract text from PDF/DOCX/TXT files, split it into
    overlapping word chunks, embed the chunks with a SentenceTransformer,
    store them in a FAISS inner-product index, and answer questions by
    retrieving the best-matching chunks and passing them to a causal LLM.

    Attributes set by ``__init__``:
        embedder: SentenceTransformer used for chunk and query embeddings.
        model / tokenizer: the LLM and its tokenizer, or ``None`` when both
            the primary and the fallback model failed to load.
        documents: list of text chunks, positionally aligned with ``index``.
        index: FAISS index over ``documents`` (``None`` until processed).
        is_indexed: ``True`` once ``process_documents`` has succeeded.
    """

    # File extension -> (label used in error messages, reader method name).
    _READERS = {
        '.pdf': ('PDF', '_read_pdf'),
        '.docx': ('DOCX', '_read_docx'),
        '.txt': ('TXT', '_read_txt'),
    }

    def __init__(self):
        print("🚀 Initializing RAG System...")
        # Give every attribute a defined value up front so the object is
        # usable (and inspectable) even if a loading step below fails.
        self.model = None
        self.tokenizer = None
        self.documents = []
        self.index = None
        self.is_indexed = False
        # Initialize embedding model (lightweight)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embedding model loaded")
        # Initialize quantized LLM
        self.setup_llm()

    def setup_llm(self):
        """Load the 4-bit quantized Mistral model, falling back on failure.

        Any exception raised while loading (missing GPU, missing
        bitsandbytes, download failure, ...) is printed and then
        ``setup_fallback_model`` is tried instead.
        """
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            model_name = "mistralai/Mistral-7B-Instruct-v0.1"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            # NOTE: trust_remote_code is intentionally not enabled; this
            # checkpoint uses an architecture shipped with transformers, so
            # there is no reason to allow repository code to execute.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16
            )
            # Publish tokenizer and model together, only after both loaded.
            self.tokenizer = tokenizer
            self.model = model
            print("✅ Quantized Mistral model loaded")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fallback to a smaller model if Mistral fails
            self.setup_fallback_model()

    def setup_fallback_model(self):
        """Load a small causal LM when the primary model is unavailable.

        On failure both ``self.model`` and ``self.tokenizer`` are set to
        ``None``; ``generate_answer`` checks for that.
        """
        try:
            model_name = "microsoft/DialoGPT-small"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # generate_answer tokenizes with padding=True, which raises when
            # the tokenizer has no pad token, so reuse EOS as the pad token.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(model_name)
            self.tokenizer = tokenizer
            self.model = model
            print("✅ Fallback model loaded")
        except Exception as e:
            print(f"❌ Fallback model failed: {e}")
            self.model = None
            self.tokenizer = None

    # ------------------------------------------------------------------
    # Text extraction
    # ------------------------------------------------------------------
    def _read_pdf(self, file_path: str) -> str:
        """Return the text of every PDF page, one page per line group; raises on failure."""
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # extract_text() may yield None for pages without a text layer.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )

    def _read_docx(self, file_path: str) -> str:
        """Return the paragraphs of a DOCX file joined by newlines; raises on failure."""
        doc = docx.Document(file_path)
        return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)

    def _read_txt(self, file_path: str) -> str:
        """Return the content of a text file; raises on I/O failure.

        UTF-8 is tried first. Only a decoding failure triggers the latin-1
        retry, so a missing or unreadable file is not read twice.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except UnicodeDecodeError:
            with open(file_path, 'r', encoding='latin-1') as file:
                return file.read()

    def _read_text(self, file_path: str) -> str:
        """Dispatch on the file extension and return the extracted text.

        Raises:
            ValueError: for an unsupported extension, or when the reader
                fails. The message is the same user-facing text that
                ``extract_text_from_file`` returns.
        """
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension not in self._READERS:
            raise ValueError(f"Unsupported file format: {file_extension}")
        label, reader_name = self._READERS[file_extension]
        try:
            return getattr(self, reader_name)(file_path)
        except Exception as e:
            raise ValueError(f"Error reading {label}: {e}") from e

    def extract_text_from_file(self, file_path: str) -> str:
        """Extract text from a PDF, DOCX or TXT file.

        Never raises: on failure the returned string is an error message
        starting with ``"Error"`` or ``"Unsupported"``. Code that needs to
        tell failure apart from document text reliably should call
        ``_read_text`` and catch ``ValueError`` instead.
        """
        try:
            return self._read_text(file_path)
        except ValueError as e:
            return str(e)
        except Exception as e:
            return f"Error reading file: {str(e)}"

    def extract_from_pdf(self, file_path: str) -> str:
        """Extract text from a PDF; returns an error message string on failure."""
        try:
            return self._read_pdf(file_path)
        except Exception as e:
            return f"Error reading PDF: {str(e)}"

    def extract_from_docx(self, file_path: str) -> str:
        """Extract text from a DOCX; returns an error message string on failure."""
        try:
            return self._read_docx(file_path)
        except Exception as e:
            return f"Error reading DOCX: {str(e)}"

    def extract_from_txt(self, file_path: str) -> str:
        """Extract text from a TXT file; returns an error message string on failure."""
        try:
            return self._read_txt(file_path)
        except Exception as e:
            return f"Error reading TXT: {str(e)}"

    # ------------------------------------------------------------------
    # Chunking and indexing
    # ------------------------------------------------------------------
    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks of whitespace-separated words.

        Args:
            text: the text to split.
            chunk_size: maximum number of words per chunk (must be > 0).
            overlap: number of words shared by consecutive chunks
                (must satisfy ``0 <= overlap < chunk_size``).

        Returns:
            The list of chunks; empty when ``text`` is blank.

        Raises:
            ValueError: if ``chunk_size`` / ``overlap`` are out of range.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            # A non-positive stride would either make range() fail or
            # silently produce no chunks at all.
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
        words = text.split()
        if not words:
            return []
        chunks = []
        for start in range(0, len(words), chunk_size - overlap):
            chunks.append(' '.join(words[start:start + chunk_size]))
            # This window already reached the end; a further window would
            # only repeat the tail of this one.
            if start + chunk_size >= len(words):
                break
        return chunks

    def process_documents(self, files) -> str:
        """Extract, chunk, embed and index the uploaded files.

        Args:
            files: the value of the Gradio file component: a list whose
                items are either path strings or objects exposing the path
                as ``.name``. A single item is accepted as well; ``None``
                items are skipped.

        Returns:
            A human-readable status message. The previous index is kept
            untouched unless the whole pipeline succeeds.
        """
        if not files:
            return "❌ No files uploaded!"
        if not isinstance(files, (list, tuple)):
            files = [files]
        try:
            sections = []
            processed_files = []
            # Extract text from all files
            for file in files:
                if file is None:
                    continue
                if isinstance(file, (str, os.PathLike)):
                    file_path = os.fspath(file)
                else:
                    file_path = file.name
                file_name = os.path.basename(file_path)
                try:
                    # Failures arrive as exceptions, so a document whose own
                    # text starts with "Error" is not mistaken for a failure.
                    file_text = self._read_text(file_path)
                except ValueError as e:
                    return f"❌ {e}"
                sections.append(f"\n\n--- {file_name} ---\n\n{file_text}")
                processed_files.append(file_name)
            all_text = "".join(sections)
            if not all_text.strip():
                return "❌ No text extracted from files!"
            # Chunk the text
            documents = self.chunk_text(all_text)
            if not documents:
                return "❌ No valid text chunks created!"
            # Create embeddings
            print(f"📄 Creating embeddings for {len(documents)} chunks...")
            embeddings = self.embedder.encode(documents, show_progress_bar=True)
            # FAISS works on C-contiguous float32 arrays. Convert first so the
            # in-place normalization is applied to the array that is indexed.
            embeddings = np.ascontiguousarray(embeddings, dtype='float32')
            # Unit-length vectors make inner product equal cosine similarity.
            faiss.normalize_L2(embeddings)
            index = faiss.IndexFlatIP(embeddings.shape[1])
            index.add(embeddings)
            # Commit only now, so documents and index always stay aligned.
            self.documents = documents
            self.index = index
            self.is_indexed = True
            return f"✅ Successfully processed {len(processed_files)} files:\n" + \
                f"📄 Files: {', '.join(processed_files)}\n" + \
                f"📊 Created {len(self.documents)} text chunks\n" + \
                f"🔍 Ready for Q&A!"
        except Exception as e:
            return f"❌ Error processing documents: {str(e)}"

    # ------------------------------------------------------------------
    # Retrieval and generation
    # ------------------------------------------------------------------
    def retrieve_context(self, query: str, k: int = 3, min_score: float = 0.1) -> str:
        """Return the indexed chunks most similar to the query.

        Args:
            query: the user question.
            k: maximum number of chunks to return.
            min_score: chunks whose similarity score is not above this
                threshold are dropped.

        Returns:
            The selected chunks joined by blank lines, or ``""`` when
            nothing is indexed, nothing passes the threshold, or retrieval
            fails (the error is printed).
        """
        if not self.is_indexed or self.index is None:
            return ""
        # Asking for more neighbours than there are vectors yields -1 labels.
        k = min(k, len(self.documents))
        if k <= 0:
            return ""
        try:
            # Get query embedding, prepared the same way as the chunk vectors
            query_embedding = np.ascontiguousarray(
                self.embedder.encode([query]), dtype='float32'
            )
            faiss.normalize_L2(query_embedding)
            # Search for similar chunks
            scores, indices = self.index.search(query_embedding, k)
            # Keep valid hits only; a negative label would otherwise index
            # the document list from its end.
            relevant_docs = [
                self.documents[idx]
                for score, idx in zip(scores[0], indices[0])
                if 0 <= idx < len(self.documents) and score > min_score
            ]
            return "\n\n".join(relevant_docs)
        except Exception as e:
            print(f"Error in retrieval: {e}")
            return ""

    def generate_answer(self, query: str, context: str) -> str:
        """Generate an answer to the query from the retrieved context.

        Args:
            query: the user question.
            context: retrieved text; only its first 2000 characters are
                placed in the prompt.

        Returns:
            The generated answer, a placeholder when the model produced no
            text, or an error message (this method does not raise).
        """
        if self.model is None or self.tokenizer is None:
            return "❌ Model not available. Please try again."
        try:
            # Create prompt. No literal "<s>" here: the tokenizer inserts its
            # own BOS token, so spelling it out would duplicate it.
            prompt = (
                "[INST] Based on the following context, answer the question. "
                "If the answer is not in the context, say "
                "\"I don't have enough information to answer this question.\"\n"
                f"Context:\n{context[:2000]}\n"
                f"Question: {query}\n"
                "Answer: [/INST]"
            )
            # Tokenize
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
                padding=True
            )
            # The model may live on an accelerator (device_map="auto");
            # the input tensors have to be on the same device.
            inputs = inputs.to(self.model.device)
            prompt_length = inputs["input_ids"].shape[1]
            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            # The output repeats the prompt tokens first; decoding only what
            # follows them is reliable even when the prompt was truncated.
            answer = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            return answer if answer else "I couldn't generate a proper response."
        except Exception as e:
            return f"❌ Error generating answer: {str(e)}"

    def answer_question(self, query: str) -> str:
        """Answer a question about the processed documents.

        Returns a Markdown-formatted string holding the answer and a preview
        (at most 500 characters) of the context it was based on, or a
        guidance/error message when no answer can be produced.
        """
        if not query or not query.strip():
            return "❓ Please ask a question!"
        if not self.is_indexed:
            return "📝 Please upload and process documents first!"
        try:
            # Retrieve relevant context
            context = self.retrieve_context(query)
            if not context:
                return "🔍 No relevant information found in the uploaded documents."
            # Generate answer
            answer = self.generate_answer(query, context)
            preview = context[:500]
            if len(context) > 500:
                # Mark the preview as cut off only when it really is.
                preview += "..."
            return f"💡 **Answer:** {answer}\n\n📄 **Source Context:** {preview}"
        except Exception as e:
            return f"❌ Error answering question: {str(e)}"
# Initialize the RAG system.
# A single DocumentRAG instance is built when this module is imported; its
# constructor loads the embedding model and the LLM, so this can take a while.
# The Gradio callbacks wired up in create_interface() use this shared instance.
print("Initializing Document RAG System...")
rag_system = DocumentRAG()
# Gradio Interface
def create_interface():
    """Build the Gradio UI for uploading documents and asking questions.

    The app has two tabs: one that sends the uploaded files to
    ``rag_system.process_documents`` and shows the returned status, and one
    that sends a question to ``rag_system.answer_question`` and shows the
    returned answer.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is expected to launch it.
    """
    with gr.Blocks(title="📚 Document Q&A with RAG", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "\n"
            "# 📚 Document Q&A System\n"
            "Upload your documents and ask questions about them!\n"
            "**Supported formats:** PDF, DOCX, TXT\n"
        )
        with gr.Tab("📤 Upload Documents"):
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="Upload Documents",
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"]
                    )
                    process_btn = gr.Button("🔄 Process Documents", variant="primary")
                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=8,
                        interactive=False
                    )
            process_btn.click(
                fn=rag_system.process_documents,
                inputs=[file_upload],
                outputs=[process_status]
            )
        with gr.Tab("❓ Ask Questions"):
            with gr.Row():
                with gr.Column():
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="What would you like to know about your documents?",
                        lines=3
                    )
                    ask_btn = gr.Button("🔍 Get Answer", variant="primary")
                with gr.Column():
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=10,
                        interactive=False
                    )
            ask_btn.click(
                fn=rag_system.answer_question,
                inputs=[question_input],
                outputs=[answer_output]
            )
            # Example questions
            gr.Markdown(
                "\n"
                "### 💡 Example Questions:\n"
                "- What is the main topic of the document?\n"
                "- Can you summarize the key points?\n"
                "- What are the conclusions mentioned?\n"
                "- Are there any specific numbers or statistics?\n"
            )
    return demo
# Launch the app
if __name__ == "__main__":
    # Run only when executed as a script: build the UI, then serve it on
    # every network interface at port 7860 with a public share link requested.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)