Testys committed
Commit 5cefd40 · 1 Parent(s): 801d9f2

Update search_utils.py

Files changed (1)
  1. search_utils.py +44 -24
search_utils.py CHANGED
@@ -9,10 +9,14 @@ import os
 import requests
 from functools import lru_cache
 from typing import List, Dict
+import pandas as pd
+from urllib.parse import quote
 
 # Configure logging
-logging.basicConfig(level=logging.WARNING,
-                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+    level=logging.WARNING,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
 logger = logging.getLogger("OptimizedSearch")
 
 class OptimizedMetadataManager:
@@ -21,24 +25,27 @@ class OptimizedMetadataManager:
         self._init_url_resolver()
 
     def _init_metadata(self):
-        """Memory-mapped metadata loading"""
+        """Memory-mapped metadata loading.
+        Preloads all metadata (title and summary) into memory from parquet files.
+        """
         self.metadata_dir = Path("unzipped_cache/metadata_shards")
         self.metadata = {}
 
         # Preload all metadata into memory
         for parquet_file in self.metadata_dir.glob("*.parquet"):
             df = pd.read_parquet(parquet_file, columns=["title", "summary"])
+            # Using the dataframe index as key (assumes unique indices across files)
             self.metadata.update(df.to_dict(orient="index"))
 
         self.total_docs = len(self.metadata)
         logger.info(f"Loaded {self.total_docs} metadata entries into memory")
 
     def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
-        """Batch retrieval of metadata"""
+        """Batch retrieval of metadata entries for a list of indices."""
        return [self.metadata.get(idx, {"title": "", "summary": ""}) for idx in indices]
 
     def _init_url_resolver(self):
-        """Initialize API session and cache"""
+        """Initialize API session and adapter for faster URL resolution."""
         self.session = requests.Session()
         adapter = requests.adapters.HTTPAdapter(
             pool_connections=10,
@@ -49,23 +56,26 @@ class OptimizedMetadataManager:
 
     @lru_cache(maxsize=10_000)
     def resolve_url(self, title: str) -> str:
-        """Optimized URL resolution with fail-fast"""
+        """Optimized URL resolution with caching and a fail-fast approach."""
         try:
             # Try arXiv first
             arxiv_url = self._get_arxiv_url(title)
-            if arxiv_url: return arxiv_url
+            if arxiv_url:
+                return arxiv_url
 
             # Fallback to Semantic Scholar
             semantic_url = self._get_semantic_url(title)
-            if semantic_url: return semantic_url
+            if semantic_url:
+                return semantic_url
 
         except Exception as e:
             logger.warning(f"URL resolution failed: {str(e)}")
 
+        # Default fallback to Google Scholar search
         return f"https://scholar.google.com/scholar?q={quote(title)}"
 
     def _get_arxiv_url(self, title: str) -> str:
-        """Fast arXiv lookup with timeout"""
+        """Fast arXiv lookup with a short timeout."""
         with self.session.get(
             "http://export.arxiv.org/api/query",
             params={"search_query": f'ti:"{title}"', "max_results": 1},
@@ -76,14 +86,15 @@ class OptimizedMetadataManager:
             return ""
 
     def _parse_arxiv_response(self, xml: str) -> str:
-        """Fast XML parsing using string operations"""
-        if "<entry>" not in xml: return ""
+        """Fast XML parsing using simple string operations."""
+        if "<entry>" not in xml:
+            return ""
         start = xml.find("<id>") + 4
         end = xml.find("</id>", start)
         return xml[start:end].replace("http:", "https:") if start > 3 else ""
 
     def _get_semantic_url(self, title: str) -> str:
-        """Batch-friendly Semantic Scholar lookup"""
+        """Semantic Scholar lookup with a short timeout."""
         with self.session.get(
             "https://api.semanticscholar.org/graph/v1/paper/search",
             params={"query": title[:200], "limit": 1},
@@ -97,53 +108,62 @@ class OptimizedMetadataManager:
 
 class OptimizedSemanticSearch:
     def __init__(self):
+        # Load the sentence transformer model
         self.model = SentenceTransformer('all-MiniLM-L6-v2')
         self._load_faiss_indexes()
         self.metadata_mgr = OptimizedMetadataManager()
 
     def _load_faiss_indexes(self):
-        """Load indexes with memory mapping"""
+        """Load the FAISS index with memory mapping for read-only access."""
+        # Here we assume the FAISS index has been combined into one file.
         self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
         logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")
 
     def search(self, query: str, top_k: int = 5) -> List[Dict]:
-        """Optimized search pipeline"""
+        """Optimized search pipeline:
+        - Encodes the query.
+        - Performs FAISS search (fetching extra results for deduplication).
+        - Retrieves metadata and processes results.
+        """
         # Batch encode query
         query_embedding = self.model.encode([query], convert_to_numpy=True)
 
-        # FAISS search
-        distances, indices = self.index.search(query_embedding, top_k*2)  # Search extra for dedup
+        # FAISS search: we search for more than top_k to allow for deduplication.
+        distances, indices = self.index.search(query_embedding, top_k * 2)
 
         # Batch metadata retrieval
         results = self.metadata_mgr.get_metadata_batch(indices[0])
 
-        # Process results
+        # Process and return the final results
         return self._process_results(results, distances[0], top_k)
 
     def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
-        """Parallel result processing"""
+        """Parallel processing of search results:
+        - Resolve source URLs in parallel.
+        - Add similarity scores.
+        - Deduplicate and sort the results.
+        """
         with concurrent.futures.ThreadPoolExecutor() as executor:
-            # Parallel URL resolution
+            # Parallel URL resolution for each result
             futures = {
                 executor.submit(
                     self.metadata_mgr.resolve_url,
                     res["title"]
                 ): idx for idx, res in enumerate(results)
             }
-
-            # Update results as URLs resolve
+            # Update each result as URLs resolve
             for future in concurrent.futures.as_completed(futures):
                 idx = futures[future]
                 try:
                     results[idx]["source"] = future.result()
                 except Exception as e:
                     results[idx]["source"] = ""
-
-        # Add similarity scores
+
+        # Add similarity scores based on distances
         for idx, dist in enumerate(distances[:len(results)]):
             results[idx]["similarity"] = 1 - (dist / 2)
 
-        # Deduplicate and sort
+        # Deduplicate by title and sort by similarity score (descending)
         seen = set()
         final_results = []
         for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
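
For context, here is a minimal usage sketch of the updated OptimizedSemanticSearch class. It is an illustration only, and assumes the rest of search_utils.py already imports numpy as np, faiss, concurrent.futures, pathlib.Path and sentence_transformers.SentenceTransformer (they are referenced in the hunks above but not added by this commit), and that combined_index.faiss and the unzipped_cache/metadata_shards/*.parquet files exist locally.

# Illustrative sketch, not part of the commit; see the assumptions above.
from search_utils import OptimizedSemanticSearch

searcher = OptimizedSemanticSearch()  # loads the model, FAISS index and parquet metadata
hits = searcher.search("semantic search with sentence transformers", top_k=5)

for hit in hits:
    # _process_results attaches "source" and "similarity" to each metadata dict
    print(f'{hit["similarity"]:.3f}  {hit["title"]}  {hit["source"]}')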