File size: 4,997 Bytes
0ad7e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Manage tab for Video Model Studio UI
"""

import gradio as gr
import logging
from typing import Dict, Any, List, Optional

from .base_tab import BaseTab
from ..config import HF_API_TOKEN

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

class ManageTab(BaseTab):
    """Manage tab for storage management and model publication.

    The tab exposes three groups of controls:

    * Publishing: a repository id textbox, a "make public" checkbox and a
      push button.
    * Storage management: download buttons for the dataset and the model.
    * A global "stop everything and delete my data" button together with a
      status textbox that is created hidden.

    Every component is stored in ``self.components`` under a string key so
    that ``connect_events`` can look it up afterwards.
    """

    def __init__(self, app_state):
        """Initialise the tab and set its id and display title.

        Args:
            app_state: Forwarded unchanged to ``BaseTab.__init__``.
        """
        super().__init__(app_state)
        self.id = "manage_tab"
        self.title = "5️⃣  Manage"

    def create(self, parent=None) -> gr.TabItem:
        """Create the Manage tab UI components.

        Args:
            parent: Not used by this implementation.

        Returns:
            The ``gr.TabItem`` that contains all components of this tab.
        """
        with gr.TabItem(self.title, id=self.id) as tab:
            with gr.Column():
                # --- Publishing section ---------------------------------
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Publishing")
                        gr.Markdown("Your model can be pushed to Hugging Face (this will use HF_API_TOKEN)")

                        with gr.Row():
                            with gr.Column():
                                self.components["repo_id"] = gr.Textbox(
                                    label="HuggingFace Model Repository",
                                    placeholder="username/model-name",
                                    info="The repository will be created if it doesn't exist"
                                )
                                self.components["make_public"] = gr.Checkbox(
                                    label="Check this to make your model public (ie. visible and downloadable by anyone)",
                                    info="Your model is private by default"
                                )
                                self.components["push_model_btn"] = gr.Button(
                                    "Push my model"
                                )

                # --- Storage management section -------------------------
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("## Storage management")
                                with gr.Row():
                                    self.components["download_dataset_btn"] = gr.DownloadButton(
                                        "Download dataset",
                                        variant="secondary",
                                        size="lg"
                                    )
                                    self.components["download_model_btn"] = gr.DownloadButton(
                                        "Download model",
                                        variant="secondary",
                                        size="lg"
                                    )

                        with gr.Row():
                            self.components["global_stop_btn"] = gr.Button(
                                "Stop everything and delete my data",
                                variant="stop"
                            )
                            # Created hidden; it is only used as an output
                            # target by the handlers wired in connect_events.
                            self.components["global_status"] = gr.Textbox(
                                label="Global Status",
                                interactive=False,
                                visible=False
                            )

        return tab

    def connect_events(self) -> None:
        """Connect event handlers to UI components.

        Besides this tab's own components, the global stop handler writes to
        components owned by the ``split_tab``, ``caption_tab``, ``train_tab``
        and ``import_tab`` entries of ``self.app.tabs``; those entries and
        their components must already be present when this method runs.
        """
        # Repository ID validation: the textbox is both the input and the
        # output, so the handler's return value is written back into it.
        self.components["repo_id"].change(
            fn=self.app.validate_repo,
            inputs=[self.components["repo_id"]],
            outputs=[self.components["repo_id"]]
        )

        # Download buttons: each handler's return value is written back to
        # the button that triggered it.
        self.components["download_dataset_btn"].click(
            fn=self.app.trainer.create_training_dataset_zip,
            outputs=[self.components["download_dataset_btn"]]
        )

        self.components["download_model_btn"].click(
            fn=self.app.trainer.get_model_output_safetensors,
            outputs=[self.components["download_model_btn"]]
        )

        # Global stop button: the handler's outputs are mapped, in this
        # order, onto the status box of this tab and onto components that
        # belong to the other tabs.
        self.components["global_stop_btn"].click(
            fn=self.app.handle_global_stop,
            outputs=[
                self.components["global_status"],
                self.app.tabs["split_tab"].components["video_list"],
                self.app.tabs["caption_tab"].components["training_dataset"],
                self.app.tabs["train_tab"].components["status_box"],
                self.app.tabs["train_tab"].components["log_box"],
                self.app.tabs["split_tab"].components["detect_status"],
                self.app.tabs["import_tab"].components["import_status"],
                self.app.tabs["caption_tab"].components["preview_status"]
            ]
        )

        # Push model button.
        # The lambda resolves self.app.upload_to_hub at click time and
        # forwards only the repository id.
        # NOTE(review): the "make_public" checkbox is not passed as an input
        # to any handler here, so its value is currently unused — confirm
        # whether upload_to_hub should receive it.
        # NOTE(review): the result is written to "global_status", which is
        # created with visible=False — confirm the user can see the outcome.
        self.components["push_model_btn"].click(
            fn=lambda repo_id: self.app.upload_to_hub(repo_id),
            inputs=[self.components["repo_id"]],
            outputs=[self.components["global_status"]]
        )