import numpy as np
import pandas as pd
import faiss
import zipfile
from pathlib import Path
from sentence_transformers import SentenceTransformer, util
import streamlit as st
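
# Dependency note (an assumption, not stated in the original file): running this
# app requires faiss-cpu (or faiss-gpu), sentence-transformers, streamlit,
# pandas, numpy, and a parquet engine such as pyarrow for pd.read_parquet.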

class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self.total_docs = 0
        self._ensure_unzipped()  # Removed Streamlit elements from here
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Handle ZIP extraction without Streamlit elements"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
            else:
                raise FileNotFoundError("Metadata ZIP file not found")

    def _build_shard_map(self):
        """Create index range to shard mapping"""
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            parts = f.stem.split("_")
            start = int(parts[1])
            end = int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)
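
    # Shard filename convention assumed by _build_shard_map (illustrative; the
    # actual prefix is not shown in this file): each parquet file is named
    # <prefix>_<start>_<end>.parquet, so a hypothetical "metadata_0_49999.parquet"
    # would cover global document indices 0 through 49999 and push total_docs
    # to 50000.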

    def get_metadata(self, global_indices):
        """Retrieve metadata with validation"""
        # Check for an empty numpy array properly
        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Convert numpy array to list for processing
        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices

        # Filter valid indices
        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
        if not valid_indices:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Group indices by shard with boundary check
        shard_groups = {}
        for idx in valid_indices:
            found = False
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    found = True
                    break
            if not found:
                st.warning(f"Index {idx} out of shard range (0-{self.total_docs - 1})")

        # Load and process shards
        results = []
        for shard, local_indices in shard_groups.items():
            try:
                if shard not in self.loaded_shards:
                    self.loaded_shards[shard] = pd.read_parquet(
                        self.shard_dir / shard,
                        columns=["title", "summary", "source"]
                    )
                if local_indices:
                    results.append(self.loaded_shards[shard].iloc[local_indices])
            except Exception as e:
                st.error(f"Error loading shard {shard}: {str(e)}")
                continue

        return pd.concat(results).reset_index(drop=True) if results else pd.DataFrame()

class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()  # No Streamlit elements in constructor
        self.shard_sizes = []

    def load_model(self):
        # Load the sentence embedding model used for query encoding
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards"""
        self.shard_sizes = []
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            self.index_shards.append(index)
            self.shard_sizes.append(index.ntotal)

    def _global_index(self, shard_idx, local_idx):
        """Convert local index to global index"""
        return sum(self.shard_sizes[:shard_idx]) + local_idx
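
    # Worked example of the offset arithmetic (hypothetical sizes): with
    # shard_sizes = [1000, 1000, 500], shard_idx=1 and local_idx=5 map to
    # global index sum([1000]) + 5 = 1005.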

    def search(self, query, top_k=5):
        """Search with validation"""
        if not query or not self.index_shards:
            return pd.DataFrame()

        try:
            query_embedding = self.model.encode([query], convert_to_numpy=True)
        except Exception as e:
            st.error(f"Query encoding failed: {str(e)}")
            return pd.DataFrame()

        all_distances = []
        all_global_indices = []

        # Search each shard, keeping every distance paired with its index
        for shard_idx, index in enumerate(self.index_shards):
            if index.ntotal == 0:
                continue
            try:
                distances, indices = index.search(query_embedding, top_k)
                for dist, idx in zip(distances[0], indices[0]):
                    # FAISS pads missing results with -1, so skip invalid hits
                    if 0 <= idx < index.ntotal:
                        all_distances.append(dist)
                        all_global_indices.append(self._global_index(shard_idx, idx))
            except Exception as e:
                st.error(f"Search failed in shard {shard_idx}: {str(e)}")
                continue

        # Distances and indices are appended in pairs, so the arrays stay aligned
        return self._process_results(
            np.array(all_distances),
            np.array(all_global_indices),
            top_k
        )

    def _process_results(self, distances, global_indices, top_k):
        """Process raw search results into formatted DataFrame"""
        # Proper numpy array emptiness checks
        if global_indices.size == 0 or distances.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        try:
            # Convert numpy indices to a Python list for metadata retrieval
            indices_list = global_indices.tolist()

            # Get metadata for matched indices
            results = self.metadata_mgr.get_metadata(indices_list)

            # Ensure distances match results length
            if len(results) != len(distances):
                distances = distances[:len(results)]

            # Assuming normalized embeddings and squared-L2 FAISS distances d,
            # 1 - d/2 recovers a cosine-style similarity score
            results['similarity'] = 1 - (distances / 2)

            # Deduplicate, sort by similarity, and keep the top_k results
            results = (
                results.drop_duplicates(subset=["title", "source"])
                .sort_values("similarity", ascending=False)
                .head(top_k)
            )
            return results.reset_index(drop=True)
        except Exception as e:
            st.error(f"Result processing failed: {str(e)}")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
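

# Minimal usage sketch (an assumption, not part of the original file): it assumes
# the FAISS shards in "compressed_shards/" and the metadata ZIP/shards are present
# next to the app, and wires the two classes into a basic Streamlit page.
if __name__ == "__main__":
    st.title("Semantic Search")  # hypothetical page title
    searcher = SemanticSearch()
    searcher.initialize_system()

    user_query = st.text_input("Enter a search query")
    if user_query:
        hits = searcher.search(user_query, top_k=5)
        if hits.empty:
            st.info("No results found.")
        else:
            st.dataframe(hits)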