Spaces:
Running
Running
Commit
·
41a8716
1
Parent(s):
2264c6e
Revert to commit a9df757
Browse files
vms/ui/app_ui.py
CHANGED
|
@@ -403,7 +403,6 @@ class AppUI:
|
|
| 403 |
]
|
| 404 |
)
|
| 405 |
|
| 406 |
-
|
| 407 |
# Button update timer for button components (every 1 second)
|
| 408 |
button_timer = gr.Timer(value=1)
|
| 409 |
button_outputs = [
|
|
|
|
| 403 |
]
|
| 404 |
)
|
| 405 |
|
|
|
|
| 406 |
# Button update timer for button components (every 1 second)
|
| 407 |
button_timer = gr.Timer(value=1)
|
| 408 |
button_outputs = [
|
vms/ui/models/tabs/training_tab.py
CHANGED
|
@@ -88,8 +88,9 @@ class TrainingTab(BaseTab):
|
|
| 88 |
gr.Markdown(model.model_display_name or "Unknown")
|
| 89 |
|
| 90 |
with gr.Column(scale=2, min_width=20):
|
| 91 |
-
progress_text = f"Step {model.current_step}/{model.total_steps}
|
| 92 |
gr.Markdown(progress_text)
|
|
|
|
| 93 |
|
| 94 |
with gr.Column(scale=2, min_width=20):
|
| 95 |
with gr.Row():
|
|
|
|
| 88 |
gr.Markdown(model.model_display_name or "Unknown")
|
| 89 |
|
| 90 |
with gr.Column(scale=2, min_width=20):
|
| 91 |
+
progress_text = f"Step {model.current_step}/{model.total_steps}"
|
| 92 |
gr.Markdown(progress_text)
|
| 93 |
+
gr.Progress(value=model.training_progress/100)
|
| 94 |
|
| 95 |
with gr.Column(scale=2, min_width=20):
|
| 96 |
with gr.Row():
|
vms/ui/project/services/training.py
CHANGED
|
@@ -1823,9 +1823,12 @@ class TrainingService:
|
|
| 1823 |
try:
|
| 1824 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 1825 |
if not checkpoints:
|
| 1826 |
-
return "
|
| 1827 |
|
| 1828 |
-
|
|
|
|
|
|
|
|
|
|
| 1829 |
except Exception as e:
|
| 1830 |
logger.warning(f"Error getting checkpoint info for button text: {e}")
|
| 1831 |
-
return "
|
|
|
|
| 1823 |
try:
|
| 1824 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 1825 |
if not checkpoints:
|
| 1826 |
+
return "📥 Download checkpoints (not available)"
|
| 1827 |
|
| 1828 |
+
# Get the latest checkpoint by step number
|
| 1829 |
+
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
|
| 1830 |
+
step_num = int(latest_checkpoint.name.split("_")[-1])
|
| 1831 |
+
return f"📥 Download checkpoints (step {step_num})"
|
| 1832 |
except Exception as e:
|
| 1833 |
logger.warning(f"Error getting checkpoint info for button text: {e}")
|
| 1834 |
+
return "📥 Download checkpoints (not available)"
|
vms/ui/project/tabs/manage_tab.py
CHANGED
|
@@ -25,6 +25,50 @@ class ManageTab(BaseTab):
|
|
| 25 |
self.id = "manage_tab"
|
| 26 |
self.title = "5️⃣ Storage"
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def create(self, parent=None) -> gr.TabItem:
|
| 30 |
"""Create the Manage tab UI components"""
|
|
@@ -46,19 +90,19 @@ class ManageTab(BaseTab):
|
|
| 46 |
gr.Markdown("📦 Training dataset download disabled for large datasets")
|
| 47 |
|
| 48 |
self.components["download_model_btn"] = gr.DownloadButton(
|
| 49 |
-
|
| 50 |
variant="secondary",
|
| 51 |
size="lg"
|
| 52 |
)
|
| 53 |
|
| 54 |
self.components["download_checkpoint_btn"] = gr.DownloadButton(
|
| 55 |
-
|
| 56 |
variant="secondary",
|
| 57 |
size="lg"
|
| 58 |
)
|
| 59 |
|
| 60 |
self.components["download_output_btn"] = gr.DownloadButton(
|
| 61 |
-
"📁 Download output
|
| 62 |
variant="secondary",
|
| 63 |
size="lg",
|
| 64 |
visible=False
|
|
|
|
| 25 |
self.id = "manage_tab"
|
| 26 |
self.title = "5️⃣ Storage"
|
| 27 |
|
| 28 |
+
def get_download_button_text(self) -> str:
|
| 29 |
+
"""Get the dynamic text for the download button based on current model state"""
|
| 30 |
+
try:
|
| 31 |
+
model_info = self.app.training.get_model_output_info()
|
| 32 |
+
if model_info["path"] and model_info["steps"]:
|
| 33 |
+
return f"🧠 Download weights ({model_info['steps']} steps)"
|
| 34 |
+
elif model_info["path"]:
|
| 35 |
+
return "🧠 Download weights (.safetensors)"
|
| 36 |
+
else:
|
| 37 |
+
return "🧠 Download weights (not available)"
|
| 38 |
+
except Exception as e:
|
| 39 |
+
logger.warning(f"Error getting model info for button text: {e}")
|
| 40 |
+
return "🧠 Download weights (.safetensors)"
|
| 41 |
+
|
| 42 |
+
def get_checkpoint_button_text(self) -> str:
|
| 43 |
+
"""Get the dynamic text for the download checkpoint button"""
|
| 44 |
+
try:
|
| 45 |
+
return self.app.training.get_checkpoint_button_text()
|
| 46 |
+
except Exception as e:
|
| 47 |
+
logger.warning(f"Error getting checkpoint button text: {e}")
|
| 48 |
+
return "📥 Download checkpoints (not available)"
|
| 49 |
+
|
| 50 |
+
def update_download_button_text(self) -> gr.update:
|
| 51 |
+
"""Update the download button text"""
|
| 52 |
+
return gr.update(value=self.get_download_button_text())
|
| 53 |
+
|
| 54 |
+
def update_checkpoint_button_text(self) -> gr.update:
|
| 55 |
+
"""Update the checkpoint button text"""
|
| 56 |
+
return gr.update(value=self.get_checkpoint_button_text())
|
| 57 |
+
|
| 58 |
+
def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
|
| 59 |
+
"""Update both download button texts"""
|
| 60 |
+
return (
|
| 61 |
+
gr.update(value=self.get_download_button_text()),
|
| 62 |
+
gr.update(value=self.get_checkpoint_button_text())
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def download_and_update_button(self):
|
| 66 |
+
"""Handle download and return updated button with current text"""
|
| 67 |
+
# Get the safetensors path for download
|
| 68 |
+
path = self.app.training.get_model_output_safetensors()
|
| 69 |
+
# For DownloadButton, we need to return the file path directly for download
|
| 70 |
+
# The button text will be updated on next render
|
| 71 |
+
return path
|
| 72 |
|
| 73 |
def create(self, parent=None) -> gr.TabItem:
|
| 74 |
"""Create the Manage tab UI components"""
|
|
|
|
| 90 |
gr.Markdown("📦 Training dataset download disabled for large datasets")
|
| 91 |
|
| 92 |
self.components["download_model_btn"] = gr.DownloadButton(
|
| 93 |
+
self.get_download_button_text(),
|
| 94 |
variant="secondary",
|
| 95 |
size="lg"
|
| 96 |
)
|
| 97 |
|
| 98 |
self.components["download_checkpoint_btn"] = gr.DownloadButton(
|
| 99 |
+
self.get_checkpoint_button_text(),
|
| 100 |
variant="secondary",
|
| 101 |
size="lg"
|
| 102 |
)
|
| 103 |
|
| 104 |
self.components["download_output_btn"] = gr.DownloadButton(
|
| 105 |
+
"📁 Download output directory (.zip)",
|
| 106 |
variant="secondary",
|
| 107 |
size="lg",
|
| 108 |
visible=False
|
vms/ui/project/tabs/train_tab.py
CHANGED
|
@@ -494,7 +494,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 494 |
save_iterations, repo_id, progress
|
| 495 |
)
|
| 496 |
|
| 497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
|
| 499 |
def handle_resume_training(
|
| 500 |
self, model_type, model_version, training_type,
|
|
@@ -506,7 +511,10 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 506 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 507 |
|
| 508 |
if not checkpoints:
|
| 509 |
-
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
| 512 |
|
|
@@ -518,7 +526,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 518 |
resume_from_checkpoint="latest"
|
| 519 |
)
|
| 520 |
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
def handle_start_from_lora_training(
|
| 524 |
self, model_type, model_version, training_type,
|
|
@@ -529,22 +542,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 529 |
# Find the latest LoRA weights
|
| 530 |
lora_weights_path = self.app.output_path / "lora_weights"
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
if not lora_weights_path.exists():
|
| 533 |
-
return "No LoRA weights found", "Please train a model first or start a new training session"
|
| 534 |
|
| 535 |
# Find the latest LoRA checkpoint directory
|
| 536 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
| 537 |
key=lambda x: int(x.name), reverse=True)
|
| 538 |
|
| 539 |
if not lora_dirs:
|
| 540 |
-
return "No LoRA weight directories found", "Please train a model first or start a new training session"
|
| 541 |
|
| 542 |
latest_lora_dir = lora_dirs[0]
|
| 543 |
|
| 544 |
# Verify the LoRA weights file exists
|
| 545 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
| 546 |
if not lora_weights_file.exists():
|
| 547 |
-
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
|
| 548 |
|
| 549 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
| 550 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
@@ -565,7 +582,11 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 565 |
save_iterations, repo_id, progress,
|
| 566 |
)
|
| 567 |
|
| 568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
|
| 570 |
def connect_events(self) -> None:
|
| 571 |
"""Connect event handlers to UI components"""
|
|
@@ -748,7 +769,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 748 |
],
|
| 749 |
outputs=[
|
| 750 |
self.components["status_box"],
|
| 751 |
-
self.components["log_box"]
|
|
|
|
|
|
|
| 752 |
]
|
| 753 |
)
|
| 754 |
|
|
@@ -768,7 +791,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 768 |
],
|
| 769 |
outputs=[
|
| 770 |
self.components["status_box"],
|
| 771 |
-
self.components["log_box"]
|
|
|
|
|
|
|
| 772 |
]
|
| 773 |
)
|
| 774 |
|
|
@@ -788,7 +813,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 788 |
],
|
| 789 |
outputs=[
|
| 790 |
self.components["status_box"],
|
| 791 |
-
self.components["log_box"]
|
|
|
|
|
|
|
| 792 |
]
|
| 793 |
)
|
| 794 |
|
|
@@ -804,7 +831,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 804 |
self.components["current_task_box"],
|
| 805 |
self.components["start_btn"],
|
| 806 |
self.components["stop_btn"],
|
| 807 |
-
third_btn
|
|
|
|
|
|
|
| 808 |
]
|
| 809 |
)
|
| 810 |
|
|
@@ -816,7 +845,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 816 |
self.components["current_task_box"],
|
| 817 |
self.components["start_btn"],
|
| 818 |
self.components["stop_btn"],
|
| 819 |
-
third_btn
|
|
|
|
|
|
|
| 820 |
]
|
| 821 |
)
|
| 822 |
|
|
@@ -1209,7 +1240,12 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
|
|
| 1209 |
variant="stop"
|
| 1210 |
)
|
| 1211 |
|
| 1212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1213 |
|
| 1214 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
| 1215 |
"""Update UI components based on training state"""
|
|
|
|
| 494 |
save_iterations, repo_id, progress
|
| 495 |
)
|
| 496 |
|
| 497 |
+
# Update download button texts
|
| 498 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 499 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 500 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 501 |
+
|
| 502 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
| 503 |
|
| 504 |
def handle_resume_training(
|
| 505 |
self, model_type, model_version, training_type,
|
|
|
|
| 511 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 512 |
|
| 513 |
if not checkpoints:
|
| 514 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 515 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 516 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 517 |
+
return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
|
| 518 |
|
| 519 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
| 520 |
|
|
|
|
| 526 |
resume_from_checkpoint="latest"
|
| 527 |
)
|
| 528 |
|
| 529 |
+
# Update download button texts
|
| 530 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 531 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 532 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 533 |
+
|
| 534 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
| 535 |
|
| 536 |
def handle_start_from_lora_training(
|
| 537 |
self, model_type, model_version, training_type,
|
|
|
|
| 542 |
# Find the latest LoRA weights
|
| 543 |
lora_weights_path = self.app.output_path / "lora_weights"
|
| 544 |
|
| 545 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 546 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 547 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 548 |
+
|
| 549 |
if not lora_weights_path.exists():
|
| 550 |
+
return "No LoRA weights found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
| 551 |
|
| 552 |
# Find the latest LoRA checkpoint directory
|
| 553 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
| 554 |
key=lambda x: int(x.name), reverse=True)
|
| 555 |
|
| 556 |
if not lora_dirs:
|
| 557 |
+
return "No LoRA weight directories found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
| 558 |
|
| 559 |
latest_lora_dir = lora_dirs[0]
|
| 560 |
|
| 561 |
# Verify the LoRA weights file exists
|
| 562 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
| 563 |
if not lora_weights_file.exists():
|
| 564 |
+
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory", download_btn_text, checkpoint_btn_text
|
| 565 |
|
| 566 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
| 567 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
|
|
| 582 |
save_iterations, repo_id, progress,
|
| 583 |
)
|
| 584 |
|
| 585 |
+
# Update download button texts
|
| 586 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 587 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 588 |
+
|
| 589 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
| 590 |
|
| 591 |
def connect_events(self) -> None:
|
| 592 |
"""Connect event handlers to UI components"""
|
|
|
|
| 769 |
],
|
| 770 |
outputs=[
|
| 771 |
self.components["status_box"],
|
| 772 |
+
self.components["log_box"],
|
| 773 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 774 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 775 |
]
|
| 776 |
)
|
| 777 |
|
|
|
|
| 791 |
],
|
| 792 |
outputs=[
|
| 793 |
self.components["status_box"],
|
| 794 |
+
self.components["log_box"],
|
| 795 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 796 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 797 |
]
|
| 798 |
)
|
| 799 |
|
|
|
|
| 813 |
],
|
| 814 |
outputs=[
|
| 815 |
self.components["status_box"],
|
| 816 |
+
self.components["log_box"],
|
| 817 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 818 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 819 |
]
|
| 820 |
)
|
| 821 |
|
|
|
|
| 831 |
self.components["current_task_box"],
|
| 832 |
self.components["start_btn"],
|
| 833 |
self.components["stop_btn"],
|
| 834 |
+
third_btn,
|
| 835 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 836 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 837 |
]
|
| 838 |
)
|
| 839 |
|
|
|
|
| 845 |
self.components["current_task_box"],
|
| 846 |
self.components["start_btn"],
|
| 847 |
self.components["stop_btn"],
|
| 848 |
+
third_btn,
|
| 849 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 850 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 851 |
]
|
| 852 |
)
|
| 853 |
|
|
|
|
| 1240 |
variant="stop"
|
| 1241 |
)
|
| 1242 |
|
| 1243 |
+
# Update download button texts
|
| 1244 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 1245 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 1246 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 1247 |
+
|
| 1248 |
+
return start_btn, resume_btn, stop_btn, delete_checkpoints_btn, download_btn_text, checkpoint_btn_text
|
| 1249 |
|
| 1250 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
| 1251 |
"""Update UI components based on training state"""
|