import zipfile
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer

class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract metadata shards from zip if needed"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
                st.toast("✅ Successfully extracted metadata shards!", icon="📦")
            else:
                raise FileNotFoundError("No metadata shards found!")

    def _build_shard_map(self):
        """Map index ranges to shard files"""
        # Shard filenames are expected to encode their global index range,
        # e.g. metadata_0_49999.parquet -> range (0, 49999).
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []
        shard_groups = {}

        # Group indices by shard
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    break

        # Load required shards
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=["title", "summary", "source"]
                )
            results.append(self.loaded_shards[shard].iloc[local_indices])

        return pd.concat(results).reset_index(drop=True)
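
# Illustrative example (not part of the original listing): with shards named
# metadata_0_49999.parquet and metadata_50000_99999.parquet, a call such as
# MetadataManager().get_metadata([3, 50010]) loads both shards and returns
# the rows at local positions 3 and 10 respectively.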

class SemanticSearch:
    def __init__(self, shard_dir="compressed_shards"):
        self.shard_dir = Path(shard_dir)
        self.shard_dir.mkdir(exist_ok=True, parents=True)
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()

    # The leading underscore in "_self" is Streamlit's convention for
    # excluding an argument from cache hashing, which implies the
    # @st.cache_resource decorator was dropped from the listing; it is
    # restored here so the model loads only once per session.
    @st.cache_resource
    def load_model(_self):
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_index_shards()

    def _load_index_shards(self):
        """Load FAISS shards directly from local directory"""
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            self.index_shards.append(faiss.read_index(str(shard_path)))

    def search(self, query, top_k=5):
        """Search across all shards"""
        # Note: the cosine conversion in _process_results assumes the corpus
        # embeddings were L2-normalized when the index was built; if so,
        # pass normalize_embeddings=True here as well.
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        all_scores = []
        all_indices = []

        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            # Convert local indices to global shard offsets
            global_indices = [
                self._calculate_global_index(shard_idx, idx)
                for idx in indices[0]
            ]
            all_scores.extend(distances[0])
            all_indices.extend(global_indices)

        return self._process_results(np.array(all_scores), np.array(all_indices), top_k)

    def _calculate_global_index(self, shard_idx, local_idx):
        """Convert shard-local index to global index"""
        # Implement your specific shard indexing logic here
        # Example: return f"{shard_idx}-{local_idx}"
        return local_idx  # Simple version if using unique IDs
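
    # Hedged sketch (not from the original): if the FAISS shards were built
    # by splitting one corpus in sorted-filename order, the global index is
    # the local index plus the sizes of all preceding shards. This assumes
    # metadata row order matches the concatenated FAISS order.
    def _calculate_global_index_by_offset(self, shard_idx, local_idx):
        offset = sum(index.ntotal for index in self.index_shards[:shard_idx])
        return offset + int(local_idx)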

    def _process_results(self, distances, indices, top_k):
        """Format search results"""
        # For unit-length vectors, squared L2 distance d satisfies
        # d = 2 - 2 * cos_sim, so cos_sim = 1 - d / 2. Assuming the shards
        # are IndexFlatL2 (which returns squared distances), this is exact
        # for normalized embeddings and only an approximation otherwise.
        results = pd.DataFrame({
            'global_index': indices,
            'similarity': 1 - (distances / 2)
        })
        return results.sort_values('similarity', ascending=False).head(top_k)

    def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
        """Threshold-filtered search"""
        # Over-fetch (top_k * 2) so filtering can still return top_k rows
        results = self.search(query, top_k * 2)
        filtered = results[results['similarity'] > similarity_threshold].head(top_k)
        return filtered.reset_index(drop=True)

# Variant of MetadataManager that lazily downloads parquet shards from a
# Hugging Face Hub repo instead of extracting a local zip archive. Named
# distinctly so it does not shadow the class SemanticSearch instantiates.
class HubMetadataManager:
    def __init__(self, repo_id, shard_dir="metadata_shards"):
        self.repo_id = repo_id
        self.shard_dir = Path(shard_dir)
        self.shard_dir.mkdir(exist_ok=True, parents=True)
        self.shard_map = {}
        self.loaded_shards = {}
        self._build_shard_map()

    def _build_shard_map(self):
        """Map index ranges to shard files"""
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def _download_shard(self, shard_name):
        """Download missing shards on demand"""
        if not (self.shard_dir / shard_name).exists():
            # hf_hub_download preserves the repo-relative path under
            # local_dir, so the parent directory is passed here so the file
            # lands at shard_dir/shard_name (this assumes shard_dir is named
            # "metadata_shards", matching the path inside the repo).
            hf_hub_download(
                repo_id=self.repo_id,
                filename=f"metadata_shards/{shard_name}",
                local_dir=self.shard_dir.parent,
                cache_dir="metadata_cache"
            )

    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []

        # Group indices by shard
        shard_groups = {}
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    break

        # Process each required shard
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self._download_shard(shard)
                self.loaded_shards[shard] = pd.read_parquet(self.shard_dir / shard)
            results.append(self.loaded_shards[shard].iloc[local_indices])

        return pd.concat(results).reset_index(drop=True)
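
# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Ties the pieces together as a small Streamlit app. Assumes FAISS shards in
# "compressed_shards/" and metadata in "metadata_shards/" (or a local
# "metadata_shards.zip"), matching the defaults above.
if __name__ == "__main__":
    searcher = SemanticSearch()
    searcher.initialize_system()

    query = st.text_input("Search query")
    if query:
        hits = searcher.search_with_threshold(query, top_k=5, similarity_threshold=0.6)
        if hits.empty:
            st.info("No results above the similarity threshold.")
        else:
            meta = searcher.metadata_mgr.get_metadata(hits["global_index"].tolist())
            # Note: get_metadata groups rows by shard, so row order can
            # differ from hits when results span multiple shards.
            st.dataframe(meta)
            st.dataframe(hits)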