Commit 98352eb
Parent(s): 9fd1204

try to crack Finetrainers
vms/patches/__init__.py
ADDED
@@ -0,0 +1,6 @@
+"""
+Patches module for VideoModelStudio
+
+This module contains monkey patches and modifications for third-party libraries
+to extend their functionality for our specific use cases.
+"""
vms/patches/finetrainers_lora_loading.py
ADDED
@@ -0,0 +1,142 @@
+"""
+Monkey patch for Finetrainers to support loading existing LoRA weights as training initialization.
+
+This patch extends the SFTTrainer to accept a --pretrained_lora_path argument that allows
+starting training from existing LoRA weights instead of random initialization.
+"""
+
+import logging
+import json
+from typing import Optional, Dict, Any
+from pathlib import Path
+
+import safetensors.torch
+from peft import set_peft_model_state_dict
+
+logger = logging.getLogger(__name__)
+
+# Global flag to track if patch has been applied
+_PATCH_APPLIED = False
+
+def _load_pretrained_lora_weights(self, lora_path: str) -> None:
+    """Load existing LoRA weights as training initialization
+
+    Args:
+        lora_path: Path to directory containing pytorch_lora_weights.safetensors
+    """
+    lora_path = Path(lora_path)
+
+    # Find the safetensors file
+    safetensors_file = lora_path / "pytorch_lora_weights.safetensors"
+    if not safetensors_file.exists():
+        raise FileNotFoundError(f"LoRA weights file not found: {safetensors_file}")
+
+    logger.info(f"Loading pretrained LoRA weights from: {safetensors_file}")
+
+    try:
+        # Load the LoRA weights
+        lora_state_dict = safetensors.torch.load_file(str(safetensors_file))
+
+        # Extract metadata if available
+        metadata = {}
+        try:
+            with open(safetensors_file, 'rb') as f:
+                # Try to read metadata from safetensors header
+                header_size = int.from_bytes(f.read(8), 'little')
+                header_data = f.read(header_size)
+                header = json.loads(header_data.decode('utf-8'))
+                metadata = header.get('__metadata__', {})
+        except Exception as e:
+            logger.debug(f"Could not read metadata from safetensors: {e}")
+
+        # Log metadata info if available
+        if metadata:
+            logger.info(f"LoRA metadata: rank={metadata.get('rank', 'unknown')}, "
+                        f"alpha={metadata.get('lora_alpha', 'unknown')}")
+
+        # Apply the LoRA weights to the model
+        set_peft_model_state_dict(self.transformer, lora_state_dict)
+
+        logger.info(f"Successfully loaded LoRA weights from {safetensors_file}")
+
+        # Log the loaded keys for debugging
+        logger.debug(f"Loaded LoRA keys: {list(lora_state_dict.keys())}")
+
+    except Exception as e:
+        logger.error(f"Failed to load LoRA weights from {safetensors_file}: {e}")
+        raise RuntimeError(f"Failed to load LoRA weights: {e}")
+
+
+def patched_prepare_trainable_parameters(self) -> None:
+    """Patched version of _prepare_trainable_parameters that supports pretrained LoRA loading"""
+
+    # Call the original method first
+    original_prepare_trainable_parameters(self)
+
+    # Check if pretrained LoRA path is provided
+    if hasattr(self.args, 'pretrained_lora_path') and self.args.pretrained_lora_path:
+        logger.info(f"Pretrained LoRA path specified: {self.args.pretrained_lora_path}")
+
+        # Only load if we're doing LoRA training
+        if hasattr(self.args, 'training_type') and str(self.args.training_type) == 'TrainingType.LORA':
+            self._load_pretrained_lora_weights(self.args.pretrained_lora_path)
+        else:
+            logger.warning("pretrained_lora_path specified but training_type is not LORA")
+
+
+def apply_lora_loading_patch() -> None:
+    """Apply the monkey patch to enable LoRA weight loading in Finetrainers"""
+    global _PATCH_APPLIED
+
+    if _PATCH_APPLIED:
+        logger.debug("Finetrainers LoRA loading patch already applied")
+        return
+
+    try:
+        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
+
+        # Store reference to original method
+        global original_prepare_trainable_parameters
+        original_prepare_trainable_parameters = SFTTrainer._prepare_trainable_parameters
+
+        # Apply patches
+        SFTTrainer._prepare_trainable_parameters = patched_prepare_trainable_parameters
+        SFTTrainer._load_pretrained_lora_weights = _load_pretrained_lora_weights
+
+        _PATCH_APPLIED = True
+        logger.info("Successfully applied Finetrainers LoRA loading patch")
+
+    except ImportError as e:
+        logger.error(f"Failed to import Finetrainers classes for patching: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Failed to apply Finetrainers LoRA loading patch: {e}")
+        raise
+
+
+def remove_lora_loading_patch() -> None:
+    """Remove the monkey patch (for testing purposes)"""
+    global _PATCH_APPLIED
+
+    if not _PATCH_APPLIED:
+        return
+
+    try:
+        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
+
+        # Restore original method
+        SFTTrainer._prepare_trainable_parameters = original_prepare_trainable_parameters
+
+        # Remove added method
+        if hasattr(SFTTrainer, '_load_pretrained_lora_weights'):
+            delattr(SFTTrainer, '_load_pretrained_lora_weights')
+
+        _PATCH_APPLIED = False
+        logger.info("Removed Finetrainers LoRA loading patch")
+
+    except Exception as e:
+        logger.error(f"Failed to remove Finetrainers LoRA loading patch: {e}")
+
+
+# Store reference to original method (will be set when patch is applied)
+original_prepare_trainable_parameters = None
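
Taken together, the module works in two steps: apply_lora_loading_patch() swaps SFTTrainer._prepare_trainable_parameters for the patched version, and the patched version reads args.pretrained_lora_path after the original setup has run. A minimal usage sketch follows; only the import and the patch call come from this commit, while the commented-out trainer wiring is an assumption about how Finetrainers is typically driven and may differ.

from vms.patches.finetrainers_lora_loading import apply_lora_loading_patch

# Apply the monkey patch before the trainer prepares its trainable parameters.
apply_lora_loading_patch()

# Hypothetical wiring (exact construction depends on Finetrainers):
# args.pretrained_lora_path = "/path/to/output/lora_weights/500"  # dir holding pytorch_lora_weights.safetensors
# trainer = SFTTrainer(args)   # _prepare_trainable_parameters now also loads the LoRA weights
# trainer.run()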
vms/ui/project/services/training.py
CHANGED
@@ -54,6 +54,7 @@ from vms.utils import (
     prepare_finetrainers_dataset,
     copy_files_to_training_dir
 )
+from vms.patches.finetrainers_lora_loading import apply_lora_loading_patch
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -71,6 +72,17 @@ class TrainingService:
         self.file_handler = None
         self.setup_logging()
         self.ensure_valid_ui_state_file()
+
+        # Apply Finetrainers patches for LoRA weight loading
+        try:
+            apply_lora_loading_patch()
+        except Exception as e:
+            logger.warning(f"Failed to apply Finetrainers LoRA loading patch: {e}")
+
+        # Start background cleanup task
+        self._cleanup_stop_event = threading.Event()
+        self._cleanup_thread = threading.Thread(target=self._background_cleanup_task, daemon=True)
+        self._cleanup_thread.start()
 
         logger.info("Training service initialized")
 
@@ -573,6 +585,7 @@ class TrainingService:
         precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
         lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
         progress: Optional[gr.Progress] = None,
+        pretrained_lora_path: Optional[str] = None,
     ) -> Tuple[str, str]:
        """Start training with finetrainers"""
 
@@ -822,6 +835,29 @@ class TrainingService:
             logger.error(error_msg)
             self.append_log(error_msg)
             return error_msg, "No valid checkpoints available"
+
+        # Add pretrained LoRA path if provided (for starting fresh training with existing weights)
+        if pretrained_lora_path:
+            # Validate the LoRA path exists and contains required files
+            lora_path = Path(pretrained_lora_path)
+            lora_weights_file = lora_path / "pytorch_lora_weights.safetensors"
+
+            if not lora_path.exists():
+                error_msg = f"Pretrained LoRA path does not exist: {pretrained_lora_path}"
+                logger.error(error_msg)
+                self.append_log(error_msg)
+                return error_msg, "LoRA path not found"
+
+            if not lora_weights_file.exists():
+                error_msg = f"LoRA weights file not found: {lora_weights_file}"
+                logger.error(error_msg)
+                self.append_log(error_msg)
+                return error_msg, "LoRA weights file missing"
+
+            # Set the pretrained LoRA path for the patched Finetrainers
+            config.pretrained_lora_path = str(lora_path)
+            self.append_log(f"Starting training with pretrained LoRA weights from: {lora_path}")
+            logger.info(f"Using pretrained LoRA weights: {lora_path}")
 
         # Common settings for both models
         config.mixed_precision = DEFAULT_MIXED_PRECISION
@@ -1158,6 +1194,94 @@ class TrainingService:
                 logger.error(f"Failed to remove corrupted checkpoint {checkpoint_dir}: {e}")
                 self.append_log(f"Failed to remove corrupted checkpoint {checkpoint_dir.name}: {e}")
 
+    def cleanup_old_lora_weights(self, max_to_keep: int = 2) -> None:
+        """Remove old LoRA weight directories, keeping only the most recent ones
+
+        Args:
+            max_to_keep: Maximum number of LoRA weight directories to keep (default: 2)
+        """
+        lora_weights_path = self.app.output_path / "lora_weights"
+
+        if not lora_weights_path.exists():
+            logger.debug("LoRA weights directory does not exist, nothing to clean up")
+            return
+
+        # Find all LoRA weight directories (should be named with step numbers)
+        lora_dirs = []
+        for item in lora_weights_path.iterdir():
+            if item.is_dir() and item.name.isdigit():
+                lora_dirs.append(item)
+
+        if len(lora_dirs) <= max_to_keep:
+            logger.debug(f"Found {len(lora_dirs)} LoRA weight directories, no cleanup needed (keeping {max_to_keep})")
+            return
+
+        # Sort by step number (directory name) in descending order (newest first)
+        lora_dirs_sorted = sorted(lora_dirs, key=lambda x: int(x.name), reverse=True)
+
+        # Keep the most recent max_to_keep directories, remove the rest
+        dirs_to_keep = lora_dirs_sorted[:max_to_keep]
+        dirs_to_remove = lora_dirs_sorted[max_to_keep:]
+
+        logger.info(f"Cleaning up old LoRA weights: keeping {len(dirs_to_keep)}, removing {len(dirs_to_remove)}")
+        self.append_log(f"Cleaning up old LoRA weights: keeping latest {max_to_keep} directories")
+
+        for lora_dir in dirs_to_remove:
+            try:
+                step_num = int(lora_dir.name)
+                logger.info(f"Removing old LoRA weights at step {step_num}: {lora_dir}")
+                shutil.rmtree(lora_dir)
+                self.append_log(f"Removed old LoRA weights: step {step_num}")
+            except Exception as e:
+                logger.error(f"Failed to remove old LoRA weights {lora_dir}: {e}")
+                self.append_log(f"Failed to remove old LoRA weights {lora_dir.name}: {e}")
+
+        # Log what we kept
+        kept_steps = [int(d.name) for d in dirs_to_keep]
+        kept_steps.sort(reverse=True)
+        logger.info(f"Kept LoRA weights for steps: {kept_steps}")
+        self.append_log(f"Kept LoRA weights for steps: {kept_steps}")
+
+    def _background_cleanup_task(self) -> None:
+        """Background task that runs every 10 minutes to clean up old LoRA weights"""
+        cleanup_interval = 600  # 10 minutes in seconds
+
+        logger.info("Started background LoRA cleanup task (runs every 10 minutes)")
+
+        while not self._cleanup_stop_event.is_set():
+            try:
+                # Wait for 10 minutes or until stop event is set
+                if self._cleanup_stop_event.wait(timeout=cleanup_interval):
+                    break  # Stop event was set
+
+                # Only run cleanup if we have an output path
+                if hasattr(self.app, 'output_path') and self.app.output_path:
+                    lora_weights_path = self.app.output_path / "lora_weights"
+
+                    # Only cleanup if the directory exists and has content
+                    if lora_weights_path.exists():
+                        lora_dirs = [d for d in lora_weights_path.iterdir() if d.is_dir() and d.name.isdigit()]
+
+                        if len(lora_dirs) > 2:
+                            logger.info(f"Background cleanup: Found {len(lora_dirs)} LoRA weight directories, cleaning up old ones")
+                            self.cleanup_old_lora_weights(max_to_keep=2)
+                        else:
+                            logger.debug(f"Background cleanup: Found {len(lora_dirs)} LoRA weight directories, no cleanup needed")
+
+            except Exception as e:
+                logger.error(f"Background LoRA cleanup task error: {e}")
+                # Continue running despite errors
+
+        logger.info("Background LoRA cleanup task stopped")
+
+    def stop_background_cleanup(self) -> None:
+        """Stop the background cleanup task"""
+        if hasattr(self, '_cleanup_stop_event'):
+            self._cleanup_stop_event.set()
+        if hasattr(self, '_cleanup_thread') and self._cleanup_thread.is_alive():
+            self._cleanup_thread.join(timeout=5)
+        logger.info("Background cleanup task stopped")
+
     def recover_interrupted_training(self) -> Dict[str, Any]:
         """Attempt to recover interrupted training
 
@@ -1493,6 +1617,13 @@ class TrainingService:
         gr.Info(success_msg)
         self.save_status(state='completed', message=success_msg)
 
+        # Clean up old LoRA weights to save disk space
+        try:
+            self.cleanup_old_lora_weights(max_to_keep=2)
+        except Exception as e:
+            logger.warning(f"Failed to cleanup old LoRA weights: {e}")
+            self.append_log(f"Warning: Failed to cleanup old LoRA weights: {e}")
+
         # Upload final model if repository was specified
         session = self.load_session()
         if session and session['params'].get('repo_id'):
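
cleanup_old_lora_weights and _background_cleanup_task both assume that each subdirectory of output/lora_weights/ is named after the training step that produced it. A standalone sketch of the same pruning rule, useful for checking what a given layout would lose; the directory names below are examples, not taken from this commit.

from pathlib import Path

def dirs_to_remove(lora_weights_path: Path, max_to_keep: int = 2) -> list:
    """Return the step directories that the cleanup above would delete."""
    step_dirs = [d for d in lora_weights_path.iterdir() if d.is_dir() and d.name.isdigit()]
    step_dirs.sort(key=lambda d: int(d.name), reverse=True)  # newest step first, numerically
    return step_dirs[max_to_keep:]

# Example layout:
#   output/lora_weights/200/pytorch_lora_weights.safetensors
#   output/lora_weights/400/pytorch_lora_weights.safetensors
#   output/lora_weights/600/pytorch_lora_weights.safetensors
# With max_to_keep=2, only the "200" directory would be removed.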
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -102,6 +102,18 @@ class ManageTab(BaseTab):
                         "Push my model"
                     )
 
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("## 🧹 Maintenance")
+                gr.Markdown("Clean up old files to free disk space.")
+
+                with gr.Row():
+                    self.components["cleanup_lora_btn"] = gr.Button(
+                        "🧹 Keep last 2 LoRA weights and clean up older ones",
+                        variant="secondary",
+                        size="lg"
+                    )
+
         with gr.Row():
             with gr.Column():
                 gr.Markdown("## ♻️ Delete your data")
@@ -225,6 +237,12 @@ class ManageTab(BaseTab):
             outputs=[self.components["download_output_btn"]]
         )
 
+        # LoRA cleanup button
+        self.components["cleanup_lora_btn"].click(
+            fn=self.cleanup_old_lora_weights,
+            outputs=[]
+        )
+
         # Dataset deletion with modal
         self.components["delete_dataset_btn"].click(
             fn=lambda: Modal(visible=True),
@@ -346,6 +364,16 @@ class ManageTab(BaseTab):
         else:
             return f"Failed to upload model to {repo_id}"
 
+    def cleanup_old_lora_weights(self):
+        """Clean up old LoRA weight directories, keeping only the latest 2"""
+        try:
+            self.app.training.cleanup_old_lora_weights(max_to_keep=2)
+            gr.Info("✅ Successfully cleaned up old LoRA weights")
+        except Exception as e:
+            error_msg = f"❌ Failed to cleanup LoRA weights: {str(e)}"
+            gr.Error(error_msg)
+            logger.error(f"LoRA cleanup failed: {e}")
+
     def delete_dataset(self):
         """Delete dataset files (images, videos, captions)"""
         status_messages = {}
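
The cleanup handler reports back through gr.Info and gr.Error rather than through component outputs, which is why the click wiring above passes outputs=[]. A self-contained sketch of that pattern; the component and function names here are illustrative, not the ones used in ManageTab.

import gradio as gr

def run_cleanup():
    try:
        # ... do the maintenance work here ...
        gr.Info("Cleanup finished")              # shown as a toast in the UI
    except Exception as e:
        raise gr.Error(f"Cleanup failed: {e}")   # surfaced as an error toast

with gr.Blocks() as demo:
    cleanup_btn = gr.Button("Run cleanup")
    cleanup_btn.click(fn=run_cleanup, outputs=[])

# demo.launch()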
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -341,12 +341,20 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
             ## ⚙️ Train your model on your dataset
             - **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
             - **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
+            - **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
             """)
 
         with gr.Row():
             # Check for existing checkpoints to determine button text
             checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
             has_checkpoints = len(checkpoints) > 0
+
+            # Check for existing LoRA weights
+            lora_weights_path = self.app.output_path / "lora_weights"
+            has_lora_weights = False
+            if lora_weights_path.exists():
+                lora_dirs = [d for d in lora_weights_path.iterdir() if d.is_dir()]
+                has_lora_weights = len(lora_dirs) > 0
 
             self.components["start_btn"] = gr.Button(
                 "🚀 Start new training",
@@ -361,6 +369,13 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                 interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
             )
 
+            # Add new button for starting from LoRA weights
+            self.components["start_from_lora_btn"] = gr.Button(
+                "🔄 Start over using latest LoRA weights",
+                variant="secondary",
+                interactive=has_lora_weights and not ASK_USER_TO_DUPLICATE_SPACE
+            )
+
         with gr.Row():
             # Just use stop and pause buttons for now to ensure compatibility
             self.components["stop_btn"] = gr.Button(
@@ -497,6 +512,52 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
             resume_from_checkpoint="latest"
         )
 
+    def handle_start_from_lora_training(
+        self, model_type, model_version, training_type,
+        lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
+        save_iterations, repo_id, progress=gr.Progress()
+    ):
+        """Handle starting training from existing LoRA weights"""
+        # Find the latest LoRA weights
+        lora_weights_path = self.app.output_path / "lora_weights"
+
+        if not lora_weights_path.exists():
+            return "No LoRA weights found", "Please train a model first or start a new training session"
+
+        # Find the latest LoRA checkpoint directory
+        lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
+                           key=lambda x: int(x.name), reverse=True)
+
+        if not lora_dirs:
+            return "No LoRA weight directories found", "Please train a model first or start a new training session"
+
+        latest_lora_dir = lora_dirs[0]
+
+        # Verify the LoRA weights file exists
+        lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
+        if not lora_weights_file.exists():
+            return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
+
+        # Clear checkpoints to start fresh (but keep LoRA weights)
+        for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
+            if checkpoint.is_dir():
+                shutil.rmtree(checkpoint)
+
+        # Delete session.json to start fresh
+        session_file = self.app.output_path / "session.json"
+        if session_file.exists():
+            session_file.unlink()
+
+        self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
+
+        # Start training with the LoRA weights
+        return self.handle_training_start(
+            model_type, model_version, training_type,
+            lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
+            save_iterations, repo_id, progress,
+            pretrained_lora_path=str(latest_lora_dir)
+        )
+
     def connect_events(self) -> None:
         """Connect event handlers to UI components"""
         # Model type change event - Update model version dropdown choices and default parameters
@@ -701,6 +762,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                 self.components["log_box"]
             ]
         )
+
+        self.components["start_from_lora_btn"].click(
+            fn=self.handle_start_from_lora_training,
+            inputs=[
+                self.components["model_type"],
+                self.components["model_version"],
+                self.components["training_type"],
+                self.components["lora_rank"],
+                self.components["lora_alpha"],
+                self.components["train_steps"],
+                self.components["batch_size"],
+                self.components["learning_rate"],
+                self.components["save_iterations"],
+                self.app.tabs["manage_tab"].components["repo_id"]
+            ],
+            outputs=[
+                self.components["status_box"],
+                self.components["log_box"]
+            ]
+        )
 
 
         # Use simplified event handlers for pause/resume and stop
@@ -780,6 +861,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
         save_iterations, repo_id,
         progress=gr.Progress(),
         resume_from_checkpoint=None,
+        pretrained_lora_path=None,
    ):
        """Handle training start with proper log parser reset and checkpoint detection"""
 
@@ -840,7 +922,8 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                 num_gpus=num_gpus,
                 precomputation_items=precomputation_items,
                 lr_warmup_steps=lr_warmup_steps,
-                progress=progress
+                progress=progress,
+                pretrained_lora_path=pretrained_lora_path
             )
         except Exception as e:
             logger.exception("Error starting training")
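
One detail worth noting in handle_start_from_lora_training: the step directories are sorted with key=lambda x: int(x.name) rather than lexicographically, which matters once step counts cross a digit boundary. A quick illustration with example values:

names = ["200", "900", "1000"]

print(max(names))           # "900"  -- lexicographic comparison picks the wrong directory
print(max(names, key=int))  # "1000" -- numeric comparison matches the handler's sort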
|