import numpy as np
import faiss
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer
import concurrent.futures
import requests
from functools import lru_cache
from typing import List, Dict
import pandas as pd
from urllib.parse import quote

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("OptimizedSearch")

class OptimizedMetadataManager:
    def __init__(self):
        self._init_metadata()
        self._init_url_resolver()

    def _init_metadata(self):
        """Memory-mapped metadata loading.
           Preloads all metadata (title and summary) into memory from parquet files.
        """
        self.metadata_dir = Path("unzipped_cache/metadata_shards")
        self.metadata = {}
        
        # Preload all metadata into memory
        for parquet_file in self.metadata_dir.glob("*.parquet"):
            df = pd.read_parquet(parquet_file, columns=["title", "summary"])
            # Using the dataframe index as key (assumes unique indices across files)
            self.metadata.update(df.to_dict(orient="index"))
            
        self.total_docs = len(self.metadata)
        logger.info(f"Loaded {self.total_docs} metadata entries into memory")

    def get_metadata_batch(self, indices: np.ndarray) -> List[Dict]:
        """Batch metadata lookup. Returns copies so callers can annotate results
           without mutating the in-memory store; missing indices (e.g. FAISS
           padding value -1) fall back to empty entries."""
        return [dict(self.metadata.get(int(idx), {"title": "", "summary": ""}))
                for idx in indices]

    def _init_url_resolver(self):
        """Initialize API session and adapter for faster URL resolution."""
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
            max_retries=3
        )
        self.session.mount("https://", adapter)

    @lru_cache(maxsize=10_000)
    def resolve_url(self, title: str) -> str:
        """Optimized URL resolution with caching and a fail-fast approach."""
        try:
            # Try arXiv first
            arxiv_url = self._get_arxiv_url(title)
            if arxiv_url:
                return arxiv_url
            
            # Fallback to Semantic Scholar
            semantic_url = self._get_semantic_url(title)
            if semantic_url:
                return semantic_url
            
        except Exception as e:
            logger.warning(f"URL resolution failed: {str(e)}")
            
        # Default fallback to Google Scholar search
        return f"https://scholar.google.com/scholar?q={quote(title)}"

    def _get_arxiv_url(self, title: str) -> str:
        """Fast arXiv lookup with a short timeout."""
        with self.session.get(
            "http://export.arxiv.org/api/query",
            params={"search_query": f'ti:"{title}"', "max_results": 1},
            timeout=2
        ) as response:
            if response.ok:
                return self._parse_arxiv_response(response.text)
        return ""

    def _parse_arxiv_response(self, xml: str) -> str:
        """Fast XML parsing using simple string operations."""
        entry_start = xml.find("<entry>")
        if entry_start == -1:
            return ""
        # The Atom feed has its own <id> element, so only look inside the first <entry>.
        start = xml.find("<id>", entry_start)
        if start == -1:
            return ""
        start += len("<id>")
        end = xml.find("</id>", start)
        return xml[start:end].replace("http:", "https:") if end != -1 else ""

    def _get_semantic_url(self, title: str) -> str:
        """Semantic Scholar lookup with a short timeout."""
        with self.session.get(
            "https://api.semanticscholar.org/graph/v1/paper/search",
            params={"query": title[:200], "limit": 1},
            timeout=2
        ) as response:
            if response.ok:
                data = response.json()
                if data.get("data"):
                    return data["data"][0].get("url", "")
        return ""

class OptimizedSemanticSearch:
    def __init__(self):
        # Load the sentence transformer model
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self._load_faiss_indexes()
        self.metadata_mgr = OptimizedMetadataManager()
        
    def _load_faiss_indexes(self):
        """Load the FAISS index with memory mapping for read-only access."""
        # Here we assume the FAISS index has been combined into one file.
        self.index = faiss.read_index(
            "combined_index.faiss",
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
        )
        logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Optimized search pipeline:
           - Encodes the query.
           - Performs FAISS search (fetching extra results for deduplication).
           - Retrieves metadata and processes results.
        """
        # Batch encode query
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        
        # FAISS search: we search for more than top_k to allow for deduplication.
        distances, indices = self.index.search(query_embedding, top_k * 2)
        
        # Batch metadata retrieval
        results = self.metadata_mgr.get_metadata_batch(indices[0])
        
        # Process and return the final results
        return self._process_results(results, distances[0], top_k)

    def _process_results(self, results: List[Dict], distances: np.ndarray, top_k: int) -> List[Dict]:
        """Parallel processing of search results:
           - Resolve source URLs in parallel.
           - Add similarity scores.
           - Deduplicate and sort the results.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Parallel URL resolution for each result
            futures = {
                executor.submit(
                    self.metadata_mgr.resolve_url, 
                    res["title"]
                ): idx for idx, res in enumerate(results)
            }
            # Update each result as URLs resolve
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                try:
                    results[idx]["source"] = future.result()
                except Exception:
                    results[idx]["source"] = ""
        
        # Convert distances to similarity scores. This mapping assumes the index
        # stores unit-normalized embeddings under the (squared) L2 metric, where
        # distance = 2 - 2 * cosine, hence cosine similarity = 1 - distance / 2.
        for idx, dist in enumerate(distances[:len(results)]):
            results[idx]["similarity"] = 1 - (dist / 2)

        # Deduplicate by title and sort by similarity score (descending)
        seen = set()
        final_results = []
        for res in sorted(results, key=lambda x: x["similarity"], reverse=True):
            if res["title"] not in seen and len(final_results) < top_k:
                seen.add(res["title"])
                final_results.append(res)
        
        return final_results
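

# Example usage: a minimal sketch, assuming the combined FAISS index
# ("combined_index.faiss") and the parquet metadata shards under
# "unzipped_cache/metadata_shards/" are available locally. The query string
# below is only illustrative. Build the engine once, then run top-k searches.
if __name__ == "__main__":
    engine = OptimizedSemanticSearch()
    hits = engine.search("graph neural networks for molecular property prediction", top_k=5)
    for hit in hits:
        print(f'{hit["similarity"]:.3f}  {hit["title"]}')
        print(f'    source: {hit.get("source", "")}')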