Update search_utils.py
search_utils.py
CHANGED  +41 -200
@@ -24,208 +24,49 @@ logger = logging.getLogger("MetadataManager")

 class MetadataManager:
     def __init__(self):
-        self.
-        self.
-        self.shard_map = {}
-        self.loaded_shards = {}
+        self.metadata_path = Path("combined.parquet")
+        self.df = None
         self.total_docs = 0
-        self.api_cache = {}

         logger.info("Initializing MetadataManager")
-        self.
-        self._unzip_if_needed()
-        self._build_shard_map()
+        self._load_metadata()
         logger.info(f"Total documents indexed: {self.total_docs}")
-
-
-
-        "
-        self.cache_dir.mkdir(parents=True, exist_ok=True)
-        self.shard_dir.mkdir(parents=True, exist_ok=True)
-
-    def _unzip_if_needed(self):
-        """Extract the ZIP archive if no parquet files are found."""
-        zip_path = Path("metadata_shards.zip")
-        if not any(self.shard_dir.rglob("*.parquet")):
-            logger.info("No parquet files found, checking for zip archive")
-            if not zip_path.exists():
-                raise FileNotFoundError(f"Metadata ZIP file not found at {zip_path}")
-            logger.info(f"Extracting {zip_path} to {self.shard_dir}")
-            try:
-                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-                    zip_root = self._get_zip_root(zip_ref)
-                    zip_ref.extractall(self.shard_dir)
-                    if zip_root:
-                        nested_dir = self.shard_dir / zip_root
-                        if nested_dir.exists():
-                            self._flatten_directory(nested_dir, self.shard_dir)
-                            nested_dir.rmdir()
-                parquet_files = list(self.shard_dir.rglob("*.parquet"))
-                if not parquet_files:
-                    raise RuntimeError("Extraction completed but no parquet files found")
-                logger.info(f"Found {len(parquet_files)} parquet files after extraction")
-            except Exception as e:
-                logger.error(f"Failed to extract zip file: {str(e)}")
-                self._clean_failed_extraction()
-                raise
-
-    def _get_zip_root(self, zip_ref):
-        """Identify the common root directory within the ZIP file."""
-        try:
-            first_file = zip_ref.namelist()[0]
-            if '/' in first_file:
-                return first_file.split('/')[0]
-            return ""
-        except Exception as e:
-            logger.warning(f"Error detecting zip root: {str(e)}")
-            return ""
-
-    def _flatten_directory(self, src_dir, dest_dir):
-        """Move files from a nested directory up to the destination."""
-        for item in src_dir.iterdir():
-            if item.is_dir():
-                self._flatten_directory(item, dest_dir)
-                item.rmdir()
-            else:
-                target = dest_dir / item.name
-                if target.exists():
-                    target.unlink()
-                item.rename(target)
-
-    def _clean_failed_extraction(self):
-        """Clean up files from a failed extraction attempt."""
-        logger.info("Cleaning up failed extraction")
-        for item in self.shard_dir.iterdir():
-            if item.is_dir():
-                shutil.rmtree(item)
-            else:
-                item.unlink()
-
-    def _build_shard_map(self):
-        """Build a map from global index ranges to shard filenames."""
-        logger.info("Building shard map from parquet files")
-        parquet_files = list(self.shard_dir.glob("*.parquet"))
-        if not parquet_files:
-            raise FileNotFoundError("No parquet files found after extraction")
-        parquet_files = sorted(parquet_files, key=lambda x: int(x.stem.split("_")[1]))
-        expected_start = 0
-        for f in parquet_files:
-            try:
-                parts = f.stem.split("_")
-                if len(parts) != 3:
-                    raise ValueError("Invalid filename format")
-                start = int(parts[1])
-                end = int(parts[2])
-                if start != expected_start:
-                    raise ValueError(f"Non-contiguous shard start: expected {expected_start}, got {start}")
-                if end <= start:
-                    raise ValueError(f"Invalid shard range: {start}-{end}")
-                self.shard_map[(start, end)] = f.name
-                self.total_docs = end + 1
-                expected_start = end + 1
-                logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")
-            except Exception as e:
-                logger.error(f"Error processing shard {f.name}: {str(e)}")
-                raise RuntimeError("Invalid shard structure") from e
-        logger.info(f"Validated {len(self.shard_map)} continuous shards")
-        logger.info(f"Total document count: {self.total_docs}")
-        sorted_ranges = sorted(self.shard_map.keys())
-        for i in range(1, len(sorted_ranges)):
-            prev_end = sorted_ranges[i-1][1]
-            curr_start = sorted_ranges[i][0]
-            if curr_start != prev_end + 1:
-                logger.warning(f"Gap or overlap detected between shards: {prev_end} to {curr_start}")
-
-    def _process_shard(self, shard, local_indices):
-        """Load a shard (if not already loaded) and retrieve the specified rows."""
+
+    def _load_metadata(self):
+        """Load the combined parquet file directly"""
+        logger.info("Loading metadata from combined.parquet")
         try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-            # Convert source to string type explicitly
-            self.loaded_shards[shard]['source'] = self.loaded_shards[shard]['source'].astype(str)
-            # Convert source strings to lists
-            self.loaded_shards[shard]['source'] = self.loaded_shards[shard]['source'].apply(
-                lambda x: x.split("; ") if isinstance(x, str) else []
-            )
-            # Handle missing summaries
-            self.loaded_shards[shard]['summary'] = self.loaded_shards[shard]['summary'].fillna("")
-            logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
-
-        except Exception as e:
-            logger.error(f"Failed to read parquet file {shard}: {str(e)}")
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
-        df = self.loaded_shards[shard]
-        df_len = len(df)
-        valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
-        if len(valid_local_indices) != len(local_indices):
-            logger.warning(f"Filtered {len(local_indices) - len(valid_local_indices)} out-of-bounds indices in shard {shard}")
-        if valid_local_indices:
-            chunk = df.iloc[valid_local_indices]
-            logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
-            return chunk
-
+            # Load the parquet file
+            self.df = pd.read_parquet(self.metadata_path)
+
+            # Clean and format the data
+            self.df['source'] = self.df['source'].apply(
+                lambda x: [
+                    url.strip()
+                    for url in str(x).split(';')
+                    if url.strip()
+                ]
+            )
+            self.total_docs = len(self.df)
+
+            logger.info(f"Successfully loaded {self.total_docs} documents")
         except Exception as e:
-            logger.error(f"
-
-
+            logger.error(f"Failed to load metadata: {str(e)}")
+            raise
+
     def get_metadata(self, global_indices):
-        """Retrieve metadata for
+        """Retrieve metadata for given indices"""
         if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
-
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
-
-        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices
-        logger.info(f"Retrieving metadata for {len(indices_list)} indices")
-        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
-        invalid_count = len(indices_list) - len(valid_indices)
-        if invalid_count > 0:
-            logger.warning(f"Filtered out {invalid_count} invalid indices")
-        if not valid_indices:
-            logger.warning("No valid indices remain after filtering")
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
+            return pd.DataFrame(columns=["title", "summary", 'authors', "similarity", "source"])

-
-
-
-
-
-
-
-
-                    break
-            if not found:
-                logger.warning(f"Index {idx} not found in any shard range")
-
-        # Process shards concurrently
-        results = []
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = [executor.submit(self._process_shard, shard, local_indices)
-                       for shard, local_indices in shard_groups.items()]
-            for future in concurrent.futures.as_completed(futures):
-                df_chunk = future.result()
-                if not df_chunk.empty:
-                    results.append(df_chunk)
-
-        if results:
-            combined = pd.concat(results).reset_index(drop=True)
-            logger.info(f"Combined metadata: {len(combined)} records from {len(results)} shards")
-            return combined
-        else:
-            logger.warning("No metadata records retrieved")
-            return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
+        try:
+            # Directly index the DataFrame
+            results = self.df.iloc[global_indices].copy()
+            return results.reset_index(drop=True)
+        except Exception as e:
+            logger.error(f"Metadata retrieval failed: {str(e)}")
+            return pd.DataFrame(columns=["title", "summary", "similarity", "source", 'authors'])
+


 class SemanticSearch:
@@ -383,13 +224,13 @@ class SemanticSearch:
         results['similarity'] = distances

         # Ensure URL lists are properly formatted
-        results['source'] = results['source'].apply(
-
-
-
-
-
-        )
+        # results['source'] = results['source'].apply(
+        #     lambda x: [
+        #         url.strip().rstrip(')')  # Clean trailing parentheses and whitespace
+        #         for url in str(x).split(';')  # Split on semicolons
+        #         if url.strip()  # Remove empty strings
+        #     ] if isinstance(x, (str, list)) else []
+        # )

         # Deduplicate and sort
         required_columns = ["title", "summary", "authors", "source", "similarity"]
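The removed get_metadata then fanned those per-shard lookups out over a thread pool and concatenated the chunks. A minimal sketch of that pattern, with load_chunk standing in for _process_shard and a hypothetical shard_groups mapping:

import concurrent.futures
import pandas as pd

def load_chunk(shard, local_indices):
    # Stand-in for _process_shard: it would read the shard's parquet file
    # and select the requested rows with .iloc.
    return pd.DataFrame({"title": [f"{shard}#{i}" for i in local_indices]})

shard_groups = {"meta_0_999.parquet": [5, 7], "meta_1000_1999.parquet": [0]}

results = []
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(load_chunk, shard, idxs) for shard, idxs in shard_groups.items()]
    for future in concurrent.futures.as_completed(futures):
        chunk = future.result()
        if not chunk.empty:
            results.append(chunk)

combined = pd.concat(results).reset_index(drop=True) if results else pd.DataFrame()
print(len(combined))  # 3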
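On the new path, _load_metadata reads combined.parquet once and splits the semicolon-separated source column into lists of URLs. A small sketch of that cleaning step on made-up rows; the summary fillna mirrors what the removed shard loader did, while the new code itself only normalizes source:

import pandas as pd

df = pd.DataFrame({
    "title": ["Paper A", "Paper B"],
    "summary": ["First abstract", None],
    "authors": ["Doe, J.", "Roe, R."],
    "source": ["https://arxiv.org/abs/1234.5678; https://example.org/a", ""],
})

# Same transformation as in _load_metadata: split on ';' and drop empties.
df["source"] = df["source"].apply(
    lambda x: [url.strip() for url in str(x).split(";") if url.strip()]
)
df["summary"] = df["summary"].fillna("")  # carried over from the old shard loader

print(df["source"].tolist())
# [['https://arxiv.org/abs/1234.5678', 'https://example.org/a'], []]

Because source already arrives as a list of URLs, the per-result formatting in SemanticSearch is commented out in the second hunk.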