Testys committed on
Commit dd6b309 · 1 Parent(s): d7bc2ed

Update search_utils.py

Files changed (1)
  1. search_utils.py +243 -38
search_utils.py CHANGED
@@ -2,9 +2,22 @@ import numpy as np
 import pandas as pd
 import faiss
 import zipfile
+import logging
 from pathlib import Path
 from sentence_transformers import SentenceTransformer, util
 import streamlit as st
+import time
+import os
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger("MetadataManager")
 
 class MetadataManager:
     def __init__(self):
@@ -12,45 +25,94 @@ class MetadataManager:
         self.shard_map = {}
         self.loaded_shards = {}
         self.total_docs = 0
-        self._ensure_unzipped()  # Removed Streamlit elements from here
+
+        logger.info("Initializing MetadataManager")
+        self._ensure_unzipped()
         self._build_shard_map()
+        logger.info(f"Total documents indexed: {self.total_docs}")
+        logger.info(f"Total shards found: {len(self.shard_map)}")
 
     def _ensure_unzipped(self):
         """Handle ZIP extraction without Streamlit elements"""
+        logger.info(f"Checking for shard directory: {self.shard_dir}")
         if not self.shard_dir.exists():
             zip_path = Path("metadata_shards.zip")
+            logger.info(f"Shard directory not found, looking for zip file: {zip_path}")
             if zip_path.exists():
+                logger.info(f"Extracting from zip file: {zip_path}")
+                start_time = time.time()
                 with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                     zip_ref.extractall(self.shard_dir)
+                logger.info(f"Extraction completed in {time.time() - start_time:.2f} seconds")
             else:
-                raise FileNotFoundError("Metadata ZIP file not found")
+                error_msg = "Metadata ZIP file not found"
+                logger.error(error_msg)
+                raise FileNotFoundError(error_msg)
+        else:
+            logger.info("Shard directory exists, skipping extraction")
 
     def _build_shard_map(self):
         """Create index range to shard mapping"""
+        logger.info("Building shard map from parquet files")
         self.total_docs = 0
-        for f in sorted(self.shard_dir.glob("*.parquet")):
-            parts = f.stem.split("_")
-            start = int(parts[1])
-            end = int(parts[2])
-            self.shard_map[(start, end)] = f.name
-            self.total_docs = max(self.total_docs, end + 1)
+        shard_files = list(self.shard_dir.glob("*.parquet"))
+        logger.info(f"Found {len(shard_files)} parquet files")
+
+        if not shard_files:
+            logger.warning("No parquet files found in shard directory")
+
+        for f in sorted(shard_files):
+            try:
+                parts = f.stem.split("_")
+                if len(parts) < 3:
+                    logger.warning(f"Skipping file with invalid name format: {f}")
+                    continue
+
+                start = int(parts[1])
+                end = int(parts[2])
+                self.shard_map[(start, end)] = f.name
+                self.total_docs = max(self.total_docs, end + 1)
+                logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")
+            except Exception as e:
+                logger.error(f"Error parsing shard filename {f}: {str(e)}")
+
+        # Log shard statistics
+        logger.info(f"Shard map built with {len(self.shard_map)} shards")
+        logger.info(f"Total document count: {self.total_docs}")
+
+        # Validate shard boundaries for gaps or overlaps
+        sorted_ranges = sorted(self.shard_map.keys())
+        for i in range(1, len(sorted_ranges)):
+            prev_end = sorted_ranges[i-1][1]
+            curr_start = sorted_ranges[i][0]
+            if curr_start != prev_end + 1:
+                logger.warning(f"Gap or overlap detected between shards: {prev_end} to {curr_start}")
 
     def get_metadata(self, global_indices):
         """Retrieve metadata with validation"""
         # Check for empty numpy array properly
         if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
+            logger.warning("Empty indices array passed to get_metadata")
             return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
 
         # Convert numpy array to list for processing
         indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices
+        logger.info(f"Retrieving metadata for {len(indices_list)} indices")
 
         # Filter valid indices
         valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
+        invalid_count = len(indices_list) - len(valid_indices)
+        if invalid_count > 0:
+            logger.warning(f"Filtered out {invalid_count} invalid indices")
+
         if not valid_indices:
+            logger.warning("No valid indices remain after filtering")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
 
         # Group indices by shard with boundary check
         shard_groups = {}
+        unassigned_indices = []
+
         for idx in valid_indices:
             found = False
             for (start, end), shard in self.shard_map.items():
@@ -61,49 +123,137 @@ class MetadataManager:
                     found = True
                     break
             if not found:
-                st.warning(f"Index {idx} out of shard range (0-{self.total_docs-1})")
+                unassigned_indices.append(idx)
+                logger.warning(f"Index {idx} not found in any shard range")
+
+        if unassigned_indices:
+            logger.warning(f"Could not assign {len(unassigned_indices)} indices to any shard")
 
         # Load and process shards
         results = []
         for shard, local_indices in shard_groups.items():
             try:
+                logger.info(f"Processing shard {shard} with {len(local_indices)} indices")
+                start_time = time.time()
+
                 if shard not in self.loaded_shards:
-                    self.loaded_shards[shard] = pd.read_parquet(
-                        self.shard_dir / shard,
-                        columns=["title", "summary", "source"]
-                    )
+                    logger.info(f"Loading shard file: {shard}")
+                    shard_path = self.shard_dir / shard
+
+                    # Verify file exists
+                    if not shard_path.exists():
+                        logger.error(f"Shard file not found: {shard_path}")
+                        continue
+
+                    # Log file size
+                    file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
+                    logger.info(f"Shard file size: {file_size_mb:.2f} MB")
+
+                    # Attempt to read the parquet file
+                    try:
+                        self.loaded_shards[shard] = pd.read_parquet(
+                            shard_path,
+                            columns=["title", "summary", "source"]
+                        )
+                        logger.info(f"Successfully loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
+                    except Exception as e:
+                        logger.error(f"Failed to read parquet file {shard}: {str(e)}")
+
+                        # Try to read file schema for debugging
+                        try:
+                            schema = pd.read_parquet(shard_path, engine='pyarrow').dtypes
+                            logger.info(f"Parquet schema: {schema}")
+                        except:
+                            pass
+                        continue
 
                 if local_indices:
-                    results.append(self.loaded_shards[shard].iloc[local_indices])
+                    # Validate indices are within dataframe bounds
+                    df_len = len(self.loaded_shards[shard])
+                    valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
+
+                    if len(valid_local_indices) != len(local_indices):
+                        logger.warning(f"Filtered {len(local_indices) - len(valid_local_indices)} out-of-bounds indices")
+
+                    if valid_local_indices:
+                        logger.debug(f"Retrieving rows at indices: {valid_local_indices}")
+                        chunk = self.loaded_shards[shard].iloc[valid_local_indices]
+                        results.append(chunk)
+                        logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
+
+                logger.info(f"Shard processing completed in {time.time() - start_time:.2f} seconds")
+
             except Exception as e:
-                st.error(f"Error loading shard {shard}: {str(e)}")
+                logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
                 continue
 
-        return pd.concat(results).reset_index(drop=True) if results else pd.DataFrame()
+        # Combine results
+        if results:
+            combined = pd.concat(results).reset_index(drop=True)
+            logger.info(f"Combined metadata: {len(combined)} records from {len(results)} shards")
+            return combined
+        else:
+            logger.warning("No metadata records retrieved")
+            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
 
 class SemanticSearch:
     def __init__(self):
         self.shard_dir = Path("compressed_shards")
         self.model = None
         self.index_shards = []
-        self.metadata_mgr = MetadataManager()  # No Streamlit elements in constructor
+        self.metadata_mgr = MetadataManager()
        self.shard_sizes = []
 
+        # Configure search logger
+        self.logger = logging.getLogger("SemanticSearch")
+        self.logger.info("Initializing SemanticSearch")
+
     @st.cache_resource
     def load_model(_self):
         return SentenceTransformer('all-MiniLM-L6-v2')
 
     def initialize_system(self):
+        self.logger.info("Loading sentence transformer model")
+        start_time = time.time()
         self.model = self.load_model()
+        self.logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
+
+        self.logger.info("Loading FAISS indices")
         self._load_faiss_shards()
 
     def _load_faiss_shards(self):
         """Load all FAISS index shards"""
+        self.logger.info(f"Searching for index files in {self.shard_dir}")
+
+        if not self.shard_dir.exists():
+            self.logger.error(f"Shard directory not found: {self.shard_dir}")
+            return
+
+        index_files = list(self.shard_dir.glob("*.index"))
+        self.logger.info(f"Found {len(index_files)} index files")
+
         self.shard_sizes = []
-        for shard_path in sorted(self.shard_dir.glob("*.index")):
-            index = faiss.read_index(str(shard_path))
-            self.index_shards.append(index)
-            self.shard_sizes.append(index.ntotal)
+        self.index_shards = []
+
+        for shard_path in sorted(index_files):
+            try:
+                self.logger.info(f"Loading index: {shard_path}")
+                start_time = time.time()
+
+                # Log file size
+                file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
+                self.logger.info(f"Index file size: {file_size_mb:.2f} MB")
+
+                index = faiss.read_index(str(shard_path))
+                self.index_shards.append(index)
+                self.shard_sizes.append(index.ntotal)
+
+                self.logger.info(f"Loaded index with {index.ntotal} vectors in {time.time() - start_time:.2f} seconds")
+            except Exception as e:
+                self.logger.error(f"Failed to load index {shard_path}: {str(e)}")
+
+        total_vectors = sum(self.shard_sizes)
+        self.logger.info(f"Total loaded vectors: {total_vectors} across {len(self.index_shards)} shards")
 
     def _global_index(self, shard_idx, local_idx):
         """Convert local index to global index"""
@@ -111,67 +261,122 @@ class SemanticSearch:
 
     def search(self, query, top_k=5):
         """Search with validation"""
-        if not query or not self.index_shards:
+        self.logger.info(f"Searching for query: '{query}' (top_k={top_k})")
+        start_time = time.time()
+
+        if not query:
+            self.logger.warning("Empty query provided")
+            return pd.DataFrame()
+
+        if not self.index_shards:
+            self.logger.error("No index shards loaded")
             return pd.DataFrame()
 
         try:
+            self.logger.info("Encoding query")
             query_embedding = self.model.encode([query], convert_to_numpy=True)
+            self.logger.debug(f"Query encoded to shape {query_embedding.shape}")
         except Exception as e:
-            st.error(f"Query encoding failed: {str(e)}")
+            self.logger.error(f"Query encoding failed: {str(e)}")
             return pd.DataFrame()
 
         all_distances = []
         all_global_indices = []
 
         # Search with index validation
+        self.logger.info(f"Searching across {len(self.index_shards)} shards")
        for shard_idx, index in enumerate(self.index_shards):
             if index.ntotal == 0:
+                self.logger.warning(f"Skipping empty shard {shard_idx}")
                 continue
 
             try:
+                shard_start = time.time()
                 distances, indices = index.search(query_embedding, top_k)
-                valid_indices = [idx for idx in indices[0] if 0 <= idx < index.ntotal]
+
+                valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
+                valid_indices = indices[0][valid_mask].tolist()
+                valid_distances = distances[0][valid_mask].tolist()
+
+                if len(valid_indices) != top_k:
+                    self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
+
                 global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
 
-                all_distances.extend(distances[0][:len(valid_indices)])
+                all_distances.extend(valid_distances)
                 all_global_indices.extend(global_indices)
+
+                self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
             except Exception as e:
-                st.error(f"Search failed in shard {shard_idx}: {str(e)}")
+                self.logger.error(f"Search failed in shard {shard_idx}: {str(e)}")
                 continue
 
-        # Ensure equal array lengths
-        min_length = min(len(all_distances), len(all_global_indices))
-        return self._process_results(
-            np.array(all_distances[:min_length]),
-            np.array(all_global_indices[:min_length]),
+        self.logger.info(f"Search found {len(all_global_indices)} results across all shards")
+
+        # Process results
+        results = self._process_results(
+            np.array(all_distances),
+            np.array(all_global_indices),
             top_k
         )
+
+        self.logger.info(f"Search completed in {time.time() - start_time:.2f} seconds with {len(results)} final results")
+        return results
 
     def _process_results(self, distances, global_indices, top_k):
         """Process raw search results into formatted DataFrame"""
+        process_start = time.time()
+
         # Proper numpy array emptiness checks
         if global_indices.size == 0 or distances.size == 0:
+            self.logger.warning("No search results to process")
             return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
 
         try:
-            # Convert numpy indices to Python list for metadata retrieval
-            indices_list = global_indices.tolist()
-
             # Get metadata for matched indices
-            results = self.metadata_mgr.get_metadata(indices_list)
+            self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
+            metadata_start = time.time()
+            results = self.metadata_mgr.get_metadata(global_indices)
+            self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
 
+            # Empty results check
+            if len(results) == 0:
+                self.logger.warning("No metadata found for indices")
+                return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
+
             # Ensure distances match results length
             if len(results) != len(distances):
-                distances = distances[:len(results)]
+                self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
+
+                if len(results) < len(distances):
+                    self.logger.info("Truncating distances array to match results length")
+                    distances = distances[:len(results)]
+                else:
+                    # Should not happen but handle it anyway
+                    self.logger.error("More results than distances - this shouldn't happen")
+                    distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
 
             # Calculate similarity scores
+            self.logger.debug("Calculating similarity scores")
             results['similarity'] = 1 - (distances / 2)
 
+            # Log similarity statistics
+            if not results.empty:
+                self.logger.debug(f"Similarity stats: min={results['similarity'].min():.3f}, " +
+                                  f"max={results['similarity'].max():.3f}, " +
+                                  f"mean={results['similarity'].mean():.3f}")
+
             # Deduplicate and sort results
+            pre_dedup = len(results)
             results = results.drop_duplicates(subset=["title", "source"]).sort_values("similarity", ascending=False).head(top_k)
+            post_dedup = len(results)
 
+            if pre_dedup > post_dedup:
+                self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
+
+            self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
             return results.reset_index(drop=True)
 
         except Exception as e:
-            st.error(f"Result processing failed: {str(e)}")
-            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
+            self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
+            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
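
A note on the similarity line this commit keeps, results['similarity'] = 1 - (distances / 2): for unit-normalized embeddings, ||a - b||^2 = 2 - 2*cos(a, b), so 1 - d/2 recovers the cosine similarity when d is a squared L2 distance (what a FAISS IndexFlatL2 reports). The diff does not show how the shards were built or whether embeddings are normalized, so this is a hedged reading; a minimal sketch checking the identity on made-up vectors:

import numpy as np

# Sketch: for unit vectors, 1 - (squared L2 distance / 2) equals cosine similarity.
# The vectors below are random illustrations, not data from this repository.
rng = np.random.default_rng(0)
a, b = rng.normal(size=384), rng.normal(size=384)      # 384 dims, as in all-MiniLM-L6-v2
a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)    # normalize to unit length

sq_l2 = float(np.sum((a - b) ** 2))   # squared L2, as IndexFlatL2 would return
cosine = float(a @ b)
assert np.isclose(1 - sq_l2 / 2, cosine)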
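For context, a minimal sketch of how the updated classes might be driven from a Streamlit app. The file name app.py, the widget layout, and the result rendering are assumptions for illustration, not part of this commit:

# app.py -- hypothetical driver for the updated search_utils.py (not in this commit)
import streamlit as st
from search_utils import SemanticSearch

@st.cache_resource
def get_searcher():
    searcher = SemanticSearch()
    searcher.initialize_system()   # loads the model and the FAISS index shards
    return searcher

searcher = get_searcher()
query = st.text_input("Search query")

if query:
    # Expected columns per search_utils.py: title, summary, source, similarity
    results = searcher.search(query, top_k=5)
    if results.empty:
        st.info("No results found.")
    else:
        for _, row in results.iterrows():
            st.markdown(f"**{row['title']}** ({row['similarity']:.2f})")
            st.write(row["summary"])
            st.caption(row["source"])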