Spaces:
Sleeping
Sleeping
import re | |
import os | |
import faiss | |
import numpy as np | |
import gradio as gr | |
from typing import List | |
from sentence_transformers import SentenceTransformer | |
from transformers import pipeline | |
from PyPDF2 import PdfReader | |
import docx2txt | |
# === Helper functions === | |
def clean_text(text: str) -> str:
    """Collapse every whitespace run to one space and trim both ends.

    Args:
        text: Raw text, possibly containing newlines, tabs or repeated spaces.

    Returns:
        The text with each ``\\s+`` run replaced by a single space and no
        leading or trailing whitespace.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()
def chunk_text(text: str, max_chunk_size: int = 300, overlap: int = 50) -> List[str]:
    """Split text into sentence-aligned chunks with optional overlap for semantic search.

    Sentences (split after ``.``, ``?`` or ``!`` followed by whitespace) are
    packed greedily into chunks of at most ``max_chunk_size`` characters. A
    single sentence longer than ``max_chunk_size`` is kept whole as its own
    (oversized) chunk. When ``overlap`` is positive, every chunk after the
    first is prefixed with the last ``overlap`` characters of the preceding
    chunk so neighbouring chunks share some context.

    Args:
        text: Input text; callers pass text already run through ``clean_text``.
        max_chunk_size: Upper bound on characters packed per chunk, before
            the overlap prefix is added.
        overlap: Trailing characters of the previous chunk to prepend;
            ``0`` or a negative value disables the overlap.

    Returns:
        Whitespace-normalized chunks; an empty list when ``text`` has no
        non-whitespace characters.
    """
    sentences = re.split(r'(?<=[.?!])\s+', text)
    chunks: List[str] = []
    current = ""
    for sentence in sentences:
        if len(current) + len(sentence) <= max_chunk_size:
            current += sentence + " "
            continue
        # Flush only a non-empty buffer: when the very first sentence is longer
        # than max_chunk_size the buffer is still empty, and appending it would
        # create a blank chunk that later gets embedded and indexed.
        if current.strip():
            chunks.append(current.strip())
        current = sentence + " "
    if current.strip():
        chunks.append(current.strip())

    overlapped_chunks: List[str] = []
    for i, chunk in enumerate(chunks):
        combined = chunk
        # overlap must be > 0: the slice [-0:] would copy the *entire* previous chunk.
        if i > 0 and overlap > 0:
            combined = chunks[i - 1][-overlap:] + " " + combined
        overlapped_chunks.append(clean_text(combined))
    return overlapped_chunks
def extract_text_from_pdf(file_path: str) -> str:
    """Extract and whitespace-normalize the text of every readable page of a PDF.

    Args:
        file_path: Path to the PDF file.

    Returns:
        The page texts joined with spaces and passed through ``clean_text``.
        A page that fails to extract is reported and skipped, so the other
        pages are still returned. If the file cannot be opened or its pages
        cannot be enumerated, the error is printed and whatever was gathered
        (possibly ``""``) is returned.
    """
    page_texts: List[str] = []
    try:
        reader = PdfReader(file_path)
        for page_number, page in enumerate(reader.pages, start=1):
            try:
                # ``or ""`` defensively tolerates an empty/None result (e.g. an
                # image-only page) instead of raising on string concatenation.
                page_texts.append(page.extract_text() or "")
            except Exception as e:
                # One malformed page should not discard all remaining pages.
                print(f"Error reading PDF {file_path} (page {page_number}): {e}")
    except Exception as e:
        print(f"Error reading PDF {file_path}: {e}")
    return clean_text(" ".join(page_texts))
def extract_text_from_docx(file_path: str) -> str:
    """Return the whitespace-normalized text of a ``.docx`` file.

    Args:
        file_path: Path to the DOCX file.

    Returns:
        ``clean_text`` applied to the output of ``docx2txt.process``, or
        ``""`` if either step raises (the error is printed).
    """
    try:
        return clean_text(docx2txt.process(file_path))
    except Exception as err:
        print(f"Error reading DOCX {file_path}: {err}")
    return ""
def extract_text_from_txt(file_path: str) -> str:
    """Read a plain-text file and return its whitespace-normalized contents.

    The file is decoded as UTF-8 with ``utf-8-sig``, so a leading byte-order
    mark is dropped rather than leaking into the text (``\\s`` does not match
    U+FEFF). ``errors='replace'`` substitutes U+FFFD for undecodable bytes, so
    a file in another encoding still yields searchable text instead of
    nothing.

    Args:
        file_path: Path to the text file.

    Returns:
        The contents passed through ``clean_text``, or ``""`` if the file
        cannot be read (the error is printed).
    """
    try:
        with open(file_path, 'r', encoding='utf-8-sig', errors='replace') as f:
            return clean_text(f.read())
    except Exception as e:
        print(f"Error reading TXT {file_path}: {e}")
        return ""
# === Main RAG System === | |
class SmartDocumentRAG:
    """Retrieval-augmented question answering over uploaded PDF/DOCX/TXT files.

    Uploaded files are converted to text, split into chunks, embedded with a
    SentenceTransformer model and stored in an in-memory FAISS inner-product
    index over L2-normalized vectors (so scores are cosine similarities).
    Questions are answered by an extractive ``question-answering`` pipeline
    over the best-matching chunks, with regex heuristics as a fallback.
    """

    # NOTE(review): the leading characters of the status strings returned below
    # ("β οΈ", "β", "π") look like UTF-8 emoji mis-decoded through another code
    # page; confirm the intended characters.

    def __init__(
        self,
        embedding_model: str = 'sentence-transformers/all-MiniLM-L6-v2',
        qa_model: str = "distilbert-base-uncased-distilled-squad",
    ):
        """Load the embedding and QA models and start with an empty index.

        Args:
            embedding_model: Model identifier passed to ``SentenceTransformer``
                for embedding chunks and queries.
            qa_model: Model identifier passed to the transformers
                ``question-answering`` pipeline.
        """
        self.embedder = SentenceTransformer(embedding_model)
        self.qa_pipeline = pipeline("question-answering", model=qa_model)
        self.documents: List[str] = []  # extracted text, one entry per file that yielded text
        self.chunks: List[str] = []     # row i of self.index is the embedding of chunks[i]
        self.index = None               # faiss.IndexFlatIP after a successful process_documents
        self.is_indexed = False
        self.document_summary = ""

    @staticmethod
    def _resolve_upload_path(file_obj) -> str:
        """Return the filesystem path for one uploaded file.

        ``gr.File`` may deliver plain path strings (or str subclasses) or
        temp-file-like objects exposing the path as ``.name``; both are accepted.
        """
        if isinstance(file_obj, (str, os.PathLike)):
            return os.fspath(file_obj)
        return file_obj.name

    @staticmethod
    def _extract_text(file_path: str) -> str:
        """Run the extractor matching the file extension; ``""`` for unsupported types."""
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".pdf":
            return extract_text_from_pdf(file_path)
        if ext == ".docx":
            return extract_text_from_docx(file_path)
        if ext == ".txt":
            return extract_text_from_txt(file_path)
        return ""  # unsupported extension: treated like a file with no text

    @staticmethod
    def _l2_normalize(vectors: np.ndarray) -> np.ndarray:
        """Scale each row to unit L2 norm, returned as float32 for FAISS."""
        vectors = np.asarray(vectors, dtype=np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # An all-zero row would otherwise divide by zero and put NaNs in the index.
        return vectors / np.maximum(norms, 1e-12)

    def process_documents(self, uploaded_files) -> str:
        """Extract, chunk, embed and index the uploads, replacing any previous index.

        Args:
            uploaded_files: Iterable of uploads from ``gr.File`` (path strings
                or objects with a ``.name`` path). Files whose extension is
                not .pdf/.docx/.txt are skipped.

        Returns:
            A human-readable status message for the UI.
        """
        if not uploaded_files:
            return "β οΈ No files uploaded."

        # Reset all state up front so an early return below can never leave a
        # stale index or summary paired with the freshly cleared chunk list.
        self.documents.clear()
        self.chunks.clear()
        self.index = None
        self.is_indexed = False
        self.document_summary = ""

        for file_obj in uploaded_files:
            text = self._extract_text(self._resolve_upload_path(file_obj))
            if text:
                self.documents.append(text)

        all_text = " ".join(self.documents)
        if not all_text.strip():
            return "β οΈ No extractable text found in uploaded files."

        # Blank chunks carry nothing to retrieve but would still be embedded.
        self.chunks = [chunk for chunk in chunk_text(all_text) if chunk.strip()]
        if not self.chunks:
            return "β οΈ No extractable text found in uploaded files."

        # Unit-length vectors make the inner product (IndexFlatIP) a cosine similarity.
        embeddings = self._l2_normalize(self.embedder.encode(self.chunks, convert_to_numpy=True))
        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)
        self.is_indexed = True

        self.document_summary = self.generate_summary(all_text)
        return f"β Processed {len(self.documents)} document(s), {len(self.chunks)} chunks indexed."

    def generate_summary(self, text: str) -> str:
        """Return a naive summary: the first five sentences of ``text``.

        Sentences are split after ``.``, ``?`` or ``!`` followed by whitespace.
        """
        sentences = re.split(r'(?<=[.?!])\s+', text)
        return ' '.join(sentences[:5])

    def find_relevant_content(self, query: str, top_k: int = 3) -> str:
        """Return the chunks most similar to ``query``, joined with spaces.

        Args:
            query: The user's question.
            top_k: Maximum number of chunks to retrieve.

        Returns:
            Up to ``top_k`` chunks whose cosine similarity exceeds 0.1, in
            FAISS rank order; ``""`` if nothing is indexed or nothing matches.
        """
        if not self.is_indexed or not self.chunks:
            return ""
        k = min(top_k, len(self.chunks))
        if k <= 0:
            return ""
        query_emb = self._l2_normalize(self.embedder.encode([query], convert_to_numpy=True))
        scores, indices = self.index.search(query_emb, k)
        # FAISS pads missing results with label -1; a negative index would
        # otherwise silently wrap around to the last chunk.
        relevant_chunks = [
            self.chunks[idx]
            for score, idx in zip(scores[0], indices[0])
            if idx >= 0 and score > 0.1
        ]
        return " ".join(relevant_chunks)

    def extract_direct_answer(self, query: str, context: str) -> str:
        """Pull a direct answer out of ``context`` with keyword-triggered regexes.

        Used when the QA model fails or is not confident. The lowercased query
        is checked (plain substring tests), in order, for name, experience,
        skill and education keywords. The first category whose regex matches
        ``context`` wins. Otherwise the first ``.``-delimited sentence of
        ``context`` is returned.

        Args:
            query: The user's question.
            context: Retrieved chunk text to search.

        Returns:
            A markdown-formatted answer string.
        """
        q = query.lower()
        if any(word in q for word in ['name', 'who is', 'who']):
            # Two consecutive capitalized words, e.g. "Jane Doe".
            names = re.findall(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', context)
            if names:
                return f"**Name:** {names[0]}"
        if any(word in q for word in ['experience', 'years']):
            years = re.findall(r'(\d+)[\+\-\s]*(?:years?|yrs?)', context.lower())
            if years:
                return f"**Experience:** {years[0]} years"
        if any(word in q for word in ['skill', 'technology', 'tech']):
            skills = re.findall(r'\b(?:Python|Java|JavaScript|React|Node|SQL|AWS|Docker|Kubernetes|Git|HTML|CSS|Angular|Vue|Spring|Django|Flask|MongoDB|PostgreSQL)\b', context, re.I)
            if skills:
                # The match is case-insensitive, so "Python" and "python" are the
                # same skill: de-duplicate on the lowercased form in O(n), keeping
                # the first spelling seen and first-appearance order.
                unique_skills = {}
                for skill in skills:
                    unique_skills.setdefault(skill.lower(), skill)
                return f"**Skills:** {', '.join(unique_skills.values())}"
        if any(word in q for word in ['education', 'degree', 'university']):
            edu = re.findall(r'(?:Bachelor|Master|PhD|B\.?S\.?|M\.?S\.?|B\.?A\.?|M\.?A\.?).*?(?:in|of)\s+([^.]+)', context, re.I)
            if edu:
                return f"**Education:** {edu[0]}"
        # Fallback: first sentence from context.
        sentences = [s.strip() for s in context.split('.') if s.strip()]
        if sentences:
            return f"**Answer:** {sentences[0]}"
        return "I found relevant content but could not extract a specific answer."

    def answer_question(self, query: str) -> str:
        """Answer ``query`` from the indexed documents.

        A query containing "summary", "summarize", "overview" or "about" (as a
        substring) returns the stored document summary. Otherwise the top
        chunks are retrieved and passed to the QA pipeline. If the pipeline
        raises, returns an empty answer, or scores below 0.1, the regex
        fallback ``extract_direct_answer`` is used instead.

        Args:
            query: The user's question.

        Returns:
            A markdown-formatted answer or a status message.
        """
        if not query.strip():
            return "β Please ask a question."
        if not self.is_indexed:
            return "π Please upload and process documents first."
        q_lower = query.lower()
        if any(word in q_lower for word in ['summary', 'summarize', 'overview', 'about']):
            return f"π **Document Summary:**\n\n{self.document_summary}"
        context = self.find_relevant_content(query, top_k=3)
        if not context:
            return "π No relevant information found. Try rephrasing your question."
        try:
            result = self.qa_pipeline(question=query, context=context)
            answer = result.get('answer', '').strip()
            score = result.get('score', 0)
        except Exception as e:
            print(f"QA model error: {e}")
            return self.extract_direct_answer(query, context)
        # Low model confidence: fall back to regex extraction.
        if score < 0.1 or not answer:
            return self.extract_direct_answer(query, context)
        # Mark the context as truncated only when it actually was.
        snippet = context if len(context) <= 200 else f"{context[:200]}..."
        return f"**Answer:** {answer}\n\n**Context:** {snippet}"
# === Gradio UI === | |
def main():
    """Build the Gradio UI around one SmartDocumentRAG instance and launch it.

    Two tabs are created: "Upload & Process" (PDF/DOCX/TXT upload wired to
    ``process_documents``) and "Q&A" (free-form question or canned summary
    request wired to ``answer_question``). The server listens on
    0.0.0.0:7860 and is launched with ``share=True``.
    """
    engine = SmartDocumentRAG()

    # The wrappers keep these exact function names: per the Gradio docs, an
    # event's default api_name is taken from the handler's function name.
    def process_files(files):
        return engine.process_documents(files)

    def ask_question(question):
        return engine.answer_question(question)

    def get_summary():
        return engine.answer_question("summary")

    with gr.Blocks(title="π§ Enhanced Document Q&A", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
# π§ Enhanced Document Q&A System
**Optimized with Better Models & Semantic Search**
- Upload PDF, DOCX, TXT files
- Semantic search + QA pipeline
- Direct answer extraction fallback
""")

        with gr.Tab("π€ Upload & Process"):
            with gr.Row():
                with gr.Column():
                    uploads = gr.File(
                        label="π Upload Documents",
                        file_types=['.pdf', '.docx', '.txt'],
                        file_count="multiple",
                        height=150,
                    )
                    process_button = gr.Button(
                        "π Process Documents",
                        variant="primary",
                        size="lg",
                    )
                with gr.Column():
                    status_box = gr.Textbox(
                        label="π Processing Status",
                        lines=10,
                        interactive=False,
                    )
            process_button.click(fn=process_files, inputs=uploads, outputs=status_box)

        with gr.Tab("β Q&A"):
            with gr.Row():
                with gr.Column():
                    question_box = gr.Textbox(
                        label="π€ Ask Your Question",
                        lines=3,
                        placeholder="Name? Experience? Skills? Education?",
                    )
                    with gr.Row():
                        ask_button = gr.Button("π§ Get Answer", variant="primary")
                        summary_button = gr.Button("π Get Summary", variant="secondary")
                with gr.Column():
                    answer_box = gr.Textbox(
                        label="π‘ Answer",
                        lines=8,
                        interactive=False,
                    )
            ask_button.click(fn=ask_question, inputs=question_box, outputs=answer_box)
            summary_button.click(fn=get_summary, inputs=None, outputs=answer_box)

    app.launch(server_name="0.0.0.0", server_port=7860, share=True)


if __name__ == "__main__":
    main()