import numpy as np
import pandas as pd
import faiss
import zipfile
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer, util
import streamlit as st
import time
import os
from urllib.parse import quote
import requests
import shutil
import concurrent.futures
from functools import lru_cache

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger("MetadataManager")


class MetadataManager:
    def __init__(self):
        self.metadata_path = Path("combined.parquet")
        self.df = None
        self.total_docs = 0
        logger.info("Initializing MetadataManager")
        self._load_metadata()
        logger.info(f"Total documents indexed: {self.total_docs}")

    def _load_metadata(self):
        """Load the combined parquet file directly."""
        logger.info("Loading metadata from combined.parquet")
        try:
            # Load the parquet file
            self.df = pd.read_parquet(self.metadata_path)
            # Clean and format the data: split the semicolon-separated
            # source field into a list of http(s) URLs
            self.df['source'] = self.df['source'].apply(
                lambda x: [
                    url.strip()
                    for url in str(x).split(';')
                    if url.strip() and url.startswith('http')
                ]
            )
            self.total_docs = len(self.df)
            logger.info(f"Successfully loaded {self.total_docs} documents")
        except Exception as e:
            logger.error(f"Failed to load metadata: {str(e)}")
            raise

    def get_metadata(self, global_indices):
        """Retrieve metadata for given indices with deduplication by title."""
        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
            return pd.DataFrame(columns=["title", "summary", "authors", "similarity", "source"])
        try:
            # Directly index the DataFrame
            results = self.df.iloc[global_indices].copy()
            # Deduplicate by title to avoid near-duplicate results
            if len(results) > 1:
                results = results.drop_duplicates(subset=["title"])
            return results
        except Exception as e:
            logger.error(f"Metadata retrieval failed: {str(e)}")
            return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
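
# Illustrative use of MetadataManager on its own (hypothetical indices; assumes
# combined.parquet sits next to this file with title/summary/authors/source columns):
#   mgr = MetadataManager()
#   rows = mgr.get_metadata(np.array([0, 5, 42]))
#   print(rows[["title", "source"]])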

class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()
        self.shard_sizes = []
        # Configure search logger
        self.logger = logging.getLogger("SemanticSearch")
        self.logger.info("Initializing SemanticSearch")

    @st.cache_resource
    def load_model(_self):
        # The leading underscore tells Streamlit's cache not to hash the instance
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.logger.info("Loading sentence transformer model")
        start_time = time.time()
        self.model = self.load_model()
        self.logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
        self.logger.info("Loading FAISS indices")
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards."""
        self.logger.info(f"Searching for index files in {self.shard_dir}")
        if not self.shard_dir.exists():
            self.logger.error(f"Shard directory not found: {self.shard_dir}")
            return
        index_files = list(self.shard_dir.glob("*.index"))
        self.logger.info(f"Found {len(index_files)} index files")
        self.shard_sizes = []
        self.index_shards = []
        for shard_path in sorted(index_files):
            try:
                self.logger.info(f"Loading index: {shard_path}")
                start_time = time.time()
                # Log file size
                file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
                self.logger.info(f"Index file size: {file_size_mb:.2f} MB")
                index = faiss.read_index(str(shard_path))
                self.index_shards.append(index)
                self.shard_sizes.append(index.ntotal)
                self.logger.info(f"Loaded index with {index.ntotal} vectors in {time.time() - start_time:.2f} seconds")
            except Exception as e:
                self.logger.error(f"Failed to load index {shard_path}: {str(e)}")
        self.total_vectors = sum(self.shard_sizes)
        self.logger.info(f"Total loaded vectors: {self.total_vectors} across {len(self.index_shards)} shards")

    def _global_index(self, shard_idx, local_idx):
        """Convert a shard-local index to a global index by offsetting with the sizes of all preceding shards."""
        return sum(self.shard_sizes[:shard_idx]) + local_idx
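
    # Worked example (hypothetical sizes): with shard_sizes = [1000, 500],
    # _global_index(1, 3) returns 1000 + 3 = 1003, i.e. the 4th vector of
    # shard 1 maps to row 1003 of the combined metadata DataFrame.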

    def search(self, query, top_k=5):
        """Search with validation."""
        self.logger.info(f"Searching for query: '{query}' (top_k={top_k})")
        start_time = time.time()
        if not query:
            self.logger.warning("Empty query provided")
            return pd.DataFrame()
        if not self.index_shards:
            self.logger.error("No index shards loaded")
            return pd.DataFrame()
        try:
            self.logger.info("Encoding query")
            query_embedding = self.model.encode([query], convert_to_numpy=True)
            self.logger.debug(f"Query encoded to shape {query_embedding.shape}")
        except Exception as e:
            self.logger.error(f"Query encoding failed: {str(e)}")
            return pd.DataFrame()
        all_distances = []
        all_global_indices = []
        # Search with index validation
        self.logger.info(f"Searching across {len(self.index_shards)} shards")
        for shard_idx, index in enumerate(self.index_shards):
            if index.ntotal == 0:
                self.logger.warning(f"Skipping empty shard {shard_idx}")
                continue
            try:
                shard_start = time.time()
                distances, indices = index.search(query_embedding, top_k)
                valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
                valid_indices = indices[0][valid_mask].tolist()
                valid_distances = distances[0][valid_mask].tolist()
                if len(valid_indices) != top_k:
                    self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
                global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
                all_distances.extend(valid_distances)
                all_global_indices.extend(global_indices)
                self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
            except Exception as e:
                self.logger.error(f"Search failed in shard {shard_idx}: {str(e)}")
                continue
        self.logger.info(f"Search found {len(all_global_indices)} results across all shards")
        # Process results
        results = self._process_results(
            np.array(all_distances),
            np.array(all_global_indices),
            top_k
        )
        self.logger.info(f"Search completed in {time.time() - start_time:.2f} seconds with {len(results)} final results")
        return results
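
    # Note: each shard contributes up to top_k candidates; the union is
    # re-ranked by similarity and trimmed back to top_k in _process_results.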

    def _search_shard(self, shard_idx, index, query_embedding, top_k):
        """Search a single FAISS shard for the query embedding with proper error handling."""
        if index.ntotal == 0:
            self.logger.warning(f"Skipping empty shard {shard_idx}")
            return None
        try:
            shard_start = time.time()
            distances, indices = index.search(query_embedding, top_k)
            # Filter out invalid indices (-1 is returned by FAISS for insufficient results)
            valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
            valid_indices = indices[0][valid_mask]
            valid_distances = distances[0][valid_mask]
            if len(valid_indices) == 0:
                self.logger.debug(f"Shard {shard_idx}: No valid results found")
                return None
            if len(valid_indices) != top_k:
                self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
            global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
            # Filter out any invalid global indices (could happen if _global_index validation fails)
            valid_global = [(d, i) for d, i in zip(valid_distances, global_indices) if i >= 0]
            if not valid_global:
                return None
            final_distances, final_indices = zip(*valid_global)
            self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
            return final_distances, final_indices
        except Exception as e:
            self.logger.error(f"Search failed in shard {shard_idx}: {str(e)}")
            return None
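
    # `_search_shard` and the `concurrent.futures` import are currently unused:
    # `search` runs its own sequential per-shard loop. The helper below is a
    # minimal sketch (its name and pooling strategy are assumptions, not part of
    # the original search path) of how the same fan-out could run on a thread
    # pool; FAISS's search call typically releases the GIL, so threads can overlap.
    def _search_shards_parallel(self, query_embedding, top_k):
        """Sketch: run _search_shard across all shards with a thread pool."""
        all_distances, all_global_indices = [], []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self._search_shard, shard_idx, index, query_embedding, top_k)
                for shard_idx, index in enumerate(self.index_shards)
            ]
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                if result is not None:
                    distances, indices = result
                    all_distances.extend(distances)
                    all_global_indices.extend(indices)
        return np.array(all_distances), np.array(all_global_indices)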

    def _process_results(self, distances, global_indices, top_k):
        """Process raw search results into a formatted DataFrame."""
        process_start = time.time()
        # Proper numpy array emptiness checks
        if global_indices.size == 0 or distances.size == 0:
            self.logger.warning("No search results to process")
            return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
        try:
            # Get metadata for matched indices
            self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
            metadata_start = time.time()
            results = self.metadata_mgr.get_metadata(global_indices)
            self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
            # Empty results check
            if len(results) == 0:
                self.logger.warning("No metadata found for indices")
                return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
            # Ensure distances match results length (metadata may have been deduplicated)
            if len(results) != len(distances):
                self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
                if len(results) < len(distances):
                    self.logger.info("Truncating distances array to match results length")
                    distances = distances[:len(results)]
                else:
                    # Should not happen, but handle it anyway
                    self.logger.error("More results than distances - this shouldn't happen")
                    distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
            # Calculate similarity scores; assuming unit-norm embeddings in a
            # squared-L2 FAISS index, distance d = 2 - 2*cos, so cos = 1 - d/2
            self.logger.debug("Calculating similarity scores")
            results['similarity'] = 1 - (distances / 2)
            # Log similarity statistics
            if not results.empty:
                self.logger.debug(f"Similarity stats: min={results['similarity'].min():.3f}, " +
                                  f"max={results['similarity'].max():.3f}, " +
                                  f"mean={results['similarity'].mean():.3f}")
            # Deduplicate and sort results
            pre_dedup = len(results)
            results = results.drop_duplicates(subset=["title"]).sort_values("similarity", ascending=False).head(top_k)
            post_dedup = len(results)
            if pre_dedup > post_dedup:
                self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
            self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
            return results.reset_index(drop=True)
        except Exception as e:
            self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
            return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
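

# Minimal usage sketch (illustrative only): the query string below is made up,
# and this assumes combined.parquet plus the compressed_shards/ directory are
# present. Running the module directly bypasses the Streamlit app; outside a
# Streamlit session st.cache_resource should still execute the wrapped
# function, though it may log warnings.
if __name__ == "__main__":
    searcher = SemanticSearch()
    searcher.initialize_system()
    hits = searcher.search("graph neural networks for drug discovery", top_k=3)
    print(hits[["title", "similarity"]])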