import zipfile
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
class MetadataManager:
    """Serves metadata lookups from parquet shards extracted from a local zip."""

    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract metadata shards from zip if needed"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
                st.toast("✅ Successfully extracted metadata shards!", icon="📦")
            else:
                raise FileNotFoundError("No metadata shards found!")
    def _build_shard_map(self):
        """Map index ranges to shard files"""
        # Expects filenames of the form <prefix>_<start>_<end>.parquet, where
        # <start> and <end> are the global row indices covered by the shard
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name
    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []
        # Group indices by shard so each parquet file is read at most once
        shard_groups = {}
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append(idx - start)
                    break
        # Load required shards lazily and cache them in memory
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=["title", "summary", "source"]
                )
            results.append(self.loaded_shards[shard].iloc[local_indices])
        return pd.concat(results).reset_index(drop=True)
class SemanticSearch:
    def __init__(self, shard_dir="compressed_shards"):
        self.shard_dir = Path(shard_dir)
        self.shard_dir.mkdir(exist_ok=True, parents=True)
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()

    @st.cache_resource
    def load_model(_self):
        # The leading underscore tells Streamlit not to hash `self`
        # when caching the loaded model
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_index_shards()

    def _load_index_shards(self):
        """Load FAISS shards directly from local directory"""
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            self.index_shards.append(faiss.read_index(str(shard_path)))
    def search(self, query, top_k=5):
        """Search across all shards"""
        # normalize_embeddings=True keeps the query vector unit-length so the
        # L2-to-cosine conversion in _process_results holds; this assumes the
        # index shards were also built from normalized embeddings
        query_embedding = self.model.encode(
            [query], convert_to_numpy=True, normalize_embeddings=True
        )
        all_scores = []
        all_indices = []
        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            # Convert shard-local indices to global offsets
            global_indices = [
                self._calculate_global_index(shard_idx, idx)
                for idx in indices[0]
            ]
            all_scores.extend(distances[0])
            all_indices.extend(global_indices)
        return self._process_results(np.array(all_scores), np.array(all_indices), top_k)
    def _calculate_global_index(self, shard_idx, local_idx):
        """Convert shard-local index to global index"""
        # Implement your specific shard indexing logic here.
        # Example: return f"{shard_idx}-{local_idx}"
        # Returning local_idx unchanged is only correct when every shard
        # already stores globally unique IDs (e.g., via faiss.IndexIDMap)
        return local_idx
    def _process_results(self, distances, indices, top_k):
        """Format search results"""
        results = pd.DataFrame({
            'global_index': indices,
            # For unit vectors, squared L2 distance satisfies d^2 = 2 - 2*cos,
            # so cosine similarity = 1 - d^2/2 (FAISS L2 indexes return
            # squared distances)
            'similarity': 1 - (distances / 2)
        })
        return results.sort_values('similarity', ascending=False).head(top_k)
    def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
        """Threshold-filtered search"""
        # Over-fetch, then keep only the hits above the similarity threshold
        results = self.search(query, top_k * 2)
        filtered = results[results['similarity'] > similarity_threshold].head(top_k)
        return filtered.reset_index(drop=True)
class HubMetadataManager:
    """Variant of MetadataManager that downloads missing parquet shards on
    demand from a Hugging Face Hub repo instead of a local zip archive."""

    def __init__(self, repo_id, shard_dir="metadata_shards"):
        self.repo_id = repo_id
        self.shard_dir = Path(shard_dir)
        self.shard_map = {}
        self.loaded_shards = {}
        self._build_shard_map()

    def _build_shard_map(self):
        """Map index ranges to shard files"""
        # Same <prefix>_<start>_<end>.parquet convention as MetadataManager
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name
    def _download_shard(self, shard_name):
        """Download missing shards on demand"""
        if not (self.shard_dir / shard_name).exists():
            # `filename` includes the repo subfolder, and hf_hub_download
            # mirrors that path under local_dir, so download relative to the
            # parent directory to land the file at self.shard_dir / shard_name
            hf_hub_download(
                repo_id=self.repo_id,
                filename=f"metadata_shards/{shard_name}",
                local_dir=self.shard_dir.parent,
                cache_dir="metadata_cache"
            )
    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []
        # Group indices by shard so each shard is fetched and read only once
        shard_groups = {}
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append(idx - start)
                    break
        # Download (if missing), load, and slice each required shard
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self._download_shard(shard)
                self.loaded_shards[shard] = pd.read_parquet(self.shard_dir / shard)
            results.append(self.loaded_shards[shard].iloc[local_indices])
        return pd.concat(results).reset_index(drop=True)
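

# Minimal usage sketch, assuming FAISS index shards already exist in
# "compressed_shards/" and metadata shards follow the naming convention
# above; the query string here is purely illustrative.
if __name__ == "__main__":
    searcher = SemanticSearch()
    searcher.initialize_system()
    hits = searcher.search_with_threshold("sparse retrieval methods", top_k=5)
    metadata = searcher.metadata_mgr.get_metadata(hits['global_index'].tolist())
    print(pd.concat([hits, metadata], axis=1))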