# semantic-search / search_utils.py
import numpy as np
import pandas as pd
import faiss
import zipfile
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer
import concurrent.futures
import os
import requests
from functools import lru_cache
from typing import List, Dict
from urllib.parse import quote
# Configure logging
logging.basicConfig(level=logging.WARNING,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("OptimizedSearch")
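# Note: the logger.info calls below are suppressed at WARNING level;
# lower the level to INFO to see index/metadata load statistics.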
class OptimizedMetadataManager:
def __init__(self):
self._init_metadata()
self._init_url_resolver()
    def _init_metadata(self):
        """Preload all metadata shards into memory for fast batch lookups."""
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}
        # Load every parquet shard; record keys are the shard indices, which are
        # expected to line up with the FAISS vector ids.
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            self.metadata.update(df.to_dict(orient="index"))
        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")
    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch retrieval of metadata (returns copies so callers can annotate them)."""
        # Copying avoids mutating the shared cache when "source"/"similarity" are added later.
        return [dict(self.metadata.get(int(idx), {"title": "", "summary": ""})) for idx in indices]
    def _init_url_resolver(self):
        """Initialize the pooled HTTP session used for URL resolution."""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        # Mount the pooled adapter on both schemes (the arXiv endpoint uses http).
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)
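    # lru_cache on a bound method keeps a reference to self for the cache's lifetime;
    # acceptable here since a single manager instance is created per search engine.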
    @lru_cache(maxsize=10_000)
    def resolve_url(self, title: str) -> str:
        """Resolve a paper title to a URL, failing fast to a Scholar search."""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url:
                return arxiv_url
            # Fall back to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url:
                return semantic_url
        except Exception as e:
            logger.warning(f"URL resolution failed: {str(e)}")
        # Last resort: a Google Scholar title search.
        return f"https://scholar.google.com/scholar?q={quote(title)}"
def _get_arxiv_url(self, title: str) -> str:
"""Fast arXiv lookup with timeout"""
with self.session.get(
"http://export.arxiv.org/api/query",
params={"search_query": f'ti:"{title}"', "max_results": 1},
timeout=2
) as response:
if response.ok:
return self._parse_arxiv_response(response.text)
return ""
    def _parse_arxiv_response(self, xml: str) -> str:
        """Extract the paper URL from an arXiv Atom response using string operations."""
        # The feed has its own <id> element, so look only inside the first <entry>.
        entry_pos = xml.find("<entry>")
        if entry_pos == -1:
            return ""
        start = xml.find("<id>", entry_pos) + 4
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:") if start > 3 else ""
    def _get_semantic_url(self, title: str) -> str:
        """Semantic Scholar title lookup with a short timeout."""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            # The Graph API omits "url" unless it is requested explicitly.
            params={"query": title[:200], "limit": 1, "fields": "url"},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
        return ""
class OptimizedSemanticSearch:
def __init__(self):
self.model = SentenceTransformer('all-MiniLM-L6-v2')
self._load_faiss_indexes()
self.metadata_mgr = OptimizedMetadataManager()
def _load_faiss_indexes(self):
"""Load indexes with memory mapping"""
self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")
def search(self, query: str, top_k: int = 5) -> List[Dict]:
"""Optimized search pipeline"""
# Batch encode query
query_embedding = self.model.encode([query], convert_to_numpy=True)
# FAISS search
distances, indices = self.index.search(query_embedding, top_k*2) # Search extra for dedup
# Batch metadata retrieval
results = self.metadata_mgr.get_metadata_batch(indices[0])
# Process results
return self._process_results(results, distances[0], top_k)
def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
"""Parallel result processing"""
with concurrent.futures.ThreadPoolExecutor() as executor:
# Parallel URL resolution
futures = {
executor.submit(
self.metadata_mgr.resolve_url,
res["title"]
): idx for idx, res in enumerate(results)
}
# Update results as URLs resolve
for future in concurrent.futures.as_completed(futures):
idx = futures[future]
try:
results[idx]["source"] = future.result()
except Exception as e:
results[idx]["source"] = ""
        # Add similarity scores: for normalized embeddings, the squared L2 distance
        # d = 2 - 2*cos, so 1 - d/2 recovers the cosine similarity.
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)
# Deduplicate and sort
seen = set()
final_results = []
for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
if res["title"] not in seen and len(final_results) < top_k:
seen.add(res["title"])
final_results.append(res)
return final_results
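

# Minimal usage sketch (illustrative; assumes "combined_index.faiss" and the
# metadata parquet shards referenced above are present in the working directory).
if __name__ == "__main__":
    searcher = OptimizedSemanticSearch()
    for hit in searcher.search("transformer attention mechanisms", top_k=5):
        print(f"{hit.get('similarity', 0):.3f}  {hit.get('title', '')}  {hit.get('source', '')}")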