Testys committed on
Commit 0880e2f · 1 parent: 2906fee

Update search_utils.py

Files changed (1)
  1. search_utils.py +24 -8
search_utils.py CHANGED
@@ -36,7 +36,6 @@ class MetadataManager:
         self._ensure_directories()
         self._unzip_if_needed()
         self._build_shard_map()
-        self._init_url_resolver()
         logger.info(f"Total documents indexed: {self.total_docs}")
         logger.info(f"Total shards found: {len(self.shard_map)}")
 
@@ -145,12 +144,15 @@ class MetadataManager:
             shard_path = self.shard_dir / shard
             if not shard_path.exists():
                 logger.error(f"Shard file not found: {shard_path}")
-                return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
+                return pd.DataFrame(columns=["title", "summary", "similarity", "authors", "source"])
+
             file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
             logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
+
             try:
-                self.loaded_shards[shard] = pd.read_parquet(shard_path, columns=["title", "summary"])
+                self.loaded_shards[shard] = pd.read_parquet(shard_path, columns=["title", "summary", "source", "authors"])
                 logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
+
             except Exception as e:
                 logger.error(f"Failed to read parquet file {shard}: {str(e)}")
                 try:
@@ -158,7 +160,7 @@ class MetadataManager:
                     logger.info(f"Parquet schema: {schema}")
                 except Exception:
                     pass
-                return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
+                return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
             df = self.loaded_shards[shard]
             df_len = len(df)
             valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
@@ -168,9 +170,10 @@ class MetadataManager:
             chunk = df.iloc[valid_local_indices]
             logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
             return chunk
+
         except Exception as e:
             logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
+            return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
 
     def get_metadata(self, global_indices):
         """Retrieve metadata for a batch of global indices using parallel shard processing."""
@@ -328,14 +331,17 @@ class SemanticSearch:
         if index.ntotal == 0:
             self.logger.warning(f"Skipping empty shard {shard_idx}")
             return None
+
         try:
             shard_start = time.time()
             distances, indices = index.search(query_embedding, top_k)
             valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
             valid_indices = indices[0][valid_mask].tolist()
             valid_distances = distances[0][valid_mask].tolist()
+
             if len(valid_indices) != top_k:
                 self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
+
             global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
             self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
             return valid_distances, global_indices
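
The hunk above only adds whitespace, but the surrounding flow is the core of the search path: each shard is queried independently, FAISS's -1 placeholder ids are masked out, and the surviving local row ids are mapped to global ids via _global_index. The self-contained sketch below assumes a global id is the local row id plus a per-shard offset; the real offset bookkeeping lives elsewhere in the class and is not shown in this diff.

# Sketch of per-shard FAISS search with local-to-global id mapping (assumed offsets).
import faiss
import numpy as np

def search_shard(index, shard_offsets, shard_idx, query_embedding, top_k=5):
    if index.ntotal == 0:
        return None  # skip empty shards, as the diffed code does
    distances, indices = index.search(query_embedding, top_k)
    # FAISS pads missing hits with -1, so mask those out before mapping ids.
    valid = (indices[0] >= 0) & (indices[0] < index.ntotal)
    local_ids = indices[0][valid]
    global_ids = (local_ids + shard_offsets[shard_idx]).tolist()
    return distances[0][valid].tolist(), global_ids

# Toy usage: one shard of three 4-d vectors whose global ids start at 100.
dim = 4
vectors = np.random.rand(3, dim).astype("float32")
index = faiss.IndexFlatL2(dim)
index.add(vectors)
print(search_shard(index, shard_offsets=[100], shard_idx=0,
                   query_embedding=vectors[:1], top_k=2))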
@@ -348,21 +354,23 @@ class SemanticSearch:
         process_start = time.time()
         if global_indices.size == 0 or distances.size == 0:
             self.logger.warning("No search results to process")
-            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
+            return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
         try:
             self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
             metadata_start = time.time()
             results = self.metadata_mgr.get_metadata(global_indices)
             self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
+
             if len(results) == 0:
                 self.logger.warning("No metadata found for indices")
-                return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
+                return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
             if len(results) != len(distances):
                 self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
                 if len(results) < len(distances):
                     distances = distances[:len(results)]
                 else:
                     distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
+
             self.logger.debug("Calculating similarity scores")
             results['similarity'] = 1 - (distances / 2)
             if not results.empty:
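
The kept line results['similarity'] = 1 - (distances / 2) is exact when the embeddings are unit-normalised and the distances are the squared L2 values that faiss.IndexFlatL2 reports, since ||a - b||^2 = 2 - 2*cos(a, b) for unit vectors; that is an assumption about the index, not something this diff states. A quick numerical check of the identity:

# For unit vectors, 1 - (squared L2 distance) / 2 equals the cosine similarity.
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=8)
b = rng.normal(size=8)
a /= np.linalg.norm(a)
b /= np.linalg.norm(b)

d_sq = float(np.sum((a - b) ** 2))   # squared L2 distance
cosine = float(a @ b)                # true cosine similarity
print(np.isclose(1 - d_sq / 2, cosine))  # True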
@@ -370,13 +378,21 @@ class SemanticSearch:
                                       f"max={results['similarity'].max():.3f}, " +
                                       f"mean={results['similarity'].mean():.3f}")
             results['source'] = results["source"]
+
+            # Ensure we have all required columns
+            required_columns = ["title", "summary", "authors", "source", "similarity"]
+            for col in required_columns:
+                if col not in results.columns:
+                    results[col] = None  # Fill missing columns with None
 
             pre_dedup = len(results)
-            results = results.drop_duplicates(subset=["title", "source"]).sort_values("similarity", ascending=False).head(top_k)
+            results = results.drop_duplicates(subset=["title", "authors", "source"]).sort_values("similarity", ascending=False).head(top_k)
+
             post_dedup = len(results)
             if pre_dedup > post_dedup:
                 self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
             self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
+
             return results.reset_index(drop=True)
         except Exception as e:
             self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
 