Testys committed on
Commit
d286a45
·
1 Parent(s): ce1eaaf

Update search_utils.py

Browse files
Files changed (1) hide show
  1. search_utils.py +74 -30
search_utils.py CHANGED
@@ -21,60 +21,104 @@ logger = logging.getLogger("MetadataManager")
21
 
22
class MetadataManager:
    """Locates parquet metadata shards on disk, extracting them from
    ``metadata_shards.zip`` when the shard directory is missing, and builds a
    ``(start_index, end_index) -> filename`` lookup map."""

    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}       # (start_idx, end_idx) -> shard filename
        self.loaded_shards = {}   # shards already read by callers, keyed per shard
        self.total_docs = 0

        logger.info("Initializing MetadataManager")
        self._ensure_unzipped()
        self._build_shard_map()
        logger.info(f"Total documents indexed: {self.total_docs}")
        logger.info(f"Total shards found: {len(self.shard_map)}")

    def _ensure_unzipped(self):
        """Handle ZIP extraction without Streamlit elements"""
        logger.info(f"Checking for shard directory: {self.shard_dir}")
        if self.shard_dir.exists():
            # Directory already present: assume it was extracted previously.
            logger.info("Shard directory exists, skipping extraction")
            return

        zip_path = Path("metadata_shards.zip")
        logger.info(f"Shard directory not found, looking for zip file: {zip_path}")
        if not zip_path.exists():
            error_msg = "Metadata ZIP file not found"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)

        logger.info(f"Extracting from zip file: {zip_path}")
        start_time = time.time()
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.shard_dir)
        logger.info(f"Extraction completed in {time.time() - start_time:.2f} seconds")

    def _build_shard_map(self):
        """Create index range to shard mapping"""
        logger.info("Building shard map from parquet files")
        self.total_docs = 0
        shard_files = list(self.shard_dir.glob("*.parquet"))
        logger.info(f"Found {len(shard_files)} parquet files")

        if not shard_files:
            logger.warning("No parquet files found in shard directory")

        for f in sorted(shard_files):
            try:
                # Filenames follow "prefix_start_end.parquet"; skip anything else.
                parts = f.stem.split("_")
                if len(parts) < 3:
                    logger.warning(f"Skipping file with invalid name format: {f}")
                    continue

                start = int(parts[1])
                end = int(parts[2])
                self.shard_map[(start, end)] = f.name
                # End index is inclusive, so the document count is end + 1.
                self.total_docs = max(self.total_docs, end + 1)
                logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")
            except Exception as e:
                # Best-effort: a bad filename is logged and skipped, not fatal.
                logger.error(f"Error parsing shard filename {f}: {str(e)}")

        # Log shard statistics
        logger.info(f"Shard map built with {len(self.shard_map)} shards")
21
 
22
class MetadataManager:
    """Maps document indices to parquet shard files cached on disk.

    On construction this ensures the cache directories exist, extracts
    ``metadata_shards.zip`` into the cache if no shards are present yet, and
    builds a validated ``(start_index, end_index) -> filename`` shard map that
    must cover a contiguous index range starting at 0.
    """

    def __init__(self):
        # Shards are cached under unzipped_cache/metadata_shards so the ZIP
        # only has to be extracted once per environment.
        self.cache_dir = Path("unzipped_cache")
        self.shard_dir = self.cache_dir / "metadata_shards"
        self.shard_map = {}       # (start_idx, end_idx) -> shard filename
        self.loaded_shards = {}   # shards already read by callers, keyed per shard
        self.total_docs = 0

        logger.info("Initializing MetadataManager")
        self._ensure_directories()
        self._unzip_if_needed()
        self._build_shard_map()
        logger.info(f"Total documents indexed: {self.total_docs}")
        logger.info(f"Total shards found: {len(self.shard_map)}")

    def _ensure_directories(self):
        """Create necessary directories if they don't exist"""
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.shard_dir.mkdir(parents=True, exist_ok=True)

    def _cleanup_shard_dir(self):
        """Remove everything under shard_dir, files first, then empty dirs.

        The previous cleanup called unlink() on every glob("*") entry, which
        raises IsADirectoryError if the ZIP contained subdirectories and masks
        the original extraction error. Deleting deepest paths first (reverse
        lexicographic order puts children before their parents) handles
        nested entries safely.
        """
        for entry in sorted(self.shard_dir.rglob("*"), reverse=True):
            if entry.is_dir():
                entry.rmdir()
            else:
                entry.unlink()

    def _unzip_if_needed(self):
        """Extract metadata_shards.zip into shard_dir unless shards are cached.

        Raises:
            FileNotFoundError: no cached shards and no ZIP archive to extract.
            ValueError: the ZIP archive contains no parquet files.
            RuntimeError: extraction finished but produced no parquet files.
        """
        zip_path = Path("metadata_shards.zip")

        # Check if we need to unzip
        if not any(self.shard_dir.glob("*.parquet")):
            logger.info("No parquet files found, checking for zip archive")

            if not zip_path.exists():
                raise FileNotFoundError(f"Metadata ZIP file not found at {zip_path}")

            logger.info(f"Extracting {zip_path} to {self.shard_dir}")
            try:
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    # Validate zip contents before extraction
                    zip_files = zip_ref.namelist()
                    if not any(fname.endswith('.parquet') for fname in zip_files):
                        raise ValueError("ZIP file contains no parquet files")

                    zip_ref.extractall(self.shard_dir)
                    logger.info(f"Extracted {len(zip_files)} files")

                # Verify extraction succeeded
                if not any(self.shard_dir.glob("*.parquet")):
                    raise RuntimeError("Extraction completed but no parquet files found")

            except Exception as e:
                logger.error(f"Failed to extract zip file: {str(e)}")
                # Clean up partial extraction so the next run retries cleanly.
                self._cleanup_shard_dir()
                raise
        else:
            logger.info("Parquet files already exist in cache directory")

    @staticmethod
    def _parse_shard_name(path):
        """Return (start, end) parsed from a 'prefix_start_end.parquet' name.

        Raises:
            ValueError: the stem does not split into exactly three parts, or
                the index parts are not integers.
        """
        parts = path.stem.split("_")
        if len(parts) != 3:
            raise ValueError("Invalid filename format")
        return int(parts[1]), int(parts[2])

    def _build_shard_map(self):
        """Create validated index range to shard mapping.

        Raises:
            FileNotFoundError: shard_dir contains no parquet files.
            RuntimeError: a filename is malformed, the shards are not
                contiguous from index 0, or a shard has an invalid range.
        """
        logger.info("Building shard map from parquet files")
        parquet_files = list(self.shard_dir.glob("*.parquet"))

        if not parquet_files:
            raise FileNotFoundError("No parquet files found after extraction")

        # Parse every filename BEFORE sorting. The previous version sorted
        # with key=lambda x: int(x.stem.split("_")[1]) ahead of validation,
        # so a malformed name crashed inside the sort key with a raw
        # ValueError/IndexError instead of the intended RuntimeError.
        parsed = []
        for f in parquet_files:
            try:
                parsed.append((self._parse_shard_name(f), f))
            except Exception as e:
                logger.error(f"Error processing shard {f.name}: {str(e)}")
                raise RuntimeError("Invalid shard structure") from e

        # Sort shards by their starting index so continuity can be checked.
        parsed.sort(key=lambda item: item[0][0])

        # Track expected next index
        expected_start = 0

        for (start, end), f in parsed:
            try:
                # Validate continuity
                if start != expected_start:
                    raise ValueError(f"Non-contiguous shard start: expected {expected_start}, got {start}")

                # Validate range. The end index is inclusive (total_docs is
                # end + 1 below), so a single-document shard has start == end
                # and is legal; only a reversed range is an error. The old
                # check (end <= start) wrongly rejected one-document shards.
                if end < start:
                    raise ValueError(f"Invalid shard range: {start}-{end}")

                self.shard_map[(start, end)] = f.name
                self.total_docs = end + 1
                expected_start = end + 1

                logger.debug(f"Mapped shard {f.name}: indices {start}-{end}")

            except Exception as e:
                logger.error(f"Error processing shard {f.name}: {str(e)}")
                raise RuntimeError("Invalid shard structure") from e

        logger.info(f"Validated {len(self.shard_map)} continuous shards")
        logger.info(f"Total document count: {self.total_docs}")

        # Log shard statistics
        logger.info(f"Shard map built with {len(self.shard_map)} shards")