# chatui-helper / rag_tool.py
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from document_processor import DocumentProcessor, DocumentChunk
from vector_store import VectorStore, SearchResult


class RAGTool:
    """RAG tool for integrating document search with chat"""

    def __init__(self):
        self.processor = DocumentProcessor(chunk_size=800, chunk_overlap=100)
        self.vector_store = VectorStore()
        self.processed_files = []
        self.total_chunks = 0

    def process_uploaded_files(self, file_paths: List[str]) -> Dict[str, Any]:
        """Process uploaded files and build vector index"""
        # Validate files
        valid_files = []
        errors = []
        for file_path in file_paths:
            try:
                # Check file size (10MB limit)
                size_mb = os.path.getsize(file_path) / (1024 * 1024)
                if size_mb > 10:
                    errors.append({
                        'file': Path(file_path).name,
                        'error': f'File too large ({size_mb:.1f}MB). Maximum size is 10MB.'
                    })
                    continue
                valid_files.append(file_path)
            except Exception as e:
                errors.append({
                    'file': Path(file_path).name,
                    'error': str(e)
                })

        if not valid_files:
            return {
                'success': False,
                'message': 'No valid files to process',
                'errors': errors
            }

        # Process files
        all_chunks, summary = self.processor.process_multiple_files(valid_files)
        if not all_chunks:
            return {
                'success': False,
                'message': 'No content extracted from files',
                'summary': summary
            }

        # Build vector index
        chunk_dicts = [chunk.to_dict() for chunk in all_chunks]
        self.vector_store.build_index(chunk_dicts, show_progress=False)

        # Update stats
        self.processed_files = summary['files_processed']
        self.total_chunks = len(all_chunks)

        # Calculate index size
        index_stats = self.vector_store.get_stats()

        return {
            'success': True,
            'message': f'Successfully processed {len(valid_files)} files into {self.total_chunks} chunks',
            'summary': summary,
            'index_stats': index_stats,
            'errors': errors
        }

    def get_relevant_context(self, query: str, max_chunks: int = 3) -> str:
        """Get relevant context for a query"""
        if not self.vector_store.index:
            return ""

        # Search for relevant chunks
        results = self.vector_store.search(
            query=query,
            top_k=max_chunks,
            score_threshold=0.3
        )
        if not results:
            return ""

        # Format context
        context_parts = []
        for result in results:
            file_name = result.metadata.get('file_name', 'Unknown')
            context_parts.append(
                f"[Document: {file_name} - Relevance: {result.score:.2f}]\n{result.text}"
            )
        return "\n\n".join(context_parts)
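
    # Example of the string returned by get_relevant_context (illustrative values,
    # not produced by any specific document set):
    #
    #   [Document: report.pdf - Relevance: 0.82]
    #   ...chunk text...
    #
    #   [Document: notes.md - Relevance: 0.47]
    #   ...chunk text...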

    def get_serialized_data(self) -> Optional[Dict[str, Any]]:
        """Get serialized data for deployment"""
        if not self.vector_store.index:
            return None
        return self.vector_store.serialize()
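
    # The payload returned by get_serialized_data() is produced by VectorStore.serialize();
    # judging from how RAGContext in the generated module consumes it, it is expected to
    # contain at least 'index_base64' (base64-encoded FAISS index), 'chunks', and 'chunk_ids'.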

    def get_deployment_info(self) -> Dict[str, Any]:
        """Get information for deployment package"""
        if not self.vector_store.index:
            return {
                'enabled': False,
                'message': 'No documents processed'
            }

        # Estimate package size increase
        index_stats = self.vector_store.get_stats()
        estimated_size_mb = (
            # Index size estimation
            (index_stats['total_chunks'] * index_stats['dimension'] * 4) / (1024 * 1024) +
            # Chunks text size estimation
            (sum(len(chunk['text']) for chunk in self.vector_store.chunks.values()) / (1024 * 1024))
        ) * 1.5  # Add overhead for base64 encoding
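        # Worked example (illustrative; the actual dimension depends on the embedding
        # model used by VectorStore): with 1,000 chunks of 384-dimensional float32
        # vectors, the index term is 1000 * 384 * 4 bytes ≈ 1.46 MB before adding the
        # chunk text and the 1.5x base64 overhead.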
        return {
            'enabled': True,
            'total_files': len(self.processed_files),
            'total_chunks': self.total_chunks,
            'estimated_size_mb': round(estimated_size_mb, 2),
            'files': [f['name'] for f in self.processed_files]
        }


def create_rag_module_for_space(serialized_data: Dict[str, Any]) -> str:
    """Create a minimal RAG module for the deployed space"""
    return '''# RAG Module for deployed space
import numpy as np
import faiss
import base64
import json


class RAGContext:
    def __init__(self, serialized_data):
        # Deserialize FAISS index (faiss expects a uint8 numpy array, not raw bytes)
        index_bytes = base64.b64decode(serialized_data['index_base64'])
        self.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
        # Restore chunks and mappings
        self.chunks = serialized_data['chunks']
        self.chunk_ids = serialized_data['chunk_ids']

    def get_context(self, query_embedding, max_chunks=3):
        """Get relevant context using pre-computed embedding"""
        if not self.index:
            return ""
        # Normalize and search
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, max_chunks)
        # Format results
        context_parts = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < 0.3:
                continue
            chunk = self.chunks[self.chunk_ids[idx]]
            file_name = chunk.get('metadata', {}).get('file_name', 'Document')
            context_parts.append(
                f"[{file_name} - Relevance: {score:.2f}]\\n{chunk['text']}"
            )
        return "\\n\\n".join(context_parts) if context_parts else ""


# Initialize RAG context
RAG_DATA = json.loads(\'\'\'{{rag_data_json}}\'\'\')
rag_context = RAGContext(RAG_DATA) if RAG_DATA else None


def get_rag_context(query):
    """Get relevant context for a query"""
    if not rag_context:
        return ""
    # In production, you'd compute the query embedding here
    # For now, return empty (would need an embedding service)
    return ""
'''
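

# Illustrative sketch of how the generated module might be wired up at deployment
# time (an assumption about the surrounding app, not code that runs here): the
# `{{rag_data_json}}` placeholder is expected to be replaced with the serialized
# index before the module is written into the Space, e.g.:
#
#     import json
#     module_source = create_rag_module_for_space(serialized)
#     module_source = module_source.replace('{{rag_data_json}}', json.dumps(serialized))
#     Path('rag_module.py').write_text(module_source)  # 'rag_module.py' is hypothetical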


def format_context_for_prompt(context: str, query: str) -> str:
    """Format RAG context for inclusion in prompt"""
    if not context:
        return ""
    return f"""Relevant context from uploaded documents:
{context}
Please use the above context to help answer the user's question: {query}"""
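

# Minimal usage sketch (illustrative only; the file path below is hypothetical and
# the real chat integration calls these methods from the app's request handlers):
if __name__ == "__main__":
    tool = RAGTool()
    result = tool.process_uploaded_files(["docs/example.pdf"])  # hypothetical path
    print(result["message"])
    if result["success"]:
        question = "What does the document say about pricing?"
        context = tool.get_relevant_context(question)
        print(format_context_for_prompt(context, question))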