Testys committed on
Commit 801d9f2 · 1 Parent(s): 391339a

Update search_utils.py

Files changed (1)
  1. search_utils.py +128 -567
search_utils.py CHANGED
@@ -1,593 +1,154 @@
  import numpy as np
- import pandas as pd
  import faiss
  import zipfile
  import logging
  from pathlib import Path
- from sentence_transformers import SentenceTransformer, util
- import streamlit as st
- import time
  import os
- from urllib.parse import quote
  import requests
-

  # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-     handlers=[
-         logging.StreamHandler()
-     ]
- )
- logger = logging.getLogger("MetadataManager")

- class MetadataManager:
      def __init__(self):
-         self.cache_dir = Path("unzipped_cache")
-         self.shard_dir = self.cache_dir / "metadata_shards"
-         self.shard_map = {}
-         self.loaded_shards = {}
-         self.total_docs = 0
-         self.api_cache = {}
-
-         logger.info("Initializing MetadataManager")
-         self._ensure_directories()
-         self._unzip_if_needed()
-         self._build_shard_map()
-         logger.info(f"Total documents indexed: {self.total_docs}")
-         logger.info(f"Total shards found: {len(self.shard_map)}")
-
-     def _ensure_directories(self):
-         """Create necessary directories if they don't exist"""
-         self.cache_dir.mkdir(parents=True, exist_ok=True)
-         self.shard_dir.mkdir(parents=True, exist_ok=True)
-
-     def _unzip_if_needed(self):
-         """Handle ZIP extraction with nested directory handling"""
-         zip_path = Path("metadata_shards.zip")
-
-         # Check if we need to unzip by looking for parquet files in any subdirectory
-         if not any(self.shard_dir.rglob("*.parquet")):
-             logger.info("No parquet files found, checking for zip archive")
-
-             if not zip_path.exists():
-                 raise FileNotFoundError(f"Metadata ZIP file not found at {zip_path}")
-
-             logger.info(f"Extracting {zip_path} to {self.shard_dir}")
-             try:
-                 with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-                     # Check for nested directory structure in zip
-                     zip_root = self._get_zip_root(zip_ref)
-
-                     # Extract while preserving structure
-                     zip_ref.extractall(self.shard_dir)
-
-                     # Handle nested directory if exists
-                     if zip_root:
-                         nested_dir = self.shard_dir / zip_root
-                         if nested_dir.exists():
-                             # Move files up from nested directory
-                             self._flatten_directory(nested_dir, self.shard_dir)
-                             nested_dir.rmdir()
-
-                     # Verify extraction
-                     parquet_files = list(self.shard_dir.rglob("*.parquet"))
-                     if not parquet_files:
-                         raise RuntimeError("Extraction completed but no parquet files found")
-
-                     logger.info(f"Found {len(parquet_files)} parquet files after extraction")
-
-             except Exception as e:
-                 logger.error(f"Failed to extract zip file: {str(e)}")
-                 self._clean_failed_extraction()
-                 raise
-
-     def _get_zip_root(self, zip_ref):
-         """Identify common root directory in zip file"""
          try:
-             first_file = zip_ref.namelist()[0]
-             if '/' in first_file:
-                 return first_file.split('/')[0]
-             return ""
-         except Exception as e:
-             logger.warning(f"Error detecting zip root: {str(e)}")
-             return ""
-
-     def _flatten_directory(self, src_dir, dest_dir):
-         """Move files from nested directory to destination"""
-         for item in src_dir.iterdir():
-             if item.is_dir():
-                 self._flatten_directory(item, dest_dir)
-                 item.rmdir()
-             else:
-                 target = dest_dir / item.name
-                 if target.exists():
-                     target.unlink()
-                 item.rename(target)
-
-     def _clean_failed_extraction(self):
-         """Remove any extracted files after failed attempt"""
-         logger.info("Cleaning up failed extraction")
-         for item in self.shard_dir.iterdir():
-             if item.is_dir():
-                 shutil.rmtree(item)
-             else:
-                 item.unlink()
-
-     def _build_shard_map(self):
-         """Create validated index range to shard mapping"""
-         logger.info("Building shard map from parquet files")
-         parquet_files = list(self.shard_dir.glob("*.parquet"))
-
-         if not parquet_files:
-             raise FileNotFoundError("No parquet files found after extraction")
-
-         # Sort files by numerical order
-         parquet_files = sorted(parquet_files, key=lambda x: int(x.stem.split("_")[1]))
-
-         # Track expected next index
-         expected_start = 0
-
-         for f in parquet_files:
-             try:
-                 parts = f.stem.split("_")
-                 if len(parts) != 3:
-                     raise ValueError("Invalid filename format")
-
-                 start = int(parts[1])
-                 end = int(parts[2])
-
-                 # Validate continuity
-                 if start != expected_start:
-                     raise ValueError(f"Non-contiguous shard start: expected {expected_start}, got {start}")
-
-                 # Validate range
-                 if end <= start:
-                     raise ValueError(f"Invalid shard range: {start}-{end}")
-
-                 self.shard_map[(start, end)] = f.name
-                 self.total_docs = end + 1
-                 expected_start = end + 1
-
-                 logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")
-
-             except Exception as e:
-                 logger.error(f"Error processing shard {f.name}: {str(e)}")
-                 raise RuntimeError("Invalid shard structure") from e
-
-         logger.info(f"Validated {len(self.shard_map)} continuous shards")
-         logger.info(f"Total document count: {self.total_docs}")
-
-         # Log shard statistics
-         logger.info(f"Shard map built with {len(self.shard_map)} shards")
-         logger.info(f"Total document count: {self.total_docs}")
-
-         # Validate shard boundaries for gaps or overlaps
-         sorted_ranges = sorted(self.shard_map.keys())
-         for i in range(1, len(sorted_ranges)):
-             prev_end = sorted_ranges[i-1][1]
-             curr_start = sorted_ranges[i][0]
-             if curr_start != prev_end + 1:
-                 logger.warning(f"Gap or overlap detected between shards: {prev_end} to {curr_start}")
-
-     def get_metadata(self, global_indices):
-         """Retrieve metadata with validation"""
-         # Check for empty numpy array properly
-         if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
-             logger.warning("Empty indices array passed to get_metadata")
-             return pd.DataFrame(columns=["title", "summary", "similarity"])
-
-         # Convert numpy array to list for processing
-         indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices
-         logger.info(f"Retrieving metadata for {len(indices_list)} indices")
-
-         # Filter valid indices
-         valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
-         invalid_count = len(indices_list) - len(valid_indices)
-         if invalid_count > 0:
-             logger.warning(f"Filtered out {invalid_count} invalid indices")
-
-         if not valid_indices:
-             logger.warning("No valid indices remain after filtering")
-             return pd.DataFrame(columns=["title", "summary", "similarity"])
-
-         # Group indices by shard with boundary check
-         shard_groups = {}
-         unassigned_indices = []
-
-         for idx in valid_indices:
-             found = False
-             for (start, end), shard in self.shard_map.items():
-                 if start <= idx <= end:
-                     if shard not in shard_groups:
-                         shard_groups[shard] = []
-                     shard_groups[shard].append(idx - start)
-                     found = True
-                     break
-             if not found:
-                 unassigned_indices.append(idx)
-                 logger.warning(f"Index {idx} not found in any shard range")
-
-         if unassigned_indices:
-             logger.warning(f"Could not assign {len(unassigned_indices)} indices to any shard")
-
-         # Load and process shards
-         results = []
-         for shard, local_indices in shard_groups.items():
-             try:
-                 logger.info(f"Processing shard {shard} with {len(local_indices)} indices")
-                 start_time = time.time()
-
-                 if shard not in self.loaded_shards:
-                     logger.info(f"Loading shard file: {shard}")
-                     shard_path = self.shard_dir / shard
-
-                     # Verify file exists
-                     if not shard_path.exists():
-                         logger.error(f"Shard file not found: {shard_path}")
-                         continue
-
-                     # Log file size
-                     file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
-                     logger.info(f"Shard file size: {file_size_mb:.2f} MB")
-
-                     # Attempt to read the parquet file
-                     try:
-                         self.loaded_shards[shard] = pd.read_parquet(
-                             shard_path,
-                             columns=["title", "summary"]
-                         )
-                         logger.info(f"Successfully loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
-                     except Exception as e:
-                         logger.error(f"Failed to read parquet file {shard}: {str(e)}")
-
-                         # Try to read file schema for debugging
-                         try:
-                             schema = pd.read_parquet(shard_path, engine='pyarrow').dtypes
-                             logger.info(f"Parquet schema: {schema}")
-                         except:
-                             pass
-                         continue
-
-                 if local_indices:
-                     # Validate indices are within dataframe bounds
-                     df_len = len(self.loaded_shards[shard])
-                     valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
-
-                     if len(valid_local_indices) != len(local_indices):
-                         logger.warning(f"Filtered {len(local_indices) - len(valid_local_indices)} out-of-bounds indices")
-
-                     if valid_local_indices:
-                         logger.debug(f"Retrieving rows at indices: {valid_local_indices}")
-                         chunk = self.loaded_shards[shard].iloc[valid_local_indices]
-                         results.append(chunk)
-                         logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
-
-                 logger.info(f"Shard processing completed in {time.time() - start_time:.2f} seconds")
-
-             except Exception as e:
-                 logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
-                 continue
-
-         # Combine results
-         if results:
-             combined = pd.concat(results).reset_index(drop=True)
-             logger.info(f"Combined metadata: {len(combined)} records from {len(results)} shards")
-             return combined
-         else:
-             logger.warning("No metadata records retrieved")
-             return pd.DataFrame(columns=["title", "summary", "similarity"])
-
-     def _resolve_paper_url(self, title):
-         """Find paper URL using multiple strategies"""
-         # Check cache first
-         if title in self.api_cache:
-             return self.api_cache[title]
-
-         links = {}
-
-         # Try arXiv first
-         arxiv_url = self._get_arxiv_url(title)
-         if arxiv_url:
-             links["arxiv"] = arxiv_url
-
-         # Attempt to get a direct link using Semantic Scholar's API
-         semantic_url = self._get_semantic_scholar_url(title)
-         if semantic_url:
-             links["semantic_search"] = semantic_url
-
-         # Fallback to Google Scholar search
-         scholar_url = f"https://scholar.google.com/scholar?q={quote(title)}"
-         links["google"] = scholar_url
-
-         self.api_cache[title] = links
-
-         return links
-
-     def _get_arxiv_url(self, title):
-         """Search arXiv API for paper"""
-         try:
-             response = requests.get(
-                 "http://export.arxiv.org/api/query",
-                 params={
-                     "search_query": f'ti:"{title}"',
-                     "max_results": 1,
-                     "sortBy": "relevance"
-                 },
-                 timeout=5
-             )
-             response.raise_for_status()
-
-             # Parse XML response
-             from xml.etree import ElementTree as ET
-             root = ET.fromstring(response.content)
-             entry = root.find('{http://www.w3.org/2005/Atom}entry')
-             if entry is not None:
-                 arxiv_id = entry.find('{http://www.w3.org/2005/Atom}id').text
-                 return arxiv_id.replace('http:', 'https:')  # Force HTTPS
          except Exception as e:
-             logger.error(f"arXiv API failed for '{title}': {str(e)}")
-         return None
-
-     def _get_semantic_scholar_url(self, title):
-         """Search Semantic Scholar API for a paper by title and return its URL."""
-         try:
-             response = requests.get(
-                 "https://api.semanticscholar.org/graph/v1/paper/search",
-                 params={
-                     "query": title,
-                     "limit": 1,
-                     "fields": "paperId,url,title"
-                 },
-                 timeout=5
-             )
-             response.raise_for_status()  # This raises for 429 or other errors
-             data = response.json()
-
-             if "data" in data and len(data["data"]) > 0:
-                 paper = data["data"][0]
-                 if paper.get("url"):
-                     return paper["url"]
-                 elif paper.get("paperId"):
-                     return f"https://www.semanticscholar.org/paper/{paper['paperId']}"
-         except requests.exceptions.HTTPError as http_err:
-             if response.status_code == 429:
-                 time.sleep(1)  # simple backoff delay; consider exponential backoff
-         except Exception as e:
-             logger.error(f"Semantic Scholar API failed for '{title}': {e}")
-
-         return None
-
-
- class SemanticSearch:
      def __init__(self):
-         self.shard_dir = Path("compressed_shards")
-         self.model = None
-         self.index_shards = []
-         self.metadata_mgr = MetadataManager()
-         self.shard_sizes = []
-
-         # Configure search logger
-         self.logger = logging.getLogger("SemanticSearch")
-         self.logger.info("Initializing SemanticSearch")
-
-     @st.cache_resource
-     def load_model(_self):
-         return SentenceTransformer('all-MiniLM-L6-v2')
-
-     def initialize_system(self):
-         self.logger.info("Loading sentence transformer model")
-         start_time = time.time()
-         self.model = self.load_model()
-         self.logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
-
-         self.logger.info("Loading FAISS indices")
-         self._load_faiss_shards()
-
-     def _load_faiss_shards(self):
-         """Load all FAISS index shards"""
-         self.logger.info(f"Searching for index files in {self.shard_dir}")
-
-         if not self.shard_dir.exists():
-             self.logger.error(f"Shard directory not found: {self.shard_dir}")
-             return
-
-         index_files = list(self.shard_dir.glob("*.index"))
-         self.logger.info(f"Found {len(index_files)} index files")
-
-         self.shard_sizes = []
-         self.index_shards = []
-
-         for shard_path in sorted(index_files):
-             try:
-                 self.logger.info(f"Loading index: {shard_path}")
-                 start_time = time.time()
-
-                 # Log file size
-                 file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
-                 self.logger.info(f"Index file size: {file_size_mb:.2f} MB")
-
-                 index = faiss.read_index(str(shard_path))
-                 self.index_shards.append(index)
-                 self.shard_sizes.append(index.ntotal)
-
-                 self.logger.info(f"Loaded index with {index.ntotal} vectors in {time.time() - start_time:.2f} seconds")
-             except Exception as e:
-                 self.logger.error(f"Failed to load index {shard_path}: {str(e)}")
-
-         total_vectors = sum(self.shard_sizes)
-         self.logger.info(f"Total loaded vectors: {total_vectors} across {len(self.index_shards)} shards")
-
-     def _global_index(self, shard_idx, local_idx):
-         """Convert local index to global index"""
-         return sum(self.shard_sizes[:shard_idx]) + local_idx
-
-     def search(self, query, top_k=5):
-         """Search with validation"""
-         self.logger.info(f"Searching for query: '{query}' (top_k={top_k})")
-         start_time = time.time()
-
-         if not query:
-             self.logger.warning("Empty query provided")
-             return pd.DataFrame()
-
-         if not self.index_shards:
-             self.logger.error("No index shards loaded")
-             return pd.DataFrame()
-
-         try:
-             self.logger.info("Encoding query")
-             query_embedding = self.model.encode([query], convert_to_numpy=True)
-             self.logger.debug(f"Query encoded to shape {query_embedding.shape}")
-         except Exception as e:
-             self.logger.error(f"Query encoding failed: {str(e)}")
-             return pd.DataFrame()
-
-         all_distances = []
-         all_global_indices = []
-
-         # Search with index validation
-         self.logger.info(f"Searching across {len(self.index_shards)} shards")
-         for shard_idx, index in enumerate(self.index_shards):
-             if index.ntotal == 0:
-                 self.logger.warning(f"Skipping empty shard {shard_idx}")
-                 continue
-
-             try:
-                 shard_start = time.time()
-                 distances, indices = index.search(query_embedding, top_k)
-
-                 valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
-                 valid_indices = indices[0][valid_mask].tolist()
-                 valid_distances = distances[0][valid_mask].tolist()
-
-                 if len(valid_indices) != top_k:
-                     self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
-
-                 global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
-
-                 all_distances.extend(valid_distances)
-                 all_global_indices.extend(global_indices)
-
-                 self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
-             except Exception as e:
-                 self.logger.error(f"Search failed in shard {shard_idx}: {str(e)}")
-                 continue
-
-         self.logger.info(f"Search found {len(all_global_indices)} results across all shards")

          # Process results
-         results = self._process_results(
-             np.array(all_distances),
-             np.array(all_global_indices),
-             top_k
-         )
-
-         self.logger.info(f"Search completed in {time.time() - start_time:.2f} seconds with {len(results)} final results")
-         return results
-
-     def _process_results(self, distances, global_indices, top_k):
-         """Process raw search results into formatted DataFrame"""
-         process_start = time.time()
-
-         # Proper numpy array emptiness checks
-         if global_indices.size == 0 or distances.size == 0:
-             self.logger.warning("No search results to process")
-             return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
-
-         try:
-             # Get metadata for matched indices
-             self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
-             metadata_start = time.time()
-             results = self.metadata_mgr.get_metadata(global_indices)
-             self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
-
-             # Empty results check
-             if len(results) == 0:
-                 self.logger.warning("No metadata found for indices")
-                 return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
-
-             # Ensure distances match results length
-             if len(results) != len(distances):
-                 self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
-
-                 if len(results) < len(distances):
-                     self.logger.info("Truncating distances array to match results length")
-                     distances = distances[:len(results)]
-                 else:
-                     # Should not happen but handle it anyway
-                     self.logger.error("More results than distances - this shouldn't happen")
-                     distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
-
-             # Calculate similarity scores
-             self.logger.debug("Calculating similarity scores")
-             results['similarity'] = 1 - (distances / 2)
-
-             # Log similarity statistics
-             if not results.empty:
-                 self.logger.debug(f"Similarity stats: min={results['similarity'].min():.3f}, " +
-                                   f"max={results['similarity'].max():.3f}, " +
-                                   f"mean={results['similarity'].mean():.3f}")
-
-             results['source'] = results['title'].apply(
-                 lambda title: self._format_source_links(
-                     self.metadata_mgr._resolve_paper_url(title)
-                 )
-             )
-
-             # Deduplicate and sort results
-             pre_dedup = len(results)
-             results = results.drop_duplicates(subset=["title", "source"]).sort_values("similarity", ascending=False).head(top_k)
-             post_dedup = len(results)
-
-             if pre_dedup > post_dedup:
-                 self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
-
-             self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
-             return results.reset_index(drop=True)
-
-             # Add URL resolution for final results only
-             final_results = results.sort_values("similarity", ascending=False).head(top_k)
-
-             # Resolve URLs for top results only
-             final_results['source'] = final_results['title'].apply(
-                 lambda title: self._format_source_links(
-                     self.metadata_mgr._resolve_paper_url(title)
-                 )
-             )
-
-             # Deduplicate based on title only
-             final_results = final_results.drop_duplicates(subset=["title"]).head(top_k)
-
-             return final_results.reset_index(drop=True)
-
-         except Exception as e:
-             self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
-             return pd.DataFrame(columns=["title", "summary", "similarity"])
-
-     def _format_source_links(self, links):
-         """Generate an HTML snippet for the available source links."""
-         html_parts = []
-         if "arxiv" in links:
-             html_parts.append(
-                 f"<a class='source-link' href='{links['arxiv']}' target='_blank' rel='noopener noreferrer'> 📜 arXiv</a>"
-             )
-         if "semantic" in links:
-             html_parts.append(
-                 f"<a class='source-link' href='{links['semantic']}' target='_blank' rel='noopener noreferrer'> 🌐 Semantic Scholar</a>"
-             )
-         if "google" in links:
-             html_parts.append(
-                 f"<a class='source-link' href='{links['google']}' target='_blank' rel='noopener noreferrer'> 🔍 Google Scholar</a>"
-             )
-         return " | ".join(html_parts)
 
  import numpy as np
+ import pandas as pd  # still needed below by pd.read_parquet(...)
  import faiss
  import zipfile
  import logging
  from pathlib import Path
+ from sentence_transformers import SentenceTransformer
+ import concurrent.futures
  import os
  import requests
+ from urllib.parse import quote  # still needed below by the Google Scholar fallback
+ from functools import lru_cache
+ from typing import List, Dict

  # Configure logging
+ logging.basicConfig(level=logging.WARNING,
+                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger("OptimizedSearch")
 
18
+ class OptimizedMetadataManager:
19
  def __init__(self):
20
+ self._init_metadata()
21
+ self._init_url_resolver()
22
+
23
+ def _init_metadata(self):
24
+ """Memory-mapped metadata loading"""
25
+ self.metadata_dir = Path("unzipped_cache/metadata_shards")
26
+ self.metadata = {}
27
+
28
+ # Preload all metadata into memory
29
+ for parquet_file in self.metadata_dir.glob("*.parquet"):
30
+ df = pd.read_parquet(parquet_file, columns=["title", "summary"])
31
+ self.metadata.update(df.to_dict(orient="index"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ self.total_docs = len(self.metadata)
34
+ logger.info(f"Loaded {self.total_docs} metadata entries into memory")
35
+
36
+ def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
37
+ """Batch retrieval of metadata"""
38
+ return [self.metadata.get(idx, {"title": "", "summary": ""}) for idx in indices]
39
+
40
+ def _init_url_resolver(self):
41
+ """Initialize API session and cache"""
42
+ self.session = requests.Session()
43
+ adapter = requests.adapters.HTTPAdapter(
44
+ pool_connections=10,
45
+ pool_maxsize=10,
46
+ max_retries=3
47
+ )
48
+ self.session.mount("https://", adapter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ @lru_cache(maxsize=10_000)
51
+ def resolve_url(self, title: str) -> str:
52
+ """Optimized URL resolution with fail-fast"""
53
  try:
54
+ # Try arXiv first
55
+ arxiv_url = self._get_arxiv_url(title)
56
+ if arxiv_url: return arxiv_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ # Fallback to Semantic Scholar
59
+ semantic_url = self._get_semantic_url(title)
60
+ if semantic_url: return semantic_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
62
  except Exception as e:
63
+ logger.warning(f"URL resolution failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ return f"https://scholar.google.com/scholar?q={quote(title)}"
66
+
67
+ def _get_arxiv_url(self, title: str) -> str:
68
+ """Fast arXiv lookup with timeout"""
69
+ with self.session.get(
70
+ "http://export.arxiv.org/api/query",
71
+ params={"search_query": f'ti:"{title}"', "max_results": 1},
72
+ timeout=2
73
+ ) as response:
74
+ if response.ok:
75
+ return self._parse_arxiv_response(response.text)
76
+ return ""
77
+
78
+ def _parse_arxiv_response(self, xml: str) -> str:
79
+ """Fast XML parsing using string operations"""
80
+ if "<entry>" not in xml: return ""
81
+ start = xml.find("<id>") + 4
82
+ end = xml.find("</id>", start)
83
+ return xml[start:end].replace("http:", "https:") if start > 3 else ""
84
+
85
+ def _get_semantic_url(self, title: str) -> str:
86
+ """Batch-friendly Semantic Scholar lookup"""
87
+ with self.session.get(
88
+ "https://api.semanticscholar.org/graph/v1/paper/search",
89
+ params={"query": title[:200], "limit": 1},
90
+ timeout=2
91
+ ) as response:
92
+ if response.ok:
93
+ data = response.json()
94
+ if data.get("data"):
95
+ return data["data"][0].get("url", "")
96
+ return ""
97
+
98
+ class OptimizedSemanticSearch:
99
  def __init__(self):
100
+ self.model = SentenceTransformer('all-MiniLM-L6-v2')
101
+ self._load_faiss_indexes()
102
+ self.metadata_mgr = OptimizedMetadataManager()
 
 
103
 
104
+ def _load_faiss_indexes(self):
105
+ """Load indexes with memory mapping"""
106
+ self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
107
+ logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ def search(self, query: str, top_k: int = 5) -> List[Dict]:
110
+ """Optimized search pipeline"""
111
+ # Batch encode query
112
+ query_embedding = self.model.encode([query], convert_to_numpy=True)
 
 
 
 
 
 
113
 
114
+ # FAISS search
115
+ distances, indices = self.index.search(query_embedding, top_k*2) # Search extra for dedup
116
 
117
+ # Batch metadata retrieval
118
+ results = self.metadata_mgr.get_metadata_batch(indices[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Process results
121
+ return self._process_results(results, distances[0], top_k)
122
+
123
+ def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
124
+ """Parallel result processing"""
125
+ with concurrent.futures.ThreadPoolExecutor() as executor:
126
+ # Parallel URL resolution
127
+ futures = {
128
+ executor.submit(
129
+ self.metadata_mgr.resolve_url,
130
+ res["title"]
131
+ ): idx for idx, res in enumerate(results)
132
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # Update results as URLs resolve
135
+ for future in concurrent.futures.as_completed(futures):
136
+ idx = futures[future]
137
+ try:
138
+ results[idx]["source"] = future.result()
139
+ except Exception as e:
140
+ results[idx]["source"] = ""
141
+
142
+ # Add similarity scores
143
+ for idx, dist in enumerate(distances[:len(results)]):
144
+ results[idx]["similarity"] = 1 - (dist / 2)
145
+
146
+ # Deduplicate and sort
147
+ seen = set()
148
+ final_results = []
149
+ for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
150
+ if res["title"] not in seen and len(final_results) < top_k:
151
+ seen.add(res["title"])
152
+ final_results.append(res)
153
+
154
+ return final_results
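
For reference, a minimal usage sketch of the rewritten module, assuming the artifacts it expects (`combined_index.faiss` and the parquet shards under `unzipped_cache/metadata_shards/`) are present in the working directory; the query string below is illustrative only:

# Minimal usage sketch of the new API in search_utils.py
from search_utils import OptimizedSemanticSearch

searcher = OptimizedSemanticSearch()  # loads the model, the memory-mapped FAISS index, and all metadata
hits = searcher.search("contrastive learning for dense retrieval", top_k=5)
for hit in hits:
    # each hit is a dict with "title", "summary", "similarity", and a resolved "source" URL
    print(f'{hit["similarity"]:.3f}  {hit["title"]}  {hit["source"]}')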