# Hugging Face Space (status: Running)
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
import PyPDF2 | |
import docx | |
import io | |
import os | |
from typing import List, Optional | |
class DocumentRAG:
    """Retrieval-augmented question answering over uploaded documents.

    Text is extracted from PDF/DOCX/TXT files, split into overlapping word
    chunks, embedded with a SentenceTransformer model and stored in a FAISS
    inner-product index. Questions are answered by retrieving the most
    similar chunks and passing them as context to a causal language model.
    """

    def __init__(self):
        """Load the embedding model and the LLM, and start with an empty index."""
        print("π Initializing RAG System...")
        # Initialize embedding model (lightweight)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        print("β Embedding model loaded")
        # Initialize quantized LLM (sets self.model / self.tokenizer)
        self.setup_llm()
        # Document storage
        self.documents = []
        self.index = None
        self.is_indexed = False

    def setup_llm(self):
        """Load the 4-bit quantized Mistral model, falling back on any failure."""
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            model_name = "mistralai/Mistral-7B-Instruct-v0.1"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True
            )
            print("β Quantized Mistral model loaded")
        except Exception as e:
            print(f"β Error loading model: {e}")
            # Fallback to a smaller model if Mistral fails
            self.setup_fallback_model()

    def setup_fallback_model(self):
        """Load a small fallback model; leave model/tokenizer as None on failure."""
        try:
            model_name = "microsoft/DialoGPT-small"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # generate_answer tokenizes with padding=True, which needs a pad
            # token; mirror what setup_llm does for the primary tokenizer.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            print("β Fallback model loaded")
        except Exception as e:
            print(f"β Fallback model failed: {e}")
            self.model = None
            self.tokenizer = None

    def _read_pdf(self, file_path: str) -> str:
        """Return the text of every PDF page, one page per line group; raises on failure."""
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # Guard against extract_text() yielding None for a page.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )

    def _read_docx(self, file_path: str) -> str:
        """Return the paragraphs of a DOCX file joined by newlines; raises on failure."""
        doc = docx.Document(file_path)
        return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)

    def _read_txt(self, file_path: str) -> str:
        """Return the content of a text file (UTF-8, else Latin-1); raises on failure."""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except UnicodeDecodeError:
            # Latin-1 maps every byte, so this only fails on I/O errors.
            with open(file_path, 'r', encoding='latin-1') as file:
                return file.read()

    def _load_text(self, file_path: str):
        """Extract text and report success explicitly.

        Returns a ``(ok, text)`` tuple: ``(True, extracted_text)`` on success,
        ``(False, error_message)`` otherwise. Keeping the flag separate from
        the text means a document whose content happens to start with
        "Error" or "Unsupported" is not mistaken for a failure.
        """
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
        except Exception as e:
            return False, f"Error reading file: {str(e)}"

        readers = {
            '.pdf': ('PDF', self._read_pdf),
            '.docx': ('DOCX', self._read_docx),
            '.txt': ('TXT', self._read_txt),
        }
        if file_extension not in readers:
            return False, f"Unsupported file format: {file_extension}"

        label, reader = readers[file_extension]
        try:
            return True, reader(file_path)
        except Exception as e:
            return False, f"Error reading {label}: {str(e)}"

    def extract_text_from_file(self, file_path: str) -> str:
        """Extract text from a .pdf, .docx or .txt file.

        Returns the extracted text, or a message starting with
        "Unsupported file format" / "Error reading" when extraction fails.
        """
        return self._load_text(file_path)[1]

    def extract_from_pdf(self, file_path: str) -> str:
        """Extract text from a PDF; returns an "Error reading PDF" message on failure."""
        try:
            return self._read_pdf(file_path)
        except Exception as e:
            return f"Error reading PDF: {str(e)}"

    def extract_from_docx(self, file_path: str) -> str:
        """Extract text from a DOCX; returns an "Error reading DOCX" message on failure."""
        try:
            return self._read_docx(file_path)
        except Exception as e:
            return f"Error reading DOCX: {str(e)}"

    def extract_from_txt(self, file_path: str) -> str:
        """Extract text from a TXT file; returns an "Error reading TXT" message on failure."""
        try:
            return self._read_txt(file_path)
        except Exception as e:
            return f"Error reading TXT: {str(e)}"

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks of whitespace-separated words.

        Args:
            text: The text to split.
            chunk_size: Maximum number of words per chunk.
            overlap: Number of words shared by consecutive chunks.

        Returns:
            The list of chunks; empty when the text has no words.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``overlap`` is
                not in ``[0, chunk_size)`` (the window would never advance).
        """
        if not text or not text.strip():
            return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        words = text.split()
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(' '.join(words[start:start + chunk_size]))
            # Stop once a window reaches the end, so no chunk is a pure
            # suffix of the previous one.
            if start + chunk_size >= len(words):
                break
        return chunks

    def process_documents(self, files) -> str:
        """Extract, chunk, embed and index the uploaded files.

        Args:
            files: Iterable of uploads; each item is either a path string or
                an object exposing the path as ``.name``. ``None`` items are
                skipped.

        Returns:
            A human-readable status message. The stored documents and index
            are only replaced after the new index has been built completely.
        """
        if not files:
            return "β No files uploaded!"
        try:
            sections = []
            processed_files = []
            # Extract text from all files
            for file in files:
                if file is None:
                    continue
                path = file if isinstance(file, str) else file.name
                ok, file_text = self._load_text(path)
                if not ok:
                    return f"β {file_text}"
                base_name = os.path.basename(path)
                sections.append(f"\n\n--- {base_name} ---\n\n{file_text}")
                processed_files.append(base_name)

            all_text = "".join(sections)
            if not all_text.strip():
                return "β No text extracted from files!"

            # Chunk the text
            documents = self.chunk_text(all_text)
            if not documents:
                return "β No valid text chunks created!"

            # Create embeddings
            print(f"π Creating embeddings for {len(documents)} chunks...")
            embeddings = self.embedder.encode(documents, show_progress_bar=True)
            # FAISS operates on contiguous float32 arrays; convert before
            # normalizing so the in-place normalization is valid.
            embeddings = np.ascontiguousarray(embeddings, dtype='float32')

            # Build FAISS index; normalized vectors + inner product = cosine
            index = faiss.IndexFlatIP(embeddings.shape[1])
            faiss.normalize_L2(embeddings)
            index.add(embeddings)

            # Commit the new state only now, so a failure above cannot leave
            # documents and index out of sync.
            self.documents = documents
            self.index = index
            self.is_indexed = True
            return f"β Successfully processed {len(processed_files)} files:\n" + \
                f"π Files: {', '.join(processed_files)}\n" + \
                f"π Created {len(self.documents)} text chunks\n" + \
                f"π Ready for Q&A!"
        except Exception as e:
            return f"β Error processing documents: {str(e)}"

    def retrieve_context(self, query: str, k: int = 3) -> str:
        """Return the chunks most similar to the query, joined by blank lines.

        Args:
            query: The user question.
            k: Maximum number of chunks to retrieve.

        Returns:
            The joined chunks whose similarity exceeds 0.1, or an empty
            string when nothing is indexed, nothing matches, or retrieval
            fails.
        """
        if not self.is_indexed or self.index is None or not self.documents:
            return ""
        try:
            # Get query embedding
            query_embedding = np.ascontiguousarray(
                self.embedder.encode([query]), dtype='float32'
            )
            faiss.normalize_L2(query_embedding)

            # Search for similar chunks; never ask for more than exist
            top_k = min(k, len(self.documents))
            scores, indices = self.index.search(query_embedding, top_k)

            min_score = 0.1  # Similarity threshold
            relevant_docs = [
                self.documents[idx]
                for score, idx in zip(scores[0], indices[0])
                # FAISS marks missing results with index -1, which would
                # otherwise silently select the last chunk.
                if 0 <= idx < len(self.documents) and score > min_score
            ]
            return "\n\n".join(relevant_docs)
        except Exception as e:
            print(f"Error in retrieval: {e}")
            return ""

    def _build_prompt(self, query: str, context: str, max_context_chars: int = 2000) -> str:
        """Build the instruction prompt, limiting the context length."""
        instruction = (
            "Based on the following context, answer the question. "
            "If the answer is not in the context, say "
            "\"I don't have enough information to answer this question.\""
        )
        return (
            f"<s>[INST] {instruction}\n"
            "Context:\n"
            f"{context[:max_context_chars]}\n"
            f"Question: {query}\n"
            "Answer: [/INST]"
        )

    def generate_answer(self, query: str, context: str) -> str:
        """Generate an answer to the query from the given context.

        Returns the generated text, or an error/status message when the
        model is unavailable or generation fails.
        """
        if self.model is None or self.tokenizer is None:
            return "β Model not available. Please try again."
        try:
            prompt = self._build_prompt(query, context)

            # Tokenize
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
                padding=True
            )
            # The model may live on a GPU (device_map="auto"); inputs must
            # be on the same device.
            device = getattr(self.model, "device", None)
            if device is not None:
                inputs = inputs.to(device)

            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # The output holds the prompt tokens followed by the new tokens;
            # decode only the new ones instead of string-matching the prompt.
            prompt_length = inputs["input_ids"].shape[1]
            answer = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            return answer if answer else "I couldn't generate a proper response."
        except Exception as e:
            return f"β Error generating answer: {str(e)}"

    def answer_question(self, query: str) -> str:
        """Answer a question from the indexed documents.

        Returns the formatted answer with a preview of the source context,
        or a status message when the query is empty, nothing is indexed, no
        relevant context is found, or an error occurs.
        """
        if not query or not query.strip():
            return "β Please ask a question!"
        if not self.is_indexed:
            return "π Please upload and process documents first!"
        try:
            # Retrieve relevant context
            context = self.retrieve_context(query)
            if not context:
                return "π No relevant information found in the uploaded documents."
            # Generate answer
            answer = self.generate_answer(query, context)
            return f"π‘ **Answer:** {answer}\n\nπ **Source Context:** {context[:500]}..."
        except Exception as e:
            return f"β Error answering question: {str(e)}"
# Initialize the RAG system.
# This runs at import time: constructing DocumentRAG executes its __init__,
# which loads the embedding model and sets up the LLM. The single instance
# is shared by the Gradio callbacks wired up in create_interface().
print("Initializing Document RAG System...")
rag_system = DocumentRAG()
# Gradio Interface | |
def create_interface():
    """Build the Gradio Blocks UI and return it.

    The UI has an upload tab wired to ``rag_system.process_documents`` and a
    question tab wired to ``rag_system.answer_question``.
    """
    # NOTE(review): the nesting of the layout below was reconstructed from
    # syntax alone — confirm the placement of the example-questions block.
    intro_markdown = """
# π Document Q&A System
Upload your documents and ask questions about them!
**Supported formats:** PDF, DOCX, TXT
"""
    examples_markdown = """
### π‘ Example Questions:
- What is the main topic of the document?
- Can you summarize the key points?
- What are the conclusions mentioned?
- Are there any specific numbers or statistics?
"""

    with gr.Blocks(title="π Document Q&A with RAG", theme=gr.themes.Soft()) as app:
        gr.Markdown(intro_markdown)

        # Tab 1: upload files and build the index
        with gr.Tab("π€ Upload Documents"):
            with gr.Row():
                with gr.Column():
                    upload_box = gr.File(
                        label="Upload Documents",
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"],
                    )
                    run_processing = gr.Button("π Process Documents", variant="primary")
                with gr.Column():
                    status_box = gr.Textbox(
                        label="Processing Status",
                        lines=8,
                        interactive=False,
                    )
            run_processing.click(
                fn=rag_system.process_documents,
                inputs=[upload_box],
                outputs=[status_box],
            )

        # Tab 2: ask questions against the indexed documents
        with gr.Tab("β Ask Questions"):
            with gr.Row():
                with gr.Column():
                    question_box = gr.Textbox(
                        label="Your Question",
                        placeholder="What would you like to know about your documents?",
                        lines=3,
                    )
                    run_query = gr.Button("π Get Answer", variant="primary")
                with gr.Column():
                    answer_box = gr.Textbox(
                        label="Answer",
                        lines=10,
                        interactive=False,
                    )
            run_query.click(
                fn=rag_system.answer_question,
                inputs=[question_box],
                outputs=[answer_box],
            )
            # Example questions
            gr.Markdown(examples_markdown)

    return app
# Launch the app | |
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and request a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)