Testys committed
Commit 7f04a94 · 1 Parent(s): c056209

Update search_utils.py

Files changed (1):
  1. search_utils.py (+24 -44)
search_utils.py CHANGED
@@ -9,14 +9,10 @@ import os
 import requests
 from functools import lru_cache
 from typing import List, Dict
-import pandas as pd
-from urllib.parse import quote
 
 # Configure logging
-logging.basicConfig(
-    level=logging.WARNING,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
+logging.basicConfig(level=logging.WARNING,
+                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger("OptimizedSearch")
 
 class OptimizedMetadataManager:
@@ -25,27 +21,24 @@ class OptimizedMetadataManager:
         self._init_url_resolver()
 
     def _init_metadata(self):
-        """Memory-mapped metadata loading.
-        Preloads all metadata (title and summary) into memory from parquet files.
-        """
+        """Memory-mapped metadata loading"""
         self.metadata_dir = Path("unzipped_cache/metadata_shards")
         self.metadata = {}
 
         # Preload all metadata into memory
         for parquet_file in self.metadata_dir.glob("*.parquet"):
             df = pd.read_parquet(parquet_file, columns=["title", "summary"])
-            # Using the dataframe index as key (assumes unique indices across files)
             self.metadata.update(df.to_dict(orient="index"))
 
         self.total_docs = len(self.metadata)
         logger.info(f"Loaded {self.total_docs} metadata entries into memory")
 
     def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
-        """Batch retrieval of metadata entries for a list of indices."""
+        """Batch retrieval of metadata"""
         return [self.metadata.get(idx, {"title": "", "summary": ""}) for idx in indices]
 
     def _init_url_resolver(self):
-        """Initialize API session and adapter for faster URL resolution."""
+        """Initialize API session and cache"""
         self.session = requests.Session()
         adapter = requests.adapters.HTTPAdapter(
             pool_connections=10,
@@ -56,26 +49,23 @@ class OptimizedMetadataManager:
 
     @lru_cache(maxsize=10_000)
     def resolve_url(self, title: str) -> str:
-        """Optimized URL resolution with caching and a fail-fast approach."""
+        """Optimized URL resolution with fail-fast"""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
-            if arxiv_url:
-                return arxiv_url
+            if arxiv_url: return arxiv_url
 
            # Fallback to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
-            if semantic_url:
-                return semantic_url
+            if semantic_url: return semantic_url
 
        except Exception as e:
            logger.warning(f"URL resolution failed: {str(e)}")
 
-        # Default fallback to Google Scholar search
        return f"https://scholar.google.com/scholar?q={quote(title)}"
 
     def _get_arxiv_url(self, title: str) -> str:
-        """Fast arXiv lookup with a short timeout."""
+        """Fast arXiv lookup with timeout"""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
@@ -86,15 +76,14 @@ class OptimizedMetadataManager:
        return ""
 
     def _parse_arxiv_response(self, xml: str) -> str:
-        """Fast XML parsing using simple string operations."""
-        if "<entry>" not in xml:
-            return ""
+        """Fast XML parsing using string operations"""
+        if "<entry>" not in xml: return ""
        start = xml.find("<id>") + 4
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:") if start > 3 else ""
 
     def _get_semantic_url(self, title: str) -> str:
-        """Semantic Scholar lookup with a short timeout."""
+        """Batch-friendly Semantic Scholar lookup"""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            params={"query": title[:200], "limit": 1},
@@ -108,62 +97,53 @@ class OptimizedMetadataManager:
 
 class OptimizedSemanticSearch:
     def __init__(self):
-        # Load the sentence transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()
 
     def _load_faiss_indexes(self):
-        """Load the FAISS index with memory mapping for read-only access."""
-        # Here we assume the FAISS index has been combined into one file.
+        """Load indexes with memory mapping"""
        self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")
 
     def search(self, query: str, top_k: int = 5) -> List[Dict]:
-        """Optimized search pipeline:
-        - Encodes the query.
-        - Performs FAISS search (fetching extra results for deduplication).
-        - Retrieves metadata and processes results.
-        """
+        """Optimized search pipeline"""
        # Batch encode query
        query_embedding = self.model.encode([query], convert_to_numpy=True)
 
-        # FAISS search: we search for more than top_k to allow for deduplication.
-        distances, indices = self.index.search(query_embedding, top_k * 2)
+        # FAISS search
+        distances, indices = self.index.search(query_embedding, top_k*2)  # Search extra for dedup
 
        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])
 
-        # Process and return the final results
+        # Process results
        return self._process_results(results, distances[0], top_k)
 
     def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
-        """Parallel processing of search results:
-        - Resolve source URLs in parallel.
-        - Add similarity scores.
-        - Deduplicate and sort the results.
-        """
+        """Parallel result processing"""
        with concurrent.futures.ThreadPoolExecutor() as executor:
-            # Parallel URL resolution for each result
            futures = {
                executor.submit(
                    self.metadata_mgr.resolve_url,
                    res["title"]
                ): idx for idx, res in enumerate(results)
            }
-            # Update each result as URLs resolve
+            # Parallel URL resolution
+
+            # Update results as URLs resolve
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception as e:
                    results[idx]["source"] = ""
-
-        # Add similarity scores based on distances
+
+        # Add similarity scores
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)
 
-        # Deduplicate by title and sort by similarity score (descending)
+        # Deduplicate and sort
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
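
For orientation, a minimal usage sketch of the updated module. It assumes search_utils.py is importable as-is, that combined_index.faiss and the unzipped_cache/metadata_shards/*.parquet shards exist in the working directory, and that the result keys match _process_results above; the query string and print formatting are illustrative only.

# Hypothetical driver script; paths and the query are assumptions,
# the class and method names come from search_utils.py above.
from search_utils import OptimizedSemanticSearch

searcher = OptimizedSemanticSearch()   # loads the FAISS index and parquet metadata
hits = searcher.search("transformer models for semantic search", top_k=5)

for hit in hits:
    # Each hit carries title/summary from the metadata shards, a resolved
    # source URL (arXiv, then Semantic Scholar, then a Google Scholar link),
    # and a similarity score derived from the FAISS distance as 1 - dist / 2.
    print(f"{hit['similarity']:.3f}  {hit['title']}\n    {hit['source']}")

Since resolve_url is wrapped in lru_cache, repeated titles within the same process reuse the previously resolved URL instead of hitting the APIs again.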