import zipfile
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer

class MetadataManager:
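    """Serves row-level metadata from range-partitioned Parquet shards on local disk."""
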
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract metadata shards from zip if needed"""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
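                # Assumes the .parquet shards sit at the zip's root, so they
                # land directly inside shard_dir after extraction.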
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
                st.toast("βœ… Successfully extracted metadata shards!", icon="πŸ“¦")
            else:
                raise FileNotFoundError("No metadata shards found!")

    def _build_shard_map(self):
        """Map index ranges to shard files"""
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []
        shard_groups = {}
        
        # Group indices by shard
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    break
        
        # Load required shards
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=["title", "summary", "source"]
                )
            
            results.append(self.loaded_shards[shard].iloc[local_indices])
        
        # pd.concat raises on an empty list, so bail out early if nothing matched.
        if not results:
            return pd.DataFrame(columns=["title", "summary", "source"])
        # Rows come back grouped by shard, so their order can differ from `indices`.
        return pd.concat(results).reset_index(drop=True)


class SemanticSearch:
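    """Shard-aware semantic search over local FAISS index shards; metadata
    lookups are delegated to MetadataManager."""
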
    def __init__(self, shard_dir="compressed_shards"):
        self.shard_dir = Path(shard_dir)
        self.shard_dir.mkdir(exist_ok=True, parents=True)
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()
        
    @st.cache_resource
    def load_model(_self):
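        # The leading underscore tells st.cache_resource not to hash `_self`,
        # which Streamlit cannot use as a cache key.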
        return SentenceTransformer('all-MiniLM-L6-v2')
        
    def initialize_system(self):
        self.model = self.load_model()
        self._load_index_shards()

    def _load_index_shards(self):
        """Load FAISS shards directly from local directory"""
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            self.index_shards.append(faiss.read_index(str(shard_path)))

    def search(self, query, top_k=5):
        """Search across all shards"""
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        all_scores = []
        all_indices = []
        
        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            # Convert local indices to global shard offsets
            global_indices = [
                self._calculate_global_index(shard_idx, idx)
                for idx in indices[0]
            ]
            all_scores.extend(distances[0])
            all_indices.extend(global_indices)
            
        return self._process_results(np.array(all_scores), np.array(all_indices), top_k)

    def _calculate_global_index(self, shard_idx, local_idx):
        """Convert shard-local index to global index"""
        # Placeholder: implement the indexing scheme your shards were built with.
        # A common scheme adds a cumulative per-shard row offset, e.g.
        # shard_offsets[shard_idx] + local_idx, with offsets computed at load time.
        return local_idx  # Identity only works if IDs are already globally unique

    def _process_results(self, distances, indices, top_k):
        """Format search results"""
        results = pd.DataFrame({
            'global_index': indices,
            # FAISS L2 indexes return squared distances; for unit-normalized
            # embeddings, cosine similarity = 1 - d/2. This assumes both the
            # query and the indexed vectors were normalized when encoded.
            'similarity': 1 - (distances / 2)
        })
        return results.sort_values('similarity', ascending=False).head(top_k)

    def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
        """Threshold-filtered search"""
        results = self.search(query, top_k*2)
        filtered = results[results['similarity'] > similarity_threshold].head(top_k)
        return filtered.reset_index(drop=True)




# Hub-backed alternative to MetadataManager: it downloads missing metadata
# shards on demand from a Hugging Face Hub repo. Named distinctly so it does
# not shadow the MetadataManager that SemanticSearch instantiates above.
class HubMetadataManager:
    def __init__(self, repo_id, shard_dir="metadata_shards"):
        self.repo_id = repo_id
        self.shard_dir = Path(shard_dir)
        self.shard_map = {}
        self.loaded_shards = {}
        self._build_shard_map()

    def _build_shard_map(self):
        """Map index ranges to shard files"""
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def _download_shard(self, shard_name):
        """Download missing shards on demand"""
        if not (self.shard_dir / shard_name).exists():
            # hf_hub_download recreates the repo path under local_dir, so point
            # local_dir at the parent: the file then lands at
            # self.shard_dir / shard_name (assuming shard_dir keeps the repo's
            # "metadata_shards" folder name).
            hf_hub_download(
                repo_id=self.repo_id,
                filename=f"metadata_shards/{shard_name}",
                local_dir=self.shard_dir.parent,
                cache_dir="metadata_cache"
            )

    def get_metadata(self, indices):
        """Retrieve metadata for specific indices"""
        results = []
        
        # Group indices by shard
        shard_groups = {}
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    break
        
        # Process each required shard
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self._download_shard(shard)
                self.loaded_shards[shard] = pd.read_parquet(self.shard_dir/shard)
            
            results.append(self.loaded_shards[shard].iloc[local_indices])
        
        # pd.concat raises on an empty list, so bail out early if nothing matched.
        if not results:
            return pd.DataFrame()
        return pd.concat(results).reset_index(drop=True)
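

# Minimal usage sketch (illustrative, not part of the original app wiring).
# Assumes FAISS shards exist under compressed_shards/ and metadata shards under
# metadata_shards/ (or a metadata_shards.zip), matching the defaults above; the
# query string is just an example. Meant to run inside a Streamlit app, though
# in bare Python the Streamlit calls only degrade to warnings.
if __name__ == "__main__":
    searcher = SemanticSearch()
    searcher.initialize_system()

    hits = searcher.search_with_threshold("semantic search with FAISS", top_k=5)
    print(hits)

    # Fetch title/summary/source rows for the returned global indices.
    print(searcher.metadata_mgr.get_metadata(hits["global_index"].tolist()))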