import zipfile
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer


class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self.total_docs = 0
        self._ensure_unzipped()  # No Streamlit elements here
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Handle ZIP extraction without Streamlit elements."""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, "r") as zip_ref:
                    zip_ref.extractall(self.shard_dir)
            else:
                raise FileNotFoundError("Metadata ZIP file not found")

    def _build_shard_map(self):
        """Create the index-range-to-shard mapping.

        Shard filenames are expected to encode their global index range,
        e.g. <prefix>_<start>_<end>.parquet.
        """
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            parts = f.stem.split("_")
            start, end = int(parts[1]), int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)

    def get_metadata(self, global_indices):
        """Retrieve metadata rows, preserving the order of the requested indices."""
        # Check for an empty numpy array properly
        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Convert a numpy array to a plain list for processing
        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices

        # Filter out-of-range indices
        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
        if not valid_indices:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Group indices by shard, remembering each index's position in the
        # request so rows can be re-aligned with their distances afterwards
        shard_groups = {}
        for pos, idx in enumerate(valid_indices):
            found = False
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append((pos, idx - start))
                    found = True
                    break
            if not found:
                st.warning(f"Index {idx} out of shard range (0-{self.total_docs - 1})")

        # Load shards lazily and collect the requested rows
        results = []
        for shard, entries in shard_groups.items():
            try:
                if shard not in self.loaded_shards:
                    self.loaded_shards[shard] = pd.read_parquet(
                        self.shard_dir / shard,
                        columns=["title", "summary", "source"]
                    )
                positions = [pos for pos, _ in entries]
                local_indices = [local for _, local in entries]
                chunk = self.loaded_shards[shard].iloc[local_indices].copy()
                chunk.index = positions
                results.append(chunk)
            except Exception as e:
                st.error(f"Error loading shard {shard}: {str(e)}")
                continue

        if not results:
            return pd.DataFrame()
        # Restore the original request order before handing rows back
        return pd.concat(results).sort_index().reset_index(drop=True)


class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()  # No Streamlit elements in the constructor
        self.shard_sizes = []

    @st.cache_resource
    def load_model(_self):
        # The leading underscore tells st.cache_resource not to hash `self`
        return SentenceTransformer("all-MiniLM-L6-v2")

    def initialize_system(self):
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards."""
        self.shard_sizes = []
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            self.index_shards.append(index)
            self.shard_sizes.append(index.ntotal)

    def _global_index(self, shard_idx, local_idx):
        """Convert a shard-local index to a global index."""
        return sum(self.shard_sizes[:shard_idx]) + local_idx

    def search(self, query, top_k=5):
        """Search with validation."""
        if not query or not self.index_shards:
            return pd.DataFrame()

        try:
            query_embedding = self.model.encode([query], convert_to_numpy=True)
        except Exception as e:
            st.error(f"Query encoding failed: {str(e)}")
            return pd.DataFrame()

        all_distances = []
        all_global_indices = []

        # Search each shard, keeping distance/index pairs together so that
        # filtering out invalid hits cannot misalign them
        for shard_idx, index in enumerate(self.index_shards):
            if index.ntotal == 0:
                continue
            try:
                distances, indices = index.search(query_embedding, top_k)
                for dist, idx in zip(distances[0], indices[0]):
                    if 0 <= idx < index.ntotal:  # FAISS pads missing hits with -1
                        all_distances.append(dist)
                        all_global_indices.append(self._global_index(shard_idx, idx))
            except Exception as e:
                st.error(f"Search failed in shard {shard_idx}: {str(e)}")
                continue

        return self._process_results(
            np.array(all_distances),
            np.array(all_global_indices),
            top_k
        )

    def _process_results(self, distances, global_indices, top_k):
        """Process raw search results into a formatted DataFrame."""
        # Proper numpy array emptiness checks
        if global_indices.size == 0 or distances.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        try:
            # Convert numpy indices to a Python list for metadata retrieval
            indices_list = global_indices.tolist()

            # Get metadata for the matched indices (rows come back in request order)
            results = self.metadata_mgr.get_metadata(indices_list)

            # Guard against rows dropped during metadata lookup
            if len(results) != len(distances):
                distances = distances[:len(results)]

            # Convert distances to similarity scores; assumes unit-normalized
            # embeddings in an L2 index, where squared distance d2 = 2 - 2*cos,
            # hence cos = 1 - d2 / 2
            results["similarity"] = 1 - (distances / 2)

            # Sort before deduplicating so the highest-scoring duplicate is kept
            results = (
                results.sort_values("similarity", ascending=False)
                .drop_duplicates(subset=["title", "source"])
                .head(top_k)
            )

            return results.reset_index(drop=True)
        except Exception as e:
            st.error(f"Result processing failed: {str(e)}")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
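

# Minimal usage sketch, assuming "compressed_shards/" holds the FAISS *.index
# files and "metadata_shards.zip" (or an already extracted "metadata_shards/"
# directory) sits next to this script. The page layout below is illustrative,
# not part of the classes above. Launch with: streamlit run app.py
def main():
    st.title("Semantic Search")
    searcher = SemanticSearch()
    searcher.initialize_system()

    query = st.text_input("Search query")
    if query:
        results = searcher.search(query, top_k=5)
        if results.empty:
            st.info("No matches found.")
        else:
            st.dataframe(results)


if __name__ == "__main__":
    main()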