jbilcke-hf (HF Staff) committed
Commit 7f039e5 · 1 Parent(s): 55a4bb8

The HF Space got corrupted; here's an attempt at salvaging it.

vms/ui/project/services/training.py CHANGED
@@ -810,9 +810,18 @@ class TrainingService:
 
         # Update with resume_from_checkpoint if provided
         if resume_from_checkpoint:
-            config.resume_from_checkpoint = resume_from_checkpoint
-            self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint} (will use 'latest')")
-            config.resume_from_checkpoint = "latest"
+            # Validate checkpoints and find a valid one to resume from
+            valid_checkpoint = self.validate_and_find_valid_checkpoint()
+            if valid_checkpoint:
+                config.resume_from_checkpoint = "latest"
+                checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1])
+                self.append_log(f"Resuming from validated checkpoint at step {checkpoint_step}")
+                logger.info(f"Resuming from validated checkpoint: {valid_checkpoint}")
+            else:
+                error_msg = "No valid checkpoints found to resume from"
+                logger.error(error_msg)
+                self.append_log(error_msg)
+                return error_msg, "No valid checkpoints available"
 
         # Common settings for both models
         config.mixed_precision = DEFAULT_MIXED_PRECISION
@@ -1088,6 +1097,77 @@ class TrainingService:
         except:
             return False
 
+    def validate_and_find_valid_checkpoint(self) -> Optional[str]:
+        """Validate checkpoint directories and find the most recent valid one
+
+        Returns:
+            Path to valid checkpoint directory or None if no valid checkpoint found
+        """
+        # Find all checkpoint directories
+        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+        if not checkpoints:
+            logger.info("No checkpoint directories found")
+            return None
+
+        # Sort by step number in descending order (latest first)
+        sorted_checkpoints = sorted(checkpoints, key=lambda x: int(x.name.split("_")[-1]), reverse=True)
+
+        corrupted_checkpoints = []
+
+        for checkpoint_dir in sorted_checkpoints:
+            step_num = int(checkpoint_dir.name.split("_")[-1])
+            logger.info(f"Validating checkpoint at step {step_num}: {checkpoint_dir}")
+
+            # Check if the .metadata file exists
+            metadata_file = checkpoint_dir / ".metadata"
+            if not metadata_file.exists():
+                logger.warning(f"Checkpoint {checkpoint_dir.name} is corrupted: missing .metadata file")
+                corrupted_checkpoints.append(checkpoint_dir)
+                continue
+
+            # Try to read the metadata file to ensure it's not corrupted
+            try:
+                with open(metadata_file, 'r') as f:
+                    metadata = json.load(f)
+                # Basic validation - metadata should have expected structure
+                if not isinstance(metadata, dict):
+                    raise ValueError("Invalid metadata format")
+                logger.info(f"Checkpoint {checkpoint_dir.name} is valid")
+
+                # Clean up any corrupted checkpoints we found before this valid one
+                if corrupted_checkpoints:
+                    self.cleanup_corrupted_checkpoints(corrupted_checkpoints)
+
+                return str(checkpoint_dir)
+
+            except (json.JSONDecodeError, IOError, ValueError) as e:
+                logger.warning(f"Checkpoint {checkpoint_dir.name} is corrupted: failed to read .metadata: {e}")
+                corrupted_checkpoints.append(checkpoint_dir)
+                continue
+
+        # If we reach here, all checkpoints are corrupted
+        if corrupted_checkpoints:
+            logger.error("All checkpoint directories are corrupted")
+            self.cleanup_corrupted_checkpoints(corrupted_checkpoints)
+
+        return None
+
+    def cleanup_corrupted_checkpoints(self, corrupted_checkpoints: List[Path]) -> None:
+        """Remove corrupted checkpoint directories
+
+        Args:
+            corrupted_checkpoints: List of corrupted checkpoint directory paths
+        """
+        for checkpoint_dir in corrupted_checkpoints:
+            try:
+                step_num = int(checkpoint_dir.name.split("_")[-1])
+                logger.info(f"Removing corrupted checkpoint at step {step_num}: {checkpoint_dir}")
+                shutil.rmtree(checkpoint_dir)
+                self.append_log(f"Removed corrupted checkpoint: {checkpoint_dir.name}")
+            except Exception as e:
+                logger.error(f"Failed to remove corrupted checkpoint {checkpoint_dir}: {e}")
+                self.append_log(f"Failed to remove corrupted checkpoint {checkpoint_dir.name}: {e}")
+
     def recover_interrupted_training(self) -> Dict[str, Any]:
         """Attempt to recover interrupted training
 
@@ -1097,9 +1177,9 @@ class TrainingService:
         status = self.get_status()
         ui_updates = {}
 
-        # Check for any checkpoints, even if status doesn't indicate training
-        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
-        has_checkpoints = len(checkpoints) > 0
+        # Check for any valid checkpoints, even if status doesn't indicate training
+        valid_checkpoint = self.validate_and_find_valid_checkpoint()
+        has_checkpoints = valid_checkpoint is not None
 
         # If status indicates training but process isn't running, or if we have checkpoints
         # and no active training process, try to recover
@@ -1145,15 +1225,13 @@ class TrainingService:
            }
            return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
 
-        # Find the latest checkpoint if we have checkpoints
+        # Use the valid checkpoint we found
         latest_checkpoint = None
         checkpoint_step = 0
 
-        if has_checkpoints:
-            # Find the latest checkpoint by step number
-            latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
-            checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
-            logger.info(f"Found checkpoint at step {checkpoint_step}")
+        if has_checkpoints and valid_checkpoint:
+            checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1])
+            logger.info(f"Found valid checkpoint at step {checkpoint_step}")
 
         # both options are valid, but imho it is easier to just return "latest"
         # under the hood Finetrainers will convert ("latest") to (-1)
@@ -1480,17 +1558,20 @@ class TrainingService:
            self.append_log(f"Error uploading to hub: {str(e)}")
            return False
 
-    def get_model_output_safetensors(self) -> Optional[str]:
-        """Return the path to the model safetensors
+    def get_model_output_info(self) -> Dict[str, Any]:
+        """Return info about the model safetensors including path and step count
 
         Returns:
-            Path to safetensors file or None if not found
+            Dict with 'path' (str or None) and 'steps' (int or None)
         """
+        result = {"path": None, "steps": None}
 
         # Check if the root level file exists (this should be the primary location)
         model_output_safetensors_path = self.app.output_path / "pytorch_lora_weights.safetensors"
         if model_output_safetensors_path.exists():
-            return str(model_output_safetensors_path)
+            result["path"] = str(model_output_safetensors_path)
+            # For root level, we can't determine steps easily, so return None
+            return result
 
         # Check in lora_weights directory
         lora_weights_dir = self.app.output_path / "lora_weights"
@@ -1503,6 +1584,9 @@ class TrainingService:
            latest_lora_checkpoint = max(lora_checkpoints, key=lambda x: int(x.name))
            logger.info(f"Found latest LoRA checkpoint: {latest_lora_checkpoint}")
 
+            # Extract step count from directory name
+            result["steps"] = int(latest_lora_checkpoint.name)
+
            # List contents of the latest checkpoint directory
            checkpoint_contents = list(latest_lora_checkpoint.glob("*"))
            logger.info(f"Contents of LoRA checkpoint {latest_lora_checkpoint.name}: {checkpoint_contents}")
@@ -1511,7 +1595,8 @@ class TrainingService:
            lora_safetensors = latest_lora_checkpoint / "pytorch_lora_weights.safetensors"
            if lora_safetensors.exists():
                logger.info(f"Found weights in latest LoRA checkpoint: {lora_safetensors}")
-                return str(lora_safetensors)
+                result["path"] = str(lora_safetensors)
+                return result
 
            # Also check for other common weight file names
            possible_weight_files = [
@@ -1525,24 +1610,27 @@ class TrainingService:
                weight_path = latest_lora_checkpoint / weight_file
                if weight_path.exists():
                    logger.info(f"Found weights file {weight_file} in latest LoRA checkpoint: {weight_path}")
-                    return str(weight_path)
+                    result["path"] = str(weight_path)
+                    return result
 
            # Check if any .safetensors files exist
            safetensors_files = list(latest_lora_checkpoint.glob("*.safetensors"))
            if safetensors_files:
                logger.info(f"Found .safetensors files in LoRA checkpoint: {safetensors_files}")
                # Return the first .safetensors file found
-                return str(safetensors_files[0])
+                result["path"] = str(safetensors_files[0])
+                return result
 
            # Fallback: check for direct safetensors file in lora_weights root
            lora_safetensors = lora_weights_dir / "pytorch_lora_weights.safetensors"
            if lora_safetensors.exists():
                logger.info(f"Found weights in lora_weights directory: {lora_safetensors}")
-                return str(lora_safetensors)
+                result["path"] = str(lora_safetensors)
+                return result
        else:
            logger.info(f"pytorch_lora_weights.safetensors not found in lora_weights directory")
 
-        # If not found in root or lora_weights, log the issue
+        # If not found in root or lora_weights, log the issue and check fallback
        logger.warning(f"Model weights not found at expected location: {model_output_safetensors_path}")
        logger.info(f"Checking output directory contents: {list(self.app.output_path.glob('*'))}")
 
@@ -1553,6 +1641,9 @@ class TrainingService:
            latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
            logger.info(f"Latest checkpoint directory: {latest_checkpoint}")
 
+            # Extract step count from checkpoint directory name
+            result["steps"] = int(latest_checkpoint.name.split("_")[-1])
+
            # Log contents of latest checkpoint
            checkpoint_contents = list(latest_checkpoint.glob("*"))
            logger.info(f"Contents of latest checkpoint {latest_checkpoint.name}: {checkpoint_contents}")
@@ -1560,11 +1651,20 @@ class TrainingService:
            checkpoint_weights = latest_checkpoint / "pytorch_lora_weights.safetensors"
            if checkpoint_weights.exists():
                logger.info(f"Found weights in latest checkpoint: {checkpoint_weights}")
-                return str(checkpoint_weights)
+                result["path"] = str(checkpoint_weights)
+                return result
            else:
                logger.info(f"pytorch_lora_weights.safetensors not found in checkpoint directory")
 
-        return None
+        return result
+
+    def get_model_output_safetensors(self) -> Optional[str]:
+        """Return the path to the model safetensors
+
+        Returns:
+            Path to safetensors file or None if not found
+        """
+        return self.get_model_output_info()["path"]
 
     def create_training_dataset_zip(self) -> str:
        """Create a ZIP file containing all training data
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -25,6 +25,32 @@ class ManageTab(BaseTab):
         self.id = "manage_tab"
         self.title = "5️⃣ Storage"
 
+    def get_download_button_text(self) -> str:
+        """Get the dynamic text for the download button based on current model state"""
+        try:
+            model_info = self.app.training.get_model_output_info()
+            if model_info["path"] and model_info["steps"]:
+                return f"🧠 Download weights ({model_info['steps']} steps)"
+            elif model_info["path"]:
+                return "🧠 Download weights (.safetensors)"
+            else:
+                return "🧠 Download weights (not available)"
+        except Exception as e:
+            logger.warning(f"Error getting model info for button text: {e}")
+            return "🧠 Download weights (.safetensors)"
+
+    def update_download_button_text(self) -> gr.update:
+        """Update the download button text"""
+        return gr.update(value=self.get_download_button_text())
+
+    def download_and_update_button(self):
+        """Handle download and return updated button with current text"""
+        # Get the safetensors path for download
+        path = self.app.training.get_model_output_safetensors()
+        # For DownloadButton, we need to return the file path directly for download
+        # The button text will be updated on next render
+        return path
+
     def create(self, parent=None) -> gr.TabItem:
         """Create the Manage tab UI components"""
         with gr.TabItem(self.title, id=self.id) as tab:
@@ -45,7 +71,7 @@ class ManageTab(BaseTab):
                    gr.Markdown("📦 Training dataset download disabled for large datasets")
 
                self.components["download_model_btn"] = gr.DownloadButton(
-                    "🧠 Download weights (.safetensors)",
+                    self.get_download_button_text(),
                    variant="secondary",
                    size="lg"
                )
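As written, get_download_button_text() is only evaluated once, when the tab is built, so the label reflects the model state at construction time. A hypothetical way to refresh it at runtime, not part of this commit, is to route an existing UI event through update_download_button_text(), roughly as in the sketch below; build_demo and the extra "Refresh label" button are illustrative names, while gr.Blocks, gr.Button.click, and gr.DownloadButton are standard Gradio APIs.

import gradio as gr


def build_demo(manage_tab):
    """Sketch: rebuild the download button label on demand via an event."""
    with gr.Blocks() as demo:
        download_btn = gr.DownloadButton(manage_tab.get_download_button_text())
        refresh_btn = gr.Button("Refresh label")  # assumed helper button, not in the diff
        refresh_btn.click(
            fn=manage_tab.update_download_button_text,  # returns gr.update(value=...)
            outputs=[download_btn],
        )
    return demo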