import concurrent.futures
import logging
from functools import lru_cache
from pathlib import Path
from typing import Dict, List
from urllib.parse import quote

import faiss
import numpy as np
import pandas as pd
import requests
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("OptimizedSearch")
class OptimizedMetadataManager:
    def __init__(self):
        self._init_metadata()
        self._init_url_resolver()

    def _init_metadata(self):
        """Preload all metadata shards into memory."""
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}
        # Load the title/summary columns from every parquet shard
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            self.metadata.update(df.to_dict(orient="index"))
        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")

    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch retrieval of metadata."""
        # Return copies so downstream code can attach fields without mutating the cache
        return [dict(self.metadata.get(int(idx), {"title": "", "summary": ""}))
                for idx in indices]

    def _init_url_resolver(self):
        """Initialize a pooled HTTP session for URL resolution."""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)  # arXiv export API is queried over http

    # Note: lru_cache on a bound method keys on (self, title); acceptable here
    # because a single manager instance lives for the whole process.
    @lru_cache(maxsize=10_000)
    def resolve_url(self, title: str) -> str:
        """URL resolution with fail-fast fallbacks."""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url:
                return arxiv_url
            # Fall back to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url:
                return semantic_url
        except Exception as e:
            logger.warning(f"URL resolution failed: {e}")
        # Last resort: a Google Scholar search link
        return f"https://scholar.google.com/scholar?q={quote(title)}"

    def _get_arxiv_url(self, title: str) -> str:
        """Fast arXiv title lookup with a short timeout."""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
            timeout=2
        ) as response:
            if response.ok:
                return self._parse_arxiv_response(response.text)
        return ""

    def _parse_arxiv_response(self, xml: str) -> str:
        """Fast XML parsing using plain string operations."""
        if "<entry>" not in xml:
            return ""
        start = xml.find("<id>") + 4
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:") if start > 3 else ""

    def _get_semantic_url(self, title: str) -> str:
        """Semantic Scholar title lookup with a short timeout."""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            params={"query": title[:200], "limit": 1},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
        return ""
class OptimizedSemanticSearch:
    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()

    def _load_faiss_indexes(self):
        """Load the FAISS index with memory mapping."""
        self.index = faiss.read_index(
            "combined_index.faiss",
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
        )
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Search pipeline: encode query, FAISS lookup, batch metadata, post-process."""
        # Encode the query as a batch of one
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        # FAISS search; fetch extra hits so deduplication can still return top_k results
        distances, indices = self.index.search(query_embedding, top_k * 2)
        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])
        # Process results
        return self._process_results(results, distances[0], top_k)

    def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
        """Resolve URLs in parallel, attach scores, deduplicate, and truncate to top_k."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Parallel URL resolution
            futures = {
                executor.submit(self.metadata_mgr.resolve_url, res["title"]): idx
                for idx, res in enumerate(results)
            }
            # Update results as URLs resolve
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception:
                    results[idx]["source"] = ""
        # Map distance to a similarity score (exact only if the index stores squared L2
        # distances over unit-normalized embeddings, where 1 - d^2/2 equals cosine)
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)
        # Deduplicate by title and keep the top_k highest-similarity hits
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
            if res["title"] not in seen and len(final_results) < top_k:
                seen.add(res["title"])
                final_results.append(res)
        return final_results
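

# Minimal usage sketch, assuming "combined_index.faiss" and the parquet shards under
# "unzipped_cache/metadata_shards" exist locally; the query string below is illustrative.
if __name__ == "__main__":
    searcher = OptimizedSemanticSearch()
    for hit in searcher.search("transformer models for semantic search", top_k=5):
        print(f"{hit['similarity']:.3f}  {hit['title']}  {hit.get('source', '')}")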