import numpy as np
import pandas as pd
import faiss
import zipfile
from pathlib import Path
from sentence_transformers import SentenceTransformer
import streamlit as st


class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self.total_docs = 0
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Handle ZIP extraction automatically"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
                st.toast("📦 Metadata shards extracted successfully!", icon="✅")
            else:
                st.error("❌ Missing metadata_shards.zip file!")
                raise FileNotFoundError("Metadata ZIP file not found")

    def _build_shard_map(self):
        """Create index range to shard mapping"""
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            parts = f.stem.split("_")
            start = int(parts[1])
            end = int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)
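    # Example (assumed naming): a shard saved as "metadata_0_49999.parquet"
    # parses to start=0, end=49999 and serves global indices 0 through 49999.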

    def get_metadata(self, global_indices):
        """Retrieve metadata rows for a list of global indices"""
        results = []
        shard_groups = {}
        # Group indices by owning shard, converting each to a shard-local offset
        for idx in global_indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    break
        # Load each required shard once (cached), then pull the requested rows
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=["title", "summary", "source"]
                )
            results.append(self.loaded_shards[shard].iloc[local_indices])
        # pd.concat raises on an empty list, so handle the no-match case
        if not results:
            return pd.DataFrame(columns=["title", "summary", "source"])
        return pd.concat(results).reset_index(drop=True)
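
# Worked example (hypothetical shard layout): with shards "metadata_0_4.parquet"
# and "metadata_5_9.parquet", get_metadata([2, 7]) groups the indices as
# {first shard: [2], second shard: [2]} and returns the two matching rows.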


class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()
        self.shard_sizes = []

    @st.cache_resource
    def load_model(_self):
        return SentenceTransformer('all-MiniLM-L6-v2')
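
    # Note: st.cache_resource skips hashing parameters whose names start with
    # an underscore, so `_self` lets the decorator wrap an instance method
    # without attempting to hash the SemanticSearch object itself.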

    def initialize_system(self):
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards and record their sizes"""
        # Reset both lists so a repeated call cannot duplicate shards
        self.index_shards = []
        self.shard_sizes = []
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            self.index_shards.append(index)
            self.shard_sizes.append(index.ntotal)

    def _global_index(self, shard_idx, local_idx):
        """Convert a shard-local index to a global index"""
        return sum(self.shard_sizes[:shard_idx]) + local_idx
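
    # e.g. with shard_sizes == [1000, 1000], _global_index(1, 5) == 1005:
    # 1000 vectors precede shard 1, so local index 5 maps to global index 1005.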

    def search(self, query, top_k=5):
        """Search every shard and return the top_k most similar documents"""
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        all_distances = []
        all_global_indices = []
        # Take top_k candidates from each shard, then merge them globally
        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            global_indices = [self._global_index(shard_idx, idx) for idx in indices[0]]
            all_distances.extend(distances[0])
            all_global_indices.extend(global_indices)
        # Attach metadata and convert distances to similarities
        results = self.metadata_mgr.get_metadata(all_global_indices)
        # FAISS L2 indexes return squared distances; for unit-norm embeddings
        # (all-MiniLM-L6-v2 normalizes its output), cosine = 1 - d/2
        results['similarity'] = 1 - (np.array(all_distances) / 2)
        return results.sort_values('similarity', ascending=False).head(top_k)
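

# Minimal usage sketch (assumed Streamlit entry point; the widget labels are
# illustrative, not part of this module):
#
#     searcher = SemanticSearch()
#     searcher.initialize_system()
#     query = st.text_input("Search query")
#     if query:
#         st.dataframe(searcher.search(query, top_k=5))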