import os
import re
import numpy as np
import gc
import torch
import time
import shutil
import hashlib
import pickle
import traceback
from typing import List, Dict, Any, Tuple, Optional, Union, Generator
from dataclasses import dataclass
import gradio as gr

# Import dependencies (no need for pip install commands)
import fitz  # PyMuPDF
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from llama_cpp import Llama
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from huggingface_hub import hf_hub_download

# Download nltk resources
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception:
    print("Failed to download NLTK resources, continuing without them")

# Setup directories for Spaces
os.makedirs("pdfs", exist_ok=True)
os.makedirs("models", exist_ok=True)
os.makedirs("pdf_cache", exist_ok=True)

# Download model from Hugging Face Hub
model_path = hf_hub_download(
    repo_id="TheBloke/phi-2-GGUF",
    filename="phi-2.Q8_0.gguf",
    repo_type="model",
    local_dir="models"
)
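
# Optional sanity check, assuming the downloaded model is required for startup:
# hf_hub_download returns the resolved local file path, so failing fast here makes
# a missing or incomplete download obvious in the Spaces build/runtime logs.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Expected GGUF model file at {model_path}")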

# === MEMORY MANAGEMENT UTILITIES ===
def clear_memory():
    """Clear memory to prevent OOM errors"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# === PDF PROCESSING ===
@dataclass
class PDFChunk:
    """Class to represent a chunk of text extracted from a PDF"""
    text: str
    source: str
    page_num: int
    chunk_id: int

class PDFProcessor:
    def __init__(self, pdf_dir: str = "pdfs"):
        """Initialize PDF processor

        Args:
            pdf_dir: Directory containing PDF files
        """
        self.pdf_dir = pdf_dir
        # Smaller chunk size with more overlap for better retrieval
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=384,
            chunk_overlap=288,  # 75% overlap for better context preservation
            length_function=len,
            is_separator_regex=False,
        )
        # Create cache directory
        self.cache_dir = os.path.join(os.getcwd(), "pdf_cache")
        os.makedirs(self.cache_dir, exist_ok=True)

    def list_pdfs(self) -> List[str]:
        """List all PDF files in the directory"""
        if not os.path.exists(self.pdf_dir):
            return []
        return [f for f in os.listdir(self.pdf_dir) if f.lower().endswith('.pdf')]

    def _get_cache_path(self, pdf_path: str) -> str:
        """Get the cache file path for a PDF"""
        # Fingerprint only the first 8KB for speed, and close the file handle properly
        with open(pdf_path, 'rb') as f:
            pdf_hash = hashlib.md5(f.read(8192)).hexdigest()
        return os.path.join(self.cache_dir, f"{os.path.basename(pdf_path)}_{pdf_hash}.pkl")

    def _is_cached(self, pdf_path: str) -> bool:
        """Check if a PDF is cached"""
        cache_path = self._get_cache_path(pdf_path)
        return os.path.exists(cache_path)

    def _load_from_cache(self, pdf_path: str) -> Optional[List[PDFChunk]]:
        """Load chunks from cache, or None if the cache is missing or unreadable"""
        cache_path = self._get_cache_path(pdf_path)
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except Exception:
            return None

    def _save_to_cache(self, pdf_path: str, chunks: List[PDFChunk]) -> None:
        """Save chunks to cache"""
        cache_path = self._get_cache_path(pdf_path)
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(chunks, f)
        except Exception as e:
            print(f"Warning: Failed to cache PDF {pdf_path}: {str(e)}")

    def clean_text(self, text: str) -> str:
        """Clean extracted text"""
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        # Remove header/footer patterns (common in PDFs)
        text = re.sub(r'(?<!\w)page \d+(?!\w)', '', text, flags=re.IGNORECASE)
        return text
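
    # Illustrative example of clean_text: "Introduction   to RAG\n\nPage 12" becomes
    # "Introduction to RAG " since whitespace runs collapse to single spaces and a
    # standalone "page N" footer is stripped (a stray space can remain in its place).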

    def extract_text_from_pdf(self, pdf_path: str) -> List[PDFChunk]:
        """Extract text content from a PDF file with improved extraction

        Args:
            pdf_path: Path to the PDF file
        Returns:
            List of PDFChunk objects extracted from the PDF
        """
        # Check cache first
        if self._is_cached(pdf_path):
            cached_chunks = self._load_from_cache(pdf_path)
            if cached_chunks:
                print(f"Loaded {len(cached_chunks)} chunks from cache for {os.path.basename(pdf_path)}")
                return cached_chunks
        try:
            doc = fitz.open(pdf_path)
            pdf_chunks = []
            pdf_name = os.path.basename(pdf_path)
            for page_num in range(len(doc)):
                page = doc.load_page(page_num)
                # Extract text with more options for better quality
                page_text = page.get_text("text", sort=True)
                # Try alternative layout analysis if the extracted text is too short
                if len(page_text) < 100:
                    try:
                        page_dict = page.get_text("dict", sort=True)
                        # Convert the dict structure back to plain text
                        if isinstance(page_dict, dict) and "blocks" in page_dict:
                            extracted_text = ""
                            for block in page_dict["blocks"]:
                                if "lines" in block:
                                    for line in block["lines"]:
                                        if "spans" in line:
                                            for span in line["spans"]:
                                                if "text" in span:
                                                    extracted_text += span["text"] + " "
                            page_text = extracted_text
                    except Exception:
                        # Fallback to default extraction
                        page_text = page.get_text("text")
                # Clean the text
                page_text = self.clean_text(page_text)
                # Extract tables
                try:
                    tables = page.find_tables()
                    if tables and hasattr(tables, "tables"):
                        for table in tables.tables:
                            table_text = ""
                            for i, row in enumerate(table.rows):
                                row_cells = []
                                for cell in row.cells:
                                    if hasattr(cell, "rect"):
                                        cell_text = page.get_text("text", clip=cell.rect)
                                        cell_text = self.clean_text(cell_text)
                                        row_cells.append(cell_text)
                                if row_cells:
                                    table_text += " | ".join(row_cells) + "\n"
                            # Add table text to page text
                            if table_text.strip():
                                page_text += "\n\nTABLE:\n" + table_text
                except Exception as table_err:
                    print(f"Warning: Skipping table extraction for page {page_num}: {str(table_err)}")
                # Split the page text into chunks
                if page_text.strip():
                    page_chunks = self.text_splitter.split_text(page_text)
                    # Create PDFChunk objects
                    for i, chunk_text in enumerate(page_chunks):
                        pdf_chunks.append(PDFChunk(
                            text=chunk_text,
                            source=pdf_name,
                            page_num=page_num + 1,  # 1-based page numbering for humans
                            chunk_id=i
                        ))
                # Clear memory periodically
                if page_num % 10 == 0:
                    clear_memory()
            doc.close()
            # Cache the results
            self._save_to_cache(pdf_path, pdf_chunks)
            return pdf_chunks
        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {str(e)}")
            return []

    def process_pdf(self, pdf_name: str) -> List[PDFChunk]:
        """Process a single PDF file and extract chunks

        Args:
            pdf_name: Name of the PDF file in the pdf_dir
        Returns:
            List of PDFChunk objects from the PDF
        """
        pdf_path = os.path.join(self.pdf_dir, pdf_name)
        return self.extract_text_from_pdf(pdf_path)

    def process_all_pdfs(self, batch_size: int = 2) -> List[PDFChunk]:
        """Process all PDFs in batches to manage memory

        Args:
            batch_size: Number of PDFs to process in each batch
        Returns:
            List of all PDFChunk objects from all PDFs
        """
        all_chunks = []
        pdf_files = self.list_pdfs()
        if not pdf_files:
            print("No PDF files found in the directory.")
            return []
        # Process PDFs in batches
        for i in range(0, len(pdf_files), batch_size):
            batch = pdf_files[i:i+batch_size]
            print(f"Processing batch {i//batch_size + 1}/{(len(pdf_files)-1)//batch_size + 1}")
            for pdf_name in batch:
                print(f"Processing {pdf_name}")
                chunks = self.process_pdf(pdf_name)
                all_chunks.extend(chunks)
                print(f"Extracted {len(chunks)} chunks from {pdf_name}")
            # Clear memory after each batch
            clear_memory()
        return all_chunks

# === VECTOR DATABASE SETUP ===
class VectorDBManager:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Initialize vector database manager

        Args:
            model_name: Name of the embedding model
        """
        # Initialize embedding model with normalization
        try:
            self.embedding_model = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True}
            )
        except Exception as e:
            print(f"Error initializing embedding model {model_name}: {str(e)}")
            print("Falling back to all-MiniLM-L6-v2 model")
            self.embedding_model = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True}
            )
        self.vectordb = None
        # BM25 index for hybrid search
        self.bm25_index = None
        self.chunks = []
        self.tokenized_chunks = []

    def _prepare_bm25(self, chunks: List[PDFChunk]):
        """Prepare BM25 index for hybrid search"""
        # Tokenize chunks for BM25
        try:
            tokenized_chunks = []
            stop_words = set(stopwords.words('english'))
            for chunk in chunks:
                # Tokenize and remove stopwords
                tokens = word_tokenize(chunk.text.lower())
                filtered_tokens = [w for w in tokens if w.isalnum() and w not in stop_words]
                tokenized_chunks.append(filtered_tokens)
            # Create BM25 index
            self.bm25_index = BM25Okapi(tokenized_chunks)
        except Exception as e:
            print(f"Error creating BM25 index: {str(e)}")
            print(traceback.format_exc())
            self.bm25_index = None

    def create_vector_db(self, chunks: List[PDFChunk]) -> None:
        """Create vector database from text chunks

        Args:
            chunks: List of PDFChunk objects
        """
        try:
            if not chunks or len(chunks) == 0:
                print("ERROR: No chunks provided to create vector database")
                return
            print(f"Creating vector DB with {len(chunks)} chunks")
            # Store chunks for hybrid search
            self.chunks = chunks
            # Prepare data for vector DB
            chunk_texts = [chunk.text for chunk in chunks]
            # Create BM25 index for hybrid search
            print("Creating BM25 index for hybrid search")
            self._prepare_bm25(chunks)
            # Process in smaller batches to manage memory
            batch_size = 16  # Reduced for Spaces
            all_embeddings = []
            for i in range(0, len(chunk_texts), batch_size):
                batch = chunk_texts[i:i+batch_size]
                print(f"Embedding batch {i//batch_size + 1}/{(len(chunk_texts)-1)//batch_size + 1}")
                # Generate embeddings for the batch
                batch_embeddings = self.embedding_model.embed_documents(batch)
                all_embeddings.extend(batch_embeddings)
                # Clear memory after each batch
                clear_memory()
            # Create FAISS index
            print(f"Creating FAISS index with {len(all_embeddings)} embeddings")
            self.vectordb = FAISS.from_embeddings(
                text_embeddings=list(zip(chunk_texts, all_embeddings)),
                embedding=self.embedding_model
            )
            print(f"Vector database created with {len(chunks)} documents")
        except Exception as e:
            print(f"Error creating vector database: {str(e)}")
            print(traceback.format_exc())
            raise

    def _format_chunk_with_metadata(self, chunk: PDFChunk) -> str:
        """Format a chunk with its metadata for better context"""
        return f"Source: {chunk.source} | Page: {chunk.page_num}\n\n{chunk.text}"

    def hybrid_search(self, query: str, k: int = 5, alpha: float = 0.7) -> List[str]:
        """Hybrid search combining vector search and BM25

        Args:
            query: Query text
            k: Number of results to return
            alpha: Weight for vector search (1-alpha for BM25)
        Returns:
            List of formatted documents
        """
        if self.vectordb is None:
            print("Vector database not initialized")
            return []
        try:
            # Get vector search results
            vector_results = self.vectordb.similarity_search(query, k=k*2)
            vector_texts = [doc.page_content for doc in vector_results]
            final_results = []
            # Combine with BM25 if available
            if self.bm25_index is not None:
                try:
                    # Tokenize query for BM25
                    query_tokens = word_tokenize(query.lower())
                    stop_words = set(stopwords.words('english'))
                    filtered_query = [w for w in query_tokens if w.isalnum() and w not in stop_words]
                    # Get BM25 scores
                    bm25_scores = self.bm25_index.get_scores(filtered_query)
                    # Combine scores (normalized)
                    combined_results = []
                    seen_texts = set()
                    # First add vector results with their positions as scores
                    for i, text in enumerate(vector_texts):
                        if text not in seen_texts:
                            seen_texts.add(text)
                            # Find corresponding chunk
                            for j, chunk in enumerate(self.chunks):
                                if chunk.text == text:
                                    # Combine scores: alpha * vector_score + (1-alpha) * bm25_score
                                    # For vector, use inverse of position as score (normalized)
                                    vector_score = 1.0 - (i / len(vector_texts))
                                    # Normalize BM25 score
                                    bm25_score = bm25_scores[j] / max(bm25_scores) if max(bm25_scores) > 0 else 0
                                    combined_score = alpha * vector_score + (1 - alpha) * bm25_score
                                    combined_results.append((chunk, combined_score))
                                    break
                    # Sort by combined score
                    combined_results.sort(key=lambda x: x[1], reverse=True)
                    # Get top k results
                    top_chunks = [item[0] for item in combined_results[:k]]
                    # Format results with metadata
                    final_results = [self._format_chunk_with_metadata(chunk) for chunk in top_chunks]
                except Exception as e:
                    print(f"Error in BM25 scoring: {str(e)}")
                    # Fallback to vector search results
                    final_results = vector_texts[:k]
            else:
                # Just use vector search results if BM25 is not available
                final_results = vector_texts[:k]
            return final_results
        except Exception as e:
            print(f"Error during hybrid search: {str(e)}")
            return []
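
    # Illustrative numbers for the score blend above (not from a real run): with the
    # default alpha = 0.7, a chunk ranked 2nd of 10 vector hits gets
    # vector_score = 1.0 - (1 / 10) = 0.9; if its BM25 score is half the corpus maximum,
    # bm25_score = 0.5 and combined_score = 0.7 * 0.9 + 0.3 * 0.5 = 0.78.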

# === QUERY EXPANSION ===
class QueryExpander:
    def __init__(self, llm_model):
        """Initialize query expander

        Args:
            llm_model: LLM model for query expansion
        """
        self.llm = llm_model

    def expand_query(self, query: str) -> str:
        """Expand the query using the LLM to improve retrieval

        Args:
            query: Original query
        Returns:
            Expanded query
        """
        try:
            prompt = f"""I need to search for documents related to this question: "{query}"
Please help me expand this query by identifying key concepts, synonyms, and related terms that might be used in the documents.
Return only the expanded search query, without any explanations or additional text.
Expanded query:"""
            expanded = self.llm.generate(prompt, max_tokens=100, temperature=0.3)
            # Combine original and expanded
            combined = f"{query} {expanded}"
            # Limit length
            if len(combined) > 300:
                combined = combined[:300]
            return combined
        except Exception:
            # Return original query if expansion fails
            return query

# === LLM SETUP ===
class Phi2Model:
    def __init__(self, model_path: str = model_path):
        """Initialize Phi-2 model

        Args:
            model_path: Path to the model file
        """
        try:
            # Initialize Phi-2 with llama.cpp - optimized for Spaces
            self.llm = Llama(
                model_path=model_path,
                n_ctx=1024,  # Reduced context window for Spaces
                n_batch=64,  # Reduced batch size
                n_gpu_layers=0,  # Run on CPU for compatibility
                verbose=False
            )
        except Exception as e:
            print(f"Error initializing Phi-2 model: {str(e)}")
            raise

    def generate(self, prompt: str,
                 max_tokens: int = 512,
                 temperature: float = 0.7,
                 top_p: float = 0.9,
                 stream: bool = False) -> Union[str, Generator[str, None, None]]:
        """Generate text using Phi-2

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            stream: Whether to stream the output
        Returns:
            Generated text or generator if streaming
        """
        try:
            if stream:
                return self._generate_stream(prompt, max_tokens, temperature, top_p)
            else:
                output = self.llm(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    echo=False
                )
                return output["choices"][0]["text"]
        except Exception as e:
            print(f"Error generating text: {str(e)}")
            return "Error: Could not generate response."

    def _generate_stream(self, prompt: str,
                         max_tokens: int = 512,
                         temperature: float = 0.7,
                         top_p: float = 0.9) -> Generator[str, None, None]:
        """Stream text generation using Phi-2

        Args:
            prompt: Input prompt
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
        Yields:
            The accumulated response text after each new token, so the Gradio
            textbox can simply display the latest value
        """
        response = ""
        for output in self.llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            echo=False,
            stream=True
        ):
            token = output["choices"][0]["text"]
            response += token
            yield response

# === RAG SYSTEM ===
class RAGSystem:
    def __init__(self, pdf_processor: PDFProcessor,
                 vector_db: VectorDBManager,
                 model: Phi2Model):
        """Initialize RAG system

        Args:
            pdf_processor: PDF processor instance
            vector_db: Vector database manager instance
            model: LLM model instance
        """
        self.pdf_processor = pdf_processor
        self.vector_db = vector_db
        self.model = model
        self.query_expander = QueryExpander(model)
        self.is_initialized = False

    def process_documents(self) -> bool:
        """Process all documents and create vector database

        Returns:
            True if successful, False otherwise
        """
        try:
            # Process PDFs
            chunks = self.pdf_processor.process_all_pdfs()
            if not chunks:
                print("No chunks were extracted from PDFs")
                return False
            print(f"Total chunks extracted: {len(chunks)}")
            # Create vector database
            print("Creating vector database...")
            self.vector_db.create_vector_db(chunks)
            # Verify success
            if self.vector_db.vectordb is None:
                print("Failed to create vector database")
                return False
            # Set initialization flag
            self.is_initialized = True
            return True
        except Exception as e:
            print(f"Error processing documents: {str(e)}")
            print(traceback.format_exc())
            return False

    def generate_prompt(self, query: str, contexts: List[str]) -> str:
        """Generate prompt for the LLM with better instructions

        Args:
            query: User query
            contexts: Retrieved contexts
        Returns:
            Formatted prompt
        """
        # Format contexts with numbering for better reference
        formatted_contexts = ""
        for i, context in enumerate(contexts):
            formatted_contexts += f"[CONTEXT {i+1}]\n{context}\n\n"
        # Create prompt with better instructions
        prompt = f"""You are an AI assistant that answers questions based on the provided context information.
User Query: {query}
Below are relevant passages from documents that might help answer the query:
{formatted_contexts}
Using ONLY the information provided in the context above, provide a comprehensive answer to the user's query.
If the provided context doesn't contain relevant information to answer the query, clearly state: "I don't have enough information in the provided context to answer this question."
Do not use any prior knowledge that is not contained in the provided context.
If quoting from the context, mention the source document and page number.
Organize your answer in a clear, coherent manner.
Answer:"""
        return prompt

    def answer_query(self, query: str, k: int = 5, max_tokens: int = 512,
                     temperature: float = 0.7, stream: bool = False) -> Union[str, Generator[str, None, None]]:
        """Answer a query using RAG with query expansion

        Args:
            query: User query
            k: Number of contexts to retrieve
            max_tokens: Maximum number of tokens to generate
            temperature: Temperature for generation
            stream: Whether to stream the output
        Returns:
            Answer text or generator if streaming
        """
        # Check if system is initialized
        if not self.is_initialized or self.vector_db.vectordb is None:
            return "Error: Documents have not been processed yet. Please process documents first."
        try:
            # Expand query for better retrieval
            expanded_query = self.query_expander.expand_query(query)
            print(f"Expanded query: {expanded_query}")
            # Retrieve relevant contexts using hybrid search
            contexts = self.vector_db.hybrid_search(expanded_query, k=k)
            if not contexts:
                return "No relevant information found in the documents. Please try a different query or check if documents were processed correctly."
            # Generate prompt with improved instructions
            prompt = self.generate_prompt(query, contexts)
            # Generate answer
            return self.model.generate(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=stream
            )
        except Exception as e:
            print(f"Error answering query: {str(e)}")
            print(traceback.format_exc())
            return f"Error processing your query: {str(e)}"

# === GRADIO INTERFACE ===
class RAGInterface:
    def __init__(self, rag_system: RAGSystem):
        """Initialize Gradio interface

        Args:
            rag_system: RAG system instance
        """
        self.rag_system = rag_system
        self.interface = None
        self.is_processing = False

    def upload_file(self, files):
        """Upload PDF files"""
        try:
            os.makedirs("pdfs", exist_ok=True)
            uploaded_files = []
            for file in files:
                destination = os.path.join("pdfs", os.path.basename(file.name))
                shutil.copy(file.name, destination)
                uploaded_files.append(os.path.basename(file.name))
            # Verify files exist in the directory
            pdf_files = [f for f in os.listdir("pdfs") if f.lower().endswith('.pdf')]
            if not pdf_files:
                return "No PDF files were uploaded successfully."
            return f"Successfully uploaded {len(uploaded_files)} files: {', '.join(uploaded_files)}"
        except Exception as e:
            return f"Error uploading files: {str(e)}"

    def process_documents(self):
        """Process all documents

        Returns:
            Status message
        """
        if self.is_processing:
            return "Document processing is already in progress. Please wait."
        try:
            self.is_processing = True
            start_time = time.time()
            success = self.rag_system.process_documents()
            elapsed = time.time() - start_time
            self.is_processing = False
            if success:
                return f"Documents processed successfully in {elapsed:.2f} seconds."
            else:
                return "Failed to process documents. Check the logs for more information."
        except Exception as e:
            self.is_processing = False
            return f"Error processing documents: {str(e)}"

    def answer_query(self, query, k, max_tokens, temperature):
        """Answer a query

        Args:
            query: User query
            k: Number of contexts to retrieve
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
        Returns:
            Answer
        """
        if not query.strip():
            return "Please enter a question."
        try:
            return self.rag_system.answer_query(
                query,
                k=k,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False
            )
        except Exception as e:
            return f"Error answering query: {str(e)}"

    def answer_query_stream(self, query, k, max_tokens, temperature):
        """Stream answer to a query

        Args:
            query: User query
            k: Number of contexts to retrieve
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
        Yields:
            Generated text
        """
        if not query.strip():
            yield "Please enter a question."
            return
        try:
            yield from self.rag_system.answer_query(
                query,
                k=k,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True
            )
        except Exception as e:
            yield f"Error answering query: {str(e)}"

    def create_interface(self):
        """Create Gradio interface"""
        with gr.Blocks(title="PDF RAG System") as interface:
            gr.Markdown("# PDF RAG System with Phi-2")
            gr.Markdown("Upload your PDF documents, process them, and ask questions to get answers based on the content.")
            with gr.Tab("Upload & Process"):
                with gr.Row():
                    pdf_files = gr.File(
                        file_count="multiple",
                        label="Upload PDF Files",
                        file_types=[".pdf"]
                    )
                    upload_button = gr.Button("Upload", variant="primary")
                upload_output = gr.Textbox(label="Upload Status", lines=2)
                upload_button.click(self.upload_file, inputs=[pdf_files], outputs=upload_output)
                process_button = gr.Button("Process Documents", variant="primary")
                process_output = gr.Textbox(label="Processing Status", lines=2)
                process_button.click(self.process_documents, inputs=[], outputs=process_output)
            with gr.Tab("Query"):
                with gr.Row():
                    with gr.Column():
                        query_input = gr.Textbox(
                            label="Question",
                            lines=3,
                            placeholder="Ask a question about your documents..."
                        )
                        with gr.Row():
                            k_slider = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=3,
                                step=1,
                                label="Number of Contexts"
                            )
                            max_tokens_slider = gr.Slider(
                                minimum=100,
                                maximum=800,
                                value=400,
                                step=50,
                                label="Max Tokens"
                            )
                            temperature_slider = gr.Slider(
                                minimum=0.1,
                                maximum=1.0,
                                value=0.7,
                                step=0.1,
                                label="Temperature"
                            )
                        submit_button = gr.Button("Submit", variant="primary")
                        answer_output = gr.Textbox(label="Answer", lines=10)
                        submit_button.click(
                            self.answer_query,
                            inputs=[query_input, k_slider, max_tokens_slider, temperature_slider],
                            outputs=answer_output
                        )
                        # Add streaming capability
                        stream_button = gr.Button("Submit (Streaming)", variant="secondary")
                        stream_button.click(
                            self.answer_query_stream,
                            inputs=[query_input, k_slider, max_tokens_slider, temperature_slider],
                            outputs=answer_output
                        )
            gr.Markdown("""
## Instructions
1. Upload PDF files in the 'Upload & Process' tab.
2. Click the 'Process Documents' button to extract and index content.
3. Switch to the 'Query' tab to ask questions about your documents.
4. Adjust parameters as needed:
   - Number of Contexts: More contexts provide more information but may be less focused.
   - Max Tokens: Controls the length of the response.
   - Temperature: Lower values (0.1-0.5) give more focused answers, higher values (0.6-1.0) give more creative answers.
""")
        self.interface = interface
        return interface

    def launch(self, **kwargs):
        """Launch the Gradio interface"""
        if self.interface is None:
            self.create_interface()
        self.interface.launch(**kwargs)

# === MAIN APPLICATION ===
def main():
    """Main function to set up and launch the application"""
    try:
        # Initialize components
        pdf_processor = PDFProcessor(pdf_dir="pdfs")
        vector_db = VectorDBManager()
        phi2_model = Phi2Model()
        # Initialize RAG system
        rag_system = RAGSystem(pdf_processor, vector_db, phi2_model)
        # Create interface
        interface = RAGInterface(rag_system)
        # Launch application
        interface.launch(share=True)
    except Exception as e:
        print(f"Error initializing application: {str(e)}")
        print(traceback.format_exc())

if __name__ == "__main__":
    main()
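
# To try this outside of Spaces, save it as app.py and run `python app.py`, assuming
# the imported packages are installed (e.g. gradio, torch, numpy, pymupdf, langchain,
# langchain-community, faiss-cpu, sentence-transformers, llama-cpp-python, rank-bm25,
# nltk and huggingface-hub, typically pinned in requirements.txt). On Spaces, an
# app.py plus that requirements.txt is the usual layout for a Gradio app.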