import os
import json
from pathlib import Path
from typing import List, Dict, Any, Optional

from document_processor import DocumentProcessor, DocumentChunk
from vector_store import VectorStore, SearchResult

class RAGTool:
    """RAG tool for integrating document search with chat"""

    def __init__(self):
        self.processor = DocumentProcessor(chunk_size=800, chunk_overlap=100)
        self.vector_store = VectorStore()
        self.processed_files = []
        self.total_chunks = 0

    def process_uploaded_files(self, file_paths: List[str]) -> Dict[str, Any]:
        """Process uploaded files and build vector index"""
        # Validate files
        valid_files = []
        errors = []
        for file_path in file_paths:
            try:
                # Check file size (10MB limit)
                size_mb = os.path.getsize(file_path) / (1024 * 1024)
                if size_mb > 10:
                    errors.append({
                        'file': Path(file_path).name,
                        'error': f'File too large ({size_mb:.1f}MB). Maximum size is 10MB.'
                    })
                    continue
                valid_files.append(file_path)
            except Exception as e:
                errors.append({
                    'file': Path(file_path).name,
                    'error': str(e)
                })

        if not valid_files:
            return {
                'success': False,
                'message': 'No valid files to process',
                'errors': errors
            }

        # Process files
        all_chunks, summary = self.processor.process_multiple_files(valid_files)
        if not all_chunks:
            return {
                'success': False,
                'message': 'No content extracted from files',
                'summary': summary
            }

        # Build vector index
        chunk_dicts = [chunk.to_dict() for chunk in all_chunks]
        self.vector_store.build_index(chunk_dicts, show_progress=False)

        # Update stats
        self.processed_files = summary['files_processed']
        self.total_chunks = len(all_chunks)

        # Gather index stats for the response
        index_stats = self.vector_store.get_stats()
        return {
            'success': True,
            'message': f'Successfully processed {len(valid_files)} files into {self.total_chunks} chunks',
            'summary': summary,
            'index_stats': index_stats,
            'errors': errors
        }

    def get_relevant_context(self, query: str, max_chunks: int = 3) -> str:
        """Get relevant context for a query"""
        if not self.vector_store.index:
            return ""

        # Search for relevant chunks
        results = self.vector_store.search(
            query=query,
            top_k=max_chunks,
            score_threshold=0.3
        )
        if not results:
            return ""

        # Format context
        context_parts = []
        for result in results:
            file_name = result.metadata.get('file_name', 'Unknown')
            context_parts.append(
                f"[Document: {file_name} - Relevance: {result.score:.2f}]\n{result.text}"
            )
        return "\n\n".join(context_parts)

    def get_serialized_data(self) -> Optional[Dict[str, Any]]:
        """Get serialized index data for deployment, or None if no index exists"""
        if not self.vector_store.index:
            return None
        return self.vector_store.serialize()

    def get_deployment_info(self) -> Dict[str, Any]:
        """Get information for deployment package"""
        if not self.vector_store.index:
            return {
                'enabled': False,
                'message': 'No documents processed'
            }

        # Estimate how much the deployment package will grow
        index_stats = self.vector_store.get_stats()
        estimated_size_mb = (
            # Index size: one float32 (4 bytes) per dimension per chunk
            (index_stats['total_chunks'] * index_stats['dimension'] * 4) / (1024 * 1024) +
            # Raw chunk text size
            sum(len(chunk['text']) for chunk in self.vector_store.chunks.values()) / (1024 * 1024)
        ) * 1.5  # base64 encoding adds ~33% overhead; 1.5 leaves some headroom

        return {
            'enabled': True,
            'total_files': len(self.processed_files),
            'total_chunks': self.total_chunks,
            'estimated_size_mb': round(estimated_size_mb, 2),
            'files': [f['name'] for f in self.processed_files]
        }
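

# Worked example of the size estimate above (illustrative numbers, not taken
# from the source): 500 chunks of 384-dim float32 embeddings cost
# 500 * 384 * 4 bytes ~= 0.73 MB for the index; with ~0.40 MB of chunk text,
# the estimate is (0.73 + 0.40) * 1.5 ~= 1.70 MB added to the package.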


def create_rag_module_for_space(serialized_data: Dict[str, Any]) -> str:
    """Create a minimal RAG module for the deployed space.

    Returns a source-code template; the {{rag_data_json}} placeholder is
    substituted with the serialized index data before deployment.
    """
    return '''# RAG Module for deployed space
import numpy as np
import faiss
import base64
import json


class RAGContext:
    def __init__(self, serialized_data):
        # Deserialize the FAISS index (deserialize_index expects a uint8 array)
        index_bytes = base64.b64decode(serialized_data['index_base64'])
        self.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
        # Restore chunks and ID mappings
        self.chunks = serialized_data['chunks']
        self.chunk_ids = serialized_data['chunk_ids']

    def get_context(self, query_embedding, max_chunks=3):
        """Get relevant context using a pre-computed query embedding"""
        if not self.index:
            return ""
        # Normalize for cosine similarity, then search
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, max_chunks)
        # Format results, skipping padding hits and low-relevance matches
        context_parts = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < 0.3:
                continue
            chunk = self.chunks[self.chunk_ids[idx]]
            file_name = chunk.get('metadata', {}).get('file_name', 'Document')
            context_parts.append(
                f"[{file_name} - Relevance: {score:.2f}]\\n{chunk['text']}"
            )
        return "\\n\\n".join(context_parts) if context_parts else ""


# Initialize RAG context
RAG_DATA = json.loads(\'\'\'{{rag_data_json}}\'\'\')
rag_context = RAGContext(RAG_DATA) if RAG_DATA else None


def get_rag_context(query):
    """Get relevant context for a query"""
    if not rag_context:
        return ""
    # In production, the query embedding would be computed here.
    # For now, return empty (this requires an embedding service).
    return ""
'''
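

# A minimal sketch (hypothetical helper, not part of the original API) of how
# the template above might be hydrated: a plain string substitution of the
# {{rag_data_json}} placeholder. Assumes the JSON payload contains no
# triple-quote sequences that would break the embedded string literal.
def render_rag_module(serialized_data: Dict[str, Any]) -> str:
    template = create_rag_module_for_space(serialized_data)
    return template.replace('{{rag_data_json}}', json.dumps(serialized_data))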


def format_context_for_prompt(context: str, query: str) -> str:
    """Format RAG context for inclusion in a prompt"""
    if not context:
        return ""
    return f"""Relevant context from uploaded documents:
{context}
Please use the above context to help answer the user's question: {query}"""
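

if __name__ == "__main__":
    # Usage sketch with hypothetical file paths, for illustration only.
    tool = RAGTool()
    result = tool.process_uploaded_files(["docs/manual.pdf", "docs/faq.txt"])
    print(result['message'])
    if result['success']:
        question = "How do I reset my password?"
        context = tool.get_relevant_context(question)
        print(format_context_for_prompt(context, question))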