import concurrent.futures
import logging
import os
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Dict, List
from urllib.parse import quote

import faiss
import numpy as np
import pandas as pd
import requests
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("OptimizedSearch")


class OptimizedMetadataManager:
    def __init__(self):
        self._init_metadata()
        self._init_url_resolver()

    def _init_metadata(self):
        """Preload all metadata (title and summary) into memory from parquet shards."""
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}

        # Preload all metadata into memory
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            # Use the dataframe index as key (assumes unique indices across files)
            self.metadata.update(df.to_dict(orient="index"))

        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")

    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch retrieval of metadata entries for a list of indices."""
        # Return copies so downstream result processing doesn't mutate the cached metadata.
        return [
            dict(self.metadata.get(int(idx), {"title": "", "summary": ""}))
            for idx in indices
        ]

    def _init_url_resolver(self):
        """Initialize an API session with connection pooling for faster URL resolution."""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        self.session.mount("https://", adapter)

    # NOTE: lru_cache on a bound method also keys on `self`; acceptable here since a
    # single long-lived manager instance is used.
    @lru_cache(maxsize=10_000)
    def resolve_url(self, title: str) -> str:
        """Optimized URL resolution with caching and a fail-fast approach."""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url:
                return arxiv_url

            # Fall back to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url:
                return semantic_url
        except Exception as e:
            logger.warning(f"URL resolution failed: {str(e)}")

        # Default fallback: Google Scholar search
        return f"https://scholar.google.com/scholar?q={quote(title)}"

    def _get_arxiv_url(self, title: str) -> str:
        """Fast arXiv lookup with a short timeout."""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
            timeout=2
        ) as response:
            if response.ok:
                return self._parse_arxiv_response(response.text)
            return ""

    def _parse_arxiv_response(self, xml: str) -> str:
        """Fast XML parsing using simple string operations."""
        if "<id>" not in xml:
            return ""
        start = xml.find("<id>") + 4
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:") if start > 3 else ""

    def _get_semantic_url(self, title: str) -> str:
        """Semantic Scholar lookup with a short timeout."""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            params={"query": title[:200], "limit": 1},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
            return ""


class OptimizedSemanticSearch:
    def __init__(self):
        # Load the sentence transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()

    def _load_faiss_indexes(self):
        """Load the FAISS index with memory mapping for read-only access."""
        # Here we assume the FAISS index has been combined into one file.
        self.index = faiss.read_index(
            "combined_index.faiss",
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
        )
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Optimized search pipeline:
        - Encodes the query.
        - Performs the FAISS search (fetching extra results for deduplication).
        - Retrieves metadata and post-processes the results.
        """
        # Batch encode the query
        query_embedding = self.model.encode([query], convert_to_numpy=True)

        # FAISS search: fetch more than top_k to allow for deduplication.
        distances, indices = self.index.search(query_embedding, top_k * 2)

        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])

        # Process and return the final results
        return self._process_results(results, distances[0], top_k)

    def _process_results(self, results: List[Dict], distances: np.ndarray,
                         top_k: int) -> List[Dict]:
        """Parallel post-processing of search results:
        - Resolve source URLs in parallel.
        - Add similarity scores.
        - Deduplicate and sort the results.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Resolve URLs for all results in parallel
            futures = {
                executor.submit(self.metadata_mgr.resolve_url, res["title"]): idx
                for idx, res in enumerate(results)
            }

            # Attach each URL as it resolves
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception:
                    results[idx]["source"] = ""

        # Convert distances to similarity scores (assumes L2 distances over normalized embeddings)
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)

        # Deduplicate by title and keep the top_k results by similarity (descending)
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
            if res["title"] not in seen and len(final_results) < top_k:
                seen.add(res["title"])
                final_results.append(res)

        return final_results
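

# A minimal usage sketch, assuming "combined_index.faiss" and the parquet metadata
# shards referenced above already exist on disk; the query string is illustrative.
if __name__ == "__main__":
    searcher = OptimizedSemanticSearch()
    for hit in searcher.search("graph neural networks for molecules", top_k=5):
        print(f'{hit["similarity"]:.3f}  {hit["title"]}  {hit.get("source", "")}')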