Spaces:
Running
Running
Update search_utils.py
Browse files- search_utils.py +24 -8
search_utils.py
CHANGED
@@ -36,7 +36,6 @@ class MetadataManager:
|
|
36 |
self._ensure_directories()
|
37 |
self._unzip_if_needed()
|
38 |
self._build_shard_map()
|
39 |
-
self._init_url_resolver()
|
40 |
logger.info(f"Total documents indexed: {self.total_docs}")
|
41 |
logger.info(f"Total shards found: {len(self.shard_map)}")
|
42 |
|
@@ -145,12 +144,15 @@ class MetadataManager:
|
|
145 |
shard_path = self.shard_dir / shard
|
146 |
if not shard_path.exists():
|
147 |
logger.error(f"Shard file not found: {shard_path}")
|
148 |
-
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
|
|
149 |
file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
|
150 |
logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
|
|
|
151 |
try:
|
152 |
-
self.loaded_shards[shard] = pd.read_parquet(shard_path, columns=["title", "summary"])
|
153 |
logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
|
|
|
154 |
except Exception as e:
|
155 |
logger.error(f"Failed to read parquet file {shard}: {str(e)}")
|
156 |
try:
|
@@ -158,7 +160,7 @@ class MetadataManager:
|
|
158 |
logger.info(f"Parquet schema: {schema}")
|
159 |
except Exception:
|
160 |
pass
|
161 |
-
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
162 |
df = self.loaded_shards[shard]
|
163 |
df_len = len(df)
|
164 |
valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
|
@@ -168,9 +170,10 @@ class MetadataManager:
|
|
168 |
chunk = df.iloc[valid_local_indices]
|
169 |
logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
|
170 |
return chunk
|
|
|
171 |
except Exception as e:
|
172 |
logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
|
173 |
-
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
174 |
|
175 |
def get_metadata(self, global_indices):
|
176 |
"""Retrieve metadata for a batch of global indices using parallel shard processing."""
|
@@ -328,14 +331,17 @@ class SemanticSearch:
|
|
328 |
if index.ntotal == 0:
|
329 |
self.logger.warning(f"Skipping empty shard {shard_idx}")
|
330 |
return None
|
|
|
331 |
try:
|
332 |
shard_start = time.time()
|
333 |
distances, indices = index.search(query_embedding, top_k)
|
334 |
valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
|
335 |
valid_indices = indices[0][valid_mask].tolist()
|
336 |
valid_distances = distances[0][valid_mask].tolist()
|
|
|
337 |
if len(valid_indices) != top_k:
|
338 |
self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
|
|
|
339 |
global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
|
340 |
self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
|
341 |
return valid_distances, global_indices
|
@@ -348,21 +354,23 @@ class SemanticSearch:
|
|
348 |
process_start = time.time()
|
349 |
if global_indices.size == 0 or distances.size == 0:
|
350 |
self.logger.warning("No search results to process")
|
351 |
-
return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
|
352 |
try:
|
353 |
self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
|
354 |
metadata_start = time.time()
|
355 |
results = self.metadata_mgr.get_metadata(global_indices)
|
356 |
self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
|
|
|
357 |
if len(results) == 0:
|
358 |
self.logger.warning("No metadata found for indices")
|
359 |
-
return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
|
360 |
if len(results) != len(distances):
|
361 |
self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
|
362 |
if len(results) < len(distances):
|
363 |
distances = distances[:len(results)]
|
364 |
else:
|
365 |
distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
|
|
|
366 |
self.logger.debug("Calculating similarity scores")
|
367 |
results['similarity'] = 1 - (distances / 2)
|
368 |
if not results.empty:
|
@@ -370,13 +378,21 @@ class SemanticSearch:
|
|
370 |
f"max={results['similarity'].max():.3f}, " +
|
371 |
f"mean={results['similarity'].mean():.3f}")
|
372 |
results['source'] = results["source"]
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
pre_dedup = len(results)
|
375 |
-
results = results.drop_duplicates(subset=["title", "source"]).sort_values("similarity", ascending=False).head(top_k)
|
|
|
376 |
post_dedup = len(results)
|
377 |
if pre_dedup > post_dedup:
|
378 |
self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
|
379 |
self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
|
|
|
380 |
return results.reset_index(drop=True)
|
381 |
except Exception as e:
|
382 |
self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
|
|
|
36 |
self._ensure_directories()
|
37 |
self._unzip_if_needed()
|
38 |
self._build_shard_map()
|
|
|
39 |
logger.info(f"Total documents indexed: {self.total_docs}")
|
40 |
logger.info(f"Total shards found: {len(self.shard_map)}")
|
41 |
|
|
|
144 |
shard_path = self.shard_dir / shard
|
145 |
if not shard_path.exists():
|
146 |
logger.error(f"Shard file not found: {shard_path}")
|
147 |
+
return pd.DataFrame(columns=["title", "summary", "similarity","authors", "source"])
|
148 |
+
|
149 |
file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
|
150 |
logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
|
151 |
+
|
152 |
try:
|
153 |
+
self.loaded_shards[shard] = pd.read_parquet(shard_path, columns=["title", "summary", "source", "authors"])
|
154 |
logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
|
155 |
+
|
156 |
except Exception as e:
|
157 |
logger.error(f"Failed to read parquet file {shard}: {str(e)}")
|
158 |
try:
|
|
|
160 |
logger.info(f"Parquet schema: {schema}")
|
161 |
except Exception:
|
162 |
pass
|
163 |
+
return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
|
164 |
df = self.loaded_shards[shard]
|
165 |
df_len = len(df)
|
166 |
valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
|
|
|
170 |
chunk = df.iloc[valid_local_indices]
|
171 |
logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
|
172 |
return chunk
|
173 |
+
|
174 |
except Exception as e:
|
175 |
logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
|
176 |
+
return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
|
177 |
|
178 |
def get_metadata(self, global_indices):
|
179 |
"""Retrieve metadata for a batch of global indices using parallel shard processing."""
|
|
|
331 |
if index.ntotal == 0:
|
332 |
self.logger.warning(f"Skipping empty shard {shard_idx}")
|
333 |
return None
|
334 |
+
|
335 |
try:
|
336 |
shard_start = time.time()
|
337 |
distances, indices = index.search(query_embedding, top_k)
|
338 |
valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
|
339 |
valid_indices = indices[0][valid_mask].tolist()
|
340 |
valid_distances = distances[0][valid_mask].tolist()
|
341 |
+
|
342 |
if len(valid_indices) != top_k:
|
343 |
self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
|
344 |
+
|
345 |
global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
|
346 |
self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
|
347 |
return valid_distances, global_indices
|
|
|
354 |
process_start = time.time()
|
355 |
if global_indices.size == 0 or distances.size == 0:
|
356 |
self.logger.warning("No search results to process")
|
357 |
+
return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
|
358 |
try:
|
359 |
self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
|
360 |
metadata_start = time.time()
|
361 |
results = self.metadata_mgr.get_metadata(global_indices)
|
362 |
self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
|
363 |
+
|
364 |
if len(results) == 0:
|
365 |
self.logger.warning("No metadata found for indices")
|
366 |
+
return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
|
367 |
if len(results) != len(distances):
|
368 |
self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
|
369 |
if len(results) < len(distances):
|
370 |
distances = distances[:len(results)]
|
371 |
else:
|
372 |
distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
|
373 |
+
|
374 |
self.logger.debug("Calculating similarity scores")
|
375 |
results['similarity'] = 1 - (distances / 2)
|
376 |
if not results.empty:
|
|
|
378 |
f"max={results['similarity'].max():.3f}, " +
|
379 |
f"mean={results['similarity'].mean():.3f}")
|
380 |
results['source'] = results["source"]
|
381 |
+
|
382 |
+
# Ensure we have all required columns
|
383 |
+
required_columns = ["title", "summary", "authors", "source", "similarity"]
|
384 |
+
for col in required_columns:
|
385 |
+
if col not in results.columns:
|
386 |
+
results[col] = None # Fill missing columns with None
|
387 |
|
388 |
pre_dedup = len(results)
|
389 |
+
results = results.drop_duplicates(subset=["title","authors", "source"]).sort_values("similarity", ascending=False).head(top_k)
|
390 |
+
|
391 |
post_dedup = len(results)
|
392 |
if pre_dedup > post_dedup:
|
393 |
self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
|
394 |
self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
|
395 |
+
|
396 |
return results.reset_index(drop=True)
|
397 |
except Exception as e:
|
398 |
self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
|