import base64
from typing import List, Dict, Any
from dataclasses import dataclass

import numpy as np

try:
    from sentence_transformers import SentenceTransformer
    HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
    HAS_SENTENCE_TRANSFORMERS = False

try:
    import faiss
    HAS_FAISS = True
except ImportError:
    HAS_FAISS = False


@dataclass
class SearchResult:
    chunk_id: str
    text: str
    score: float
    metadata: Dict[str, Any]


class VectorStore:
    def __init__(self, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.embedding_model_name = embedding_model
        self.embedding_model = None
        self.index = None
        self.chunks = {}      # chunk_id -> chunk data
        self.chunk_ids = []   # Ordered list for FAISS index mapping
        self.dimension = 384  # Default for all-MiniLM-L6-v2
        if HAS_SENTENCE_TRANSFORMERS:
            self._initialize_model()

    def _initialize_model(self):
        """Initialize the embedding model."""
        if not HAS_SENTENCE_TRANSFORMERS:
            raise ImportError("sentence-transformers not installed")
        self.embedding_model = SentenceTransformer(self.embedding_model_name)
        # Update dimension based on the loaded model
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()

    def create_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Create embeddings for a list of texts."""
        if not self.embedding_model:
            self._initialize_model()

        # Process in batches for efficiency
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.embedding_model.encode(
                batch,
                convert_to_numpy=True,
                show_progress_bar=False
            )
            embeddings.append(batch_embeddings)

        if not embeddings:
            return np.empty((0, self.dimension), dtype=np.float32)
        # FAISS requires float32 input
        return np.vstack(embeddings).astype(np.float32)

    def build_index(self, chunks: List[Dict[str, Any]], show_progress: bool = True):
        """Build a FAISS index from chunks."""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        # Extract texts and build embeddings
        texts = [chunk['text'] for chunk in chunks]
        if show_progress:
            print(f"Creating embeddings for {len(texts)} chunks...")
        embeddings = self.create_embeddings(texts)

        # Build FAISS index
        if show_progress:
            print("Building FAISS index...")

        # Use IndexFlatIP for inner product (cosine similarity with normalized vectors)
        self.index = faiss.IndexFlatIP(self.dimension)

        # Normalize embeddings in place so inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

        # Store chunks and maintain the position -> chunk_id mapping
        self.chunks = {}
        self.chunk_ids = []
        for chunk in chunks:
            chunk_id = chunk['chunk_id']
            self.chunks[chunk_id] = chunk
            self.chunk_ids.append(chunk_id)

        if show_progress:
            print(f"Index built with {len(chunks)} chunks")

    def search(self, query: str, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
        """Search for chunks similar to the query."""
        if not self.index or not self.chunks:
            return []

        # Create and normalize the query embedding for cosine similarity
        query_embedding = self.create_embeddings([query])
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))

        # Convert raw FAISS results to SearchResult objects
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            results.append(SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            ))
        return results

    def serialize(self) -> Dict[str, Any]:
        """Serialize the vector store for deployment."""
        if not self.index:
            raise ValueError("No index to serialize")

        # faiss.serialize_index returns a uint8 numpy array; encode it as base64 text
        index_array = faiss.serialize_index(self.index)
        index_base64 = base64.b64encode(index_array.tobytes()).decode('utf-8')

        return {
            'index_base64': index_base64,
            'chunks': self.chunks,
            'chunk_ids': self.chunk_ids,
            'dimension': self.dimension,
            'model_name': self.embedding_model_name
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> 'VectorStore':
        """Deserialize a vector store from deployment data."""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        store = cls(embedding_model=data['model_name'])

        # faiss.deserialize_index expects a uint8 numpy array, not raw bytes
        index_bytes = base64.b64decode(data['index_base64'])
        store.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))

        # Restore chunks and mappings
        store.chunks = data['chunks']
        store.chunk_ids = data['chunk_ids']
        store.dimension = data['dimension']
        return store

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store."""
        return {
            'total_chunks': len(self.chunks),
            'index_size': self.index.ntotal if self.index else 0,
            'dimension': self.dimension,
            'model': self.embedding_model_name
        }


class LightweightVectorStore:
    """Lightweight version for deployed spaces without the embedding model."""

    def __init__(self, serialized_data: Dict[str, Any]):
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        # faiss.deserialize_index expects a uint8 numpy array, not raw bytes
        index_bytes = base64.b64decode(serialized_data['index_base64'])
        self.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))

        # Restore chunks and mappings
        self.chunks = serialized_data['chunks']
        self.chunk_ids = serialized_data['chunk_ids']
        self.dimension = serialized_data['dimension']

        # Query embeddings must be supplied externally (pre-computed or from a
        # lightweight embedding service); cache any that ship with the data.
        self.query_embeddings_cache = serialized_data.get('query_embeddings_cache', {})

    def search_with_embedding(self, query_embedding: np.ndarray, top_k: int = 5,
                              score_threshold: float = 0.3) -> List[SearchResult]:
        """Search using a pre-computed query embedding."""
        if not self.index or not self.chunks:
            return []

        # Accept 1-D vectors too; FAISS expects a 2-D float32 array
        query_embedding = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))

        # Convert raw FAISS results to SearchResult objects
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            results.append(SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            ))
        return results


# Utility functions

def estimate_index_size(num_chunks: int, dimension: int = 384) -> float:
    """Estimate the size of a flat index in MB."""
    # Rough estimate: 4 bytes per float32 * dimension * num_chunks,
    # plus ~20% overhead for index structure and metadata
    bytes_size = 4 * dimension * num_chunks
    overhead = 1.2
    return (bytes_size * overhead) / (1024 * 1024)