Commit ece1c33 by jbilcke-hf
1 Parent(s): 98352eb

fail to crack/hack finetrainers, so reverting..

.claude/settings.local.json CHANGED
@@ -1,3 +1,8 @@
 {
-  "enableAllProjectMcpServers": false
+  "enableAllProjectMcpServers": false,
+  "permissions": {
+    "allow": [
+      "Bash(rg:*)"
+    ]
+  }
 }
vms/patches/finetrainers_lora_loading.py DELETED
@@ -1,142 +0,0 @@
-"""
-Monkey patch for Finetrainers to support loading existing LoRA weights as training initialization.
-
-This patch extends the SFTTrainer to accept a --pretrained_lora_path argument that allows
-starting training from existing LoRA weights instead of random initialization.
-"""
-
-import logging
-import json
-from typing import Optional, Dict, Any
-from pathlib import Path
-
-import safetensors.torch
-from peft import set_peft_model_state_dict
-
-logger = logging.getLogger(__name__)
-
-# Global flag to track if patch has been applied
-_PATCH_APPLIED = False
-
-def _load_pretrained_lora_weights(self, lora_path: str) -> None:
-    """Load existing LoRA weights as training initialization
-
-    Args:
-        lora_path: Path to directory containing pytorch_lora_weights.safetensors
-    """
-    lora_path = Path(lora_path)
-
-    # Find the safetensors file
-    safetensors_file = lora_path / "pytorch_lora_weights.safetensors"
-    if not safetensors_file.exists():
-        raise FileNotFoundError(f"LoRA weights file not found: {safetensors_file}")
-
-    logger.info(f"Loading pretrained LoRA weights from: {safetensors_file}")
-
-    try:
-        # Load the LoRA weights
-        lora_state_dict = safetensors.torch.load_file(str(safetensors_file))
-
-        # Extract metadata if available
-        metadata = {}
-        try:
-            with open(safetensors_file, 'rb') as f:
-                # Try to read metadata from safetensors header
-                header_size = int.from_bytes(f.read(8), 'little')
-                header_data = f.read(header_size)
-                header = json.loads(header_data.decode('utf-8'))
-                metadata = header.get('__metadata__', {})
-        except Exception as e:
-            logger.debug(f"Could not read metadata from safetensors: {e}")
-
-        # Log metadata info if available
-        if metadata:
-            logger.info(f"LoRA metadata: rank={metadata.get('rank', 'unknown')}, "
-                        f"alpha={metadata.get('lora_alpha', 'unknown')}")
-
-        # Apply the LoRA weights to the model
-        set_peft_model_state_dict(self.transformer, lora_state_dict)
-
-        logger.info(f"Successfully loaded LoRA weights from {safetensors_file}")
-
-        # Log the loaded keys for debugging
-        logger.debug(f"Loaded LoRA keys: {list(lora_state_dict.keys())}")
-
-    except Exception as e:
-        logger.error(f"Failed to load LoRA weights from {safetensors_file}: {e}")
-        raise RuntimeError(f"Failed to load LoRA weights: {e}")
-
-
-def patched_prepare_trainable_parameters(self) -> None:
-    """Patched version of _prepare_trainable_parameters that supports pretrained LoRA loading"""
-
-    # Call the original method first
-    original_prepare_trainable_parameters(self)
-
-    # Check if pretrained LoRA path is provided
-    if hasattr(self.args, 'pretrained_lora_path') and self.args.pretrained_lora_path:
-        logger.info(f"Pretrained LoRA path specified: {self.args.pretrained_lora_path}")
-
-        # Only load if we're doing LoRA training
-        if hasattr(self.args, 'training_type') and str(self.args.training_type) == 'TrainingType.LORA':
-            self._load_pretrained_lora_weights(self.args.pretrained_lora_path)
-        else:
-            logger.warning("pretrained_lora_path specified but training_type is not LORA")
-
-
-def apply_lora_loading_patch() -> None:
-    """Apply the monkey patch to enable LoRA weight loading in Finetrainers"""
-    global _PATCH_APPLIED
-
-    if _PATCH_APPLIED:
-        logger.debug("Finetrainers LoRA loading patch already applied")
-        return
-
-    try:
-        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
-
-        # Store reference to original method
-        global original_prepare_trainable_parameters
-        original_prepare_trainable_parameters = SFTTrainer._prepare_trainable_parameters
-
-        # Apply patches
-        SFTTrainer._prepare_trainable_parameters = patched_prepare_trainable_parameters
-        SFTTrainer._load_pretrained_lora_weights = _load_pretrained_lora_weights
-
-        _PATCH_APPLIED = True
-        logger.info("Successfully applied Finetrainers LoRA loading patch")
-
-    except ImportError as e:
-        logger.error(f"Failed to import Finetrainers classes for patching: {e}")
-        raise
-    except Exception as e:
-        logger.error(f"Failed to apply Finetrainers LoRA loading patch: {e}")
-        raise
-
-
-def remove_lora_loading_patch() -> None:
-    """Remove the monkey patch (for testing purposes)"""
-    global _PATCH_APPLIED
-
-    if not _PATCH_APPLIED:
-        return
-
-    try:
-        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
-
-        # Restore original method
-        SFTTrainer._prepare_trainable_parameters = original_prepare_trainable_parameters
-
-        # Remove added method
-        if hasattr(SFTTrainer, '_load_pretrained_lora_weights'):
-            delattr(SFTTrainer, '_load_pretrained_lora_weights')
-
-        _PATCH_APPLIED = False
-        logger.info("Removed Finetrainers LoRA loading patch")
-
-    except Exception as e:
-        logger.error(f"Failed to remove Finetrainers LoRA loading patch: {e}")
-
-
-# Store reference to original method (will be set when patch is applied)
-original_prepare_trainable_parameters = None
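The deleted module above follows a standard monkey-patching pattern: keep a module-level reference to the original method, install a wrapper that calls the original first and then does the extra work, and provide an explicit remove function that restores the original. A minimal, self-contained sketch of that pattern, with illustrative class and method names rather than the real Finetrainers API:

    class Trainer:
        def prepare(self) -> None:
            print("original prepare")

    # Module-level reference so the wrapper can chain to the original
    _original_prepare = Trainer.prepare

    def _patched_prepare(self) -> None:
        _original_prepare(self)  # run the original behaviour first
        print("extra step: e.g. load pretrained LoRA weights here")

    def apply_patch() -> None:
        Trainer.prepare = _patched_prepare

    def remove_patch() -> None:
        Trainer.prepare = _original_prepare

    apply_patch()
    Trainer().prepare()   # original prepare + extra step
    remove_patch()
    Trainer().prepare()   # original prepare only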
vms/ui/project/services/training.py CHANGED
@@ -54,7 +54,6 @@ from vms.utils import (
     prepare_finetrainers_dataset,
     copy_files_to_training_dir
 )
-from vms.patches.finetrainers_lora_loading import apply_lora_loading_patch
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -73,11 +72,6 @@ class TrainingService:
         self.setup_logging()
         self.ensure_valid_ui_state_file()
 
-        # Apply Finetrainers patches for LoRA weight loading
-        try:
-            apply_lora_loading_patch()
-        except Exception as e:
-            logger.warning(f"Failed to apply Finetrainers LoRA loading patch: {e}")
 
         # Start background cleanup task
         self._cleanup_stop_event = threading.Event()
@@ -585,7 +579,6 @@ class TrainingService:
         precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
         lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
         progress: Optional[gr.Progress] = None,
-        pretrained_lora_path: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Start training with finetrainers"""
 
@@ -836,29 +829,6 @@ class TrainingService:
             self.append_log(error_msg)
             return error_msg, "No valid checkpoints available"
 
-        # Add pretrained LoRA path if provided (for starting fresh training with existing weights)
-        if pretrained_lora_path:
-            # Validate the LoRA path exists and contains required files
-            lora_path = Path(pretrained_lora_path)
-            lora_weights_file = lora_path / "pytorch_lora_weights.safetensors"
-
-            if not lora_path.exists():
-                error_msg = f"Pretrained LoRA path does not exist: {pretrained_lora_path}"
-                logger.error(error_msg)
-                self.append_log(error_msg)
-                return error_msg, "LoRA path not found"
-
-            if not lora_weights_file.exists():
-                error_msg = f"LoRA weights file not found: {lora_weights_file}"
-                logger.error(error_msg)
-                self.append_log(error_msg)
-                return error_msg, "LoRA weights file missing"
-
-            # Set the pretrained LoRA path for the patched Finetrainers
-            config.pretrained_lora_path = str(lora_path)
-            self.append_log(f"Starting training with pretrained LoRA weights from: {lora_path}")
-            logger.info(f"Using pretrained LoRA weights: {lora_path}")
-
         # Common settings for both models
         config.mixed_precision = DEFAULT_MIXED_PRECISION
         config.seed = DEFAULT_SEED
@@ -1823,4 +1793,47 @@ class TrainingService:
             return temp_zip_path
         except Exception as e:
             print(f"Failed to create output zip: {str(e)}")
-            raise gr.Error(f"Failed to create output zip: {str(e)}")
+            raise gr.Error(f"Failed to create output zip: {str(e)}")
+
+    def create_checkpoint_zip(self) -> Optional[str]:
+        """Create a ZIP file containing the latest finetrainers checkpoint
+
+        Returns:
+            Path to created ZIP file or None if no checkpoint found
+        """
+        # Find all checkpoint directories
+        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+        if not checkpoints:
+            logger.info("No checkpoint directories found")
+            raise gr.Error("No checkpoint directories found")
+
+        # Get the latest checkpoint by step number
+        latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
+        step_num = int(latest_checkpoint.name.split("_")[-1])
+
+        # Create temporary zip file
+        with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip:
+            temp_zip_path = str(temp_zip.name)
+        print(f"Creating zip file for checkpoint {latest_checkpoint.name}..")
+        try:
+            make_archive(latest_checkpoint, temp_zip_path)
+            print(f"Checkpoint zip file created for step {step_num}!")
+            return temp_zip_path
+        except Exception as e:
+            print(f"Failed to create checkpoint zip: {str(e)}")
+            raise gr.Error(f"Failed to create checkpoint zip: {str(e)}")
+
+    def get_checkpoint_button_text(self) -> str:
+        """Get the dynamic text for the download checkpoint button based on available checkpoints"""
+        try:
+            checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+            if not checkpoints:
+                return "πŸ“₯ Download checkpoints (not available)"
+
+            # Get the latest checkpoint by step number
+            latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
+            step_num = int(latest_checkpoint.name.split("_")[-1])
+            return f"πŸ“₯ Download checkpoints (step {step_num})"
+        except Exception as e:
+            logger.warning(f"Error getting checkpoint info for button text: {e}")
+            return "πŸ“₯ Download checkpoints (not available)"
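Both new helpers above select the most recent checkpoint by parsing the numeric suffix of the finetrainers_step_* directory name rather than taking the lexicographic maximum, which would rank step 950 above step 1800. A small self-contained sketch of that selection logic, using hypothetical directory names:

    from pathlib import Path

    # Hypothetical checkpoint directories, named the way finetrainers names them
    checkpoints = [
        Path("finetrainers_step_200"),
        Path("finetrainers_step_1800"),
        Path("finetrainers_step_950"),
    ]

    # Parse the trailing step number so 1800 beats 950 (string comparison would not)
    latest = max(checkpoints, key=lambda p: int(p.name.split("_")[-1]))
    print(latest.name)  # -> finetrainers_step_1800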
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -39,6 +39,14 @@ class ManageTab(BaseTab):
             logger.warning(f"Error getting model info for button text: {e}")
             return "🧠 Download weights (.safetensors)"
 
+    def get_checkpoint_button_text(self) -> str:
+        """Get the dynamic text for the download checkpoint button"""
+        try:
+            return self.app.training.get_checkpoint_button_text()
+        except Exception as e:
+            logger.warning(f"Error getting checkpoint button text: {e}")
+            return "πŸ“₯ Download checkpoints (not available)"
+
     def update_download_button_text(self) -> gr.update:
         """Update the download button text"""
         return gr.update(value=self.get_download_button_text())
@@ -76,6 +84,12 @@ class ManageTab(BaseTab):
                     size="lg"
                 )
 
+                self.components["download_checkpoint_btn"] = gr.DownloadButton(
+                    self.get_checkpoint_button_text(),
+                    variant="secondary",
+                    size="lg"
+                )
+
                 self.components["download_output_btn"] = gr.DownloadButton(
                     "πŸ“ Download output directory (.zip)",
                     variant="secondary",
@@ -232,6 +246,11 @@ class ManageTab(BaseTab):
             outputs=[self.components["download_model_btn"]]
         )
 
+        self.components["download_checkpoint_btn"].click(
+            fn=self.app.training.create_checkpoint_zip,
+            outputs=[self.components["download_checkpoint_btn"]]
+        )
+
         self.components["download_output_btn"].click(
             fn=self.app.training.create_output_directory_zip,
             outputs=[self.components["download_output_btn"]]
vms/ui/project/tabs/train_tab.py CHANGED
@@ -369,11 +369,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                     interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
                 )
 
-                # Add new button for starting from LoRA weights
+                # Starting from LoRA weights is DISABLED for now
                 self.components["start_from_lora_btn"] = gr.Button(
                     "πŸ”„ Start over using latest LoRA weights",
                     variant="secondary",
-                    interactive=has_lora_weights and not ASK_USER_TO_DUPLICATE_SPACE
+                    interactive=has_lora_weights and not ASK_USER_TO_DUPLICATE_SPACE,
+                    visible=False,
                 )
 
                 with gr.Row():