# Custom_Rag_Bot / app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import PyPDF2
import docx
import io
import os
from typing import List, Optional
class DocumentRAG:
    def __init__(self):
        print("🚀 Initializing RAG System...")
        # Initialize embedding model (lightweight)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embedding model loaded")
        # Initialize quantized LLM
        self.setup_llm()
        # Document storage
        self.documents = []
        self.index = None
        self.is_indexed = False
    def setup_llm(self):
        """Setup quantized Mistral model"""
        try:
            # Check if CUDA is available
            if not torch.cuda.is_available():
                print("⚠️ CUDA not available, falling back to CPU or alternative model")
                self.setup_fallback_model()
                return
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
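            # Rough estimate (assumption, not measured here): with 4-bit NF4 weights and
            # double quantization, a 7B-parameter model needs on the order of 4 GB of GPU
            # memory for weights, excluding activations and the KV cache.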
            model_name = "mistralai/Mistral-7B-Instruct-v0.1"
            # Load tokenizer first
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )
            # Fix padding token issue
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Load model with quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            print("✅ Quantized Mistral model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("🔄 Falling back to alternative model...")
            self.setup_fallback_model()
    def setup_fallback_model(self):
        """Fallback to smaller model if Mistral fails"""
        try:
            # Use a small conversational model that fits in CPU memory; answer quality is
            # limited, but it keeps the app usable when the quantized model cannot load
            model_name = "microsoft/DialoGPT-small"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            # Fix padding token for fallback model too
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            print("✅ Fallback model loaded")
        except Exception as e:
            print(f"❌ Fallback model failed: {e}")
            # Last resort: return context-based answers without any text generation
            self.model = None
            self.tokenizer = None
            print("⚠️ Using context-only mode (no text generation)")
    def simple_context_answer(self, query: str, context: str) -> str:
        """Improved context-based answering when model is not available"""
        if not context:
            return "No relevant information found in the documents."
        # Improved keyword matching approach
        query_words = set(query.lower().split())
        context_sentences = context.split('.')
        # Find sentences that contain query keywords
        relevant_sentences = []
        for sentence in context_sentences:
            sentence = sentence.strip()
            if len(sentence) < 10:  # Skip very short sentences
                continue
            sentence_words = set(sentence.lower().split())
            # Check if sentence contains query keywords
            common_words = query_words.intersection(sentence_words)
            if len(common_words) >= 1:  # A single shared keyword is enough
                relevant_sentences.append(sentence)
        if relevant_sentences:
            # Return the most relevant sentences
            return '. '.join(relevant_sentences[:3]) + '.'
        else:
            # If no exact matches, return first few sentences of context
            first_sentences = context_sentences[:2]
            if first_sentences:
                return '. '.join([s.strip() for s in first_sentences if s.strip()]) + '.'
            return "Based on the document content, I found some information but cannot provide a specific answer to your question."
    def extract_text_from_file(self, file_path: str) -> str:
        """Extract text from various file formats"""
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.pdf':
                return self.extract_from_pdf(file_path)
            elif file_extension == '.docx':
                return self.extract_from_docx(file_path)
            elif file_extension == '.txt':
                return self.extract_from_txt(file_path)
            else:
                return f"Unsupported file format: {file_extension}"
        except Exception as e:
            return f"Error reading file: {str(e)}"
    def extract_from_pdf(self, file_path: str) -> str:
        """Extract text from PDF"""
        text = ""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can yield empty results for scanned pages
                    text += (page.extract_text() or "") + "\n"
        except Exception as e:
            text = f"Error reading PDF: {str(e)}"
        return text
    def extract_from_docx(self, file_path: str) -> str:
        """Extract text from DOCX"""
        try:
            doc = docx.Document(file_path)
            text = ""
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
            return text
        except Exception as e:
            return f"Error reading DOCX: {str(e)}"
    def extract_from_txt(self, file_path: str) -> str:
        """Extract text from TXT"""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except Exception as e:
            try:
                with open(file_path, 'r', encoding='latin-1') as file:
                    return file.read()
            except Exception as e2:
                return f"Error reading TXT: {str(e2)}"
    def chunk_text(self, text: str, chunk_size: int = 200, overlap: int = 30) -> List[str]:
        """Split text into overlapping chunks while keeping sentences intact where possible"""
        if not text.strip():
            return []
        # Split by sentences first, then group into chunks
        sentences = text.replace('\n', ' ').split('. ')
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            # Add sentence to current chunk
            test_chunk = current_chunk + ". " + sentence if current_chunk else sentence
            # If the chunk gets too long, save it and start a new one that repeats the
            # last `overlap` words so neighbouring chunks share some context
            if len(test_chunk.split()) > chunk_size:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                    overlap_words = current_chunk.split()[-overlap:]
                    current_chunk = " ".join(overlap_words) + " " + sentence
                else:
                    current_chunk = sentence
            else:
                current_chunk = test_chunk
        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks
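    # Illustrative check (commented out so it never runs on import; rag_system is created
    # further below). With the defaults, ~500 words of short sentences should yield about
    # three chunks, each starting with the tail of the previous one:
    #   chunks = rag_system.chunk_text("This is a sample sentence. " * 100)
    #   print(len(chunks), [len(c.split()) for c in chunks])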
    def process_documents(self, files) -> str:
        """Process uploaded files and create embeddings"""
        if not files:
            return "❌ No files uploaded!"
        try:
            all_text = ""
            processed_files = []
            # Extract text from all files
            for file in files:
                if file is None:
                    continue
                file_text = self.extract_text_from_file(file.name)
                if not file_text.startswith("Error") and not file_text.startswith("Unsupported"):
                    all_text += f"\n\n--- {os.path.basename(file.name)} ---\n\n{file_text}"
                    processed_files.append(os.path.basename(file.name))
                else:
                    return f"❌ {file_text}"
            if not all_text.strip():
                return "❌ No text extracted from files!"
            # Chunk the text
            self.documents = self.chunk_text(all_text)
            if not self.documents:
                return "❌ No valid text chunks created!"
            # Create embeddings
            print(f"📄 Creating embeddings for {len(self.documents)} chunks...")
            embeddings = self.embedder.encode(self.documents, show_progress_bar=True)
            # Build FAISS index
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatIP(dimension)
            # Normalize embeddings for cosine similarity
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings.astype('float32'))
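            # all-MiniLM-L6-v2 produces 384-dimensional embeddings; with L2-normalised
            # vectors, inner-product search (IndexFlatIP) is equivalent to cosine similarity.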
            self.is_indexed = True
            return f"✅ Successfully processed {len(processed_files)} files:\n" + \
                   f"📄 Files: {', '.join(processed_files)}\n" + \
                   f"📊 Created {len(self.documents)} text chunks\n" + \
                   f"🔍 Ready for Q&A!"
        except Exception as e:
            return f"❌ Error processing documents: {str(e)}"
    def retrieve_context(self, query: str, k: int = 5) -> str:
        """Retrieve relevant context for the query with improved retrieval"""
        if not self.is_indexed:
            return ""
        try:
            # Get query embedding
            query_embedding = self.embedder.encode([query])
            faiss.normalize_L2(query_embedding)
            # Search for similar chunks
            scores, indices = self.index.search(query_embedding.astype('float32'), k)
            # Keep chunks above a deliberately low similarity threshold
            relevant_docs = []
            for i, idx in enumerate(indices[0]):
                if idx < len(self.documents) and scores[0][i] > 0.05:
                    relevant_docs.append(self.documents[idx])
            # If nothing clears the threshold, fall back to the top-ranked chunks anyway
            if not relevant_docs:
                for i, idx in enumerate(indices[0]):
                    if idx < len(self.documents):
                        relevant_docs.append(self.documents[idx])
                        if len(relevant_docs) >= 3:  # Take at least 3 chunks
                            break
            return "\n\n".join(relevant_docs)
        except Exception as e:
            print(f"Error in retrieval: {e}")
            return ""
    def generate_answer(self, query: str, context: str) -> str:
        """Generate answer using the LLM with improved prompting"""
        if self.model is None or self.tokenizer is None:
            return self.simple_context_answer(query, context)
        try:
            # Check whether we are running Mistral (which has a specific prompt format) or the fallback model
            model_name = getattr(self.model.config, '_name_or_path', '').lower()
            is_mistral = 'mistral' in model_name
            if is_mistral:
                # Instruction-style prompt for Mistral
                prompt = f"""<s>[INST] You are a helpful document assistant. Answer the question based on the provided context. If the exact answer isn't in the context, provide the most relevant information available.
Context:
{context[:1500]}
Question: {query}
Please provide a helpful answer based on the available information. [/INST]"""
            else:
                # Plain prompt for fallback models
                prompt = f"""Based on the following information, please answer the question:
Context:
{context[:1000]}
Question: {query}
Answer:"""
            # Tokenize with truncation so the prompt fits within the token budget
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=800,
                truncation=True,
                padding=True
            )
            # Move inputs to the same device as the model
            if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
                inputs = {k: v.cuda() for k, v in inputs.items()}
            # Generate the answer
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=150,
                    temperature=0.3,  # low temperature keeps answers close to the context
                    do_sample=True,
                    top_p=0.9,
                    num_beams=2,
                    early_stopping=True,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
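            # Note: combining do_sample=True with num_beams=2 makes transformers run
            # beam-search multinomial sampling; setting do_sample=False instead would
            # give deterministic output if reproducible answers are preferred.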
            # Decode response
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract answer based on model type
            if is_mistral and "[/INST]" in full_response:
                answer = full_response.split("[/INST]")[-1].strip()
            else:
                # For other models, remove the prompt
                if "Answer:" in full_response:
                    answer = full_response.split("Answer:")[-1].strip()
                else:
                    answer = full_response[len(prompt):].strip()
            # Clean up the answer
            answer = self.clean_answer(answer)
            return answer if answer else self.simple_context_answer(query, context)
        except Exception as e:
            print(f"Error in generation: {e}")
            return self.simple_context_answer(query, context)
    def clean_answer(self, answer: str) -> str:
        """Clean up the generated answer"""
        if not answer or len(answer) < 5:
            return ""
        # Remove obvious problematic patterns
        lines = answer.split('\n')
        cleaned_lines = []
        for line in lines:
            line = line.strip()
            if line and not any(pattern in line.lower() for pattern in [
                'what are you doing', 'what do you think', 'how are you',
                'i am an ai', 'i cannot', 'i don\'t know'
            ]):
                cleaned_lines.append(line)
        cleaned_answer = ' '.join(cleaned_lines)
        # Limit length to prevent rambling
        if len(cleaned_answer) > 500:
            sentences = cleaned_answer.split('.')
            cleaned_answer = '. '.join(sentences[:3]) + '.'
        return cleaned_answer.strip()
    def answer_question(self, query: str) -> str:
        """Main function to answer questions with improved handling"""
        if not query.strip():
            return "❓ Please ask a question!"
        if not self.is_indexed:
            return "📝 Please upload and process documents first!"
        try:
            # Retrieve relevant context
            context = self.retrieve_context(query, k=7)  # Get more chunks
            if not context:
                return "🔍 No relevant information found in the uploaded documents for your question."
            # Generate answer
            answer = self.generate_answer(query, context)
            if answer and len(answer) > 10:
                return f"💡 **Answer:** {answer}\n\n📄 **Source Context:**\n{context[:300]}..."
            else:
                # Fallback to simple context display
                return f"📄 **Based on the document content:**\n{context[:500]}..."
        except Exception as e:
            return f"❌ Error answering question: {str(e)}"
# Initialize the RAG system
print("Initializing Document RAG System...")
rag_system = DocumentRAG()
# Gradio Interface
def create_interface():
    with gr.Blocks(title="📚 Document Q&A with RAG", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 📚 Document Q&A System
        Upload your documents and ask questions about them!
        **Supported formats:** PDF, DOCX, TXT
        """)
        with gr.Tab("📤 Upload Documents"):
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="Upload Documents",
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"]
                    )
                    process_btn = gr.Button("🔄 Process Documents", variant="primary")
                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=8,
                        interactive=False
                    )
            process_btn.click(
                fn=rag_system.process_documents,
                inputs=[file_upload],
                outputs=[process_status]
            )
        with gr.Tab("❓ Ask Questions"):
            with gr.Row():
                with gr.Column():
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="What would you like to know about your documents?",
                        lines=3
                    )
                    ask_btn = gr.Button("🔍 Get Answer", variant="primary")
                with gr.Column():
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=12,
                        interactive=False
                    )
            ask_btn.click(
                fn=rag_system.answer_question,
                inputs=[question_input],
                outputs=[answer_output]
            )
        # Example questions
        gr.Markdown("""
        ### 💡 Example Questions:
        - What is the main topic of the document?
        - Can you summarize the key points?
        - What are the conclusions mentioned?
        - Are there any specific numbers or statistics?
        - Who are the main people or organizations mentioned?
        """)
    return demo
# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True  # share links are not needed on Hugging Face Spaces; this mainly matters when running locally
    )
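# Local usage (illustrative, assumes the listed packages are available):
#   pip install torch transformers sentence-transformers faiss-cpu PyPDF2 python-docx gradio bitsandbytes accelerate
#   python app.py            # then open http://localhost:7860 in a browser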