Commit ece1c33
Parent(s): 98352eb
fail to crack/hack finetrainers, so reverting..
.claude/settings.local.json CHANGED

@@ -1,3 +1,8 @@
 {
-  "enableAllProjectMcpServers": false
+  "enableAllProjectMcpServers": false,
+  "permissions": {
+    "allow": [
+      "Bash(rg:*)"
+    ]
+  }
 }
vms/patches/finetrainers_lora_loading.py DELETED

@@ -1,142 +0,0 @@
-"""
-Monkey patch for Finetrainers to support loading existing LoRA weights as training initialization.
-
-This patch extends the SFTTrainer to accept a --pretrained_lora_path argument that allows
-starting training from existing LoRA weights instead of random initialization.
-"""
-
-import logging
-import json
-from typing import Optional, Dict, Any
-from pathlib import Path
-
-import safetensors.torch
-from peft import set_peft_model_state_dict
-
-logger = logging.getLogger(__name__)
-
-# Global flag to track if patch has been applied
-_PATCH_APPLIED = False
-
-def _load_pretrained_lora_weights(self, lora_path: str) -> None:
-    """Load existing LoRA weights as training initialization
-
-    Args:
-        lora_path: Path to directory containing pytorch_lora_weights.safetensors
-    """
-    lora_path = Path(lora_path)
-
-    # Find the safetensors file
-    safetensors_file = lora_path / "pytorch_lora_weights.safetensors"
-    if not safetensors_file.exists():
-        raise FileNotFoundError(f"LoRA weights file not found: {safetensors_file}")
-
-    logger.info(f"Loading pretrained LoRA weights from: {safetensors_file}")
-
-    try:
-        # Load the LoRA weights
-        lora_state_dict = safetensors.torch.load_file(str(safetensors_file))
-
-        # Extract metadata if available
-        metadata = {}
-        try:
-            with open(safetensors_file, 'rb') as f:
-                # Try to read metadata from safetensors header
-                header_size = int.from_bytes(f.read(8), 'little')
-                header_data = f.read(header_size)
-                header = json.loads(header_data.decode('utf-8'))
-                metadata = header.get('__metadata__', {})
-        except Exception as e:
-            logger.debug(f"Could not read metadata from safetensors: {e}")
-
-        # Log metadata info if available
-        if metadata:
-            logger.info(f"LoRA metadata: rank={metadata.get('rank', 'unknown')}, "
-                        f"alpha={metadata.get('lora_alpha', 'unknown')}")
-
-        # Apply the LoRA weights to the model
-        set_peft_model_state_dict(self.transformer, lora_state_dict)
-
-        logger.info(f"Successfully loaded LoRA weights from {safetensors_file}")
-
-        # Log the loaded keys for debugging
-        logger.debug(f"Loaded LoRA keys: {list(lora_state_dict.keys())}")
-
-    except Exception as e:
-        logger.error(f"Failed to load LoRA weights from {safetensors_file}: {e}")
-        raise RuntimeError(f"Failed to load LoRA weights: {e}")
-
-
-def patched_prepare_trainable_parameters(self) -> None:
-    """Patched version of _prepare_trainable_parameters that supports pretrained LoRA loading"""
-
-    # Call the original method first
-    original_prepare_trainable_parameters(self)
-
-    # Check if pretrained LoRA path is provided
-    if hasattr(self.args, 'pretrained_lora_path') and self.args.pretrained_lora_path:
-        logger.info(f"Pretrained LoRA path specified: {self.args.pretrained_lora_path}")
-
-        # Only load if we're doing LoRA training
-        if hasattr(self.args, 'training_type') and str(self.args.training_type) == 'TrainingType.LORA':
-            self._load_pretrained_lora_weights(self.args.pretrained_lora_path)
-        else:
-            logger.warning("pretrained_lora_path specified but training_type is not LORA")
-
-
-def apply_lora_loading_patch() -> None:
-    """Apply the monkey patch to enable LoRA weight loading in Finetrainers"""
-    global _PATCH_APPLIED
-
-    if _PATCH_APPLIED:
-        logger.debug("Finetrainers LoRA loading patch already applied")
-        return
-
-    try:
-        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
-
-        # Store reference to original method
-        global original_prepare_trainable_parameters
-        original_prepare_trainable_parameters = SFTTrainer._prepare_trainable_parameters
-
-        # Apply patches
-        SFTTrainer._prepare_trainable_parameters = patched_prepare_trainable_parameters
-        SFTTrainer._load_pretrained_lora_weights = _load_pretrained_lora_weights
-
-        _PATCH_APPLIED = True
-        logger.info("Successfully applied Finetrainers LoRA loading patch")
-
-    except ImportError as e:
-        logger.error(f"Failed to import Finetrainers classes for patching: {e}")
-        raise
-    except Exception as e:
-        logger.error(f"Failed to apply Finetrainers LoRA loading patch: {e}")
-        raise
-
-
-def remove_lora_loading_patch() -> None:
-    """Remove the monkey patch (for testing purposes)"""
-    global _PATCH_APPLIED
-
-    if not _PATCH_APPLIED:
-        return
-
-    try:
-        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer
-
-        # Restore original method
-        SFTTrainer._prepare_trainable_parameters = original_prepare_trainable_parameters
-
-        # Remove added method
-        if hasattr(SFTTrainer, '_load_pretrained_lora_weights'):
-            delattr(SFTTrainer, '_load_pretrained_lora_weights')
-
-        _PATCH_APPLIED = False
-        logger.info("Removed Finetrainers LoRA loading patch")
-
-    except Exception as e:
-        logger.error(f"Failed to remove Finetrainers LoRA loading patch: {e}")
-
-
-# Store reference to original method (will be set when patch is applied)
-original_prepare_trainable_parameters = None
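A side note on the deleted helper above: it parsed the safetensors header by hand (8-byte little-endian length followed by a JSON blob) to recover LoRA metadata such as rank and alpha. The safetensors library also exposes that header directly; a minimal sketch of the same lookup, assuming a pytorch_lora_weights.safetensors file at the given path (not code from this repo):

from safetensors import safe_open

# Open the file lazily; metadata() returns the free-form string dict
# stored in the header (may be None if the writer stored no metadata).
with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}
    print(metadata.get("rank", "unknown"), metadata.get("lora_alpha", "unknown"))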
vms/ui/project/services/training.py CHANGED

@@ -54,7 +54,6 @@ from vms.utils import (
     prepare_finetrainers_dataset,
     copy_files_to_training_dir
 )
-from vms.patches.finetrainers_lora_loading import apply_lora_loading_patch
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

@@ -73,11 +72,6 @@ class TrainingService:
         self.setup_logging()
         self.ensure_valid_ui_state_file()
 
-        # Apply Finetrainers patches for LoRA weight loading
-        try:
-            apply_lora_loading_patch()
-        except Exception as e:
-            logger.warning(f"Failed to apply Finetrainers LoRA loading patch: {e}")
 
         # Start background cleanup task
         self._cleanup_stop_event = threading.Event()

@@ -585,7 +579,6 @@
         precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
         lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
         progress: Optional[gr.Progress] = None,
-        pretrained_lora_path: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Start training with finetrainers"""
 

@@ -836,29 +829,6 @@
             self.append_log(error_msg)
             return error_msg, "No valid checkpoints available"
 
-        # Add pretrained LoRA path if provided (for starting fresh training with existing weights)
-        if pretrained_lora_path:
-            # Validate the LoRA path exists and contains required files
-            lora_path = Path(pretrained_lora_path)
-            lora_weights_file = lora_path / "pytorch_lora_weights.safetensors"
-
-            if not lora_path.exists():
-                error_msg = f"Pretrained LoRA path does not exist: {pretrained_lora_path}"
-                logger.error(error_msg)
-                self.append_log(error_msg)
-                return error_msg, "LoRA path not found"
-
-            if not lora_weights_file.exists():
-                error_msg = f"LoRA weights file not found: {lora_weights_file}"
-                logger.error(error_msg)
-                self.append_log(error_msg)
-                return error_msg, "LoRA weights file missing"
-
-            # Set the pretrained LoRA path for the patched Finetrainers
-            config.pretrained_lora_path = str(lora_path)
-            self.append_log(f"Starting training with pretrained LoRA weights from: {lora_path}")
-            logger.info(f"Using pretrained LoRA weights: {lora_path}")
-
         # Common settings for both models
         config.mixed_precision = DEFAULT_MIXED_PRECISION
         config.seed = DEFAULT_SEED

@@ -1823,4 +1793,47 @@
             return temp_zip_path
         except Exception as e:
             print(f"Failed to create output zip: {str(e)}")
-            raise gr.Error(f"Failed to create output zip: {str(e)}")
+            raise gr.Error(f"Failed to create output zip: {str(e)}")
+
+    def create_checkpoint_zip(self) -> Optional[str]:
+        """Create a ZIP file containing the latest finetrainers checkpoint
+
+        Returns:
+            Path to created ZIP file or None if no checkpoint found
+        """
+        # Find all checkpoint directories
+        checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+        if not checkpoints:
+            logger.info("No checkpoint directories found")
+            raise gr.Error("No checkpoint directories found")
+
+        # Get the latest checkpoint by step number
+        latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
+        step_num = int(latest_checkpoint.name.split("_")[-1])
+
+        # Create temporary zip file
+        with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip:
+            temp_zip_path = str(temp_zip.name)
+        print(f"Creating zip file for checkpoint {latest_checkpoint.name}..")
+        try:
+            make_archive(latest_checkpoint, temp_zip_path)
+            print(f"Checkpoint zip file created for step {step_num}!")
+            return temp_zip_path
+        except Exception as e:
+            print(f"Failed to create checkpoint zip: {str(e)}")
+            raise gr.Error(f"Failed to create checkpoint zip: {str(e)}")
+
+    def get_checkpoint_button_text(self) -> str:
+        """Get the dynamic text for the download checkpoint button based on available checkpoints"""
+        try:
+            checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
+            if not checkpoints:
+                return "📥 Download checkpoints (not available)"
+
+            # Get the latest checkpoint by step number
+            latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
+            step_num = int(latest_checkpoint.name.split("_")[-1])
+            return f"📥 Download checkpoints (step {step_num})"
+        except Exception as e:
+            logger.warning(f"Error getting checkpoint info for button text: {e}")
+            return "📥 Download checkpoints (not available)"
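The new create_checkpoint_zip method calls a make_archive helper that is imported from the project's utils and not shown in this diff (note its signature differs from shutil.make_archive). A minimal sketch of what such a helper might look like, assuming it simply zips a directory into the given target path (hypothetical stand-in, not the repo's implementation):

import zipfile
from pathlib import Path

def make_archive(source_dir: Path, zip_path: str) -> None:
    # Recursively add every file under source_dir to the zip,
    # keeping paths relative to the checkpoint directory root.
    source_dir = Path(source_dir)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for file in source_dir.rglob("*"):
            if file.is_file():
                zf.write(file, file.relative_to(source_dir))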
vms/ui/project/tabs/manage_tab.py CHANGED

@@ -39,6 +39,14 @@ class ManageTab(BaseTab):
             logger.warning(f"Error getting model info for button text: {e}")
             return "🧠 Download weights (.safetensors)"
 
+    def get_checkpoint_button_text(self) -> str:
+        """Get the dynamic text for the download checkpoint button"""
+        try:
+            return self.app.training.get_checkpoint_button_text()
+        except Exception as e:
+            logger.warning(f"Error getting checkpoint button text: {e}")
+            return "📥 Download checkpoints (not available)"
+
     def update_download_button_text(self) -> gr.update:
         """Update the download button text"""
         return gr.update(value=self.get_download_button_text())

@@ -76,6 +84,12 @@
                 size="lg"
             )
 
+            self.components["download_checkpoint_btn"] = gr.DownloadButton(
+                self.get_checkpoint_button_text(),
+                variant="secondary",
+                size="lg"
+            )
+
             self.components["download_output_btn"] = gr.DownloadButton(
                 "π Download output directory (.zip)",
                 variant="secondary",

@@ -232,6 +246,11 @@
             outputs=[self.components["download_model_btn"]]
         )
 
+        self.components["download_checkpoint_btn"].click(
+            fn=self.app.training.create_checkpoint_zip,
+            outputs=[self.components["download_checkpoint_btn"]]
+        )
+
         self.components["download_output_btn"].click(
             fn=self.app.training.create_output_directory_zip,
             outputs=[self.components["download_output_btn"]]
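The wiring in the last hunk follows the usual Gradio DownloadButton pattern: the click handler returns a file path, and because the button itself is the output component, that path becomes the file the button serves. A minimal self-contained sketch of the same pattern, with a hypothetical zip helper and folder name rather than anything from this repo:

import tempfile
import zipfile
from pathlib import Path
import gradio as gr

def zip_outputs() -> str:
    # Hypothetical handler: zip a local "outputs" folder and return the
    # path so the DownloadButton that receives it serves that file.
    tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    with zipfile.ZipFile(tmp.name, "w", zipfile.ZIP_DEFLATED) as zf:
        for p in Path("outputs").rglob("*"):
            if p.is_file():
                zf.write(p, p.relative_to("outputs"))
    return tmp.name

with gr.Blocks() as demo:
    btn = gr.DownloadButton("📥 Download checkpoints", variant="secondary")
    btn.click(fn=zip_outputs, outputs=[btn])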
vms/ui/project/tabs/train_tab.py CHANGED

@@ -369,11 +369,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
                     interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
                 )
 
-                #
+                # Starting from LoRA weights is DISABLED for now
                 self.components["start_from_lora_btn"] = gr.Button(
                     "π Start over using latest LoRA weights",
                     variant="secondary",
-                    interactive=has_lora_weights and not ASK_USER_TO_DUPLICATE_SPACE
+                    interactive=has_lora_weights and not ASK_USER_TO_DUPLICATE_SPACE,
+                    visible=False,
                 )
 
                 with gr.Row():