import numpy as np
import pandas as pd
import faiss
import zipfile
from pathlib import Path
from sentence_transformers import SentenceTransformer, util
import streamlit as st


class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self.total_docs = 0
        self._ensure_unzipped()   # extract metadata shards on first run; no Streamlit calls in the constructor
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract the metadata ZIP on first run, without any Streamlit UI calls"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
            else:
                raise FileNotFoundError("Metadata ZIP file not found")

    def _build_shard_map(self):
        """Create index range to shard mapping"""
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            parts = f.stem.split("_")
            start = int(parts[1])
            end = int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)
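
    # Assumed shard naming convention (inferred from the parsing above; the concrete
    # file names are illustrative, not taken from the repository):
    #
    #   metadata_0_49999.parquet       -> shard_map[(0, 49999)]
    #   metadata_50000_99999.parquet   -> shard_map[(50000, 99999)]
    #
    # i.e. "<prefix>_<start>_<end>.parquet", where start/end are the inclusive
    # global document indices covered by that shard.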

    def get_metadata(self, global_indices):
        """Retrieve metadata rows for the given global indices, with validation"""
        # Check for an empty numpy array explicitly (truth-testing an array raises ValueError)
        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Convert a numpy array to a plain list for processing
        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else list(global_indices)

        # Keep only indices inside the known document range
        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
        if not valid_indices:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Group indices by shard, remembering each shard's start offset
        shard_groups = {}
        for idx in valid_indices:
            found = False
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault((shard, start), []).append(idx - start)
                    found = True
                    break
            if not found:
                st.warning(f"Index {idx} out of shard range (0-{self.total_docs - 1})")

        # Load shards lazily and collect the requested rows
        results = []
        for (shard, start), local_indices in shard_groups.items():
            try:
                if shard not in self.loaded_shards:
                    self.loaded_shards[shard] = pd.read_parquet(
                        self.shard_dir / shard,
                        columns=["title", "summary", "source"]
                    )
                if local_indices:
                    chunk = self.loaded_shards[shard].iloc[local_indices].copy()
                    # Carry the global index along so callers can align rows with search scores
                    chunk["global_index"] = [start + li for li in local_indices]
                    results.append(chunk)
            except Exception as e:
                st.error(f"Error loading shard {shard}: {str(e)}")
                continue

        return pd.concat(results).reset_index(drop=True) if results else pd.DataFrame()
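
# Assumed on-disk layout (inferred from MetadataManager and SemanticSearch below;
# the concrete file names are illustrative):
#
#   metadata_shards.zip                  # extracted into metadata_shards/ on first run
#   metadata_shards/
#       metadata_0_249999.parquet        # columns: title, summary, source
#       metadata_250000_499999.parquet
#   compressed_shards/
#       shard_0.index                    # FAISS indexes, one per contiguous document range
#       shard_1.index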


class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()   # constructor stays free of Streamlit UI calls
        self.shard_sizes = []

    @st.cache_resource
    def load_model(_self):
        # Leading underscore tells st.cache_resource not to hash the instance argument
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards and record their sizes"""
        self.shard_sizes = []
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            self.index_shards.append(index)
            self.shard_sizes.append(index.ntotal)

    def _global_index(self, shard_idx, local_idx):
        """Convert a (shard, local) index pair into a global index"""
        return sum(self.shard_sizes[:shard_idx]) + local_idx
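
    # Worked example (hypothetical shard sizes): with shard_sizes == [1000, 1000, 500],
    # _global_index(0, 7) -> 7, _global_index(1, 7) -> 1007, _global_index(2, 499) -> 2499,
    # i.e. each shard's documents occupy a contiguous block of the global index space.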

    def search(self, query, top_k=5):
        """Search all shards for the query and return a ranked DataFrame"""
        if not query or not self.index_shards:
            return pd.DataFrame()

        try:
            query_embedding = self.model.encode([query], convert_to_numpy=True)
        except Exception as e:
            st.error(f"Query encoding failed: {str(e)}")
            return pd.DataFrame()

        all_distances = []
        all_global_indices = []

        # Search each shard, keeping every distance paired with its index
        for shard_idx, index in enumerate(self.index_shards):
            if index.ntotal == 0:
                continue
            try:
                distances, indices = index.search(query_embedding, top_k)
                # FAISS pads missing results with -1; filter them out pairwise so
                # distances never get matched to the wrong index
                for dist, idx in zip(distances[0], indices[0]):
                    if 0 <= idx < index.ntotal:
                        all_distances.append(dist)
                        all_global_indices.append(self._global_index(shard_idx, idx))
            except Exception as e:
                st.error(f"Search failed in shard {shard_idx}: {str(e)}")
                continue

        return self._process_results(
            np.array(all_distances),
            np.array(all_global_indices),
            top_k
        )

    def _process_results(self, distances, global_indices, top_k):
        """Process raw search results into a ranked, de-duplicated DataFrame"""
        # Use .size for emptiness checks; truth-testing a numpy array raises ValueError
        if global_indices.size == 0 or distances.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        try:
            # Convert numpy indices to a Python list for metadata retrieval
            indices_list = global_indices.tolist()

            # Get metadata for the matched indices
            results = self.metadata_mgr.get_metadata(indices_list)
            if results.empty:
                return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

            # Attach each row's distance via its global index so scores stay aligned even
            # when hits come from several shards, then convert distance to similarity
            # (1 - d/2 equals cosine similarity for unit-length embeddings on an L2 index)
            dist_map = dict(zip(indices_list, distances.tolist()))
            results["similarity"] = 1 - (results["global_index"].map(dist_map) / 2)

            # De-duplicate, rank by similarity, and keep the top_k rows
            results = (
                results.drop_duplicates(subset=["title", "source"])
                .sort_values("similarity", ascending=False)
                .head(top_k)
                .drop(columns=["global_index"])
            )
            return results.reset_index(drop=True)
        except Exception as e:
            st.error(f"Result processing failed: {str(e)}")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
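
# Illustrative usage sketch (assumption: a Streamlit front end would wire the
# classes together roughly like this; the query string is an example only):
#
#   searcher = SemanticSearch()
#   searcher.initialize_system()
#   results = searcher.search("transformer models for protein folding", top_k=5)
#   if not results.empty:
#       st.dataframe(results[["title", "summary", "source", "similarity"]])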