import numpy as np
import pandas as pd
import faiss
import zipfile
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer
import streamlit as st
import time
import os

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("MetadataManager")

class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self.total_docs = 0
        
        logger.info("Initializing MetadataManager")
        self._ensure_unzipped()
        self._build_shard_map()
        logger.info(f"Total documents indexed: {self.total_docs}")
        logger.info(f"Total shards found: {len(self.shard_map)}")

    def _ensure_unzipped(self):
        """Handle ZIP extraction without Streamlit elements"""
        logger.info(f"Checking for shard directory: {self.shard_dir}")
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            logger.info(f"Shard directory not found, looking for zip file: {zip_path}")
            if zip_path.exists():
                logger.info(f"Extracting from zip file: {zip_path}")
                start_time = time.time()
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
                logger.info(f"Extraction completed in {time.time() - start_time:.2f} seconds")
            else:
                error_msg = "Metadata ZIP file not found"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)
        else:
            logger.info("Shard directory exists, skipping extraction")
                
    def _build_shard_map(self):
        """Create index range to shard mapping"""
        logger.info("Building shard map from parquet files")
        self.total_docs = 0
        shard_files = list(self.shard_dir.glob("*.parquet"))
        logger.info(f"Found {len(shard_files)} parquet files")
        
        if not shard_files:
            logger.warning("No parquet files found in shard directory")
            
        for f in sorted(shard_files):
            try:
                parts = f.stem.split("_")
                if len(parts) < 3:
                    logger.warning(f"Skipping file with invalid name format: {f}")
                    continue
                    
                start = int(parts[1])
                end = int(parts[2])
                self.shard_map[(start, end)] = f.name
                self.total_docs = max(self.total_docs, end + 1)
                logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")
            except Exception as e:
                logger.error(f"Error parsing shard filename {f}: {str(e)}")
        
        # Log shard statistics
        logger.info(f"Shard map built with {len(self.shard_map)} shards")
        logger.info(f"Total document count: {self.total_docs}")
        
        # Validate shard boundaries for gaps or overlaps
        sorted_ranges = sorted(self.shard_map.keys())
        for i in range(1, len(sorted_ranges)):
            prev_end = sorted_ranges[i-1][1]
            curr_start = sorted_ranges[i][0]
            if curr_start != prev_end + 1:
                logger.warning(f"Gap or overlap detected between shards: {prev_end} to {curr_start}")

    def get_metadata(self, global_indices):
        """Retrieve metadata with validation"""
        # Check for empty numpy array properly
        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
            logger.warning("Empty indices array passed to get_metadata")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
        
        # Convert numpy array to list for processing
        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices
        logger.info(f"Retrieving metadata for {len(indices_list)} indices")
        
        # Filter valid indices
        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
        invalid_count = len(indices_list) - len(valid_indices)
        if invalid_count > 0:
            logger.warning(f"Filtered out {invalid_count} invalid indices")
        
        if not valid_indices:
            logger.warning("No valid indices remain after filtering")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
    
        # Group indices by shard with boundary check
        shard_groups = {}
        unassigned_indices = []
        
        for idx in valid_indices:
            found = False
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    if shard not in shard_groups:
                        shard_groups[shard] = []
                    shard_groups[shard].append(idx - start)
                    found = True
                    break
            if not found:
                unassigned_indices.append(idx)
                logger.warning(f"Index {idx} not found in any shard range")
        
        if unassigned_indices:
            logger.warning(f"Could not assign {len(unassigned_indices)} indices to any shard")
    
        # Load and process shards
        results = []
        for shard, local_indices in shard_groups.items():
            try:
                logger.info(f"Processing shard {shard} with {len(local_indices)} indices")
                start_time = time.time()
                
                if shard not in self.loaded_shards:
                    logger.info(f"Loading shard file: {shard}")
                    shard_path = self.shard_dir / shard
                    
                    # Verify file exists
                    if not shard_path.exists():
                        logger.error(f"Shard file not found: {shard_path}")
                        continue
                        
                    # Log file size
                    file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
                    logger.info(f"Shard file size: {file_size_mb:.2f} MB")
                    
                    # Attempt to read the parquet file
                    try:
                        self.loaded_shards[shard] = pd.read_parquet(
                            shard_path,
                            columns=["title", "summary", "source"]
                        )
                        logger.info(f"Successfully loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
                    except Exception as e:
                        logger.error(f"Failed to read parquet file {shard}: {str(e)}")
                        
                        # Try to read file schema for debugging
                        try:
                            schema = pd.read_parquet(shard_path, engine='pyarrow').dtypes
                            logger.info(f"Parquet schema: {schema}")
                        except Exception:
                            pass
                        continue
                
                if local_indices:
                    # Validate indices are within dataframe bounds
                    df_len = len(self.loaded_shards[shard])
                    valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
                    
                    if len(valid_local_indices) != len(local_indices):
                        logger.warning(f"Filtered {len(local_indices) - len(valid_local_indices)} out-of-bounds indices")
                    
                    if valid_local_indices:
                        logger.debug(f"Retrieving rows at indices: {valid_local_indices}")
                        chunk = self.loaded_shards[shard].iloc[valid_local_indices]
                        results.append(chunk)
                        logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
                
                logger.info(f"Shard processing completed in {time.time() - start_time:.2f} seconds")
                    
            except Exception as e:
                logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
                continue
    
        # Combine results
        if results:
            combined = pd.concat(results).reset_index(drop=True)
            logger.info(f"Combined metadata: {len(combined)} records from {len(results)} shards")
            return combined
        else:
            logger.warning("No metadata records retrieved")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()
        self.shard_sizes = []
        
        # Configure search logger
        self.logger = logging.getLogger("SemanticSearch")
        self.logger.info("Initializing SemanticSearch")
        
    @st.cache_resource
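    # The leading underscore in `_self` tells Streamlit's cache not to hash
    # the SemanticSearch instance when computing the cache key.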
    def load_model(_self):
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.logger.info("Loading sentence transformer model")
        start_time = time.time()
        self.model = self.load_model()
        self.logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
        
        self.logger.info("Loading FAISS indices")
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load all FAISS index shards"""
        self.logger.info(f"Searching for index files in {self.shard_dir}")
        
        if not self.shard_dir.exists():
            self.logger.error(f"Shard directory not found: {self.shard_dir}")
            return
            
        index_files = list(self.shard_dir.glob("*.index"))
        self.logger.info(f"Found {len(index_files)} index files")
        
        self.shard_sizes = []
        self.index_shards = []
        
        for shard_path in sorted(index_files):
            try:
                self.logger.info(f"Loading index: {shard_path}")
                start_time = time.time()
                
                # Log file size
                file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
                self.logger.info(f"Index file size: {file_size_mb:.2f} MB")
                
                index = faiss.read_index(str(shard_path))
                self.index_shards.append(index)
                self.shard_sizes.append(index.ntotal)
                
                self.logger.info(f"Loaded index with {index.ntotal} vectors in {time.time() - start_time:.2f} seconds")
            except Exception as e:
                self.logger.error(f"Failed to load index {shard_path}: {str(e)}")
        
        total_vectors = sum(self.shard_sizes)
        self.logger.info(f"Total loaded vectors: {total_vectors} across {len(self.index_shards)} shards")

    def _global_index(self, shard_idx, local_idx):
        """Convert local index to global index"""
        return sum(self.shard_sizes[:shard_idx]) + local_idx

    def search(self, query, top_k=5):
        """Search with validation"""
        self.logger.info(f"Searching for query: '{query}' (top_k={top_k})")
        start_time = time.time()
        
        if not query:
            self.logger.warning("Empty query provided")
            return pd.DataFrame()
            
        if not self.index_shards:
            self.logger.error("No index shards loaded")
            return pd.DataFrame()
        
        try:
            self.logger.info("Encoding query")
            query_embedding = self.model.encode([query], convert_to_numpy=True)
            self.logger.debug(f"Query encoded to shape {query_embedding.shape}")
        except Exception as e:
            self.logger.error(f"Query encoding failed: {str(e)}")
            return pd.DataFrame()
    
        all_distances = []
        all_global_indices = []
    
        # Search with index validation
        self.logger.info(f"Searching across {len(self.index_shards)} shards")
        for shard_idx, index in enumerate(self.index_shards):
            if index.ntotal == 0:
                self.logger.warning(f"Skipping empty shard {shard_idx}")
                continue
                
            try:
                shard_start = time.time()
                distances, indices = index.search(query_embedding, top_k)
                
                valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
                valid_indices = indices[0][valid_mask].tolist()
                valid_distances = distances[0][valid_mask].tolist()
                
                if len(valid_indices) != top_k:
                    self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
                
                global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
                
                all_distances.extend(valid_distances)
                all_global_indices.extend(global_indices)
                
                self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
            except Exception as e:
                self.logger.error(f"Search failed in shard {shard_idx}: {str(e)}")
                continue
    
        self.logger.info(f"Search found {len(all_global_indices)} results across all shards")
        
        # Process results
        results = self._process_results(
            np.array(all_distances), 
            np.array(all_global_indices), 
            top_k
        )
        
        self.logger.info(f"Search completed in {time.time() - start_time:.2f} seconds with {len(results)} final results")
        return results

    def _process_results(self, distances, global_indices, top_k):
        """Process raw search results into formatted DataFrame"""
        process_start = time.time()
        
        # Proper numpy array emptiness checks
        if global_indices.size == 0 or distances.size == 0:
            self.logger.warning("No search results to process")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
        
        try:
            # Get metadata for matched indices
            self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
            metadata_start = time.time()
            results = self.metadata_mgr.get_metadata(global_indices)
            self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
            
            # Empty results check
            if len(results) == 0:
                self.logger.warning("No metadata found for indices")
                return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
                
            # Ensure distances match results length
            if len(results) != len(distances):
                self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
                
                if len(results) < len(distances):
                    self.logger.info("Truncating distances array to match results length")
                    distances = distances[:len(results)]
                else:
                    # Should not happen but handle it anyway
                    self.logger.error("More results than distances - this shouldn't happen")
                    distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
                
            # Calculate similarity scores
            self.logger.debug("Calculating similarity scores")
            results['similarity'] = 1 - (distances / 2)
            
            # Log similarity statistics
            if not results.empty:
                self.logger.debug(f"Similarity stats: min={results['similarity'].min():.3f}, " +
                                 f"max={results['similarity'].max():.3f}, " +
                                 f"mean={results['similarity'].mean():.3f}")
            
            # Deduplicate and sort results
            pre_dedup = len(results)
            results = (
                results.drop_duplicates(subset=["title", "source"])
                .sort_values("similarity", ascending=False)
                .head(top_k)
            )
            post_dedup = len(results)
            
            if pre_dedup > post_dedup:
                self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
            
            self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
            return results.reset_index(drop=True)
            
        except Exception as e:
            self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
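
# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: a compressed_shards/ directory of FAISS
# *.index files and a metadata_shards/ directory -- or metadata_shards.zip --
# sit next to this module, and the FAISS shards were built in the same corpus
# order as the metadata shards). Outside a Streamlit app, st.cache_resource
# typically just logs a warning and caches in-process. The query string below
# is only an illustrative example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    searcher = SemanticSearch()
    searcher.initialize_system()

    hits = searcher.search("sentence embeddings for semantic search", top_k=5)
    for _, row in hits.iterrows():
        print(f"{row['similarity']:.3f}  {row['title']}  ({row['source']})")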