# semantic-search/search_utils.py
import numpy as np
import faiss
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer
import concurrent.futures
import requests
from functools import lru_cache
from typing import List, Dict
import pandas as pd
from urllib.parse import quote
# Configure logging
logging.basicConfig(
level=logging.WARNING,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("OptimizedSearch")
class OptimizedMetadataManager:
def __init__(self):
self._init_metadata()
self._init_url_resolver()
def _init_metadata(self):
"""Memory-mapped metadata loading.
Preloads all metadata (title and summary) into memory from parquet files.
"""
self.metadata_dir = Path("unzipped_cache/metadata_shards")
self.metadata = {}
# Preload all metadata into memory
for parquet_file in self.metadata_dir.glob("*.parquet"):
df = pd.read_parquet(parquet_file, columns=["title", "summary"])
# Using the dataframe index as key (assumes unique indices across files)
self.metadata.update(df.to_dict(orient="index"))
self.total_docs = len(self.metadata)
logger.info(f"Loaded {self.total_docs} metadata entries into memory")
def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
"""Batch retrieval of metadata entries for a list of indices."""
return [self.metadata.get(idx, {"title": "", "summary": ""}) for idx in indices]
def _init_url_resolver(self):
"""Initialize API session and adapter for faster URL resolution."""
self.session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
pool_connections=10,
pool_maxsize=10,
max_retries=3
)
self.session.mount("https://", adapter)
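    # Note: lru_cache on a method keys the cache on (self, title); with a single
    # manager instance this effectively caches one resolved URL per title.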
@lru_cache(maxsize=10_000)
def resolve_url(self, title: str) -> str:
"""Optimized URL resolution with caching and a fail-fast approach."""
try:
# Try arXiv first
arxiv_url = self._get_arxiv_url(title)
if arxiv_url:
return arxiv_url
# Fallback to Semantic Scholar
semantic_url = self._get_semantic_url(title)
if semantic_url:
return semantic_url
except Exception as e:
logger.warning(f"URL resolution failed: {str(e)}")
# Default fallback to Google Scholar search
return f"https://scholar.google.com/scholar?q={quote(title)}"
def _get_arxiv_url(self, title: str) -> str:
"""Fast arXiv lookup with a short timeout."""
with self.session.get(
"http://export.arxiv.org/api/query",
params={"search_query": f'ti:"{title}"', "max_results": 1},
timeout=2
) as response:
if response.ok:
return self._parse_arxiv_response(response.text)
return ""
def _parse_arxiv_response(self, xml: str) -> str:
"""Fast XML parsing using simple string operations."""
if "<entry>" not in xml:
return ""
start = xml.find("<id>") + 4
end = xml.find("</id>", start)
return xml[start:end].replace("http:", "https:") if start > 3 else ""
def _get_semantic_url(self, title: str) -> str:
"""Semantic Scholar lookup with a short timeout."""
with self.session.get(
"https://api.semanticscholar.org/graph/v1/paper/search",
params={"query": title[:200], "limit": 1},
timeout=2
) as response:
if response.ok:
data = response.json()
if data.get("data"):
return data["data"][0].get("url", "")
return ""
class OptimizedSemanticSearch:
def __init__(self):
# Load the sentence transformer model
self.model = SentenceTransformer('all-MiniLM-L6-v2')
self._load_faiss_indexes()
self.metadata_mgr = OptimizedMetadataManager()
def _load_faiss_indexes(self):
"""Load the FAISS index with memory mapping for read-only access."""
# Here we assume the FAISS index has been combined into one file.
self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
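        # The MMAP/READ_ONLY flags let FAISS map the index file from disk
        # instead of loading a full copy into RAM.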
logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")
def search(self, query: str, top_k: int = 5) -> List[Dict]:
"""Optimized search pipeline:
- Encodes the query.
- Performs FAISS search (fetching extra results for deduplication).
- Retrieves metadata and processes results.
"""
# Batch encode query
query_embedding = self.model.encode([query], convert_to_numpy=True)
# FAISS search: we search for more than top_k to allow for deduplication.
distances, indices = self.index.search(query_embedding, top_k * 2)
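        # distances/indices have shape (1, top_k * 2) because a single query
        # vector was searched; row 0 is used below.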
# Batch metadata retrieval
results = self.metadata_mgr.get_metadata_batch(indices[0])
# Process and return the final results
return self._process_results(results, distances[0], top_k)
def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
"""Parallel processing of search results:
- Resolve source URLs in parallel.
- Add similarity scores.
- Deduplicate and sort the results.
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
# Parallel URL resolution for each result
futures = {
executor.submit(
self.metadata_mgr.resolve_url,
res["title"]
): idx for idx, res in enumerate(results)
}
# Update each result as URLs resolve
for future in concurrent.futures.as_completed(futures):
idx = futures[future]
try:
results[idx]["source"] = future.result()
                except Exception as e:
                    logger.warning(f"URL resolution failed for result {idx}: {e}")
                    results[idx]["source"] = ""
# Add similarity scores based on distances
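        # FAISS L2 indexes return squared distances; for unit-normalised vectors
        # cosine = 1 - d/2 (d being the squared distance), so this maps each
        # distance onto a cosine-style score. It assumes both the indexed and
        # query embeddings are normalised.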
for idx, dist in enumerate(distances[:len(results)]):
results[idx]["similarity"] = 1 - (dist / 2)
# Deduplicate by title and sort by similarity score (descending)
seen = set()
final_results = []
for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
if res["title"] not in seen and len(final_results) < top_k:
seen.add(res["title"])
final_results.append(res)
return final_results
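

if __name__ == "__main__":
    # Minimal usage sketch, assuming combined_index.faiss and the parquet
    # metadata shards referenced above already exist on disk.
    searcher = OptimizedSemanticSearch()
    for hit in searcher.search("sentence embeddings for semantic search", top_k=5):
        print(f'{hit["similarity"]:.3f}  {hit["title"]}  {hit["source"]}')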