"""Embedding Manager for Starfish This module provides embedding functionality using FAISS and SentenceTransformers for semantic similarity search and data deduplication. """ import numpy as np from typing import List, Dict, Any, Optional, Tuple, Union import faiss from sentence_transformers import SentenceTransformer import pickle import os from pathlib import Path from starfish.common.logger import get_logger logger = get_logger(__name__) class EmbeddingManager: """ Manages embeddings using SentenceTransformers and FAISS for efficient similarity search. Features: - Text embedding using pre-trained SentenceTransformers models - Fast similarity search using FAISS indexing - Persistent storage and loading of embeddings - Configurable similarity thresholds - Support for both exact and approximate nearest neighbor search """ def __init__( self, model_name: str = "all-MiniLM-L6-v2", index_type: str = "flat", similarity_threshold: float = 0.85, cache_dir: Optional[str] = None, device: Optional[str] = None, ): """ Initialize the EmbeddingManager. Args: model_name: SentenceTransformers model name or path index_type: Type of FAISS index ('flat', 'ivf', 'hnsw') similarity_threshold: Threshold for determining similar items (0-1) cache_dir: Directory to cache embeddings and models device: Device to run model on ('cpu', 'cuda', 'mps') """ self.model_name = model_name self.index_type = index_type self.similarity_threshold = similarity_threshold self.cache_dir = Path(cache_dir) if cache_dir else Path.home() / ".starfish" / "embeddings" self.cache_dir.mkdir(parents=True, exist_ok=True) # Initialize SentenceTransformer model logger.info(f"Loading SentenceTransformer model: {model_name}") self.model = SentenceTransformer(model_name, device=device) self.embedding_dim = self.model.get_sentence_embedding_dimension() # Initialize FAISS index self.index = None self.metadata = [] # Store original texts and metadata self.id_to_index = {} # Map custom IDs to FAISS indices logger.info(f"EmbeddingManager initialized with {model_name}, dim={self.embedding_dim}") def _create_index(self, dimension: int) -> faiss.Index: """Create a FAISS index based on the specified type.""" if self.index_type == "flat": # L2 distance (Euclidean) index = faiss.IndexFlatL2(dimension) elif self.index_type == "ivf": # Inverted file index for faster approximate search quantizer = faiss.IndexFlatL2(dimension) index = faiss.IndexIVFFlat(quantizer, dimension, 100) # 100 clusters elif self.index_type == "hnsw": # Hierarchical Navigable Small World for very fast approximate search index = faiss.IndexHNSWFlat(dimension, 32) else: raise ValueError(f"Unsupported index type: {self.index_type}") return index def embed_texts(self, texts: List[str], show_progress: bool = True) -> np.ndarray: """ Embed a list of texts using SentenceTransformers. Args: texts: List of texts to embed show_progress: Whether to show progress bar Returns: numpy array of embeddings with shape (len(texts), embedding_dim) """ if not texts: return np.array([]).reshape(0, self.embedding_dim) logger.info(f"Embedding {len(texts)} texts...") embeddings = self.model.encode( texts, convert_to_numpy=True, show_progress_bar=show_progress, normalize_embeddings=True, # Normalize for cosine similarity ) return embeddings.astype(np.float32) def add_texts(self, texts: List[str], metadata: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> List[int]: """ Add texts to the embedding index. 
    def add_texts(self, texts: List[str], metadata: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> List[int]:
        """
        Add texts to the embedding index.

        Args:
            texts: List of texts to add
            metadata: Optional metadata for each text
            ids: Optional custom IDs for each text

        Returns:
            List of internal indices assigned to the texts
        """
        if not texts:
            return []

        # Generate embeddings
        embeddings = self.embed_texts(texts)

        # Initialize index if needed
        if self.index is None:
            self.index = self._create_index(self.embedding_dim)
            if self.index_type == "ivf":
                # Train the IVF index
                if len(embeddings) >= 100:  # Need at least as many points as clusters
                    self.index.train(embeddings)
                else:
                    logger.warning("Not enough data to train IVF index, using flat index instead")
                    self.index = faiss.IndexFlatL2(self.embedding_dim)

        # Add to index
        start_idx = self.index.ntotal
        self.index.add(embeddings)

        # Store metadata
        if metadata is None:
            metadata = [{"text": text} for text in texts]
        else:
            # Ensure metadata includes the original text
            for i, meta in enumerate(metadata):
                if "text" not in meta:
                    meta["text"] = texts[i]
        self.metadata.extend(metadata)

        # Handle custom IDs
        indices = list(range(start_idx, start_idx + len(texts)))
        if ids:
            for i, custom_id in enumerate(ids):
                self.id_to_index[custom_id] = indices[i]

        logger.info(f"Added {len(texts)} texts to index. Total: {self.index.ntotal}")
        return indices

    def search_similar(self, query_text: str, k: int = 5, threshold: Optional[float] = None) -> List[Dict[str, Any]]:
        """
        Search for similar texts in the index.

        Args:
            query_text: Text to search for
            k: Number of similar items to return
            threshold: Similarity threshold (overrides default)

        Returns:
            List of dictionaries containing similar items with scores and metadata
        """
        if self.index is None or self.index.ntotal == 0:
            logger.warning("Index is empty or not initialized")
            return []

        # Embed query
        query_embedding = self.embed_texts([query_text], show_progress=False)

        # Search
        if self.index_type == "ivf" and hasattr(self.index, "nprobe"):
            self.index.nprobe = min(10, self.index.nlist)  # Search in up to 10 clusters
        scores, indices = self.index.search(query_embedding, k)

        # Convert squared L2 distances to cosine similarities
        # (embeddings are normalized, so cos = 1 - d^2 / 2; see note below)
        similarities = 1 - (scores[0] / 2)

        # Filter by threshold
        threshold = threshold or self.similarity_threshold
        results = []
        for idx, similarity in zip(indices[0], similarities):
            if idx != -1 and similarity >= threshold:  # -1 indicates no match found
                result = {
                    "index": int(idx),
                    "similarity": float(similarity),
                    "metadata": self.metadata[idx].copy() if idx < len(self.metadata) else {},
                    "text": self.metadata[idx].get("text", "") if idx < len(self.metadata) else "",
                }
                results.append(result)

        logger.debug(f"Found {len(results)} similar items for query (threshold={threshold})")
        return results
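    # Distance-to-similarity note (a worked sketch of the conversion used in
    # search_similar and find_duplicates): FAISS L2 indexes return *squared*
    # distances, and for unit-normalized vectors
    #     ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b = 2 - 2 cos(a, b),
    # so cosine similarity = 1 - d^2 / 2. Identical texts give d^2 = 0 -> 1.0,
    # orthogonal embeddings give d^2 = 2 -> 0.0, and opposite vectors give
    # d^2 = 4 -> -1.0.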
    def find_duplicates(self, texts: List[str], threshold: Optional[float] = None) -> List[List[int]]:
        """
        Find groups of duplicate/similar texts.

        Args:
            texts: List of texts to check for duplicates
            threshold: Similarity threshold for considering items duplicates

        Returns:
            List of lists, where each inner list contains indices of similar texts
        """
        threshold = threshold or self.similarity_threshold
        if not texts:
            return []

        # Embed all texts
        embeddings = self.embed_texts(texts, show_progress=True)

        # Create temporary index for comparison
        temp_index = faiss.IndexFlatL2(self.embedding_dim)
        temp_index.add(embeddings)

        # Find similar items
        duplicate_groups = []
        processed = set()

        for i, embedding in enumerate(embeddings):
            if i in processed:
                continue

            # Search for similar items
            query_embedding = embedding.reshape(1, -1)
            scores, indices = temp_index.search(query_embedding, len(texts))

            # Convert squared L2 distances to similarities and filter
            similarities = 1 - (scores[0] / 2)
            similar_indices = []

            for idx, similarity in zip(indices[0], similarities):
                if similarity >= threshold and int(idx) not in processed:
                    similar_indices.append(int(idx))
                    processed.add(int(idx))

            if len(similar_indices) > 1:
                duplicate_groups.append(similar_indices)

        logger.info(f"Found {len(duplicate_groups)} groups of duplicates")
        return duplicate_groups

    def save_index(self, filepath: str) -> None:
        """Save the FAISS index and metadata to disk."""
        if self.index is None:
            logger.warning("No index to save")
            return

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Save FAISS index
        faiss.write_index(self.index, str(filepath.with_suffix(".faiss")))

        # Save metadata and configuration
        metadata_file = filepath.with_suffix(".pkl")
        with open(metadata_file, "wb") as f:
            pickle.dump(
                {
                    "metadata": self.metadata,
                    "id_to_index": self.id_to_index,
                    "model_name": self.model_name,
                    "index_type": self.index_type,
                    "similarity_threshold": self.similarity_threshold,
                    "embedding_dim": self.embedding_dim,
                },
                f,
            )

        logger.info(f"Saved index to {filepath}")

    def load_index(self, filepath: str) -> None:
        """Load a FAISS index and metadata from disk."""
        filepath = Path(filepath)

        # Load FAISS index
        index_file = filepath.with_suffix(".faiss")
        if not index_file.exists():
            raise FileNotFoundError(f"Index file not found: {index_file}")
        self.index = faiss.read_index(str(index_file))

        # Load metadata and configuration
        metadata_file = filepath.with_suffix(".pkl")
        if metadata_file.exists():
            with open(metadata_file, "rb") as f:
                data = pickle.load(f)
                self.metadata = data.get("metadata", [])
                self.id_to_index = data.get("id_to_index", {})

                # Verify model compatibility
                saved_model = data.get("model_name", self.model_name)
                if saved_model != self.model_name:
                    logger.warning(f"Model mismatch: saved={saved_model}, current={self.model_name}")

        logger.info(f"Loaded index from {filepath} ({self.index.ntotal} items)")

    def get_embedding_by_id(self, custom_id: str) -> Optional[np.ndarray]:
        """Get embedding vector by custom ID."""
        if custom_id not in self.id_to_index:
            return None

        idx = self.id_to_index[custom_id]
        if self.index is None or idx >= self.index.ntotal:
            return None

        return self.index.reconstruct(idx)

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the current index."""
        return {
            "model_name": self.model_name,
            "index_type": self.index_type,
            "embedding_dimension": self.embedding_dim,
            "total_items": self.index.ntotal if self.index else 0,
            "similarity_threshold": self.similarity_threshold,
            "metadata_count": len(self.metadata),
            "custom_ids_count": len(self.id_to_index),
        }
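

# A minimal usage sketch, not part of the public API: it exercises add_texts,
# search_similar, find_duplicates, and persistence end to end. It assumes
# sentence-transformers and faiss are installed and will download the default
# model on first run; the file path below is only a demo location.
if __name__ == "__main__":
    manager = EmbeddingManager(similarity_threshold=0.8)

    # Index a few texts with metadata and custom IDs
    docs = [
        "How do I reset my password?",
        "Steps to recover a forgotten password",
        "What is the shipping cost to Canada?",
    ]
    manager.add_texts(docs, metadata=[{"source": "faq"} for _ in docs], ids=["q1", "q2", "q3"])

    # Query the index; results below the similarity threshold are filtered out
    for hit in manager.search_similar("password reset instructions", k=3):
        print(f"{hit['similarity']:.3f}  {hit['text']}")

    # Group near-duplicate texts without touching the persistent index
    print(manager.find_duplicates(docs, threshold=0.8))

    # Persist and reload the index (writes .faiss and .pkl files next to the path)
    manager.save_index("/tmp/starfish_demo_index")
    manager.load_index("/tmp/starfish_demo_index")
    print(manager.get_stats())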