import numpy as np
import faiss
import zipfile
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer
import concurrent.futures
import os
import requests
from functools import lru_cache
from typing import List, Dict
import pandas as pd
from urllib.parse import quote

# Configure logging
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("OptimizedSearch")


class OptimizedMetadataManager:
    def __init__(self):
        self._init_metadata()
        self._init_url_resolver()

    def _init_metadata(self):
        """Preload all metadata (title and summary) into memory from the parquet shards."""
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}

        # Preload all metadata into memory
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            # Using the dataframe index as key (assumes unique indices across files)
            self.metadata.update(df.to_dict(orient="index"))

        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")

    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch retrieval of metadata entries for a list of indices."""
        # Return copies so callers can attach fields (e.g. "source", "similarity")
        # without mutating the cached metadata entries.
        return [dict(self.metadata.get(idx, {"title": "", "summary": ""})) for idx in indices]
    def _init_url_resolver(self):
        """Initialize API session and adapter for faster URL resolution."""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        # arXiv is queried over plain http, so mount the pooled adapter for both schemes.
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    # Cache resolved URLs so repeated titles skip the network round-trip.
    @lru_cache(maxsize=1024)
    def resolve_url(self, title: str) -> str:
        """Optimized URL resolution with caching and a fail-fast approach."""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url:
                return arxiv_url

            # Fallback to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url:
                return semantic_url
        except Exception as e:
            logger.warning(f"URL resolution failed: {str(e)}")

        # Default fallback to Google Scholar search
        return f"https://scholar.google.com/scholar?q={quote(title)}"
    def _get_arxiv_url(self, title: str) -> str:
        """Fast arXiv lookup with a short timeout."""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
            timeout=2
        ) as response:
            if response.ok:
                return self._parse_arxiv_response(response.text)
        return ""
    def _parse_arxiv_response(self, xml: str) -> str:
        """Fast XML parsing using simple string operations."""
        entry_pos = xml.find("<entry>")
        if entry_pos == -1:
            return ""
        # Skip the feed-level <id> and read the <id> inside the first <entry>,
        # which holds the paper's abstract-page URL.
        start = xml.find("<id>", entry_pos) + 4
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:") if start > 3 else ""
    def _get_semantic_url(self, title: str) -> str:
        """Semantic Scholar lookup with a short timeout."""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            params={"query": title[:200], "limit": 1},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
        return ""
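

# The loader below expects a single "combined_index.faiss" file. As a minimal,
# hypothetical sketch (not part of the original pipeline), per-shard flat
# indexes could be merged like this; the shard directory name is an assumption.
def build_combined_index(shard_dir: str = "unzipped_cache/faiss_shards",
                         output_path: str = "combined_index.faiss") -> None:
    combined = None
    for shard_path in sorted(Path(shard_dir).glob("*.faiss")):
        shard = faiss.read_index(str(shard_path))
        if combined is None:
            # Start an empty flat L2 index with the shards' dimensionality.
            combined = faiss.IndexFlatL2(shard.d)
        # Flat indexes can hand their raw vectors back, so append them directly.
        combined.add(shard.reconstruct_n(0, shard.ntotal))
    if combined is not None:
        faiss.write_index(combined, output_path)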


class OptimizedSemanticSearch:
    def __init__(self):
        # Load the sentence transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()

    def _load_faiss_indexes(self):
        """Load the FAISS index with memory mapping for read-only access."""
        # Here we assume the FAISS index has been combined into one file.
        self.index = faiss.read_index(
            "combined_index.faiss",
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
        )
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Optimized search pipeline:

        - Encodes the query.
        - Performs FAISS search (fetching extra results for deduplication).
        - Retrieves metadata and processes results.
        """
        # Batch encode query
        query_embedding = self.model.encode([query], convert_to_numpy=True)

        # FAISS search: we search for more than top_k to allow for deduplication.
        distances, indices = self.index.search(query_embedding, top_k * 2)

        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])

        # Process and return the final results
        return self._process_results(results, distances[0], top_k)
    def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
        """Parallel processing of search results:

        - Resolve source URLs in parallel.
        - Add similarity scores.
        - Deduplicate and sort the results.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Parallel URL resolution for each result
            futures = {
                executor.submit(
                    self.metadata_mgr.resolve_url,
                    res["title"]
                ): idx for idx, res in enumerate(results)
            }

            # Update each result as URLs resolve
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception:
                    results[idx]["source"] = ""

        # Add similarity scores based on distances (assumes squared L2 distances
        # over normalized embeddings, where 1 - dist / 2 equals cosine similarity).
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)

        # Deduplicate by title and sort by similarity score (descending)
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
            if res["title"] not in seen and len(final_results) < top_k:
                seen.add(res["title"])
                final_results.append(res)

        return final_results
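

if __name__ == "__main__":
    # Minimal usage sketch: assumes combined_index.faiss and the metadata parquet
    # shards already exist on disk; the query string below is just an example.
    searcher = OptimizedSemanticSearch()
    for hit in searcher.search("transformer models for scientific paper retrieval", top_k=5):
        print(f"{hit['similarity']:.3f}  {hit['title']}  {hit['source']}")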
