Testys committed on
Commit eb5ebce · 1 Parent(s): 7f04a94

Update search_utils.py

Files changed (1)
  1. search_utils.py +412 -100
search_utils.py CHANGED
@@ -1,44 +1,225 @@
 import numpy as np
 import faiss
 import zipfile
 import logging
 from pathlib import Path
-from sentence_transformers import SentenceTransformer
-import concurrent.futures
 import os
 import requests
 from functools import lru_cache
-from typing import List, Dict

 # Configure logging
-logging.basicConfig(level=logging.WARNING,
-                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger("OptimizedSearch")

-class OptimizedMetadataManager:
     def __init__(self):
-        self._init_metadata()
         self._init_url_resolver()
-
-    def _init_metadata(self):
-        """Memory-mapped metadata loading"""
-        self.metadata_dir = Path("unzipped_cache/metadata_shards")
-        self.metadata = {}

-        # Preload all metadata into memory
-        for parquet_file in self.metadata_dir.glob("*.parquet"):
-            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
-            self.metadata.update(df.to_dict(orient="index"))
-
-        self.total_docs = len(self.metadata)
-        logger.info(f"Loaded {self.total_docs} metadata entries into memory")
-
-    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
-        """Batch retrieval of metadata"""
-        return [self.metadata.get(idx, {"title": "", "summary": ""}) for idx in indices]
-
     def _init_url_resolver(self):
-        """Initialize API session and cache"""
         self.session = requests.Session()
         adapter = requests.adapters.HTTPAdapter(
             pool_connections=10,
@@ -47,43 +228,45 @@ class OptimizedMetadataManager:
         )
         self.session.mount("https://", adapter)

-    @lru_cache(maxsize=10_000)
     def resolve_url(self, title: str) -> str:
-        """Optimized URL resolution with fail-fast"""
-        try:
-            # Try arXiv first
-            arxiv_url = self._get_arxiv_url(title)
-            if arxiv_url: return arxiv_url
-
-            # Fallback to Semantic Scholar
-            semantic_url = self._get_semantic_url(title)
-            if semantic_url: return semantic_url
-
-        except Exception as e:
-            logger.warning(f"URL resolution failed: {str(e)}")
-
-        return f"https://scholar.google.com/scholar?q={quote(title)}"
-
     def _get_arxiv_url(self, title: str) -> str:
-        """Fast arXiv lookup with timeout"""
         with self.session.get(
             "http://export.arxiv.org/api/query",
-            params={"search_query": f'ti:"{title}"', "max_results": 1},
             timeout=2
         ) as response:
             if response.ok:
                 return self._parse_arxiv_response(response.text)
             return ""
-
     def _parse_arxiv_response(self, xml: str) -> str:
-        """Fast XML parsing using string operations"""
-        if "<entry>" not in xml: return ""
         start = xml.find("<id>") + 4
         end = xml.find("</id>", start)
         return xml[start:end].replace("http:", "https:") if start > 3 else ""
-
     def _get_semantic_url(self, title: str) -> str:
-        """Batch-friendly Semantic Scholar lookup"""
         with self.session.get(
             "https://api.semanticscholar.org/graph/v1/paper/search",
             params={"query": title[:200], "limit": 1},
@@ -94,61 +277,190 @@ class OptimizedMetadataManager:
             if data.get("data"):
                 return data["data"][0].get("url", "")
             return ""

-class OptimizedSemanticSearch:
-    def __init__(self):
-        self.model = SentenceTransformer('all-MiniLM-L6-v2')
-        self._load_faiss_indexes()
-        self.metadata_mgr = OptimizedMetadataManager()
-
-    def _load_faiss_indexes(self):
-        """Load indexes with memory mapping"""
-        self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
-        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

-    def search(self, query: str, top_k: int = 5) -> List[Dict]:
-        """Optimized search pipeline"""
-        # Batch encode query
-        query_embedding = self.model.encode([query], convert_to_numpy=True)
-
-        # FAISS search
-        distances, indices = self.index.search(query_embedding, top_k*2)  # Search extra for dedup
-
-        # Batch metadata retrieval
-        results = self.metadata_mgr.get_metadata_batch(indices[0])
-
-        # Process results
-        return self._process_results(results, distances[0], top_k)
-
-    def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
-        """Parallel result processing"""
         with concurrent.futures.ThreadPoolExecutor() as executor:
-            # Parallel URL resolution
-            futures = {
-                executor.submit(
-                    self.metadata_mgr.resolve_url,
-                    res["title"]
-                ): idx for idx, res in enumerate(results)
             }
-
-            # Update results as URLs resolve
-            for future in concurrent.futures.as_completed(futures):
-                idx = futures[future]
                 try:
-                    results[idx]["source"] = future.result()
                 except Exception as e:
-                    results[idx]["source"] = ""
-
-        # Add similarity scores
-        for idx, dist in enumerate(distances[:len(results)]):
-            results[idx]["similarity"] = 1 - (dist / 2)
-
-        # Deduplicate and sort
-        seen = set()
-        final_results = []
-        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
-            if res["title"] not in seen and len(final_results) < top_k:
-                seen.add(res["title"])
-                final_results.append(res)

-        return final_results
 import numpy as np
+import pandas as pd
 import faiss
 import zipfile
 import logging
 from pathlib import Path
+from sentence_transformers import SentenceTransformer, util
+import streamlit as st
+import time
 import os
+from urllib.parse import quote
 import requests
+import shutil
+import concurrent.futures
+# Optional: Uncomment if you want to use lru_cache for instance methods
 from functools import lru_cache

 # Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[logging.StreamHandler()]
+)
+logger = logging.getLogger("MetadataManager")

+class MetadataManager:
     def __init__(self):
+        self.cache_dir = Path("unzipped_cache")
+        self.shard_dir = self.cache_dir / "metadata_shards"
+        self.shard_map = {}
+        self.loaded_shards = {}
+        self.total_docs = 0
+        self.api_cache = {}
+
+        logger.info("Initializing MetadataManager")
+        self._ensure_directories()
+        self._unzip_if_needed()
+        self._build_shard_map()
         self._init_url_resolver()
+        logger.info(f"Total documents indexed: {self.total_docs}")
+        logger.info(f"Total shards found: {len(self.shard_map)}")
+
+    def _ensure_directories(self):
+        """Create necessary directories if they don't exist."""
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.shard_dir.mkdir(parents=True, exist_ok=True)
+
+    def _unzip_if_needed(self):
+        """Extract the ZIP archive if no parquet files are found."""
+        zip_path = Path("metadata_shards.zip")
+        if not any(self.shard_dir.rglob("*.parquet")):
+            logger.info("No parquet files found, checking for zip archive")
+            if not zip_path.exists():
+                raise FileNotFoundError(f"Metadata ZIP file not found at {zip_path}")
+            logger.info(f"Extracting {zip_path} to {self.shard_dir}")
+            try:
+                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+                    zip_root = self._get_zip_root(zip_ref)
+                    zip_ref.extractall(self.shard_dir)
+                    if zip_root:
+                        nested_dir = self.shard_dir / zip_root
+                        if nested_dir.exists():
+                            self._flatten_directory(nested_dir, self.shard_dir)
+                            nested_dir.rmdir()
+                parquet_files = list(self.shard_dir.rglob("*.parquet"))
+                if not parquet_files:
+                    raise RuntimeError("Extraction completed but no parquet files found")
+                logger.info(f"Found {len(parquet_files)} parquet files after extraction")
+            except Exception as e:
+                logger.error(f"Failed to extract zip file: {str(e)}")
+                self._clean_failed_extraction()
+                raise
+
+    def _get_zip_root(self, zip_ref):
+        """Identify the common root directory within the ZIP file."""
+        try:
+            first_file = zip_ref.namelist()[0]
+            if '/' in first_file:
+                return first_file.split('/')[0]
+            return ""
+        except Exception as e:
+            logger.warning(f"Error detecting zip root: {str(e)}")
+            return ""
+
+    def _flatten_directory(self, src_dir, dest_dir):
+        """Move files from a nested directory up to the destination."""
+        for item in src_dir.iterdir():
+            if item.is_dir():
+                self._flatten_directory(item, dest_dir)
+                item.rmdir()
+            else:
+                target = dest_dir / item.name
+                if target.exists():
+                    target.unlink()
+                item.rename(target)
+
+    def _clean_failed_extraction(self):
+        """Clean up files from a failed extraction attempt."""
+        logger.info("Cleaning up failed extraction")
+        for item in self.shard_dir.iterdir():
+            if item.is_dir():
+                shutil.rmtree(item)
+            else:
+                item.unlink()
+
+    def _build_shard_map(self):
+        """Build a map from global index ranges to shard filenames."""
+        logger.info("Building shard map from parquet files")
+        parquet_files = list(self.shard_dir.glob("*.parquet"))
+        if not parquet_files:
+            raise FileNotFoundError("No parquet files found after extraction")
+        parquet_files = sorted(parquet_files, key=lambda x: int(x.stem.split("_")[1]))
+        expected_start = 0
+        for f in parquet_files:
+            try:
+                parts = f.stem.split("_")
+                if len(parts) != 3:
+                    raise ValueError("Invalid filename format")
+                start = int(parts[1])
+                end = int(parts[2])
+                if start != expected_start:
+                    raise ValueError(f"Non-contiguous shard start: expected {expected_start}, got {start}")
+                if end <= start:
+                    raise ValueError(f"Invalid shard range: {start}-{end}")
+                self.shard_map[(start, end)] = f.name
+                self.total_docs = end + 1
+                expected_start = end + 1
+                logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")
+            except Exception as e:
+                logger.error(f"Error processing shard {f.name}: {str(e)}")
+                raise RuntimeError("Invalid shard structure") from e
+        logger.info(f"Validated {len(self.shard_map)} continuous shards")
+        logger.info(f"Total document count: {self.total_docs}")
+        sorted_ranges = sorted(self.shard_map.keys())
+        for i in range(1, len(sorted_ranges)):
+            prev_end = sorted_ranges[i-1][1]
+            curr_start = sorted_ranges[i][0]
+            if curr_start != prev_end + 1:
+                logger.warning(f"Gap or overlap detected between shards: {prev_end} to {curr_start}")
+
+    def _process_shard(self, shard, local_indices):
+        """Load a shard (if not already loaded) and retrieve the specified rows."""
+        try:
+            if shard not in self.loaded_shards:
+                shard_path = self.shard_dir / shard
+                if not shard_path.exists():
+                    logger.error(f"Shard file not found: {shard_path}")
+                    return pd.DataFrame(columns=["title", "summary", "similarity"])
+                file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
+                logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
+                try:
+                    self.loaded_shards[shard] = pd.read_parquet(shard_path, columns=["title", "summary"])
+                    logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
+                except Exception as e:
+                    logger.error(f"Failed to read parquet file {shard}: {str(e)}")
+                    try:
+                        schema = pd.read_parquet(shard_path, engine='pyarrow').dtypes
+                        logger.info(f"Parquet schema: {schema}")
+                    except Exception:
+                        pass
+                    return pd.DataFrame(columns=["title", "summary", "similarity"])
+            df = self.loaded_shards[shard]
+            df_len = len(df)
+            valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
+            if len(valid_local_indices) != len(local_indices):
+                logger.warning(f"Filtered {len(local_indices) - len(valid_local_indices)} out-of-bounds indices in shard {shard}")
+            if valid_local_indices:
+                chunk = df.iloc[valid_local_indices]
+                logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
+                return chunk
+        except Exception as e:
+            logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
+        return pd.DataFrame(columns=["title", "summary", "similarity"])
+
+    def get_metadata(self, global_indices):
+        """Retrieve metadata for a batch of global indices using parallel shard processing."""
+        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
+            logger.warning("Empty indices array passed to get_metadata")
+            return pd.DataFrame(columns=["title", "summary", "similarity"])

+        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices
+        logger.info(f"Retrieving metadata for {len(indices_list)} indices")
+        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
+        invalid_count = len(indices_list) - len(valid_indices)
+        if invalid_count > 0:
+            logger.warning(f"Filtered out {invalid_count} invalid indices")
+        if not valid_indices:
+            logger.warning("No valid indices remain after filtering")
+            return pd.DataFrame(columns=["title", "summary", "similarity"])
+
+        # Group indices by shard
+        shard_groups = {}
+        for idx in valid_indices:
+            found = False
+            for (start, end), shard in self.shard_map.items():
+                if start <= idx <= end:
+                    shard_groups.setdefault(shard, []).append(idx - start)
+                    found = True
+                    break
+            if not found:
+                logger.warning(f"Index {idx} not found in any shard range")
+
+        # Process shards concurrently
+        results = []
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [executor.submit(self._process_shard, shard, local_indices)
+                       for shard, local_indices in shard_groups.items()]
+            for future in concurrent.futures.as_completed(futures):
+                df_chunk = future.result()
+                if not df_chunk.empty:
+                    results.append(df_chunk)
+
+        if results:
+            combined = pd.concat(results).reset_index(drop=True)
+            logger.info(f"Combined metadata: {len(combined)} records from {len(results)} shards")
+            return combined
+        else:
+            logger.warning("No metadata records retrieved")
+            return pd.DataFrame(columns=["title", "summary", "similarity"])
+
  def _init_url_resolver(self):
222
+ """Initialize API session and cache."""
223
  self.session = requests.Session()
224
  adapter = requests.adapters.HTTPAdapter(
225
  pool_connections=10,
 
228
  )
229
  self.session.mount("https://", adapter)
230
 
 
231
  def resolve_url(self, title: str) -> str:
232
+ """Optimized URL resolution with fail-fast."""
233
+ if title in self.api_cache:
234
+ return self.api_cache[title]
235
+
236
+ links = {}
237
+ arxiv_url = self._get_arxiv_url(title)
238
+ if arxiv_url:
239
+ links["arxiv"] = arxiv_url
240
+ semantic_url = self._get_semantic_url(title)
241
+ if semantic_url:
242
+ links["semantic"] = semantic_url
243
+ scholar_url = f"https://scholar.google.com/scholar?q={quote(title)}"
244
+ links["google"] = scholar_url
245
+
246
+ self.api_cache[title] = links
247
+ return links
248
+
249
  def _get_arxiv_url(self, title: str) -> str:
250
+ """Fast arXiv lookup with timeout."""
251
  with self.session.get(
252
  "http://export.arxiv.org/api/query",
253
+ params={"search_query": f'ti:"{title}"', "max_results": 1, "sortBy": "relevance"},
254
  timeout=2
255
  ) as response:
256
  if response.ok:
257
  return self._parse_arxiv_response(response.text)
258
  return ""
259
+
260
  def _parse_arxiv_response(self, xml: str) -> str:
261
+ """Fast XML parsing using string operations."""
262
+ if "<entry>" not in xml:
263
+ return ""
264
  start = xml.find("<id>") + 4
265
  end = xml.find("</id>", start)
266
  return xml[start:end].replace("http:", "https:") if start > 3 else ""
267
+
268
  def _get_semantic_url(self, title: str) -> str:
269
+ """Batch-friendly Semantic Scholar lookup."""
270
  with self.session.get(
271
  "https://api.semanticscholar.org/graph/v1/paper/search",
272
  params={"query": title[:200], "limit": 1},
 
277
  if data.get("data"):
278
  return data["data"][0].get("url", "")
279
  return ""
280
+
281
+ def _format_source_links(self, links):
282
+ """Generate an HTML snippet for the available source links."""
283
+ html_parts = []
284
+ if "arxiv" in links:
285
+ html_parts.append(f"<a class='source-link' href='{links['arxiv']}' target='_blank' rel='noopener noreferrer'> πŸ“œ arXiv</a>")
286
+ if "semantic" in links:
287
+ html_parts.append(f"<a class='source-link' href='{links['semantic']}' target='_blank' rel='noopener noreferrer'> 🌐 Semantic Scholar</a>")
288
+ if "google" in links:
289
+ html_parts.append(f"<a class='source-link' href='{links['google']}' target='_blank' rel='noopener noreferrer'> πŸ” Google Scholar</a>")
290
+ return " | ".join(html_parts)
291
 
 
 
 
 
 
 
 
 
 
 
292
 
+class SemanticSearch:
+    def __init__(self):
+        self.shard_dir = Path("compressed_shards")
+        self.model = None
+        self.index_shards = []
+        self.metadata_mgr = MetadataManager()
+        self.shard_sizes = []
+        self.cumulative_offsets = None
+        self.logger = logging.getLogger("SemanticSearch")
+        self.logger.info("Initializing SemanticSearch")
+
+    @st.cache_resource
+    def load_model(_self):
+        return SentenceTransformer('all-MiniLM-L6-v2')
+
+    def initialize_system(self):
+        self.logger.info("Loading sentence transformer model")
+        start_time = time.time()
+        self.model = self.load_model()
+        self.logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
+        self.logger.info("Loading FAISS indices")
+        self._load_faiss_shards()
+
+    def _load_faiss_shards(self):
+        """Load FAISS shards concurrently and precompute cumulative offsets for global indexing."""
+        self.logger.info(f"Searching for index files in {self.shard_dir}")
+        if not self.shard_dir.exists():
+            self.logger.error(f"Shard directory not found: {self.shard_dir}")
+            return
+        index_files = sorted(self.shard_dir.glob("*.index"))
+        self.logger.info(f"Found {len(index_files)} index files")
+        self.index_shards = []
+        self.shard_sizes = []
         with concurrent.futures.ThreadPoolExecutor() as executor:
+            future_to_file = {
+                executor.submit(self._load_single_index, shard_path): shard_path
+                for shard_path in index_files
             }
+            for future in concurrent.futures.as_completed(future_to_file):
+                shard_path = future_to_file[future]
                 try:
+                    index, size = future.result()
+                    if index is not None:
+                        self.index_shards.append(index)
+                        self.shard_sizes.append(size)
+                        self.logger.info(f"Loaded index {shard_path.name} with {size} vectors")
                 except Exception as e:
+                    self.logger.error(f"Error loading index {shard_path}: {str(e)}")
+        total_vectors = sum(self.shard_sizes)
+        self.logger.info(f"Total loaded vectors: {total_vectors} across {len(self.index_shards)} shards")
+        self.cumulative_offsets = np.cumsum([0] + self.shard_sizes)
+
+    def _load_single_index(self, shard_path):
+        """Load a single FAISS index shard."""
+        self.logger.info(f"Loading index: {shard_path}")
+        start_time = time.time()
+        file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
+        self.logger.info(f"Index file size: {file_size_mb:.2f} MB")
+        index = faiss.read_index(str(shard_path))
+        size = index.ntotal
+        self.logger.info(f"Index loaded in {time.time() - start_time:.2f} seconds")
+        return index, size
+
+    def _global_index(self, shard_idx, local_idx):
+        """Convert a local index (within a shard) to a global index using precomputed offsets."""
+        return int(self.cumulative_offsets[shard_idx] + local_idx)
+
+    def search(self, query, top_k=5):
+        """Search for a query using parallel FAISS shard search."""
+        self.logger.info(f"Searching for query: '{query}' (top_k={top_k})")
+        start_time = time.time()
+        if not query:
+            self.logger.warning("Empty query provided")
+            return pd.DataFrame()
+        if not self.index_shards:
+            self.logger.error("No index shards loaded")
+            return pd.DataFrame()
+        try:
+            self.logger.info("Encoding query")
+            query_embedding = self.model.encode([query], convert_to_numpy=True)
+            self.logger.debug(f"Query encoded to shape {query_embedding.shape}")
+        except Exception as e:
+            self.logger.error(f"Query encoding failed: {str(e)}")
+            return pd.DataFrame()

+        all_distances = []
+        all_global_indices = []
+        # Run shard searches in parallel
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = {
+                executor.submit(self._search_shard, shard_idx, index, query_embedding, top_k): shard_idx
+                for shard_idx, index in enumerate(self.index_shards)
+            }
+            for future in concurrent.futures.as_completed(futures):
+                result = future.result()
+                if result is not None:
+                    distances_part, global_indices_part = result
+                    all_distances.extend(distances_part)
+                    all_global_indices.extend(global_indices_part)
+        self.logger.info(f"Search found {len(all_global_indices)} results across all shards")
+        results = self._process_results(np.array(all_distances), np.array(all_global_indices), top_k)
+        self.logger.info(f"Search completed in {time.time() - start_time:.2f} seconds with {len(results)} final results")
+        return results
+
+    def _search_shard(self, shard_idx, index, query_embedding, top_k):
+        """Search a single FAISS shard for the query embedding."""
+        if index.ntotal == 0:
+            self.logger.warning(f"Skipping empty shard {shard_idx}")
+            return None
+        try:
+            shard_start = time.time()
+            distances, indices = index.search(query_embedding, top_k)
+            valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
+            valid_indices = indices[0][valid_mask].tolist()
+            valid_distances = distances[0][valid_mask].tolist()
+            if len(valid_indices) != top_k:
+                self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
+            global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
+            self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
+            return valid_distances, global_indices
+        except Exception as e:
+            self.logger.error(f"Search failed in shard {shard_idx}: {str(e)}")
+            return None
+
+    def _process_results(self, distances, global_indices, top_k):
+        """Process raw search results: retrieve metadata, calculate similarity, and deduplicate."""
+        process_start = time.time()
+        if global_indices.size == 0 or distances.size == 0:
+            self.logger.warning("No search results to process")
+            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
+        try:
+            self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
+            metadata_start = time.time()
+            results = self.metadata_mgr.get_metadata(global_indices)
+            self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
+            if len(results) == 0:
+                self.logger.warning("No metadata found for indices")
+                return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
+            if len(results) != len(distances):
+                self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
+                if len(results) < len(distances):
+                    distances = distances[:len(results)]
+                else:
+                    distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
+            self.logger.debug("Calculating similarity scores")
+            results['similarity'] = 1 - (distances / 2)
+            if not results.empty:
+                self.logger.debug(f"Similarity stats: min={results['similarity'].min():.3f}, " +
+                                  f"max={results['similarity'].max():.3f}, " +
+                                  f"mean={results['similarity'].mean():.3f}")
+            results['source'] = results['title'].apply(
+                lambda title: self._format_source_links(self.metadata_mgr.resolve_url(title))
+            )
+            pre_dedup = len(results)
+            results = results.drop_duplicates(subset=["title", "source"]).sort_values("similarity", ascending=False).head(top_k)
+            post_dedup = len(results)
+            if pre_dedup > post_dedup:
+                self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
+            self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
+            return results.reset_index(drop=True)
+        except Exception as e:
+            self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
+            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
+
+    def _format_source_links(self, links):
+        """Generate an HTML snippet for the available source links."""
+        html_parts = []
+        if "arxiv" in links:
+            html_parts.append(f"<a class='source-link' href='{links['arxiv']}' target='_blank' rel='noopener noreferrer'> 📜 arXiv</a>")
+        if "semantic" in links:
+            html_parts.append(f"<a class='source-link' href='{links['semantic']}' target='_blank' rel='noopener noreferrer'> 🌐 Semantic Scholar</a>")
+        if "google" in links:
+            html_parts.append(f"<a class='source-link' href='{links['google']}' target='_blank' rel='noopener noreferrer'> 🔍 Google Scholar</a>")
+        return " | ".join(html_parts)