jbilcke-hf (HF Staff) committed
Commit 98352eb · 1 Parent(s): 9fd1204

try to crack Finetrainers

vms/patches/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""
+Patches module for VideoModelStudio
+
+This module contains monkey patches and modifications for third-party libraries
+to extend their functionality for our specific use cases.
+"""
vms/patches/finetrainers_lora_loading.py ADDED
@@ -0,0 +1,142 @@
+"""
+Monkey patch for Finetrainers to support loading existing LoRA weights as training initialization.
+
+This patch extends the SFTTrainer to accept a --pretrained_lora_path argument that allows
+starting training from existing LoRA weights instead of random initialization.
+"""
+
+import logging
+import json
+from typing import Optional, Dict, Any
+from pathlib import Path
+
+import safetensors.torch
+from peft import set_peft_model_state_dict
+
+logger = logging.getLogger(__name__)
+
+# Global flag to track if patch has been applied
+_PATCH_APPLIED = False
+
+def _load_pretrained_lora_weights(self, lora_path: str) -> None:
+    """Load existing LoRA weights as training initialization
+
+    Args:
+        lora_path: Path to directory containing pytorch_lora_weights.safetensors
+    """
+    lora_path = Path(lora_path)
+
+    # Find the safetensors file
+    safetensors_file = lora_path / "pytorch_lora_weights.safetensors"
+    if not safetensors_file.exists():
+        raise FileNotFoundError(f"LoRA weights file not found: {safetensors_file}")
+
+    logger.info(f"Loading pretrained LoRA weights from: {safetensors_file}")
+
+    try:
+        # Load the LoRA weights
+        lora_state_dict = safetensors.torch.load_file(str(safetensors_file))
+
+        # Extract metadata if available
+        metadata = {}
+        try:
+            with open(safetensors_file, 'rb') as f:
+                # Try to read metadata from safetensors header
+                header_size = int.from_bytes(f.read(8), 'little')
+                header_data = f.read(header_size)
+                header = json.loads(header_data.decode('utf-8'))
+                metadata = header.get('__metadata__', {})
+        except Exception as e:
+            logger.debug(f"Could not read metadata from safetensors: {e}")
+
+        # Log metadata info if available
+        if metadata:
+            logger.info(f"LoRA metadata: rank={metadata.get('rank', 'unknown')}, "
+                        f"alpha={metadata.get('lora_alpha', 'unknown')}")
+
+        # Apply the LoRA weights to the model
+        set_peft_model_state_dict(self.transformer, lora_state_dict)
+
+        logger.info(f"Successfully loaded LoRA weights from {safetensors_file}")
+
+        # Log the loaded keys for debugging
+        logger.debug(f"Loaded LoRA keys: {list(lora_state_dict.keys())}")
+
+    except Exception as e:
+        logger.error(f"Failed to load LoRA weights from {safetensors_file}: {e}")
+        raise RuntimeError(f"Failed to load LoRA weights: {e}")
+
+
+def patched_prepare_trainable_parameters(self) -> None:
+    """Patched version of _prepare_trainable_parameters that supports pretrained LoRA loading"""
+
+    # Call the original method first
+    original_prepare_trainable_parameters(self)
+
+    # Check if pretrained LoRA path is provided
+    if hasattr(self.args, 'pretrained_lora_path') and self.args.pretrained_lora_path:
+        logger.info(f"Pretrained LoRA path specified: {self.args.pretrained_lora_path}")
+
+        # Only load if we're doing LoRA training
+        if hasattr(self.args, 'training_type') and str(self.args.training_type) == 'TrainingType.LORA':
+            self._load_pretrained_lora_weights(self.args.pretrained_lora_path)
+        else:
+            logger.warning("pretrained_lora_path specified but training_type is not LORA")
+
+
+def apply_lora_loading_patch() -> None:
+    """Apply the monkey patch to enable LoRA weight loading in Finetrainers"""
+    global _PATCH_APPLIED
+
+    if _PATCH_APPLIED:
+        logger.debug("Finetrainers LoRA loading patch already applied")
+        return
+
+    try:
+        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
+
+        # Store reference to original method
+        global original_prepare_trainable_parameters
+        original_prepare_trainable_parameters = SFTTrainer._prepare_trainable_parameters
+
+        # Apply patches
+        SFTTrainer._prepare_trainable_parameters = patched_prepare_trainable_parameters
+        SFTTrainer._load_pretrained_lora_weights = _load_pretrained_lora_weights
+
+        _PATCH_APPLIED = True
+        logger.info("Successfully applied Finetrainers LoRA loading patch")
+
+    except ImportError as e:
+        logger.error(f"Failed to import Finetrainers classes for patching: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Failed to apply Finetrainers LoRA loading patch: {e}")
+        raise
+
+
+def remove_lora_loading_patch() -> None:
+    """Remove the monkey patch (for testing purposes)"""
+    global _PATCH_APPLIED
+
+    if not _PATCH_APPLIED:
+        return
+
+    try:
+        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
+
+        # Restore original method
+        SFTTrainer._prepare_trainable_parameters = original_prepare_trainable_parameters
+
+        # Remove added method
+        if hasattr(SFTTrainer, '_load_pretrained_lora_weights'):
+            delattr(SFTTrainer, '_load_pretrained_lora_weights')
+
+        _PATCH_APPLIED = False
+        logger.info("Removed Finetrainers LoRA loading patch")
+
+    except Exception as e:
+        logger.error(f"Failed to remove Finetrainers LoRA loading patch: {e}")
+
+
+# Store reference to original method (will be set when patch is applied)
+original_prepare_trainable_parameters = None
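
For context, a minimal sketch of how this patch is intended to be driven from the caller's side (not part of the commit; the SimpleNamespace below is a hypothetical stand-in for the Finetrainers args object that ends up on trainer.args, and the step directory name is only an example):

    from types import SimpleNamespace

    from vms.patches.finetrainers_lora_loading import apply_lora_loading_patch

    # Install the monkey patch once at startup; repeated calls are no-ops.
    apply_lora_loading_patch()

    # Hypothetical stand-in for the Finetrainers config/args object built elsewhere in VMS.
    # The patched _prepare_trainable_parameters() only loads weights when this attribute
    # is set and the training type is LoRA.
    args = SimpleNamespace(
        training_type="TrainingType.LORA",  # the patch compares str(training_type) to this value
        pretrained_lora_path="output/lora_weights/500",  # dir containing pytorch_lora_weights.safetensors
    )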
vms/ui/project/services/training.py CHANGED
@@ -54,6 +54,7 @@ from vms.utils import (
     prepare_finetrainers_dataset,
     copy_files_to_training_dir
 )
+from vms.patches.finetrainers_lora_loading import apply_lora_loading_patch
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -71,6 +72,17 @@ class TrainingService:
         self.file_handler = None
         self.setup_logging()
         self.ensure_valid_ui_state_file()
+
+        # Apply Finetrainers patches for LoRA weight loading
+        try:
+            apply_lora_loading_patch()
+        except Exception as e:
+            logger.warning(f"Failed to apply Finetrainers LoRA loading patch: {e}")
+
+        # Start background cleanup task
+        self._cleanup_stop_event = threading.Event()
+        self._cleanup_thread = threading.Thread(target=self._background_cleanup_task, daemon=True)
+        self._cleanup_thread.start()
 
         logger.info("Training service initialized")
 
@@ -573,6 +585,7 @@ class TrainingService:
         precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
         lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
         progress: Optional[gr.Progress] = None,
+        pretrained_lora_path: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Start training with finetrainers"""
 
@@ -822,6 +835,29 @@ class TrainingService:
             logger.error(error_msg)
             self.append_log(error_msg)
             return error_msg, "No valid checkpoints available"
+
+        # Add pretrained LoRA path if provided (for starting fresh training with existing weights)
+        if pretrained_lora_path:
+            # Validate the LoRA path exists and contains required files
+            lora_path = Path(pretrained_lora_path)
+            lora_weights_file = lora_path / "pytorch_lora_weights.safetensors"
+
+            if not lora_path.exists():
+                error_msg = f"Pretrained LoRA path does not exist: {pretrained_lora_path}"
+                logger.error(error_msg)
+                self.append_log(error_msg)
+                return error_msg, "LoRA path not found"
+
+            if not lora_weights_file.exists():
+                error_msg = f"LoRA weights file not found: {lora_weights_file}"
+                logger.error(error_msg)
+                self.append_log(error_msg)
+                return error_msg, "LoRA weights file missing"
+
+            # Set the pretrained LoRA path for the patched Finetrainers
+            config.pretrained_lora_path = str(lora_path)
+            self.append_log(f"Starting training with pretrained LoRA weights from: {lora_path}")
+            logger.info(f"Using pretrained LoRA weights: {lora_path}")
 
         # Common settings for both models
         config.mixed_precision = DEFAULT_MIXED_PRECISION
@@ -1158,6 +1194,94 @@ class TrainingService:
                 logger.error(f"Failed to remove corrupted checkpoint {checkpoint_dir}: {e}")
                 self.append_log(f"Failed to remove corrupted checkpoint {checkpoint_dir.name}: {e}")
 
+    def cleanup_old_lora_weights(self, max_to_keep: int = 2) -> None:
+        """Remove old LoRA weight directories, keeping only the most recent ones
+
+        Args:
+            max_to_keep: Maximum number of LoRA weight directories to keep (default: 2)
+        """
+        lora_weights_path = self.app.output_path / "lora_weights"
+
+        if not lora_weights_path.exists():
+            logger.debug("LoRA weights directory does not exist, nothing to clean up")
+            return
+
+        # Find all LoRA weight directories (should be named with step numbers)
+        lora_dirs = []
+        for item in lora_weights_path.iterdir():
+            if item.is_dir() and item.name.isdigit():
+                lora_dirs.append(item)
+
+        if len(lora_dirs) <= max_to_keep:
+            logger.debug(f"Found {len(lora_dirs)} LoRA weight directories, no cleanup needed (keeping {max_to_keep})")
+            return
+
+        # Sort by step number (directory name) in descending order (newest first)
+        lora_dirs_sorted = sorted(lora_dirs, key=lambda x: int(x.name), reverse=True)
+
+        # Keep the most recent max_to_keep directories, remove the rest
+        dirs_to_keep = lora_dirs_sorted[:max_to_keep]
+        dirs_to_remove = lora_dirs_sorted[max_to_keep:]
+
+        logger.info(f"Cleaning up old LoRA weights: keeping {len(dirs_to_keep)}, removing {len(dirs_to_remove)}")
+        self.append_log(f"Cleaning up old LoRA weights: keeping latest {max_to_keep} directories")
+
+        for lora_dir in dirs_to_remove:
+            try:
+                step_num = int(lora_dir.name)
+                logger.info(f"Removing old LoRA weights at step {step_num}: {lora_dir}")
+                shutil.rmtree(lora_dir)
+                self.append_log(f"Removed old LoRA weights: step {step_num}")
+            except Exception as e:
+                logger.error(f"Failed to remove old LoRA weights {lora_dir}: {e}")
+                self.append_log(f"Failed to remove old LoRA weights {lora_dir.name}: {e}")
+
+        # Log what we kept
+        kept_steps = [int(d.name) for d in dirs_to_keep]
+        kept_steps.sort(reverse=True)
+        logger.info(f"Kept LoRA weights for steps: {kept_steps}")
+        self.append_log(f"Kept LoRA weights for steps: {kept_steps}")
+
+    def _background_cleanup_task(self) -> None:
+        """Background task that runs every 10 minutes to clean up old LoRA weights"""
+        cleanup_interval = 600  # 10 minutes in seconds
+
+        logger.info("Started background LoRA cleanup task (runs every 10 minutes)")
+
+        while not self._cleanup_stop_event.is_set():
+            try:
+                # Wait for 10 minutes or until stop event is set
+                if self._cleanup_stop_event.wait(timeout=cleanup_interval):
+                    break  # Stop event was set
+
+                # Only run cleanup if we have an output path
+                if hasattr(self.app, 'output_path') and self.app.output_path:
+                    lora_weights_path = self.app.output_path / "lora_weights"
+
+                    # Only cleanup if the directory exists and has content
+                    if lora_weights_path.exists():
+                        lora_dirs = [d for d in lora_weights_path.iterdir() if d.is_dir() and d.name.isdigit()]
+
+                        if len(lora_dirs) > 2:
+                            logger.info(f"Background cleanup: Found {len(lora_dirs)} LoRA weight directories, cleaning up old ones")
+                            self.cleanup_old_lora_weights(max_to_keep=2)
+                        else:
+                            logger.debug(f"Background cleanup: Found {len(lora_dirs)} LoRA weight directories, no cleanup needed")
+
+            except Exception as e:
+                logger.error(f"Background LoRA cleanup task error: {e}")
+                # Continue running despite errors
+
+        logger.info("Background LoRA cleanup task stopped")
+
+    def stop_background_cleanup(self) -> None:
+        """Stop the background cleanup task"""
+        if hasattr(self, '_cleanup_stop_event'):
+            self._cleanup_stop_event.set()
+        if hasattr(self, '_cleanup_thread') and self._cleanup_thread.is_alive():
+            self._cleanup_thread.join(timeout=5)
+        logger.info("Background cleanup task stopped")
+
     def recover_interrupted_training(self) -> Dict[str, Any]:
         """Attempt to recover interrupted training
 
@@ -1493,6 +1617,13 @@ class TrainingService:
             gr.Info(success_msg)
             self.save_status(state='completed', message=success_msg)
 
+            # Clean up old LoRA weights to save disk space
+            try:
+                self.cleanup_old_lora_weights(max_to_keep=2)
+            except Exception as e:
+                logger.warning(f"Failed to cleanup old LoRA weights: {e}")
+                self.append_log(f"Warning: Failed to cleanup old LoRA weights: {e}")
+
             # Upload final model if repository was specified
             session = self.load_session()
             if session and session['params'].get('repo_id'):
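
Both the cleanup methods above and the new "start over" flow assume Finetrainers writes each LoRA export into a step-numbered directory under output/lora_weights/. A small, self-contained sketch of the retention rule applied by cleanup_old_lora_weights(max_to_keep=2), with illustrative step numbers only:

    # Illustrative step-numbered export directories, not real output.
    step_dirs = ["200", "400", "600", "800"]

    # Keep the two highest steps, remove everything older -- the same rule the
    # service applies on disk with shutil.rmtree().
    kept = sorted(step_dirs, key=int, reverse=True)[:2]   # ['800', '600']
    removed = [d for d in step_dirs if d not in kept]     # ['200', '400']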
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -102,6 +102,18 @@ class ManageTab(BaseTab):
                     "Push my model"
                 )
 
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("## 🧹 Maintenance")
+                gr.Markdown("Clean up old files to free disk space.")
+
+                with gr.Row():
+                    self.components["cleanup_lora_btn"] = gr.Button(
+                        "🔄 Keep last 2 LoRA weights and clean up older ones",
+                        variant="secondary",
+                        size="lg"
+                    )
+
         with gr.Row():
             with gr.Column():
                 gr.Markdown("## ♻️ Delete your data")
@@ -225,6 +237,12 @@ class ManageTab(BaseTab):
             outputs=[self.components["download_output_btn"]]
         )
 
+        # LoRA cleanup button
+        self.components["cleanup_lora_btn"].click(
+            fn=self.cleanup_old_lora_weights,
+            outputs=[]
+        )
+
         # Dataset deletion with modal
         self.components["delete_dataset_btn"].click(
             fn=lambda: Modal(visible=True),
@@ -346,6 +364,16 @@ class ManageTab(BaseTab):
         else:
             return f"Failed to upload model to {repo_id}"
 
+    def cleanup_old_lora_weights(self):
+        """Clean up old LoRA weight directories, keeping only the latest 2"""
+        try:
+            self.app.training.cleanup_old_lora_weights(max_to_keep=2)
+            gr.Info("✅ Successfully cleaned up old LoRA weights")
+        except Exception as e:
+            error_msg = f"❌ Failed to cleanup LoRA weights: {str(e)}"
+            gr.Error(error_msg)
+            logger.error(f"LoRA cleanup failed: {e}")
+
     def delete_dataset(self):
         """Delete dataset files (images, videos, captions)"""
         status_messages = {}
vms/ui/project/tabs/train_tab.py CHANGED
@@ -341,12 +341,20 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                 ## ⚗️ Train your model on your dataset
                 - **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
                 - **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
+                - **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
                 """)
 
             with gr.Row():
                 # Check for existing checkpoints to determine button text
                 checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
                 has_checkpoints = len(checkpoints) > 0
+
+                # Check for existing LoRA weights
+                lora_weights_path = self.app.output_path / "lora_weights"
+                has_lora_weights = False
+                if lora_weights_path.exists():
+                    lora_dirs = [d for d in lora_weights_path.iterdir() if d.is_dir()]
+                    has_lora_weights = len(lora_dirs) > 0
 
                 self.components["start_btn"] = gr.Button(
                     "🚀 Start new training",
@@ -361,6 +369,13 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                     interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
                 )
 
+                # Add new button for starting from LoRA weights
+                self.components["start_from_lora_btn"] = gr.Button(
+                    "🔄 Start over using latest LoRA weights",
+                    variant="secondary",
+                    interactive=has_lora_weights and not ASK_USER_TO_DUPLICATE_SPACE
+                )
+
             with gr.Row():
                 # Just use stop and pause buttons for now to ensure compatibility
                 self.components["stop_btn"] = gr.Button(
@@ -497,6 +512,52 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
             resume_from_checkpoint="latest"
         )
 
+    def handle_start_from_lora_training(
+        self, model_type, model_version, training_type,
+        lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
+        save_iterations, repo_id, progress=gr.Progress()
+    ):
+        """Handle starting training from existing LoRA weights"""
+        # Find the latest LoRA weights
+        lora_weights_path = self.app.output_path / "lora_weights"
+
+        if not lora_weights_path.exists():
+            return "No LoRA weights found", "Please train a model first or start a new training session"
+
+        # Find the latest LoRA checkpoint directory
+        lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
+                           key=lambda x: int(x.name), reverse=True)
+
+        if not lora_dirs:
+            return "No LoRA weight directories found", "Please train a model first or start a new training session"
+
+        latest_lora_dir = lora_dirs[0]
+
+        # Verify the LoRA weights file exists
+        lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
+        if not lora_weights_file.exists():
+            return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
+
+        # Clear checkpoints to start fresh (but keep LoRA weights)
+        for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
+            if checkpoint.is_dir():
+                shutil.rmtree(checkpoint)
+
+        # Delete session.json to start fresh
+        session_file = self.app.output_path / "session.json"
+        if session_file.exists():
+            session_file.unlink()
+
+        self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
+
+        # Start training with the LoRA weights
+        return self.handle_training_start(
+            model_type, model_version, training_type,
+            lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
+            save_iterations, repo_id, progress,
+            pretrained_lora_path=str(latest_lora_dir)
+        )
+
     def connect_events(self) -> None:
         """Connect event handlers to UI components"""
         # Model type change event - Update model version dropdown choices and default parameters
@@ -701,6 +762,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                 self.components["log_box"]
             ]
         )
+
+        self.components["start_from_lora_btn"].click(
+            fn=self.handle_start_from_lora_training,
+            inputs=[
+                self.components["model_type"],
+                self.components["model_version"],
+                self.components["training_type"],
+                self.components["lora_rank"],
+                self.components["lora_alpha"],
+                self.components["train_steps"],
+                self.components["batch_size"],
+                self.components["learning_rate"],
+                self.components["save_iterations"],
+                self.app.tabs["manage_tab"].components["repo_id"]
+            ],
+            outputs=[
+                self.components["status_box"],
+                self.components["log_box"]
+            ]
+        )
 
 
         # Use simplified event handlers for pause/resume and stop
@@ -780,6 +861,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
         save_iterations, repo_id,
         progress=gr.Progress(),
        resume_from_checkpoint=None,
+        pretrained_lora_path=None,
    ):
        """Handle training start with proper log parser reset and checkpoint detection"""
 
@@ -840,7 +922,8 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                 num_gpus=num_gpus,
                 precomputation_items=precomputation_items,
                 lr_warmup_steps=lr_warmup_steps,
-                progress=progress
+                progress=progress,
+                pretrained_lora_path=pretrained_lora_path
             )
         except Exception as e:
             logger.exception("Error starting training")
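
Finally, a quick sanity check that the monkey patch wired up in this commit actually takes effect (a sketch, not part of the commit; it assumes finetrainers is importable at the same path the patch module uses):

    from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
    from vms.patches.finetrainers_lora_loading import (
        apply_lora_loading_patch,
        remove_lora_loading_patch,
    )

    apply_lora_loading_patch()
    assert hasattr(SFTTrainer, "_load_pretrained_lora_weights")      # method added by the patch

    remove_lora_loading_patch()
    assert not hasattr(SFTTrainer, "_load_pretrained_lora_weights")  # patch cleanly undone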