import numpy as np
import pandas as pd
import faiss
import logging
from pathlib import Path
from urllib.parse import quote
from sentence_transformers import SentenceTransformer
import concurrent.futures
import requests
from functools import lru_cache
from typing import List, Dict

# Configure logging
logging.basicConfig(level=logging.WARNING,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("OptimizedSearch")

class OptimizedMetadataManager:
    def __init__(self):
        self._init_metadata()
        self._init_url_resolver()

    def _init_metadata(self):
        """Preload all metadata shards into an in-memory dict"""
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}
        
        # Preload all metadata into memory
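        # (assumes each shard's row index is globally unique and aligned with the FAISS vector ids)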
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            self.metadata.update(df.to_dict(orient="index"))
            
        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")

    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch retrieval of metadata (returns copies so callers can annotate results safely)"""
        return [dict(self.metadata.get(int(idx), {"title": "", "summary": ""})) for idx in indices]

    def _init_url_resolver(self):
        """Initialize API session and cache"""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        # Mount on both schemes so the arXiv (http) and Semantic Scholar (https)
        # calls share the same pooling and retry settings
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

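    # NOTE: lru_cache on a bound method keeps a reference to self; acceptable here
    # because a single manager instance lives for the whole process.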
    @lru_cache(maxsize=10_000)
    def resolve_url(self, title: str) -> str:
        """Optimized URL resolution with fail-fast"""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url: return arxiv_url
            
            # Fallback to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url: return semantic_url
            
        except Exception as e:
            logger.warning(f"URL resolution failed: {str(e)}")
            
        return f"https://scholar.google.com/scholar?q={quote(title)}"

    def _get_arxiv_url(self, title: str) -> str:
        """Fast arXiv lookup with timeout"""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
            timeout=2
        ) as response:
            if response.ok:
                return self._parse_arxiv_response(response.text)
        return ""

    def _parse_arxiv_response(self, xml: str) -> str:
        """Fast XML parsing using string operations"""
        entry_pos = xml.find("<entry>")
        if entry_pos == -1:
            return ""
        # The paper link is the first <id> inside <entry>; the feed-level <id> appears earlier
        start = xml.find("<id>", entry_pos) + 4
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:") if start > 3 else ""

    def _get_semantic_url(self, title: str) -> str:
        """Batch-friendly Semantic Scholar lookup"""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            params={"query": title[:200], "limit": 1},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
        return ""

class OptimizedSemanticSearch:
    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()
        
    def _load_faiss_indexes(self):
        """Load indexes with memory mapping"""
        self.index = faiss.read_index("combined_index.faiss", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Optimized search pipeline"""
        # Batch encode query
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        
        # FAISS search
        distances, indices = self.index.search(query_embedding, top_k*2)  # Search extra for dedup
        
        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])
        
        # Process results
        return self._process_results(results, distances[0], top_k)

    def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
        """Parallel result processing"""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Parallel URL resolution
            futures = {
                executor.submit(
                    self.metadata_mgr.resolve_url, 
                    res["title"]
                ): idx for idx, res in enumerate(results)
            }
            
            # Update results as URLs resolve
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception:
                    results[idx]["source"] = ""

        # Add similarity scores
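        # (assumes an L2 index over unit-normalized embeddings, where FAISS's squared distance d gives cos = 1 - d/2)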
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)

        # Deduplicate and sort
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
            if res["title"] not in seen and len(final_results) < top_k:
                seen.add(res["title"])
                final_results.append(res)
        
        return final_results
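

# --- Example usage: a minimal sketch, not part of the original module. ---
# Assumes "combined_index.faiss" and the parquet shards under
# "unzipped_cache/metadata_shards/" already exist on disk; the query string
# below is purely illustrative.
if __name__ == "__main__":
    searcher = OptimizedSemanticSearch()
    for hit in searcher.search("efficient approximate nearest neighbor search", top_k=5):
        print(f'{hit["similarity"]:.3f}  {hit["title"]}  {hit.get("source", "")}')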