import numpy as np
import faiss
import zipfile
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer
import concurrent.futures
import os
import requests
from functools import lru_cache
from typing import List, Dict
import pandas as pd
from urllib.parse import quote
# Configure logging
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("OptimizedSearch")


class OptimizedMetadataManager:
    """In-memory metadata store with cached source-URL resolution
    (see the usage sketch after this class)."""

    def __init__(self):
        self._init_metadata()
        self._init_url_resolver()

    def _init_metadata(self):
        """Preload all metadata (title and summary) into memory from parquet shards."""
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}
        # Preload all metadata into memory
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            # Use the dataframe index as the key (assumes unique indices across files)
            self.metadata.update(df.to_dict(orient="index"))
        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")

    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch retrieval of metadata entries for an array of indices."""
        return [self.metadata.get(idx, {"title": "", "summary": ""}) for idx in indices]
    def _init_url_resolver(self):
        """Initialize an API session with connection pooling for faster URL resolution."""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        # Mount on both schemes so the plain-http arXiv call also uses the pooled adapter
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    # Note: lru_cache on a bound method keys the cache on (self, title); fine here
    # because a single manager instance is used.
    @lru_cache(maxsize=10_000)
    def resolve_url(self, title: str) -> str:
        """Resolve a paper title to a URL with caching and a fail-fast approach."""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url:
                return arxiv_url
            # Fall back to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url:
                return semantic_url
        except Exception as e:
            logger.warning(f"URL resolution failed: {str(e)}")
        # Default fallback: a Google Scholar search
        return f"https://scholar.google.com/scholar?q={quote(title)}"
    def _get_arxiv_url(self, title: str) -> str:
        """Fast arXiv title lookup with a short timeout."""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
            timeout=2
        ) as response:
            if response.ok:
                return self._parse_arxiv_response(response.text)
            return ""

    def _parse_arxiv_response(self, xml: str) -> str:
        """Fast XML parsing using simple string operations."""
        if "<entry>" not in xml:
            return ""
        start = xml.find("<id>") + 4
        end = xml.find("</id>", start)
        # str.find returns -1 when "<id>" is absent, so start > 3 means the tag was found
        return xml[start:end].replace("http:", "https:") if start > 3 else ""

    def _get_semantic_url(self, title: str) -> str:
        """Semantic Scholar lookup with a short timeout."""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            params={"query": title[:200], "limit": 1},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
            return ""


class OptimizedSemanticSearch:
    """Semantic search over a memory-mapped FAISS index with in-memory metadata
    (see the usage sketch at the end of the file)."""

    def __init__(self):
        # Load the sentence transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()

    def _load_faiss_indexes(self):
        """Load the FAISS index with memory mapping for read-only access."""
        # We assume the FAISS index shards have been combined into a single file
        self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Optimized search pipeline:
        - Encode the query.
        - Run the FAISS search, fetching extra results for deduplication.
        - Retrieve metadata and post-process the results.
        """
        # Encode the query as a single-item batch
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        # Search for more than top_k results to allow for deduplication
        distances, indices = self.index.search(query_embedding, top_k * 2)
        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])
        # Process and return the final results
        return self._process_results(results, distances[0], top_k)
    def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
        """Parallel post-processing of search results:
        - Resolve source URLs in parallel.
        - Add similarity scores.
        - Deduplicate by title and sort by score.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Resolve a source URL for each result in parallel
            futures = {
                executor.submit(
                    self.metadata_mgr.resolve_url,
                    res["title"]
                ): idx for idx, res in enumerate(results)
            }
            # Attach each URL as it resolves
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception:
                    results[idx]["source"] = ""
        # Convert distances to similarity scores: 1 - d/2 equals cosine similarity
        # only when d is the squared L2 distance between unit-normalized embeddings
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)
        # Deduplicate by title and keep the top_k results by similarity (descending)
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
            if res["title"] not in seen and len(final_results) < top_k:
                seen.add(res["title"])
                final_results.append(res)
        return final_results
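

# A minimal end-to-end usage sketch, assuming combined_index.faiss and the
# metadata shards described above are present next to this script; the query
# string is only an example.
if __name__ == "__main__":
    searcher = OptimizedSemanticSearch()
    for hit in searcher.search("transformer models for protein folding", top_k=5):
        print(f'{hit["similarity"]:.3f}  {hit["title"]}  {hit.get("source", "")}')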