Spaces:
Running
Running
| import gradio as gr | |
| from modules.Enhancer.ResembleEnhance import unload_enhancer | |
| from modules.webui import webui_config | |
| from modules.webui.webui_utils import get_speaker_names | |
| from .ft_ui_utils import get_datasets_listfile, run_speaker_ft | |
| from .ProcessMonitor import ProcessMonitor | |
| from modules.speaker import speaker_mgr | |
| from modules.models import unload_chat_tts | |
class SpeakerFt:
    """Launch speaker fine-tuning in a child process and expose its logs.

    The training itself runs as ``python -m modules.finetune.train_speaker``
    under a ``ProcessMonitor``; this class only builds the command line,
    tracks a short human-readable status string, and formats the captured
    output for display.
    """

    def __init__(self):
        self.process_monitor = ProcessMonitor()
        # Short status line shown above the process output in ``flush()``.
        self.status_str = "idle"

    def unload_main_thread_models(self):
        """Call ``unload_chat_tts()`` and ``unload_enhancer()``.

        NOTE(review): presumably this frees memory held by the web UI process
        before the training subprocess starts -- confirm against the helpers.
        """
        unload_chat_tts()
        unload_enhancer()

    @staticmethod
    def _parse_speaker_name(label: str) -> str:
        """Extract the speaker name from a dropdown label.

        Labels of the form ``"<prefix> : <name>"`` yield ``<name>`` (the
        second `` : ``-separated field, as before). Labels without the
        separator are used as-is instead of raising ``IndexError``.
        """
        parts = label.split(" : ")
        name = parts[1] if len(parts) > 1 else parts[0]
        return name.strip()

    def run(
        self,
        batch_size: int,
        epochs: int,
        lr: str,
        train_text: bool,
        data_path: str,
        select_speaker: str = "",
    ):
        """Start a fine-tuning subprocess unless one is already tracked.

        Args:
            batch_size: Forwarded as ``--batch_size``.
            epochs: Forwarded as ``--epochs``.
            lr: Accepted for interface compatibility; it is not forwarded to
                the training command here.
                NOTE(review): confirm whether ``train_speaker`` takes a
                learning-rate flag before wiring this through.
            train_text: When true, ``--train_text`` is appended.
            data_path: Forwarded as ``--data_path``.
            select_speaker: Dropdown label of the initial speaker. An empty
                value, ``None`` or ``"none"`` means no initial speaker.

        Returns:
            ``None`` normally (including when a process is already tracked),
            or ``["Speaker not found"]`` when the selected speaker cannot be
            resolved.
        """
        # Local import keeps this block self-contained.
        import sys

        if self.process_monitor.process:
            return

        # Resolve the speaker BEFORE unloading models, so that an invalid
        # selection does not throw away the loaded models for nothing.
        spk_path = None
        if select_speaker and select_speaker != "none":
            speaker_name = self._parse_speaker_name(select_speaker)
            spk = speaker_mgr.get_speaker(speaker_name)
            if spk is None:
                # Also record the failure in the status line: the return
                # value alone is not displayed when the caller wires this
                # handler without output components.
                self.status("Speaker not found")
                return ["Speaker not found"]
            spk_filename = speaker_mgr.get_speaker_filename(spk.id)
            spk_path = f"./data/speakers/{spk_filename}"

        self.unload_main_thread_models()

        # Use the interpreter running this process so the child sees the same
        # environment; a bare "python3" may be a different install or absent.
        python_exe = sys.executable or "python3"
        command = [
            python_exe,
            "-m",
            "modules.finetune.train_speaker",
            f"--batch_size={batch_size}",
            f"--epochs={epochs}",
            f"--data_path={data_path}",
        ]
        if train_text:
            command.append("--train_text")
        if spk_path:
            command.append(f"--init_speaker={spk_path}")

        self.status("Training process starting")
        self.process_monitor.start_process(command)
        self.status("Training started")

    def status(self, text: str):
        """Replace the current status line with *text*."""
        self.status_str = text

    def flush(self):
        """Return the status line followed by the captured stdout and stderr."""
        stdout, stderr = self.process_monitor.get_output()
        return f"{self.status_str}\n{stdout}\n{stderr}"

    def clear(self):
        """Reset the monitor's captured stdout/stderr and note it in the status."""
        self.process_monitor.stdout = ""
        self.process_monitor.stderr = ""
        self.status("Logs cleared")

    def stop(self):
        """Ask the monitor to stop the process and update the status line."""
        self.process_monitor.stop_process()
        self.status("Training stopped")
def create_speaker_ft_tab(demo: gr.Blocks):
    """Build the speaker fine-tuning tab and connect it to a ``SpeakerFt``.

    Layout: a narrow left column with the hyper-parameter inputs and the
    start / stop / clear buttons, and a wide right column with the log box.
    The log box is only refreshed (``demo.load(..., every=1)``) when
    ``webui_config.experimental`` is truthy.

    Args:
        demo: The enclosing ``gr.Blocks`` app; used only to register the
            periodic log refresh.
    """
    trainer = SpeakerFt()

    # Only the display names are used; "none" stands for "no initial speaker".
    _, display_names = get_speaker_names()
    initial_speaker_choices = ["none"] + display_names

    with gr.Row():
        # Left column: hyper-parameters and controls.
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("🎛️hparams")
                dataset_dd = gr.Dropdown(
                    choices=get_datasets_listfile(),
                    label="Dataset",
                )
                lr_box = gr.Textbox(value="1e-2", label="Learning Rate")
                epochs_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=10,
                    label="Epochs",
                )
                batch_slider = gr.Slider(
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=4,
                    label="Batch Size",
                )
                text_loss_box = gr.Checkbox(value=True, label="Train text_loss")
                init_speaker_dd = gr.Dropdown(
                    choices=initial_speaker_choices,
                    value="none",
                    label="Initial Speaker",
                )

            with gr.Group():
                start_btn = gr.Button("Start Training")
                stop_btn = gr.Button("Stop Training")
                clear_btn = gr.Button("Clear logs")

        # Right column: training log.
        with gr.Column(scale=5):
            with gr.Group():
                gr.Markdown("📜logs")
                log_box = gr.Textbox(
                    value="",
                    lines=20,
                    label="Log",
                    show_label=False,
                    interactive=True,
                )

    # Order must match the positional parameters of SpeakerFt.run.
    run_inputs = [
        batch_slider,
        epochs_slider,
        lr_box,
        text_loss_box,
        dataset_dd,
        init_speaker_dd,
    ]
    start_btn.click(trainer.run, inputs=run_inputs, outputs=[])
    stop_btn.click(trainer.stop)
    clear_btn.click(trainer.clear)

    if webui_config.experimental:
        demo.load(trainer.flush, every=1, outputs=[log_box])