Commit 7f039e5
Parent(s): 55a4bb8

the HF space got corrupted, here's an attempt at salvaging it

Files changed:
- vms/ui/project/services/training.py  +123 -23
- vms/ui/project/tabs/manage_tab.py    +27 -1
vms/ui/project/services/training.py
CHANGED
@@ -810,9 +810,18 @@ class TrainingService:
 
         # Update with resume_from_checkpoint if provided
         if resume_from_checkpoint:
-
-            self.
-
+            # Validate checkpoints and find a valid one to resume from
+            valid_checkpoint = self.validate_and_find_valid_checkpoint()
+            if valid_checkpoint:
+                config.resume_from_checkpoint = "latest"
+                checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1])
+                self.append_log(f"Resuming from validated checkpoint at step {checkpoint_step}")
+                logger.info(f"Resuming from validated checkpoint: {valid_checkpoint}")
+            else:
+                error_msg = "No valid checkpoints found to resume from"
+                logger.error(error_msg)
+                self.append_log(error_msg)
+                return error_msg, "No valid checkpoints available"
 
         # Common settings for both models
         config.mixed_precision = DEFAULT_MIXED_PRECISION

@@ -1088,6 +1097,77 @@ class TrainingService:
         except:
             return False
 
+    def validate_and_find_valid_checkpoint(self) -> Optional[str]:
+        """Validate checkpoint directories and find the most recent valid one
+
+        Returns:
+            Path to valid checkpoint directory or None if no valid checkpoint found
+        """
+        # Find all checkpoint directories
+        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+        if not checkpoints:
+            logger.info("No checkpoint directories found")
+            return None
+
+        # Sort by step number in descending order (latest first)
+        sorted_checkpoints = sorted(checkpoints, key=lambda x: int(x.name.split("_")[-1]), reverse=True)
+
+        corrupted_checkpoints = []
+
+        for checkpoint_dir in sorted_checkpoints:
+            step_num = int(checkpoint_dir.name.split("_")[-1])
+            logger.info(f"Validating checkpoint at step {step_num}: {checkpoint_dir}")
+
+            # Check if the .metadata file exists
+            metadata_file = checkpoint_dir / ".metadata"
+            if not metadata_file.exists():
+                logger.warning(f"Checkpoint {checkpoint_dir.name} is corrupted: missing .metadata file")
+                corrupted_checkpoints.append(checkpoint_dir)
+                continue
+
+            # Try to read the metadata file to ensure it's not corrupted
+            try:
+                with open(metadata_file, 'r') as f:
+                    metadata = json.load(f)
+                # Basic validation - metadata should have expected structure
+                if not isinstance(metadata, dict):
+                    raise ValueError("Invalid metadata format")
+                logger.info(f"Checkpoint {checkpoint_dir.name} is valid")
+
+                # Clean up any corrupted checkpoints we found before this valid one
+                if corrupted_checkpoints:
+                    self.cleanup_corrupted_checkpoints(corrupted_checkpoints)
+
+                return str(checkpoint_dir)
+
+            except (json.JSONDecodeError, IOError, ValueError) as e:
+                logger.warning(f"Checkpoint {checkpoint_dir.name} is corrupted: failed to read .metadata: {e}")
+                corrupted_checkpoints.append(checkpoint_dir)
+                continue
+
+        # If we reach here, all checkpoints are corrupted
+        if corrupted_checkpoints:
+            logger.error("All checkpoint directories are corrupted")
+            self.cleanup_corrupted_checkpoints(corrupted_checkpoints)
+
+        return None
+
+    def cleanup_corrupted_checkpoints(self, corrupted_checkpoints: List[Path]) -> None:
+        """Remove corrupted checkpoint directories
+
+        Args:
+            corrupted_checkpoints: List of corrupted checkpoint directory paths
+        """
+        for checkpoint_dir in corrupted_checkpoints:
+            try:
+                step_num = int(checkpoint_dir.name.split("_")[-1])
+                logger.info(f"Removing corrupted checkpoint at step {step_num}: {checkpoint_dir}")
+                shutil.rmtree(checkpoint_dir)
+                self.append_log(f"Removed corrupted checkpoint: {checkpoint_dir.name}")
+            except Exception as e:
+                logger.error(f"Failed to remove corrupted checkpoint {checkpoint_dir}: {e}")
+                self.append_log(f"Failed to remove corrupted checkpoint {checkpoint_dir.name}: {e}")
+
     def recover_interrupted_training(self) -> Dict[str, Any]:
         """Attempt to recover interrupted training
 

@@ -1097,9 +1177,9 @@ class TrainingService:
         status = self.get_status()
         ui_updates = {}
 
-        # Check for any checkpoints, even if status doesn't indicate training
-
-        has_checkpoints =
+        # Check for any valid checkpoints, even if status doesn't indicate training
+        valid_checkpoint = self.validate_and_find_valid_checkpoint()
+        has_checkpoints = valid_checkpoint is not None
 
         # If status indicates training but process isn't running, or if we have checkpoints
         # and no active training process, try to recover

@@ -1145,15 +1225,13 @@ class TrainingService:
             }
             return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
 
-        #
+        # Use the valid checkpoint we found
         latest_checkpoint = None
        checkpoint_step = 0
 
-        if has_checkpoints:
-
-
-            checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
-            logger.info(f"Found checkpoint at step {checkpoint_step}")
+        if has_checkpoints and valid_checkpoint:
+            checkpoint_step = int(Path(valid_checkpoint).name.split("_")[-1])
+            logger.info(f"Found valid checkpoint at step {checkpoint_step}")
 
         # both options are valid, but imho it is easier to just return "latest"
         # under the hood Finetrainers will convert ("latest") to (-1)

@@ -1480,17 +1558,20 @@ class TrainingService:
             self.append_log(f"Error uploading to hub: {str(e)}")
             return False
 
-    def
-        """Return
+    def get_model_output_info(self) -> Dict[str, Any]:
+        """Return info about the model safetensors including path and step count
 
         Returns:
-
+            Dict with 'path' (str or None) and 'steps' (int or None)
         """
+        result = {"path": None, "steps": None}
 
         # Check if the root level file exists (this should be the primary location)
         model_output_safetensors_path = self.app.output_path / "pytorch_lora_weights.safetensors"
         if model_output_safetensors_path.exists():
-
+            result["path"] = str(model_output_safetensors_path)
+            # For root level, we can't determine steps easily, so return None
+            return result
 
         # Check in lora_weights directory
         lora_weights_dir = self.app.output_path / "lora_weights"

@@ -1503,6 +1584,9 @@ class TrainingService:
             latest_lora_checkpoint = max(lora_checkpoints, key=lambda x: int(x.name))
             logger.info(f"Found latest LoRA checkpoint: {latest_lora_checkpoint}")
 
+            # Extract step count from directory name
+            result["steps"] = int(latest_lora_checkpoint.name)
+
             # List contents of the latest checkpoint directory
             checkpoint_contents = list(latest_lora_checkpoint.glob("*"))
             logger.info(f"Contents of LoRA checkpoint {latest_lora_checkpoint.name}: {checkpoint_contents}")

@@ -1511,7 +1595,8 @@ class TrainingService:
             lora_safetensors = latest_lora_checkpoint / "pytorch_lora_weights.safetensors"
             if lora_safetensors.exists():
                 logger.info(f"Found weights in latest LoRA checkpoint: {lora_safetensors}")
-
+                result["path"] = str(lora_safetensors)
+                return result
 
             # Also check for other common weight file names
             possible_weight_files = [

@@ -1525,24 +1610,27 @@ class TrainingService:
                 weight_path = latest_lora_checkpoint / weight_file
                 if weight_path.exists():
                     logger.info(f"Found weights file {weight_file} in latest LoRA checkpoint: {weight_path}")
-
+                    result["path"] = str(weight_path)
+                    return result
 
             # Check if any .safetensors files exist
             safetensors_files = list(latest_lora_checkpoint.glob("*.safetensors"))
             if safetensors_files:
                 logger.info(f"Found .safetensors files in LoRA checkpoint: {safetensors_files}")
                 # Return the first .safetensors file found
-
+                result["path"] = str(safetensors_files[0])
+                return result
 
         # Fallback: check for direct safetensors file in lora_weights root
         lora_safetensors = lora_weights_dir / "pytorch_lora_weights.safetensors"
         if lora_safetensors.exists():
             logger.info(f"Found weights in lora_weights directory: {lora_safetensors}")
-
+            result["path"] = str(lora_safetensors)
+            return result
         else:
             logger.info(f"pytorch_lora_weights.safetensors not found in lora_weights directory")
 
-        # If not found in root or lora_weights, log the issue
+        # If not found in root or lora_weights, log the issue and check fallback
         logger.warning(f"Model weights not found at expected location: {model_output_safetensors_path}")
         logger.info(f"Checking output directory contents: {list(self.app.output_path.glob('*'))}")
 

@@ -1553,6 +1641,9 @@ class TrainingService:
             latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
             logger.info(f"Latest checkpoint directory: {latest_checkpoint}")
 
+            # Extract step count from checkpoint directory name
+            result["steps"] = int(latest_checkpoint.name.split("_")[-1])
+
             # Log contents of latest checkpoint
             checkpoint_contents = list(latest_checkpoint.glob("*"))
             logger.info(f"Contents of latest checkpoint {latest_checkpoint.name}: {checkpoint_contents}")

@@ -1560,11 +1651,20 @@ class TrainingService:
             checkpoint_weights = latest_checkpoint / "pytorch_lora_weights.safetensors"
             if checkpoint_weights.exists():
                 logger.info(f"Found weights in latest checkpoint: {checkpoint_weights}")
-
+                result["path"] = str(checkpoint_weights)
+                return result
             else:
                 logger.info(f"pytorch_lora_weights.safetensors not found in checkpoint directory")
 
-        return
+        return result
+
+    def get_model_output_safetensors(self) -> Optional[str]:
+        """Return the path to the model safetensors
+
+        Returns:
+            Path to safetensors file or None if not found
+        """
+        return self.get_model_output_info()["path"]
 
     def create_training_dataset_zip(self) -> str:
         """Create a ZIP file containing all training data
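For reference, the checkpoint-salvage logic added above reduces to: list the finetrainers_step_* directories, accept the newest one whose .metadata file exists and parses as a JSON object, and delete the corrupted ones encountered along the way. Below is a minimal standalone sketch of that flow under the same assumptions as the diff (the finetrainers_step_* layout and the .metadata marker); the find_valid_checkpoint name and the "output" path are illustrative and not part of the Space's code.

import json
import shutil
from pathlib import Path
from typing import Optional

def find_valid_checkpoint(output_path: Path) -> Optional[Path]:
    """Return the newest finetrainers checkpoint whose .metadata parses; remove corrupted ones."""
    checkpoints = sorted(
        output_path.glob("finetrainers_step_*"),
        key=lambda p: int(p.name.split("_")[-1]),
        reverse=True,  # newest first
    )
    corrupted = []
    for ckpt in checkpoints:
        try:
            metadata = json.loads((ckpt / ".metadata").read_text())
            if not isinstance(metadata, dict):
                raise ValueError("metadata is not a JSON object")
        except (OSError, ValueError) as err:  # missing file, unreadable file, bad JSON, or wrong shape
            print(f"corrupted checkpoint {ckpt.name}: {err}")
            corrupted.append(ckpt)
            continue
        # First valid checkpoint wins; drop any newer, corrupted ones found before it.
        for bad in corrupted:
            shutil.rmtree(bad, ignore_errors=True)
        return ckpt
    # No valid checkpoint at all: clean up whatever was corrupted and report failure.
    for bad in corrupted:
        shutil.rmtree(bad, ignore_errors=True)
    return None

if __name__ == "__main__":
    print(find_valid_checkpoint(Path("output")))  # "output" is an illustrative path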
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -25,6 +25,32 @@ class ManageTab(BaseTab):
         self.id = "manage_tab"
         self.title = "5️⃣ Storage"
 
+    def get_download_button_text(self) -> str:
+        """Get the dynamic text for the download button based on current model state"""
+        try:
+            model_info = self.app.training.get_model_output_info()
+            if model_info["path"] and model_info["steps"]:
+                return f"🧠 Download weights ({model_info['steps']} steps)"
+            elif model_info["path"]:
+                return "🧠 Download weights (.safetensors)"
+            else:
+                return "🧠 Download weights (not available)"
+        except Exception as e:
+            logger.warning(f"Error getting model info for button text: {e}")
+            return "🧠 Download weights (.safetensors)"
+
+    def update_download_button_text(self) -> gr.update:
+        """Update the download button text"""
+        return gr.update(value=self.get_download_button_text())
+
+    def download_and_update_button(self):
+        """Handle download and return updated button with current text"""
+        # Get the safetensors path for download
+        path = self.app.training.get_model_output_safetensors()
+        # For DownloadButton, we need to return the file path directly for download
+        # The button text will be updated on next render
+        return path
+
     def create(self, parent=None) -> gr.TabItem:
         """Create the Manage tab UI components"""
         with gr.TabItem(self.title, id=self.id) as tab:

@@ -45,7 +71,7 @@ class ManageTab(BaseTab):
                 gr.Markdown("📦 Training dataset download disabled for large datasets")
 
             self.components["download_model_btn"] = gr.DownloadButton(
-
+                self.get_download_button_text(),
                 variant="secondary",
                 size="lg"
            )
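The new button label is computed once from get_model_output_info() when the tab is created; keeping it fresh afterwards is left to the caller. A possible wiring is sketched below, assuming Gradio's DownloadButton uses label for the visible text and value for the file it serves; the get_weights_path / get_weights_label helpers and the standalone Blocks app are illustrative stand-ins for the tab's real methods, not the Space's actual code.

import gradio as gr
from pathlib import Path

WEIGHTS = Path("output/pytorch_lora_weights.safetensors")  # illustrative location

def get_weights_path():
    """Stand-in for TrainingService.get_model_output_safetensors()."""
    return str(WEIGHTS) if WEIGHTS.exists() else None

def get_weights_label() -> str:
    """Stand-in for ManageTab.get_download_button_text()."""
    return "🧠 Download weights (.safetensors)" if get_weights_path() else "🧠 Download weights (not available)"

def refresh_button():
    # label is the visible text, value is the file that gets served on click
    return gr.update(label=get_weights_label(), value=get_weights_path())

with gr.Blocks() as demo:
    download_btn = gr.DownloadButton(get_weights_label(), value=get_weights_path(),
                                     variant="secondary", size="lg")
    demo.load(refresh_button, outputs=download_btn)            # refresh when the page loads
    download_btn.click(refresh_button, outputs=download_btn)   # refresh again after a download

if __name__ == "__main__":
    demo.launch()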