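"""Sharded semantic search backend for a Streamlit app.

FAISS index shards handle vector search over document embeddings, while
per-document metadata (title, summary, source) lives in Parquet shards that
are extracted from a ZIP archive and loaded lazily on first access.
"""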
import zipfile
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer


class MetadataManager:
    """Map global document indices to metadata rows stored in Parquet shards."""

    def __init__(self):
        # No Streamlit calls in the constructor, so the object can be built
        # (and cached) outside of a rendered page.
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}      # (start, end) global index range -> shard filename
        self.loaded_shards = {}  # shard filename -> cached DataFrame
        self.total_docs = 0
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract the metadata ZIP on first run, without any Streamlit UI calls."""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
            else:
                raise FileNotFoundError("Metadata ZIP file not found")

    def _build_shard_map(self):
        """Map each global index range to its shard filename."""
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            # Shard filenames encode the global index range they cover,
            # e.g. shard_0_49999.parquet -> rows 0..49999.
            parts = f.stem.split("_")
            start, end = int(parts[1]), int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)

    def get_metadata(self, global_indices):
        """Retrieve metadata rows for the given global document indices."""
        results = []
        shard_groups = {}
        # Group indices by the shard that covers them, converted to shard-local offsets.
        for idx in global_indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append(idx - start)
                    break
        # Load each required shard once (cached in memory), then pull the rows.
        # dicts preserve insertion order, so the concatenated rows line up with
        # the input order as long as the input arrives grouped shard by shard.
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=["title", "summary", "source"],
                )
            results.append(self.loaded_shards[shard].iloc[local_indices])
        if not results:
            return pd.DataFrame(columns=["title", "summary", "source"])
        return pd.concat(results).reset_index(drop=True)
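
# Expected metadata layout (an assumption inferred from _build_shard_map's
# filename parsing; the real names come from whatever pipeline produced
# metadata_shards.zip):
#
#   metadata_shards/
#       shard_0_49999.parquet       # rows for global indices 0..49999
#       shard_50000_99999.parquet   # rows for global indices 50000..99999
#
# With that layout, get_metadata([3, 50001]) reads both files and returns
# two metadata rows in input order.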

class SemanticSearch:
    """Query sharded FAISS indexes and join hits with their metadata."""

    def __init__(self):
        # Like MetadataManager, the constructor avoids Streamlit elements.
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()
        self.shard_sizes = []

    @st.cache_resource
    def load_model(_self):
        # The leading underscore tells st.cache_resource not to hash `self`,
        # so the model is cached once per process rather than per instance.
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards and record their sizes."""
        self.shard_sizes = []
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            self.index_shards.append(index)
            self.shard_sizes.append(index.ntotal)

    def _global_index(self, shard_idx, local_idx):
        """Convert a shard-local index to a global document index."""
        return sum(self.shard_sizes[:shard_idx]) + local_idx

    def search(self, query, top_k=5):
        """Encode the query, search every shard, and return ranked results."""
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        all_distances = []
        all_global_indices = []
        # Search across all shards; FAISS pads short result lists with -1,
        # so skip those placeholder indices.
        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            for dist, idx in zip(distances[0], indices[0]):
                if idx == -1:
                    continue
                all_distances.append(dist)
                all_global_indices.append(self._global_index(shard_idx, idx))
        # Join metadata, then map squared-L2 distance to cosine similarity;
        # this holds because all-MiniLM-L6-v2 outputs unit-normalized vectors.
        results = self.metadata_mgr.get_metadata(all_global_indices)
        results['similarity'] = 1 - (np.array(all_distances) / 2)
        return results.sort_values('similarity', ascending=False).head(top_k)
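

# --- Minimal Streamlit wiring sketch ---------------------------------------
# A hedged example of how these classes might be driven from a Streamlit page;
# the factory function, widget label, and top_k value below are assumptions,
# not part of the classes above. Keeping the constructors free of Streamlit
# calls is what makes the searcher safe to cache with st.cache_resource.
@st.cache_resource
def get_searcher():
    searcher = SemanticSearch()
    searcher.initialize_system()  # loads the model and all FAISS shards once
    return searcher


searcher = get_searcher()
query = st.text_input("Semantic search")
if query:
    st.dataframe(searcher.search(query, top_k=5))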