Testys committed
Commit 74dd725 · verified · 1 Parent(s): 7f7218c

Update search_utils.py

Files changed (1)
  1. search_utils.py +41 -200
search_utils.py CHANGED
@@ -24,208 +24,49 @@ logger = logging.getLogger("MetadataManager")
 
 class MetadataManager:
     def __init__(self):
-        self.cache_dir = Path("unzipped_cache")
-        self.shard_dir = self.cache_dir / "metadata_shards"
-        self.shard_map = {}
-        self.loaded_shards = {}
+        self.metadata_path = Path("combined.parquet")
+        self.df = None
         self.total_docs = 0
-        self.api_cache = {}
 
         logger.info("Initializing MetadataManager")
-        self._ensure_directories()
-        self._unzip_if_needed()
-        self._build_shard_map()
+        self._load_metadata()
         logger.info(f"Total documents indexed: {self.total_docs}")
-        logger.info(f"Total shards found: {len(self.shard_map)}")
-
-    def _ensure_directories(self):
-        """Create necessary directories if they don't exist."""
-        self.cache_dir.mkdir(parents=True, exist_ok=True)
-        self.shard_dir.mkdir(parents=True, exist_ok=True)
-
-    def _unzip_if_needed(self):
-        """Extract the ZIP archive if no parquet files are found."""
-        zip_path = Path("metadata_shards.zip")
-        if not any(self.shard_dir.rglob("*.parquet")):
-            logger.info("No parquet files found, checking for zip archive")
-            if not zip_path.exists():
-                raise FileNotFoundError(f"Metadata ZIP file not found at {zip_path}")
-            logger.info(f"Extracting {zip_path} to {self.shard_dir}")
-            try:
-                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-                    zip_root = self._get_zip_root(zip_ref)
-                    zip_ref.extractall(self.shard_dir)
-                    if zip_root:
-                        nested_dir = self.shard_dir / zip_root
-                        if nested_dir.exists():
-                            self._flatten_directory(nested_dir, self.shard_dir)
-                            nested_dir.rmdir()
-                parquet_files = list(self.shard_dir.rglob("*.parquet"))
-                if not parquet_files:
-                    raise RuntimeError("Extraction completed but no parquet files found")
-                logger.info(f"Found {len(parquet_files)} parquet files after extraction")
-            except Exception as e:
-                logger.error(f"Failed to extract zip file: {str(e)}")
-                self._clean_failed_extraction()
-                raise
-
-    def _get_zip_root(self, zip_ref):
-        """Identify the common root directory within the ZIP file."""
-        try:
-            first_file = zip_ref.namelist()[0]
-            if '/' in first_file:
-                return first_file.split('/')[0]
-            return ""
-        except Exception as e:
-            logger.warning(f"Error detecting zip root: {str(e)}")
-            return ""
-
-    def _flatten_directory(self, src_dir, dest_dir):
-        """Move files from a nested directory up to the destination."""
-        for item in src_dir.iterdir():
-            if item.is_dir():
-                self._flatten_directory(item, dest_dir)
-                item.rmdir()
-            else:
-                target = dest_dir / item.name
-                if target.exists():
-                    target.unlink()
-                item.rename(target)
-
-    def _clean_failed_extraction(self):
-        """Clean up files from a failed extraction attempt."""
-        logger.info("Cleaning up failed extraction")
-        for item in self.shard_dir.iterdir():
-            if item.is_dir():
-                shutil.rmtree(item)
-            else:
-                item.unlink()
-
-    def _build_shard_map(self):
-        """Build a map from global index ranges to shard filenames."""
-        logger.info("Building shard map from parquet files")
-        parquet_files = list(self.shard_dir.glob("*.parquet"))
-        if not parquet_files:
-            raise FileNotFoundError("No parquet files found after extraction")
-        parquet_files = sorted(parquet_files, key=lambda x: int(x.stem.split("_")[1]))
-        expected_start = 0
-        for f in parquet_files:
-            try:
-                parts = f.stem.split("_")
-                if len(parts) != 3:
-                    raise ValueError("Invalid filename format")
-                start = int(parts[1])
-                end = int(parts[2])
-                if start != expected_start:
-                    raise ValueError(f"Non-contiguous shard start: expected {expected_start}, got {start}")
-                if end <= start:
-                    raise ValueError(f"Invalid shard range: {start}-{end}")
-                self.shard_map[(start, end)] = f.name
-                self.total_docs = end + 1
-                expected_start = end + 1
-                logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")
-            except Exception as e:
-                logger.error(f"Error processing shard {f.name}: {str(e)}")
-                raise RuntimeError("Invalid shard structure") from e
-        logger.info(f"Validated {len(self.shard_map)} continuous shards")
-        logger.info(f"Total document count: {self.total_docs}")
-        sorted_ranges = sorted(self.shard_map.keys())
-        for i in range(1, len(sorted_ranges)):
-            prev_end = sorted_ranges[i-1][1]
-            curr_start = sorted_ranges[i][0]
-            if curr_start != prev_end + 1:
-                logger.warning(f"Gap or overlap detected between shards: {prev_end} to {curr_start}")
-
-    def _process_shard(self, shard, local_indices):
-        """Load a shard (if not already loaded) and retrieve the specified rows."""
+
+    def _load_metadata(self):
+        """Load the combined parquet file directly"""
+        logger.info("Loading metadata from combined.parquet")
         try:
-            if shard not in self.loaded_shards:
-                shard_path = self.shard_dir / shard
-                if not shard_path.exists():
-                    logger.error(f"Shard file not found: {shard_path}")
-                    return pd.DataFrame(columns=["title", "summary", "similarity", "authors", "source"])
-
-                file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
-                logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
-
-                try:
-                    # Load with explicit dtype for source column
-                    self.loaded_shards[shard] = pd.read_parquet(
-                        shard_path,
-                        columns=["title", "summary", "source", "authors"]
-                    )
-                    # Convert source to string type explicitly
-                    self.loaded_shards[shard]['source'] = self.loaded_shards[shard]['source'].astype(str)
-                    # Convert source strings to lists
-                    self.loaded_shards[shard]['source'] = self.loaded_shards[shard]['source'].apply(
-                        lambda x: x.split("; ") if isinstance(x, str) else []
-                    )
-                    # Handle missing summaries
-                    self.loaded_shards[shard]['summary'] = self.loaded_shards[shard]['summary'].fillna("")
-                    logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
-
-                except Exception as e:
-                    logger.error(f"Failed to read parquet file {shard}: {str(e)}")
-                    return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
-            df = self.loaded_shards[shard]
-            df_len = len(df)
-            valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
-            if len(valid_local_indices) != len(local_indices):
-                logger.warning(f"Filtered {len(local_indices) - len(valid_local_indices)} out-of-bounds indices in shard {shard}")
-            if valid_local_indices:
-                chunk = df.iloc[valid_local_indices]
-                logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
-                return chunk
-
+            # Load the parquet file
+            self.df = pd.read_parquet(self.metadata_path)
+
+            # Clean and format the data
+            self.df['source'] = self.df['source'].apply(
+                lambda x: [
+                    url.strip()
+                    for url in str(x).split(';')
+                    if url.strip()
+                ]
+            )
+            self.total_docs = len(self.df)
+
+            logger.info(f"Successfully loaded {self.total_docs} documents")
         except Exception as e:
-            logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
-
+            logger.error(f"Failed to load metadata: {str(e)}")
+            raise
+
     def get_metadata(self, global_indices):
-        """Retrieve metadata for a batch of global indices using parallel shard processing."""
+        """Retrieve metadata for given indices"""
         if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
-            logger.warning("Empty indices array passed to get_metadata")
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
-
-        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices
-        logger.info(f"Retrieving metadata for {len(indices_list)} indices")
-        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
-        invalid_count = len(indices_list) - len(valid_indices)
-        if invalid_count > 0:
-            logger.warning(f"Filtered out {invalid_count} invalid indices")
-        if not valid_indices:
-            logger.warning("No valid indices remain after filtering")
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
+            return pd.DataFrame(columns=["title", "summary", 'authors', "similarity", "source"])
 
-        # Group indices by shard
-        shard_groups = {}
-        for idx in valid_indices:
-            found = False
-            for (start, end), shard in self.shard_map.items():
-                if start <= idx <= end:
-                    shard_groups.setdefault(shard, []).append(idx - start)
-                    found = True
-                    break
-            if not found:
-                logger.warning(f"Index {idx} not found in any shard range")
-
-        # Process shards concurrently
-        results = []
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = [executor.submit(self._process_shard, shard, local_indices)
-                       for shard, local_indices in shard_groups.items()]
-            for future in concurrent.futures.as_completed(futures):
-                df_chunk = future.result()
-                if not df_chunk.empty:
-                    results.append(df_chunk)
-
-        if results:
-            combined = pd.concat(results).reset_index(drop=True)
-            logger.info(f"Combined metadata: {len(combined)} records from {len(results)} shards")
-            return combined
-        else:
-            logger.warning("No metadata records retrieved")
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
+        try:
+            # Directly index the DataFrame
+            results = self.df.iloc[global_indices].copy()
+            return results.reset_index(drop=True)
+        except Exception as e:
+            logger.error(f"Metadata retrieval failed: {str(e)}")
+            return pd.DataFrame(columns=["title", "summary", "similarity", "source", 'authors'])
+
 
 
 class SemanticSearch:
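
For orientation, a minimal usage sketch of the refactored MetadataManager from the hunk above. It assumes combined.parquet is present in the working directory with the title, summary, source, and authors columns the diff expects, and that search_utils.py is importable as a module; the index values are made up.

import numpy as np

from search_utils import MetadataManager

# Reads combined.parquet once and splits each semicolon-separated "source"
# value into a list of URLs, as the new _load_metadata above does.
manager = MetadataManager()

# Hypothetical FAISS hit positions; get_metadata slices the DataFrame with .iloc.
hits = np.array([0, 5, 42])
metadata = manager.get_metadata(hits)
print(metadata[["title", "source"]])
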
@@ -383,13 +224,13 @@ class SemanticSearch:
         results['similarity'] = distances
 
         # Ensure URL lists are properly formatted
-        results['source'] = results['source'].apply(
-            lambda x: [
-                url.strip().rstrip(')')  # Clean trailing parentheses and whitespace
-                for url in str(x).split(';')  # Split on semicolons
-                if url.strip()  # Remove empty strings
-            ] if isinstance(x, (str, list)) else []
-        )
+        # results['source'] = results['source'].apply(
+        #     lambda x: [
+        #         url.strip().rstrip(')')  # Clean trailing parentheses and whitespace
+        #         for url in str(x).split(';')  # Split on semicolons
+        #         if url.strip()  # Remove empty strings
+        #     ] if isinstance(x, (str, list)) else []
+        # )
 
         # Deduplicate and sort
         required_columns = ["title", "summary", "authors", "source", "similarity"]
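
The per-result URL formatting in the second hunk is commented out in this commit; a similar split of the semicolon-separated source field now happens once at load time in _load_metadata (first hunk). A short sketch of that load-time transformation on a made-up sample value:

import pandas as pd

# Hypothetical raw "source" values as they might appear in combined.parquet.
df = pd.DataFrame({"source": ["https://example.org/a; https://example.org/b",
                              "https://example.org/c"]})

# Same split applied in _load_metadata: break on ';' and drop empty pieces.
df["source"] = df["source"].apply(
    lambda x: [url.strip() for url in str(x).split(";") if url.strip()]
)

print(df["source"].tolist())
# [['https://example.org/a', 'https://example.org/b'], ['https://example.org/c']]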