import numpy as np
import faiss
import zipfile
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer
import concurrent.futures
import os
import requests
from functools import lru_cache
from typing import List, Dict
import pandas as pd
from urllib.parse import quote

# Configure logging
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("OptimizedSearch")
class OptimizedMetadataManager:
    def __init__(self):
        self._init_metadata()
        self._init_url_resolver()

    def _init_metadata(self):
        """Preload all metadata (title and summary) into memory from the parquet shards."""
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            # Use the dataframe index as the key (assumes unique indices across shards)
            self.metadata.update(df.to_dict(orient="index"))
        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")

    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch retrieval of metadata entries for an array of FAISS indices."""
        # Return copies so callers can attach extra fields without mutating the cache
        return [
            dict(self.metadata.get(int(idx), {"title": "", "summary": ""}))
            for idx in indices
        ]
    def _init_url_resolver(self):
        """Initialize a pooled HTTP session for faster URL resolution."""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        # Mount the adapter for both schemes; the arXiv endpoint is plain HTTP
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    @lru_cache(maxsize=1024)
    def resolve_url(self, title: str) -> str:
        """URL resolution with per-title caching and a fail-fast approach."""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url:
                return arxiv_url
            # Fall back to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url:
                return semantic_url
        except Exception as e:
            logger.warning(f"URL resolution failed: {e}")
        # Default fallback: Google Scholar search
        return f"https://scholar.google.com/scholar?q={quote(title)}"
    def _get_arxiv_url(self, title: str) -> str:
        """Fast arXiv title lookup with a short timeout."""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
            timeout=2
        ) as response:
            if response.ok:
                return self._parse_arxiv_response(response.text)
        return ""

    def _parse_arxiv_response(self, xml: str) -> str:
        """Fast XML parsing using simple string operations."""
        # Search for the <id> inside the first <entry>; the feed itself also
        # carries an <id> element, so scanning from the start of the document
        # would match the wrong one.
        entry_start = xml.find("<entry>")
        if entry_start == -1:
            return ""
        start = xml.find("<id>", entry_start)
        if start == -1:
            return ""
        start += 4
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:")
    def _get_semantic_url(self, title: str) -> str:
        """Semantic Scholar title lookup with a short timeout."""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            # Request the url field explicitly; the search endpoint does not
            # include it by default
            params={"query": title[:200], "limit": 1, "fields": "url"},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
        return ""
class OptimizedSemanticSearch:
    def __init__(self):
        # Load the sentence transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()

    def _load_faiss_indexes(self):
        """Load the FAISS index with memory mapping for read-only access."""
        # Assumes the per-shard indexes have been combined into a single file
        self.index = faiss.read_index(
            "combined_index.faiss",
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
        )
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Optimized search pipeline:
        - Encode the query.
        - Run the FAISS search (fetching extra results for deduplication).
        - Retrieve metadata and post-process the results.
        """
        # Encode the query as a single-row batch
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        # Search for more than top_k hits to leave room for deduplication
        distances, indices = self.index.search(query_embedding, top_k * 2)
        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])
        # Process and return the final results
        return self._process_results(results, distances[0], top_k)
    def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
        """Post-process search results:
        - Resolve source URLs in parallel.
        - Attach similarity scores.
        - Deduplicate by title and sort by score.
        """
        # Resolve a source URL for each result in parallel
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(self.metadata_mgr.resolve_url, res["title"]): idx
                for idx, res in enumerate(results)
            }
            # Fill in each result as its URL resolves
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception:
                    results[idx]["source"] = ""
        # Attach similarity scores derived from the FAISS distances
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)
        # Deduplicate by title and keep the top_k highest-scoring results
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
            if res["title"] not in seen and len(final_results) < top_k:
                seen.add(res["title"])
                final_results.append(res)
        return final_results