import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import PyPDF2
import docx
import io
import os
from typing import List, Optional


class DocumentRAG:
    def __init__(self):
        print("Initializing RAG System...")

        # Initialize embedding model (lightweight)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        print("Embedding model loaded")

        # Initialize quantized LLM
        self.setup_llm()

        # Document storage
        self.documents = []
        self.index = None
        self.is_indexed = False

    def setup_llm(self):
        """Set up the quantized Mistral model."""
        try:
            # Check if CUDA is available
            if not torch.cuda.is_available():
                print("CUDA not available, falling back to CPU or alternative model")
                self.setup_fallback_model()
                return

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
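            # NF4 4-bit weights with double quantization shrink the 7B model to
            # roughly 4-5 GB of GPU memory (a rough estimate, not measured here),
            # which is what makes it practical to serve on a single small GPU.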

            model_name = "mistralai/Mistral-7B-Instruct-v0.1"

            # Load tokenizer first
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )

            # Fix padding token issue
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
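            # (Mistral's tokenizer ships without a dedicated pad token, so the EOS
            # token is reused above to make padding=True work during tokenization.)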

            # Load model with quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            print("Quantized Mistral model loaded successfully")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Falling back to alternative model...")
            self.setup_fallback_model()

    def setup_fallback_model(self):
        """Fall back to a smaller model if Mistral fails to load."""
        try:
            # Use a small conversational model that can run without a GPU
            model_name = "microsoft/DialoGPT-small"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)

            # Fix padding token for the fallback model too
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("Fallback model loaded")
        except Exception as e:
            print(f"Fallback model failed: {e}")
            # Last resort: return context-based answers without text generation
            self.model = None
            self.tokenizer = None
            print("Using context-only mode (no text generation)")

    def simple_context_answer(self, query: str, context: str) -> str:
        """Keyword-based answering for when no generation model is available."""
        if not context:
            return "No relevant information found in the documents."

        # Keyword matching: keep sentences that share words with the query
        query_words = set(query.lower().split())
        context_sentences = context.split('.')

        relevant_sentences = []
        for sentence in context_sentences:
            sentence = sentence.strip()
            if len(sentence) < 10:  # Skip very short sentences
                continue

            sentence_words = set(sentence.lower().split())
            common_words = query_words.intersection(sentence_words)
            if len(common_words) >= 1:  # A single shared word is enough
                relevant_sentences.append(sentence)

        if relevant_sentences:
            # Return the most relevant sentences
            return '. '.join(relevant_sentences[:3]) + '.'
        else:
            # If nothing matches, return the first few sentences of the context
            first_sentences = context_sentences[:2]
            if first_sentences:
                return '. '.join([s.strip() for s in first_sentences if s.strip()]) + '.'
            return "Based on the document content, I found some information but cannot provide a specific answer to your question."

    def extract_text_from_file(self, file_path: str) -> str:
        """Extract text from various file formats."""
        try:
            file_extension = os.path.splitext(file_path)[1].lower()

            if file_extension == '.pdf':
                return self.extract_from_pdf(file_path)
            elif file_extension == '.docx':
                return self.extract_from_docx(file_path)
            elif file_extension == '.txt':
                return self.extract_from_txt(file_path)
            else:
                return f"Unsupported file format: {file_extension}"
        except Exception as e:
            return f"Error reading file: {str(e)}"

    def extract_from_pdf(self, file_path: str) -> str:
        """Extract text from a PDF file."""
        text = ""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    text += page.extract_text() + "\n"
        except Exception as e:
            text = f"Error reading PDF: {str(e)}"
        return text

    def extract_from_docx(self, file_path: str) -> str:
        """Extract text from a DOCX file."""
        try:
            doc = docx.Document(file_path)
            text = ""
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
            return text
        except Exception as e:
            return f"Error reading DOCX: {str(e)}"

    def extract_from_txt(self, file_path: str) -> str:
        """Extract text from a plain-text file."""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except Exception:
            # Retry with a more permissive encoding
            try:
                with open(file_path, 'r', encoding='latin-1') as file:
                    return file.read()
            except Exception as e2:
                return f"Error reading TXT: {str(e2)}"

    def chunk_text(self, text: str, chunk_size: int = 200, overlap: int = 30) -> List[str]:
        """Split text into sentence-aligned chunks of roughly chunk_size words.

        Note: the overlap parameter is accepted for API compatibility but is not
        currently applied; consecutive chunks do not share sentences.
        """
        if not text.strip():
            return []

        # Split by sentences first, then greedily group them into chunks
        sentences = text.replace('\n', ' ').split('. ')
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            # Tentatively add the sentence to the current chunk
            test_chunk = current_chunk + ". " + sentence if current_chunk else sentence

            # If the chunk gets too long, save it and start a new one
            if len(test_chunk.split()) > chunk_size:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence
            else:
                current_chunk = test_chunk

        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks

    def process_documents(self, files) -> str:
        """Process uploaded files and create embeddings."""
        if not files:
            return "No files uploaded!"

        try:
            all_text = ""
            processed_files = []

            # Extract text from all files
            for file in files:
                if file is None:
                    continue

                file_text = self.extract_text_from_file(file.name)
                if not file_text.startswith("Error") and not file_text.startswith("Unsupported"):
                    all_text += f"\n\n--- {os.path.basename(file.name)} ---\n\n{file_text}"
                    processed_files.append(os.path.basename(file.name))
                else:
                    return file_text

            if not all_text.strip():
                return "No text extracted from files!"

            # Chunk the text
            self.documents = self.chunk_text(all_text)

            if not self.documents:
                return "No valid text chunks created!"

            # Create embeddings
            print(f"Creating embeddings for {len(self.documents)} chunks...")
            embeddings = self.embedder.encode(self.documents, show_progress_bar=True)

            # Build FAISS index
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatIP(dimension)
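            # Note: IndexFlatIP ranks by inner product; normalizing the vectors to
            # unit length just below makes those scores equivalent to cosine similarity.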

            # Normalize embeddings for cosine similarity
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings.astype('float32'))

            self.is_indexed = True

            return f"Successfully processed {len(processed_files)} files:\n" + \
                   f"Files: {', '.join(processed_files)}\n" + \
                   f"Created {len(self.documents)} text chunks\n" + \
                   "Ready for Q&A!"

        except Exception as e:
            return f"Error processing documents: {str(e)}"

    def retrieve_context(self, query: str, k: int = 5) -> str:
        """Retrieve the most relevant document chunks for the query."""
        if not self.is_indexed:
            return ""

        try:
            # Embed the query and normalize it the same way as the document chunks
            query_embedding = self.embedder.encode([query])
            faiss.normalize_L2(query_embedding)

            # Search for similar chunks
            scores, indices = self.index.search(query_embedding.astype('float32'), k)
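            # `scores` are cosine similarities in [-1, 1]; `indices` map back into
            # self.documents. FAISS pads the result with -1 when fewer than k vectors
            # are indexed, hence the bounds checks below.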

            # Keep chunks above a low similarity threshold
            relevant_docs = []
            for i, idx in enumerate(indices[0]):
                if 0 <= idx < len(self.documents) and scores[0][i] > 0.05:
                    relevant_docs.append(self.documents[idx])

            # If nothing clears the threshold, take the top results anyway
            if not relevant_docs:
                for idx in indices[0]:
                    if 0 <= idx < len(self.documents):
                        relevant_docs.append(self.documents[idx])
                        if len(relevant_docs) >= 3:  # Take at least 3 chunks
                            break

            return "\n\n".join(relevant_docs)
        except Exception as e:
            print(f"Error in retrieval: {e}")
            return ""

    def generate_answer(self, query: str, context: str) -> str:
        """Generate an answer with the LLM, falling back to keyword matching."""
        if self.model is None or self.tokenizer is None:
            return self.simple_context_answer(query, context)

        try:
            # Check whether we are running Mistral (specific prompt format) or the fallback model
            model_name = getattr(self.model.config, '_name_or_path', '').lower()
            is_mistral = 'mistral' in model_name

            if is_mistral:
                # Prompt in the Mistral-Instruct format
                prompt = f"""<s>[INST] You are a helpful document assistant. Answer the question based on the provided context. If the exact answer isn't in the context, provide the most relevant information available.

Context:
{context[:1500]}

Question: {query}

Please provide a helpful answer based on the available information. [/INST]"""
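                # Mistral-7B-Instruct expects the user turn wrapped in [INST] ... [/INST];
                # only the text generated after [/INST] is the model's answer, which is
                # why the decode step below splits on that tag.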
            else:
                # Plain question-answer prompt for fallback models
                prompt = f"""Based on the following information, please answer the question:

Context:
{context[:1000]}

Question: {query}

Answer:"""

            # Tokenize, truncating overly long prompts
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=800,
                truncation=True,
                padding=True
            )

            # Move inputs to the same device as the model
            if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
                inputs = {k: v.cuda() for k, v in inputs.items()}

            # Generate the answer
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=150,
                    temperature=0.3,  # Low temperature keeps answers grounded but still natural
                    do_sample=True,
                    top_p=0.9,
                    num_beams=2,
                    early_stopping=True,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode response
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract answer based on model type
            if is_mistral and "[/INST]" in full_response:
                answer = full_response.split("[/INST]")[-1].strip()
            else:
                # For other models, remove the prompt
                if "Answer:" in full_response:
                    answer = full_response.split("Answer:")[-1].strip()
                else:
                    answer = full_response[len(prompt):].strip()

            # Clean up the answer
            answer = self.clean_answer(answer)

            return answer if answer else self.simple_context_answer(query, context)
        except Exception as e:
            print(f"Error in generation: {e}")
            return self.simple_context_answer(query, context)

    def clean_answer(self, answer: str) -> str:
        """Clean up the generated answer."""
        if not answer or len(answer) < 5:
            return ""

        # Remove lines containing obvious chit-chat or refusal patterns
        lines = answer.split('\n')
        cleaned_lines = []
        for line in lines:
            line = line.strip()
            if line and not any(pattern in line.lower() for pattern in [
                'what are you doing', 'what do you think', 'how are you',
                'i am an ai', 'i cannot', 'i don\'t know'
            ]):
                cleaned_lines.append(line)

        cleaned_answer = ' '.join(cleaned_lines)

        # Limit length to prevent rambling
        if len(cleaned_answer) > 500:
            sentences = cleaned_answer.split('.')
            cleaned_answer = '. '.join(sentences[:3]) + '.'

        return cleaned_answer.strip()

    def answer_question(self, query: str) -> str:
        """Main entry point for answering a question about the indexed documents."""
        if not query.strip():
            return "Please ask a question!"

        if not self.is_indexed:
            return "Please upload and process documents first!"

        try:
            # Retrieve relevant context (a few extra chunks for safety)
            context = self.retrieve_context(query, k=7)

            if not context:
                return "No relevant information found in the uploaded documents for your question."

            # Generate answer
            answer = self.generate_answer(query, context)

            if answer and len(answer) > 10:
                return f"**Answer:** {answer}\n\n**Source Context:**\n{context[:300]}..."
            else:
                # Fall back to showing the retrieved context directly
                return f"**Based on the document content:**\n{context[:500]}..."
        except Exception as e:
            return f"Error answering question: {str(e)}"


# Initialize the RAG system
print("Initializing Document RAG System...")
rag_system = DocumentRAG()
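
# Quick smoke test without the UI (hypothetical, not executed here): wrap a local
# file path in an object exposing a .name attribute, the way Gradio's File upload
# does, then call the same methods the buttons are wired to.
#
#   from types import SimpleNamespace
#   print(rag_system.process_documents([SimpleNamespace(name="sample.txt")]))
#   print(rag_system.answer_question("What is the main topic of the document?"))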


# Gradio Interface
def create_interface():
    with gr.Blocks(title="Document Q&A with RAG", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Document Q&A System

        Upload your documents and ask questions about them!

        **Supported formats:** PDF, DOCX, TXT
        """)

        with gr.Tab("Upload Documents"):
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="Upload Documents",
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"]
                    )
                    process_btn = gr.Button("Process Documents", variant="primary")
                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=8,
                        interactive=False
                    )

            process_btn.click(
                fn=rag_system.process_documents,
                inputs=[file_upload],
                outputs=[process_status]
            )

        with gr.Tab("Ask Questions"):
            with gr.Row():
                with gr.Column():
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="What would you like to know about your documents?",
                        lines=3
                    )
                    ask_btn = gr.Button("Get Answer", variant="primary")
                with gr.Column():
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=12,
                        interactive=False
                    )

            ask_btn.click(
                fn=rag_system.answer_question,
                inputs=[question_input],
                outputs=[answer_output]
            )

        # Example questions
        gr.Markdown("""
        ### Example Questions:

        - What is the main topic of the document?
        - Can you summarize the key points?
        - What are the conclusions mentioned?
        - Are there any specific numbers or statistics?
        - Who are the main people or organizations mentioned?
        """)

    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
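
# The imports above imply roughly the following pip packages for the Space's
# requirements.txt (package names are standard; exact versions are an assumption):
#   gradio, torch, transformers, accelerate, bitsandbytes,
#   sentence-transformers, faiss-cpu (or faiss-gpu), numpy, PyPDF2, python-docx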