Update search_utils.py
search_utils.py  +24 -44
search_utils.py
CHANGED
@@ -9,14 +9,10 @@ import os
 import requests
 from functools import lru_cache
 from typing import List, Dict
-import pandas as pd
-from urllib.parse import quote
 
 # Configure logging
-logging.basicConfig(
-
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
+logging.basicConfig(level=logging.WARNING,
+                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger("OptimizedSearch")
 
 class OptimizedMetadataManager:
@@ -25,27 +21,24 @@ class OptimizedMetadataManager:
         self._init_url_resolver()
 
     def _init_metadata(self):
-        """Memory-mapped metadata loading
-        Preloads all metadata (title and summary) into memory from parquet files.
-        """
+        """Memory-mapped metadata loading"""
         self.metadata_dir = Path("unzipped_cache/metadata_shards")
         self.metadata = {}
 
         # Preload all metadata into memory
         for parquet_file in self.metadata_dir.glob("*.parquet"):
             df = pd.read_parquet(parquet_file, columns=["title", "summary"])
-            # Using the dataframe index as key (assumes unique indices across files)
             self.metadata.update(df.to_dict(orient="index"))
 
         self.total_docs = len(self.metadata)
         logger.info(f"Loaded {self.total_docs} metadata entries into memory")
 
     def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
-        """Batch retrieval of metadata
+        """Batch retrieval of metadata"""
         return [self.metadata.get(idx, {"title": "", "summary": ""}) for idx in indices]
 
     def _init_url_resolver(self):
-        """Initialize API session and
+        """Initialize API session and cache"""
         self.session = requests.Session()
         adapter = requests.adapters.HTTPAdapter(
             pool_connections=10,
@@ -56,26 +49,23 @@ class OptimizedMetadataManager:
 
     @lru_cache(maxsize=10_000)
     def resolve_url(self, title: str) -> str:
-        """Optimized URL resolution with
+        """Optimized URL resolution with fail-fast"""
         try:
             # Try arXiv first
             arxiv_url = self._get_arxiv_url(title)
-            if arxiv_url:
-                return arxiv_url
+            if arxiv_url: return arxiv_url
 
             # Fallback to Semantic Scholar
             semantic_url = self._get_semantic_url(title)
-            if semantic_url:
-                return semantic_url
+            if semantic_url: return semantic_url
 
         except Exception as e:
             logger.warning(f"URL resolution failed: {str(e)}")
 
-        # Default fallback to Google Scholar search
         return f"https://scholar.google.com/scholar?q={quote(title)}"
 
     def _get_arxiv_url(self, title: str) -> str:
-        """Fast arXiv lookup with
+        """Fast arXiv lookup with timeout"""
         with self.session.get(
             "http://export.arxiv.org/api/query",
             params={"search_query": f'ti:"{title}"', "max_results": 1},
@@ -86,15 +76,14 @@ class OptimizedMetadataManager:
         return ""
 
     def _parse_arxiv_response(self, xml: str) -> str:
-        """Fast XML parsing using
-        if "<entry>" not in xml:
-            return ""
+        """Fast XML parsing using string operations"""
+        if "<entry>" not in xml: return ""
         start = xml.find("<id>") + 4
         end = xml.find("</id>", start)
         return xml[start:end].replace("http:", "https:") if start > 3 else ""
 
     def _get_semantic_url(self, title: str) -> str:
-        """Semantic Scholar lookup
+        """Batch-friendly Semantic Scholar lookup"""
         with self.session.get(
             "https://api.semanticscholar.org/graph/v1/paper/search",
             params={"query": title[:200], "limit": 1},
@@ -108,62 +97,53 @@ class OptimizedMetadataManager:
 
 class OptimizedSemanticSearch:
     def __init__(self):
-        # Load the sentence transformer model
         self.model = SentenceTransformer('all-MiniLM-L6-v2')
         self._load_faiss_indexes()
         self.metadata_mgr = OptimizedMetadataManager()
 
     def _load_faiss_indexes(self):
-        """Load
-        # Here we assume the FAISS index has been combined into one file.
+        """Load indexes with memory mapping"""
         self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
         logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")
 
     def search(self, query: str, top_k: int = 5) -> List[Dict]:
-        """Optimized search pipeline
-        - Encodes the query.
-        - Performs FAISS search (fetching extra results for deduplication).
-        - Retrieves metadata and processes results.
-        """
+        """Optimized search pipeline"""
        # Batch encode query
        query_embedding = self.model.encode([query], convert_to_numpy=True)
 
-        # FAISS search
-        distances, indices = self.index.search(query_embedding, top_k
+        # FAISS search
+        distances, indices = self.index.search(query_embedding, top_k*2)  # Search extra for dedup
 
         # Batch metadata retrieval
         results = self.metadata_mgr.get_metadata_batch(indices[0])
 
-        # Process
+        # Process results
         return self._process_results(results, distances[0], top_k)
 
     def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
-        """Parallel processing
-        - Resolve source URLs in parallel.
-        - Add similarity scores.
-        - Deduplicate and sort the results.
-        """
+        """Parallel result processing"""
         with concurrent.futures.ThreadPoolExecutor() as executor:
-            # Parallel URL resolution
+            # Parallel URL resolution
             futures = {
                 executor.submit(
                     self.metadata_mgr.resolve_url,
                     res["title"]
                 ): idx for idx, res in enumerate(results)
             }
-
+
+            # Update results as URLs resolve
             for future in concurrent.futures.as_completed(futures):
                 idx = futures[future]
                 try:
                     results[idx]["source"] = future.result()
                 except Exception as e:
                     results[idx]["source"] = ""
-
-        # Add similarity scores
+
+        # Add similarity scores
         for idx, dist in enumerate(distances[:len(results)]):
             results[idx]["similarity"] = 1 - (dist / 2)
 
-        # Deduplicate
+        # Deduplicate and sort
         seen = set()
         final_results = []
         for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
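For reference, a minimal usage sketch of the updated pipeline (illustrative only, not part of this commit). It assumes the module is importable as search_utils and that combined_index.faiss plus the parquet shards under unzipped_cache/metadata_shards/ exist locally; the query string is just an example.

# Illustrative sketch, not part of the commit: exercises the updated search()
# path (query encoding, over-fetching with top_k*2, metadata lookup, parallel
# URL resolution, scoring, dedup).
from search_utils import OptimizedSemanticSearch

searcher = OptimizedSemanticSearch()   # loads the model, FAISS index, and metadata shards
hits = searcher.search("graph neural networks for molecules", top_k=5)

for hit in hits:
    # Each hit carries the parquet metadata plus the fields added in
    # _process_results: a resolved "source" URL and a "similarity" score.
    print(f'{hit["similarity"]:.3f}  {hit["title"]}  {hit["source"]}')

Because resolve_url is wrapped in lru_cache(maxsize=10_000) and URL resolution runs in a ThreadPoolExecutor, repeated queries that surface overlapping papers avoid redundant arXiv and Semantic Scholar calls.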