import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,  # needed for encoder-decoder models such as Flan-T5
    BitsAndBytesConfig,
    pipeline,
)
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import PyPDF2
import docx
import io
import os
import re
from typing import List, Optional, Dict, Tuple
import json
from collections import Counter
import warnings

warnings.filterwarnings("ignore")


class SmartDocumentRAG:
    def __init__(self):
        print("🚀 Initializing Enhanced Smart RAG System...")

        # Initialize embedding model (fast, with good quality)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embedding model loaded")

        # Initialize optimized LLM with quantization
        self.setup_llm()

        # Document storage
        self.documents = []
        self.document_metadata = []
        self.index = None
        self.is_indexed = False
        self.raw_text = ""
        self.document_type = "general"
        self.document_summary = ""
        self.sentence_embeddings = []
        self.sentences = []

    def setup_llm(self):
        """Set up an optimized model with quantization."""
        try:
            # Check CUDA availability
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"🔧 Using device: {device}")

            if device == "cuda":
                self.setup_gpu_model()
            else:
                self.setup_cpu_model()
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            self.setup_fallback_model()

    def setup_gpu_model(self):
        """Set up a GPU model with 4-bit quantization."""
        try:
            # 4-bit quantization config shared by both candidate models
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_quant_storage=torch.uint8
            )

            try:
                # Try Flan-T5 first - strong at instruction-style Q&A.
                # Flan-T5 is an encoder-decoder model, so it must be loaded
                # with AutoModelForSeq2SeqLM, not AutoModelForCausalLM.
                model_name = "google/flan-t5-base"
                print(f"🤖 Loading {model_name}...")

                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16
                )

                # Create a pipeline for easier use
                self.qa_pipeline = pipeline(
                    "text2text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    max_length=512,
                    do_sample=True,
                    temperature=0.3,
                    top_p=0.9
                )

                print("✅ Flan-T5 model loaded successfully")
                self.model_type = "flan-t5"

            except Exception as e:
                print(f"Flan-T5 failed, trying Phi-2: {e}")

                # Try Phi-2 as backup - a decoder-only (causal) model
                model_name = "microsoft/phi-2"
                print(f"🤖 Loading {model_name}...")

                self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                    trust_remote_code=True
                )

                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                # Build a generation pipeline so qa_pipeline is always defined
                self.qa_pipeline = pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    max_new_tokens=256
                )

                print("✅ Phi-2 model loaded successfully")
                self.model_type = "phi-2"

        except Exception as e:
            print(f"❌ GPU models failed: {e}")
            self.setup_cpu_model()
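
    # Quantization note (background, not from the original source): with
    # load_in_4bit the weights are stored in 4-bit NF4 (double-quantized),
    # while matrix multiplications run in float16 via bnb_4bit_compute_dtype.
    # This roughly quarters weight memory at a small accuracy cost.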

    def setup_cpu_model(self):
        """Set up a CPU-optimized model."""
        try:
            # DistilBERT fine-tuned on SQuAD - extractive Q&A that runs well on CPU
            model_name = "distilbert-base-cased-distilled-squad"
            print(f"🤖 Loading CPU model: {model_name}")

            self.qa_pipeline = pipeline(
                "question-answering",
                model=model_name,
                tokenizer=model_name
            )

            self.model_type = "distilbert-qa"
            print("✅ DistilBERT Q&A model loaded successfully")

        except Exception as e:
            print(f"❌ CPU model failed: {e}")
            self.setup_fallback_model()

    def setup_fallback_model(self):
        """Fall back to a basic extractive Q&A model."""
        try:
            print("🤖 Loading fallback model...")
            self.qa_pipeline = pipeline(
                "question-answering",
                model="distilbert-base-uncased-distilled-squad"
            )
            self.model_type = "fallback"
            print("✅ Fallback model loaded")
        except Exception as e:
            print(f"❌ All models failed: {e}")
            self.qa_pipeline = None
            self.model_type = "none"

    def detect_document_type(self, text: str) -> str:
        """Detect the document type from keyword frequencies."""
        text_lower = text.lower()

        resume_patterns = [
            'experience', 'skills', 'education', 'linkedin', 'email', 'phone',
            'work experience', 'employment', 'resume', 'cv', 'curriculum vitae',
            'internship', 'projects', 'achievements', 'career', 'profile', 'objective'
        ]
        research_patterns = [
            'abstract', 'introduction', 'methodology', 'conclusion', 'references',
            'literature review', 'hypothesis', 'study', 'research', 'findings',
            'data analysis', 'results', 'discussion', 'bibliography', 'journal'
        ]
        business_patterns = [
            'company', 'revenue', 'market', 'strategy', 'business', 'financial',
            'quarter', 'profit', 'sales', 'growth', 'investment', 'stakeholder',
            'operations', 'management', 'corporate', 'enterprise', 'budget'
        ]
        technical_patterns = [
            'implementation', 'algorithm', 'system', 'technical', 'specification',
            'architecture', 'development', 'software', 'programming', 'api',
            'database', 'framework', 'deployment', 'infrastructure', 'code'
        ]

        def count_matches(patterns, text):
            score = 0
            for pattern in patterns:
                count = text.count(pattern)
                score += count * (2 if len(pattern.split()) > 1 else 1)  # weight phrases higher
            return score

        scores = {
            'resume': count_matches(resume_patterns, text_lower),
            'research': count_matches(research_patterns, text_lower),
            'business': count_matches(business_patterns, text_lower),
            'technical': count_matches(technical_patterns, text_lower)
        }

        max_score = max(scores.values())
        if max_score > 5:  # require a minimum signal before committing to a type
            return max(scores, key=scores.get)
        return 'general'
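
    # Scoring note: the patterns overlap, e.g. each occurrence of
    # "work experience" also counts toward the single-word "experience"
    # pattern, so that phrase contributes 2 (phrase weight) + 1 (substring
    # match) = 3 points per occurrence toward the resume score.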

    def create_document_summary(self, text: str) -> str:
        """Create a type-aware document summary."""
        try:
            clean_text = re.sub(r'\s+', ' ', text).strip()
            sentences = re.split(r'[.!?]+', clean_text)
            sentences = [s.strip() for s in sentences if len(s.strip()) > 30]

            if not sentences:
                return "Document contains basic information."

            # Use the first few sentences plus type-specific key information
            if self.document_type == 'resume':
                return self.extract_resume_summary(sentences, clean_text)
            elif self.document_type == 'research':
                return self.extract_research_summary(sentences)
            elif self.document_type == 'business':
                return self.extract_business_summary(sentences)
            else:
                return self.extract_general_summary(sentences)
        except Exception as e:
            print(f"Summary creation error: {e}")
            return "Document summary not available."

    def extract_resume_summary(self, sentences: List[str], full_text: str) -> str:
        """Extract a resume-specific summary with name detection."""
        summary_parts = []

        # Extract name using multiple patterns
        name = self.extract_name(full_text)
        if name:
            summary_parts.append(f"Resume of {name}")

        # Extract role/title (stop at the first match so only one role is added)
        role_patterns = [
            r'(?:software|senior|junior|lead|principal)?\s*(?:engineer|developer|analyst|manager|designer|architect|consultant)',
            r'(?:full stack|frontend|backend|data|ml|ai)\s*(?:engineer|developer)',
            r'(?:product|project|technical)\s*manager'
        ]

        role_found = False
        for sentence in sentences[:5]:
            for pattern in role_patterns:
                matches = re.findall(pattern, sentence.lower())
                if matches:
                    summary_parts.append(f"working as {matches[0].title()}")
                    role_found = True
                    break
            if role_found:
                break

        # Extract years of experience
        exp_match = re.search(r'(\d+)[\+\-\s]*(?:years?|yrs?)\s*(?:of\s*)?(?:experience|exp)', full_text.lower())
        if exp_match:
            summary_parts.append(f"with {exp_match.group(1)}+ years of experience")

        return '. '.join(summary_parts) + '.' if summary_parts else "Professional resume with career details."

    def extract_name(self, text: str) -> str:
        """Extract a person's name using multiple strategies."""
        # Strategy 1: look for name patterns in the first few lines
        lines = text.split('\n')[:10]  # first 10 lines
        for line in lines:
            line = line.strip()
            if 3 < len(line) < 50:  # likely a header line
                # Check whether it looks like a name
                name_match = re.match(r'^([A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)(?:\s|$)', line)
                if name_match:
                    return name_match.group(1)

        # Strategy 2: look for an explicit "Name:" label or a header-style name
        name_patterns = [
            r'(?:name|full name):\s*([A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)',
            r'^([A-Z][a-z]+\s+[A-Z][a-z]+)(?:\s*\n|\s*email|\s*phone|\s*linkedin)',
        ]
        for pattern in name_patterns:
            match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
            if match:
                return match.group(1)

        return ""

    def extract_research_summary(self, sentences: List[str]) -> str:
        """Extract a research paper summary."""
        # Look for abstract or introduction material
        for sentence in sentences[:5]:
            if any(word in sentence.lower() for word in ['abstract', 'study', 'research', 'paper']):
                return sentence[:200] + ('...' if len(sentence) > 200 else '')
        return "Research document with academic content."

    def extract_business_summary(self, sentences: List[str]) -> str:
        """Extract a business document summary."""
        for sentence in sentences[:3]:
            if any(word in sentence.lower() for word in ['company', 'business', 'organization']):
                return sentence[:200] + ('...' if len(sentence) > 200 else '')
        return "Business document with organizational information."

    def extract_general_summary(self, sentences: List[str]) -> str:
        """Extract a general document summary."""
        if not sentences:
            return "General document."
        first = sentences[0]
        return first[:200] + ('...' if len(first) > 200 else '')

    def extract_text_from_file(self, file_path: str) -> str:
        """Dispatch text extraction based on file extension."""
        try:
            file_extension = os.path.splitext(file_path)[1].lower()

            if file_extension == '.pdf':
                return self.extract_from_pdf(file_path)
            elif file_extension == '.docx':
                return self.extract_from_docx(file_path)
            elif file_extension == '.txt':
                return self.extract_from_txt(file_path)
            else:
                return f"Unsupported file format: {file_extension}"
        except Exception as e:
            return f"Error reading file: {str(e)}"

    def extract_from_pdf(self, file_path: str) -> str:
        """Extract and lightly clean text from a PDF."""
        text = ""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    page_text = page.extract_text()
                    if page_text.strip():
                        # Normalize whitespace and split words merged across line breaks
                        page_text = re.sub(r'\s+', ' ', page_text)
                        page_text = re.sub(r'([a-z])([A-Z])', r'\1 \2', page_text)
                        text += f"{page_text}\n"
        except Exception as e:
            text = f"Error reading PDF: {str(e)}"
        return text.strip()

    def extract_from_docx(self, file_path: str) -> str:
        """Extract text from a DOCX file."""
        try:
            doc = docx.Document(file_path)
            text = ""
            for paragraph in doc.paragraphs:
                if paragraph.text.strip():
                    text += paragraph.text.strip() + "\n"
            return text.strip()
        except Exception as e:
            return f"Error reading DOCX: {str(e)}"

    def extract_from_txt(self, file_path: str) -> str:
        """Extract text from a TXT file, trying several encodings."""
        encodings = ['utf-8', 'latin-1', 'cp1252', 'iso-8859-1']
        for encoding in encodings:
            try:
                with open(file_path, 'r', encoding=encoding) as file:
                    return file.read().strip()
            except UnicodeDecodeError:
                continue
            except Exception as e:
                return f"Error reading TXT: {str(e)}"
        return "Error: Could not decode file"

    def enhanced_chunk_text(self, text: str, max_chunk_size: int = 300, overlap: int = 50) -> List[str]:
        """
        Split text into overlapping chunks for better semantic search.

        Args:
            text: The full text to chunk.
            max_chunk_size: Maximum words per chunk.
            overlap: Number of words shared between consecutive chunks.

        Returns:
            List of text chunks.
        """
        # Clean and normalize whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        words = text.split()

        chunks = []
        start = 0
        text_len = len(words)
        step = max(max_chunk_size - overlap, 1)  # guard: the window must always advance

        while start < text_len:
            end = min(start + max_chunk_size, text_len)
            chunks.append(' '.join(words[start:end]))
            # Move the window forward by chunk size minus overlap
            start += step

        return chunks
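
    # Usage sketch (illustrative, assumes a SmartDocumentRAG instance `rag`):
    # with 700 words, max_chunk_size=300 and overlap=50, the window starts at
    # words 0, 250 and 500, so each consecutive pair of chunks shares 50 words:
    #   chunks = rag.enhanced_chunk_text("word " * 700)
    #   len(chunks)  # -> 3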

    def process_documents(self, files) -> str:
        """Process uploaded documents: extract, analyze, chunk, and index."""
        if not files:
            return "❌ No files uploaded!"

        try:
            all_text = ""
            processed_files = []

            for file in files:
                if file is None:
                    continue
                file_text = self.extract_text_from_file(file.name)
                if not file_text.startswith("Error") and not file_text.startswith("Unsupported"):
                    all_text += f"\n{file_text}"
                    processed_files.append(os.path.basename(file.name))
                else:
                    return f"❌ {file_text}"

            if not all_text.strip():
                return "❌ No text extracted from files!"

            # Store and analyze
            self.raw_text = all_text
            self.document_type = self.detect_document_type(all_text)
            self.document_summary = self.create_document_summary(all_text)

            # Chunk the text (enhanced_chunk_text returns plain strings)
            chunks = self.enhanced_chunk_text(all_text)
            if not chunks:
                return "❌ No valid text chunks created!"

            self.documents = chunks
            self.document_metadata = [{'text': chunk} for chunk in chunks]

            # Create embeddings
            print(f"🔄 Creating embeddings for {len(self.documents)} chunks...")
            embeddings = self.embedder.encode(self.documents, show_progress_bar=False)

            # Build a FAISS index over L2-normalized vectors (inner product == cosine)
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatIP(dimension)
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings.astype('float32'))

            self.is_indexed = True

            return (
                f"✅ Successfully processed {len(processed_files)} files:\n"
                f"📁 Files: {', '.join(processed_files)}\n"
                f"📄 Document Type: {self.document_type.title()}\n"
                f"🔢 Created {len(self.documents)} chunks\n"
                f"📝 Summary: {self.document_summary}\n"
                f"❓ Ready for Q&A!"
            )
        except Exception as e:
            return f"❌ Error processing documents: {str(e)}"

    def find_relevant_content(self, query: str, k: int = 3) -> str:
        """Retrieve relevant chunks with a strict similarity filter."""
        if not self.is_indexed:
            return ""

        try:
            # Semantic search over the FAISS index
            query_embedding = self.embedder.encode([query])
            faiss.normalize_L2(query_embedding)

            scores, indices = self.index.search(query_embedding.astype('float32'), min(k, len(self.documents)))

            relevant_chunks = []
            for i, idx in enumerate(indices[0]):
                score = scores[0][i]
                if idx < len(self.documents) and score > 0.4:  # strict similarity cutoff
                    relevant_chunks.append(self.documents[idx])

            return ' '.join(relevant_chunks)
        except Exception as e:
            print(f"Error in content retrieval: {e}")
            return ""

    def answer_question(self, query: str) -> str:
        """Answer a question using retrieval plus the loaded model."""
        if not query.strip():
            return "❓ Please ask a question!"

        if not self.is_indexed:
            return "📄 Please upload and process documents first!"

        try:
            query_lower = query.lower()

            # Handle summary requests explicitly
            if any(word in query_lower for word in ['summary', 'summarize', 'about', 'overview']):
                return f"📝 **Document Summary:**\n\n{self.document_summary}"

            # Retrieve relevant content chunks via semantic search
            context = self.find_relevant_content(query, k=3)
            if not context:
                return "🔍 No relevant information found. Try rephrasing your question."

            # If no QA pipeline is available, fall back to direct extraction
            if self.qa_pipeline is None:
                return self.extract_direct_answer(query, context)

            try:
                if self.model_type in ["distilbert-qa", "fallback"]:
                    # Use the extractive Q&A pipeline
                    result = self.qa_pipeline(question=query, context=context)
                    answer = result.get('answer', '').strip()
                    confidence = result.get('score', 0)

                    if confidence > 0.1 and answer:
                        return f"**Answer:** {answer}\n\n**Context:** {context[:200]}..."
                    else:
                        return self.extract_direct_answer(query, context)

                elif self.model_type == "flan-t5":
                    # Use the generative model with a grounded prompt to reduce hallucination
                    prompt = (
                        f"Answer concisely and strictly based on the following context.\n\n"
                        f"Context:\n{context}\n\n"
                        f"Question:\n{query}\n\n"
                        f"If the answer is not contained in the context, reply with 'Not found in document.'\n"
                        f"Answer:"
                    )
                    result = self.qa_pipeline(prompt, max_length=256, num_return_sequences=1)
                    # text2text-generation returns only the generated answer, not the prompt
                    answer = result[0].get('generated_text', '').strip()

                    if answer.lower() in ["not found in document.", "no answer", "unknown", ""]:
                        return "🔍 Sorry, the answer was not found in the documents."
                    else:
                        return f"**Answer:** {answer}"

                else:
                    # Default fallback extraction (covers phi-2 and any other types)
                    return self.extract_direct_answer(query, context)

            except Exception as e:
                print(f"Model inference error: {e}")
                return self.extract_direct_answer(query, context)

        except Exception as e:
            return f"❌ Error processing question: {str(e)}"

    def extract_direct_answer(self, query: str, context: str) -> str:
        """Extract a direct answer with regex heuristics as a fallback."""
        query_lower = query.lower()

        # Name extraction
        if any(word in query_lower for word in ['name', 'who is', 'who']):
            names = re.findall(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', context)
            if names:
                return f"**Name:** {names[0]}"

        # Experience extraction
        if any(word in query_lower for word in ['experience', 'years']):
            exp_matches = re.findall(r'(\d+)[\+\-\s]*(?:years?|yrs?)', context.lower())
            if exp_matches:
                return f"**Experience:** {exp_matches[0]} years"

        # Skills extraction
        if any(word in query_lower for word in ['skill', 'technology', 'tech']):
            # Common tech skills
            tech_patterns = [
                r'\b(?:Python|Java|JavaScript|React|Node|SQL|AWS|Docker|Kubernetes|Git)\b',
                r'\b(?:HTML|CSS|Angular|Vue|Spring|Django|Flask|MongoDB|PostgreSQL)\b'
            ]
            skills = []
            for pattern in tech_patterns:
                skills.extend(re.findall(pattern, context, re.IGNORECASE))
            if skills:
                return f"**Skills mentioned:** {', '.join(set(skills))}"

        # Education extraction
        if any(word in query_lower for word in ['education', 'degree', 'university']):
            edu_matches = re.findall(r'(?:Bachelor|Master|PhD|B\.?S\.?|M\.?S\.?|B\.?A\.?|M\.?A\.?).*?(?:in|of)\s+([^.]+)', context)
            if edu_matches:
                return f"**Education:** {edu_matches[0]}"

        # Otherwise return the first relevant sentence
        sentences = [s.strip() for s in context.split('.') if s.strip()]
        if sentences:
            return f"**Answer:** {sentences[0]}"

        return "I found relevant content but couldn't extract a specific answer."

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize raw text by:
        - collapsing excessive whitespace
        - splitting merged camel-case words
        - optionally removing unwanted characters
        """
        # Replace runs of whitespace/newlines/tabs with a single space
        text = re.sub(r'\s+', ' ', text).strip()

        # Fix merged words like 'wordAnotherWord' -> 'word Another Word'
        text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)

        # Optional: remove special characters except basic punctuation
        # text = re.sub(r'[^a-zA-Z0-9,.!?;:\'\"()\-\s]', '', text)

        return text
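
    # Example (illustrative):
    #   clean_text("Hello\n\n  worldFooBar")  # -> "Hello world Foo Bar"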


# Initialize the system
print("Initializing Enhanced Smart RAG System...")
rag_system = SmartDocumentRAG()


# Create the interface
def create_interface():
    with gr.Blocks(title="🧠 Enhanced Document Q&A", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🧠 Enhanced Document Q&A System

        **Optimized with Better Models & Quantization!**

        **Features:**
        - 🎯 Flan-T5 or DistilBERT for accurate Q&A
        - ⚡ 4-bit quantization for GPU efficiency
        - 📝 Direct answer extraction
        - 🔍 Enhanced semantic search
        """)

        with gr.Tab("📤 Upload & Process"):
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="📁 Upload Documents",
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"],
                        height=150
                    )
                    process_btn = gr.Button("🚀 Process Documents", variant="primary", size="lg")
                with gr.Column():
                    process_status = gr.Textbox(
                        label="📊 Processing Status",
                        lines=10,
                        interactive=False
                    )
            process_btn.click(
                fn=rag_system.process_documents,
                inputs=[file_upload],
                outputs=[process_status]
            )

        with gr.Tab("❓ Q&A"):
            with gr.Row():
                with gr.Column():
                    question_input = gr.Textbox(
                        label="🤔 Ask Your Question",
                        placeholder="What is the person's name? / How many years of experience? / What skills do they have?",
                        lines=3
                    )
                    with gr.Row():
                        ask_btn = gr.Button("🧠 Get Answer", variant="primary")
                        summary_btn = gr.Button("📝 Get Summary", variant="secondary")
                with gr.Column():
                    answer_output = gr.Textbox(
                        label="💡 Answer",
                        lines=8,
                        interactive=False
                    )
            ask_btn.click(
                fn=rag_system.answer_question,
                inputs=[question_input],
                outputs=[answer_output]
            )
            summary_btn.click(
                fn=lambda: rag_system.answer_question("summary"),
                inputs=[],
                outputs=[answer_output]
            )

    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )