import zipfile
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer


class MetadataManager:
    """Serves document metadata from local parquet shards extracted from a zip archive."""

    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract metadata shards from zip if needed"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
                st.toast("✅ Successfully extracted metadata shards!", icon="📦")
            else:
                raise FileNotFoundError("No metadata shards found!")

    def _build_shard_map(self):
        """Map index ranges to shard files"""
        # Expects filenames that encode a global index range, e.g. "<name>_<start>_<end>.parquet".
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []

        # Group indices by shard
        shard_groups = {}
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    break

        # Load required shards
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=["title", "summary", "source"]
                )
            results.append(self.loaded_shards[shard].iloc[local_indices])

        return pd.concat(results).reset_index(drop=True)


class SemanticSearch:
    """Embeds queries with SentenceTransformers and searches sharded FAISS indexes."""

    def __init__(self, shard_dir="compressed_shards"):
        self.shard_dir = Path(shard_dir)
        self.shard_dir.mkdir(exist_ok=True, parents=True)
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()

    @st.cache_resource
    def load_model(_self):
        # Leading underscore tells Streamlit not to hash the instance argument.
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_index_shards()

    def _load_index_shards(self):
        """Load FAISS shards directly from local directory"""
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            self.index_shards.append(faiss.read_index(str(shard_path)))

    def search(self, query, top_k=5):
        """Search across all shards"""
        query_embedding = self.model.encode([query], convert_to_numpy=True)

        all_scores = []
        all_indices = []

        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            # Convert local indices to global shard offsets
            global_indices = [
                self._calculate_global_index(shard_idx, idx)
                for idx in indices[0]
            ]
            all_scores.extend(distances[0])
            all_indices.extend(global_indices)

        return self._process_results(np.array(all_scores), np.array(all_indices), top_k)

    def _calculate_global_index(self, shard_idx, local_idx):
        """Convert shard-local index to global index"""
        # Implement your specific shard indexing logic here
        # Example: return f"{shard_idx}-{local_idx}"
        return local_idx  # Simple version if using unique IDs

    def _process_results(self, distances, indices, top_k):
        """Format search results"""
        results = pd.DataFrame({
            'global_index': indices,
            'similarity': 1 - (distances / 2)  # L2 to cosine approximation (normalized vectors)
        })
        return results.sort_values('similarity', ascending=False).head(top_k)

    def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
        """Threshold-filtered search"""
        results = self.search(query, top_k * 2)
        filtered = results[results['similarity'] > similarity_threshold].head(top_k)
        return filtered.reset_index(drop=True)


class RemoteMetadataManager:
    """Variant of MetadataManager that fetches parquet shards on demand from a Hugging Face Hub repo."""

    def __init__(self, repo_id, shard_dir="metadata_shards"):
        self.repo_id = repo_id
        self.shard_dir = Path(shard_dir)
        self.shard_map = {}
        self.loaded_shards = {}
        self._build_shard_map()

    def _build_shard_map(self):
        """Map index ranges to shard files"""
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def _download_shard(self, shard_name):
        """Download missing shards on demand"""
        if not (self.shard_dir / shard_name).exists():
            hf_hub_download(
                repo_id=self.repo_id,
                filename=f"metadata_shards/{shard_name}",
                local_dir=self.shard_dir,
                cache_dir="metadata_cache"
            )

    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []

        # Group indices by shard
        shard_groups = {}
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    break

        # Process each required shard
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self._download_shard(shard)
                self.loaded_shards[shard] = pd.read_parquet(self.shard_dir / shard)
            results.append(self.loaded_shards[shard].iloc[local_indices])

        return pd.concat(results).reset_index(drop=True)
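

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of wiring the pieces together, assuming FAISS shards already
# exist under compressed_shards/ and metadata shards under metadata_shards/
# (or in metadata_shards.zip). The query string is a placeholder.
if __name__ == "__main__":
    searcher = SemanticSearch()
    searcher.initialize_system()

    hits = searcher.search_with_threshold("transformer architectures", top_k=5)
    if not hits.empty:
        # Look up titles/summaries/sources for the matching global indices.
        metadata = searcher.metadata_mgr.get_metadata(hits["global_index"].tolist())
        st.dataframe(metadata)
    else:
        st.info("No results above the similarity threshold.")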