# Custom_Rag_Bot / app.py
# Hugging Face Space file by pradeepsengarr; commit fd77b07 ("Update app.py"), 10.5 kB.
import re
import os
import faiss
import numpy as np
import gradio as gr
from typing import List
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from PyPDF2 import PdfReader
import docx2txt
# === Helper functions ===
def clean_text(text: str) -> str:
    """Return *text* with each whitespace run collapsed to one space and both ends trimmed."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()
def chunk_text(text: str, max_chunk_size: int = 300, overlap: int = 50) -> List[str]:
    """Split text into sentence-aligned chunks with a character overlap.

    Sentences (split after ``.``, ``?`` or ``!``) are packed greedily into
    chunks of at most ``max_chunk_size`` characters. A single sentence longer
    than the limit is kept whole as its own chunk rather than being cut.
    Every chunk except the first is then prefixed with the last ``overlap``
    characters of the preceding chunk so context carries across boundaries.

    Args:
        text: The text to split.
        max_chunk_size: Soft upper bound, in characters, for a chunk before
            the overlap prefix is added.
        overlap: Number of trailing characters of the previous chunk to
            prepend. Values <= 0 disable overlapping.

    Returns:
        The whitespace-normalised chunks, in document order. Empty or
        whitespace-only input yields an empty list; no empty chunk is ever
        returned.
    """
    sentences = re.split(r'(?<=[.?!])\s+', text)

    chunks = []
    current = ""
    for sentence in sentences:
        if len(current) + len(sentence) <= max_chunk_size:
            current += sentence + " "
            continue
        # Flush only real content: when the very first sentence is already
        # over the limit, `current` is still empty and must not be emitted.
        if current.strip():
            chunks.append(current.strip())
        current = sentence + " "
    if current.strip():
        chunks.append(current.strip())

    overlapped_chunks = []
    for i, body in enumerate(chunks):
        # Guard overlap > 0: a slice of [-0:] would copy the whole previous chunk.
        if i > 0 and overlap > 0:
            body = chunks[i - 1][-overlap:] + " " + body
        # Same normalisation rule as clean_text(), applied inline so this
        # function has no dependency on other helpers.
        overlapped_chunks.append(re.sub(r'\s+', ' ', body).strip())
    return overlapped_chunks
def extract_text_from_pdf(file_path: str) -> str:
    """Extract the text of every readable page of a PDF.

    Extraction is best-effort: a file that cannot be opened yields an empty
    string, and a page whose extraction raises is reported and skipped so the
    remaining pages are still read.

    Args:
        file_path: Path of the PDF file.

    Returns:
        The page texts joined with spaces and passed through ``clean_text``.
    """
    page_texts = []
    try:
        reader = PdfReader(file_path)
        for page_number, page in enumerate(reader.pages, start=1):
            try:
                # `or ""` keeps the join safe if a page yields nothing
                # (e.g. an image-only page).
                page_texts.append(page.extract_text() or "")
            except Exception as e:
                print(f"Error reading PDF {file_path} (page {page_number}): {e}")
    except Exception as e:
        print(f"Error reading PDF {file_path}: {e}")
    return clean_text(" ".join(page_texts))
def extract_text_from_docx(file_path: str) -> str:
    """Return the cleaned text of a DOCX file, or an empty string on any failure."""
    try:
        return clean_text(docx2txt.process(file_path))
    except Exception as e:
        # Best-effort: report the problem and fall through to the empty result.
        print(f"Error reading DOCX {file_path}: {e}")
    return ""
def extract_text_from_txt(file_path: str) -> str:
    """Read a plain-text file and return its cleaned contents.

    The file is decoded as UTF-8. A leading byte-order mark is dropped
    (``utf-8-sig``) and undecodable bytes are replaced instead of aborting,
    so one stray byte no longer discards the whole document.

    Args:
        file_path: Path of the text file.

    Returns:
        The file contents passed through ``clean_text``, or an empty string
        if the file cannot be read.
    """
    try:
        with open(file_path, 'r', encoding='utf-8-sig', errors='replace') as f:
            return clean_text(f.read())
    except Exception as e:
        print(f"Error reading TXT {file_path}: {e}")
        return ""
# === Main RAG System ===
class SmartDocumentRAG:
    """Question answering over uploaded documents via retrieval plus extractive QA.

    Uploaded files are converted to text, split into chunks, embedded with a
    sentence-transformer and stored in a FAISS inner-product index over
    L2-normalised vectors. Questions are answered by retrieving the closest
    chunks and running an extractive question-answering pipeline on them,
    with a regex-based extractor as fallback.
    """

    # Inner-product score a retrieved chunk must exceed to be used as context.
    MIN_SIMILARITY = 0.1
    # QA-pipeline score below which the regex fallback is used instead.
    MIN_QA_SCORE = 0.1
    # Explicit summary requests. A bare "about" counts only when it ends the
    # question ("what is this document about?"), so that questions such as
    # "what does it say about X?" still reach retrieval.
    _SUMMARY_INTENT = re.compile(
        r'\b(?:summary|summari[sz]e\w*|overview)\b|\babout\b\W*$'
    )

    def __init__(self):
        # Model & embedding initialization
        self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-uncased-distilled-squad",
        )
        self.documents = []
        self.chunks = []
        self.index = None
        self.is_indexed = False
        self.document_summary = ""

    def _reset_state(self) -> None:
        """Drop everything derived from a previous upload."""
        self.documents = []
        self.chunks = []
        self.index = None
        self.is_indexed = False
        self.document_summary = ""

    @staticmethod
    def _normalize(vectors):
        """Return *vectors* as float32 rows scaled to unit L2 norm.

        All-zero rows are left as zeros instead of becoming NaN.
        """
        vectors = np.asarray(vectors, dtype='float32')
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return (vectors / np.maximum(norms, 1e-12)).astype('float32')

    def process_documents(self, uploaded_files) -> str:
        """Load, extract, chunk, embed, and index documents.

        Any previously indexed content is discarded first, so a failed or
        empty upload never leaves a stale index or summary behind.

        Args:
            uploaded_files: Iterable of uploaded files; each item is either a
                path string or an object exposing the path as ``.name``.
                Files other than ``.pdf``, ``.docx`` and ``.txt`` are skipped.

        Returns:
            A human-readable status message.
        """
        if not uploaded_files:
            return "⚠️ No files uploaded."

        self._reset_state()

        for file_obj in uploaded_files:
            # Accept both plain path strings and file wrappers with `.name`.
            file_path = getattr(file_obj, "name", file_obj)
            ext = os.path.splitext(str(file_path))[1].lower()
            if ext == ".pdf":
                text = extract_text_from_pdf(file_path)
            elif ext == ".docx":
                text = extract_text_from_docx(file_path)
            elif ext == ".txt":
                text = extract_text_from_txt(file_path)
            else:
                continue  # skip unsupported
            if text:
                self.documents.append(text)

        if not self.documents:
            return "⚠️ No extractable text found in uploaded files."

        all_text = " ".join(self.documents)

        # Empty chunks carry no information and must not be indexed.
        self.chunks = [chunk for chunk in chunk_text(all_text) if chunk]
        if not self.chunks:
            return "⚠️ No extractable text found in uploaded files."

        embeddings = self._normalize(
            self.embedder.encode(self.chunks, convert_to_numpy=True)
        )

        # Inner product over unit vectors equals cosine similarity.
        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)

        self.document_summary = self.generate_summary(all_text)
        # Flag success last so a failure above leaves the instance un-indexed.
        self.is_indexed = True
        return f"✅ Processed {len(self.documents)} document(s), {len(self.chunks)} chunks indexed."

    def generate_summary(self, text: str) -> str:
        """Return a naive summary: the first five sentences of *text*."""
        sentences = re.split(r'(?<=[.?!])\s+', text)
        return ' '.join(sentences[:5])

    def find_relevant_content(self, query: str, top_k: int = 3) -> str:
        """Return the chunks most similar to *query*, joined by spaces.

        Args:
            query: The user question.
            top_k: Maximum number of chunks to retrieve.

        Returns:
            The retrieved chunks whose score exceeds ``MIN_SIMILARITY``, or
            an empty string when nothing is indexed or nothing qualifies.
        """
        if not self.is_indexed or not self.chunks or top_k <= 0:
            return ""

        query_emb = self._normalize(
            self.embedder.encode([query], convert_to_numpy=True)
        )
        scores, indices = self.index.search(query_emb, min(top_k, len(self.chunks)))

        relevant_chunks = [
            self.chunks[idx]
            for score, idx in zip(scores[0], indices[0])
            # A negative id would silently index from the end of the list.
            if idx >= 0 and score > self.MIN_SIMILARITY
        ]
        return " ".join(relevant_chunks)

    def extract_direct_answer(self, query: str, context: str) -> str:
        """Regex-based fallback that pulls a short answer out of *context*.

        The question is matched against keyword groups (name, experience,
        skills, education); the first group that also finds something in the
        context wins. Otherwise the first sentence of the context is returned.

        Args:
            query: The user question.
            context: Retrieved text to search.

        Returns:
            A Markdown-formatted answer string.
        """
        q = query.lower()

        if any(word in q for word in ['name', 'who is', 'who']):
            names = re.findall(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', context)
            if names:
                return f"**Name:** {names[0]}"

        if any(word in q for word in ['experience', 'years']):
            years = re.findall(r'(\d+)[\+\-\s]*(?:years?|yrs?)', context.lower())
            if years:
                return f"**Experience:** {years[0]} years"

        if any(word in q for word in ['skill', 'technology', 'tech']):
            skills = re.findall(
                r'\b(?:Python|Java|JavaScript|React|Node|SQL|AWS|Docker|Kubernetes|Git|HTML|CSS|Angular|Vue|Spring|Django|Flask|MongoDB|PostgreSQL)\b',
                context,
                re.I,
            )
            if skills:
                # The match is case-insensitive, so de-duplicate the same way,
                # keeping the first spelling seen and the original order.
                first_seen = {}
                for skill in skills:
                    first_seen.setdefault(skill.casefold(), skill)
                return f"**Skills:** {', '.join(first_seen.values())}"

        if any(word in q for word in ['education', 'degree', 'university']):
            # Word boundaries stop the short abbreviations (BA, MS, ...) from
            # matching inside ordinary words such as "based" or "jobs".
            edu = re.findall(
                r'\b(?:Bachelor\w*|Master\w*|PhD|(?:B\.?S\.?|M\.?S\.?|B\.?A\.?|M\.?A\.?)(?!\w))'
                r'.*?\b(?:in|of)\s+([^.]+)',
                context,
                re.I,
            )
            if edu:
                return f"**Education:** {edu[0]}"

        # Fallback: first sentence from context
        sentences = [s.strip() for s in context.split('.') if s.strip()]
        if sentences:
            return f"**Answer:** {sentences[0]}"
        return "I found relevant content but could not extract a specific answer."

    def answer_question(self, query: str) -> str:
        """Answer *query* from the indexed documents.

        Summary requests return the stored summary. Otherwise the most
        relevant chunks are retrieved and passed to the QA pipeline; if the
        pipeline fails, returns no answer, or scores below ``MIN_QA_SCORE``,
        the regex fallback is used.

        Args:
            query: The user question.

        Returns:
            A Markdown-formatted answer or a status message.
        """
        if not query or not query.strip():
            return "❓ Please ask a question."
        if not self.is_indexed:
            return "📁 Please upload and process documents first."

        if self._SUMMARY_INTENT.search(query.lower()):
            return f"📄 **Document Summary:**\n\n{self.document_summary}"

        context = self.find_relevant_content(query, top_k=3)
        if not context:
            return "🔍 No relevant information found. Try rephrasing your question."

        try:
            result = self.qa_pipeline(question=query, context=context)
            answer = result.get('answer', '').strip()
            score = result.get('score', 0)
        except Exception as e:
            print(f"QA model error: {e}")
            return self.extract_direct_answer(query, context)

        # Confidence threshold to fall back to regex extraction
        if score < self.MIN_QA_SCORE or not answer:
            return self.extract_direct_answer(query, context)

        # Show an ellipsis only when the context was actually cut.
        snippet = context if len(context) <= 200 else context[:200] + "..."
        return f"**Answer:** {answer}\n\n**Context:** {snippet}"
# === Gradio UI ===
# Introductory Markdown shown at the top of the page.
_INTRO_MARKDOWN = """
# 🧠 Enhanced Document Q&A System
**Optimized with Better Models & Semantic Search**
- Upload PDF, DOCX, TXT files
- Semantic search + QA pipeline
- Direct answer extraction fallback
"""


def main():
    """Build the Gradio interface around one SmartDocumentRAG instance and launch it.

    The page has two tabs: one to upload and index documents, one to ask
    questions or request the document summary. The server listens on
    0.0.0.0:7860 with ``share=True``.
    """
    rag = SmartDocumentRAG()

    def get_summary():
        # The summary button reuses the question path with a fixed request.
        return rag.answer_question("summary")

    with gr.Blocks(title="🧠 Enhanced Document Q&A", theme=gr.themes.Soft()) as demo:
        gr.Markdown(_INTRO_MARKDOWN)

        with gr.Tab("📤 Upload & Process"):
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="📁 Upload Documents",
                        file_types=['.pdf', '.docx', '.txt'],
                        file_count="multiple",
                        height=150,
                    )
                    process_btn = gr.Button("🔄 Process Documents", variant="primary", size="lg")
                with gr.Column():
                    process_status = gr.Textbox(label="📋 Processing Status", lines=10, interactive=False)
            process_btn.click(fn=rag.process_documents, inputs=file_upload, outputs=process_status)

        with gr.Tab("❓ Q&A"):
            with gr.Row():
                with gr.Column():
                    question_input = gr.Textbox(
                        label="🤔 Ask Your Question",
                        lines=3,
                        placeholder="Name? Experience? Skills? Education?",
                    )
                    with gr.Row():
                        ask_btn = gr.Button("🧠 Get Answer", variant="primary")
                        summary_btn = gr.Button("📊 Get Summary", variant="secondary")
                with gr.Column():
                    answer_output = gr.Textbox(label="💡 Answer", lines=8, interactive=False)
            ask_btn.click(fn=rag.answer_question, inputs=question_input, outputs=answer_output)
            summary_btn.click(fn=get_summary, inputs=None, outputs=answer_output)

    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)


if __name__ == "__main__":
    main()