Spaces:
Runtime error
Runtime error
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- finetune_gradio.py +150 -7
finetune_gradio.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
|
|
|
|
|
|
|
| 4 |
from transformers import pipeline
|
| 5 |
import gradio as gr
|
| 6 |
import torch
|
|
|
|
| 7 |
import click
|
| 8 |
import torchaudio
|
| 9 |
from glob import glob
|
|
@@ -20,11 +23,16 @@ import psutil
|
|
| 20 |
import platform
|
| 21 |
import subprocess
|
| 22 |
from datasets.arrow_writer import ArrowWriter
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
training_process = None
|
| 26 |
system = platform.system()
|
| 27 |
python_executable = sys.executable or "python"
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
path_data = "data"
|
| 30 |
|
|
@@ -240,7 +248,12 @@ def start_training(
|
|
| 240 |
last_per_steps=800,
|
| 241 |
finetune=True,
|
| 242 |
):
|
| 243 |
-
global training_process
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
| 246 |
|
|
@@ -288,7 +301,7 @@ def start_training(
|
|
| 288 |
training_process = subprocess.Popen(cmd, shell=True)
|
| 289 |
|
| 290 |
time.sleep(5)
|
| 291 |
-
yield "
|
| 292 |
|
| 293 |
# Wait for the training process to finish
|
| 294 |
training_process.wait()
|
|
@@ -519,6 +532,17 @@ def calculate_train(
|
|
| 519 |
path_project = os.path.join(path_data, name_project)
|
| 520 |
file_duraction = os.path.join(path_project, "duration.json")
|
| 521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
with open(file_duraction, "r") as file:
|
| 523 |
data = json.load(file)
|
| 524 |
|
|
@@ -549,8 +573,8 @@ def calculate_train(
|
|
| 549 |
else:
|
| 550 |
max_samples = 64
|
| 551 |
|
| 552 |
-
num_warmup_updates = int(samples * 0.
|
| 553 |
-
save_per_updates = int(samples * 0.
|
| 554 |
last_per_steps = int(save_per_updates * 5)
|
| 555 |
|
| 556 |
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
|
@@ -559,7 +583,7 @@ def calculate_train(
|
|
| 559 |
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
| 560 |
|
| 561 |
if finetune:
|
| 562 |
-
learning_rate = 1e-
|
| 563 |
else:
|
| 564 |
learning_rate = 7.5e-5
|
| 565 |
|
|
@@ -611,6 +635,7 @@ def vocab_check(project_name):
|
|
| 611 |
sp = item.split("|")
|
| 612 |
if len(sp) != 2:
|
| 613 |
continue
|
|
|
|
| 614 |
text = sp[1].lower().strip()
|
| 615 |
|
| 616 |
for t in text:
|
|
@@ -625,6 +650,80 @@ def vocab_check(project_name):
|
|
| 625 |
return info
|
| 626 |
|
| 627 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
with gr.Blocks() as app:
|
| 629 |
with gr.Row():
|
| 630 |
project_name = gr.Textbox(label="project name", value="my_speak")
|
|
@@ -661,6 +760,18 @@ with gr.Blocks() as app:
|
|
| 661 |
)
|
| 662 |
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
|
| 663 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
with gr.TabItem("prepare Data"):
|
| 665 |
gr.Markdown(
|
| 666 |
"""```plaintext
|
|
@@ -687,6 +798,16 @@ with gr.Blocks() as app:
|
|
| 687 |
txt_info_prepare = gr.Text(label="info", value="")
|
| 688 |
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
|
| 689 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
with gr.TabItem("train Data"):
|
| 691 |
with gr.Row():
|
| 692 |
bt_calculate = bt_create = gr.Button("Auto Settings")
|
|
@@ -696,11 +817,11 @@ with gr.Blocks() as app:
|
|
| 696 |
|
| 697 |
with gr.Row():
|
| 698 |
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
| 699 |
-
learning_rate = gr.Number(label="Learning Rate", value=1e-
|
| 700 |
|
| 701 |
with gr.Row():
|
| 702 |
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
| 703 |
-
max_samples = gr.Number(label="Max Samples", value=
|
| 704 |
|
| 705 |
with gr.Row():
|
| 706 |
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
|
@@ -778,6 +899,28 @@ with gr.Blocks() as app:
|
|
| 778 |
txt_info_check = gr.Text(label="info", value="")
|
| 779 |
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
|
| 780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
|
| 782 |
@click.command()
|
| 783 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
|
| 4 |
+
import tempfile
|
| 5 |
+
import random
|
| 6 |
from transformers import pipeline
|
| 7 |
import gradio as gr
|
| 8 |
import torch
|
| 9 |
+
import gc
|
| 10 |
import click
|
| 11 |
import torchaudio
|
| 12 |
from glob import glob
|
|
|
|
| 23 |
import platform
|
| 24 |
import subprocess
|
| 25 |
from datasets.arrow_writer import ArrowWriter
|
| 26 |
+
from datasets import Dataset as Dataset_
|
| 27 |
+
from api import F5TTS
|
| 28 |
|
| 29 |
|
| 30 |
training_process = None
|
| 31 |
system = platform.system()
|
| 32 |
python_executable = sys.executable or "python"
|
| 33 |
+
tts_api = None
|
| 34 |
+
last_checkpoint = ""
|
| 35 |
+
last_device = ""
|
| 36 |
|
| 37 |
path_data = "data"
|
| 38 |
|
|
|
|
| 248 |
last_per_steps=800,
|
| 249 |
finetune=True,
|
| 250 |
):
|
| 251 |
+
global training_process, tts_api
|
| 252 |
+
|
| 253 |
+
if tts_api is not None:
|
| 254 |
+
del tts_api
|
| 255 |
+
gc.collect()
|
| 256 |
+
torch.cuda.empty_cache()
|
| 257 |
|
| 258 |
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
| 259 |
|
|
|
|
| 301 |
training_process = subprocess.Popen(cmd, shell=True)
|
| 302 |
|
| 303 |
time.sleep(5)
|
| 304 |
+
yield "train start", gr.update(interactive=False), gr.update(interactive=True)
|
| 305 |
|
| 306 |
# Wait for the training process to finish
|
| 307 |
training_process.wait()
|
|
|
|
| 532 |
path_project = os.path.join(path_data, name_project)
|
| 533 |
file_duraction = os.path.join(path_project, "duration.json")
|
| 534 |
|
| 535 |
+
if not os.path.isfile(file_duraction):
|
| 536 |
+
return (
|
| 537 |
+
1000,
|
| 538 |
+
max_samples,
|
| 539 |
+
num_warmup_updates,
|
| 540 |
+
save_per_updates,
|
| 541 |
+
last_per_steps,
|
| 542 |
+
"project not found !",
|
| 543 |
+
learning_rate,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
with open(file_duraction, "r") as file:
|
| 547 |
data = json.load(file)
|
| 548 |
|
|
|
|
| 573 |
else:
|
| 574 |
max_samples = 64
|
| 575 |
|
| 576 |
+
num_warmup_updates = int(samples * 0.05)
|
| 577 |
+
save_per_updates = int(samples * 0.10)
|
| 578 |
last_per_steps = int(save_per_updates * 5)
|
| 579 |
|
| 580 |
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
|
|
|
| 583 |
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
| 584 |
|
| 585 |
if finetune:
|
| 586 |
+
learning_rate = 1e-5
|
| 587 |
else:
|
| 588 |
learning_rate = 7.5e-5
|
| 589 |
|
|
|
|
| 635 |
sp = item.split("|")
|
| 636 |
if len(sp) != 2:
|
| 637 |
continue
|
| 638 |
+
|
| 639 |
text = sp[1].lower().strip()
|
| 640 |
|
| 641 |
for t in text:
|
|
|
|
| 650 |
return info
|
| 651 |
|
| 652 |
|
| 653 |
+
def get_random_sample_prepare(project_name):
    """Pick one random entry from the project's prepared ``raw.arrow`` dataset.

    Args:
        project_name: Project name without the ``_pinyin`` suffix; the dataset
            is looked up at ``<path_data>/<project_name>_pinyin/raw.arrow``.

    Returns:
        A ``(text, audio_path)`` tuple. ``text`` is a bracketed, comma-separated
        rendering of every element of the sample's ``text`` field and
        ``audio_path`` is the sample's ``audio_path`` field. ``("", None)`` is
        returned when the arrow file is missing or contains no rows.
    """
    file_arrow = os.path.join(path_data, project_name + "_pinyin", "raw.arrow")
    if not os.path.isfile(file_arrow):
        return "", None

    dataset = Dataset_.from_file(file_arrow)
    num_rows = len(dataset)
    if num_rows == 0:
        # Nothing to sample from; mirror the "file missing" result.
        return "", None

    # Index a single random row directly. The previous approach shuffled the
    # whole dataset with one of only 1001 possible seeds, so at most 1001
    # distinct rows could ever be shown regardless of dataset size.
    sample = dataset[random.randrange(num_rows)]

    text = "[" + " , ".join("' " + t + " '" for t in sample["text"]) + "]"
    return text, sample["audio_path"]
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def get_random_sample_transcribe(project_name):
    """Return a random ``(text, wav_path)`` pair read from the project's ``metadata.csv``.

    The metadata file is expected at
    ``<path_data>/<project_name>_pinyin/metadata.csv``. Each line is split on
    ``|``; only lines with exactly two fields are used, the first field naming
    a file under the project's ``wavs`` directory (``.wav`` is appended) and
    the second field being the text.

    Returns:
        ``(text, wav_path)`` for a randomly chosen usable line, or
        ``("", None)`` when the file is missing or has no usable line.
    """
    project_dir = os.path.join(path_data, project_name + "_pinyin")
    metadata_path = os.path.join(project_dir, "metadata.csv")
    if not os.path.isfile(metadata_path):
        return "", None

    with open(metadata_path, "r", encoding="utf-8") as handle:
        content = handle.read()

    entries = []
    for row in content.split("\n"):
        fields = row.split("|")
        if len(fields) == 2:
            wav_path = os.path.join(project_dir, "wavs", fields[0] + ".wav")
            entries.append([wav_path, fields[1]])

    if not entries:
        return "", None

    chosen_wav, chosen_text = random.choice(entries)
    return chosen_text, chosen_wav
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
def get_random_sample_infer(project_name):
    """Fetch a random transcribed sample and return it as ``(text, text, audio)``.

    The sample comes from ``get_random_sample_transcribe``; its text is
    returned twice, followed by the audio path.
    """
    sample_text, sample_audio = get_random_sample_transcribe(project_name)
    return sample_text, sample_text, sample_audio
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def infer(project_name, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
    """Generate speech for ``gen_text`` with a checkpoint and return the wav file path.

    The ``F5TTS`` instance is cached in the module-level ``tts_api`` and is
    rebuilt only when the checkpoint path or the target device changes, or
    when no instance is currently available.

    Args:
        project_name: Unused here; kept so the signature matches the UI wiring.
        file_checkpoint: Path of the checkpoint file to load.
        exp_name: Model type passed to ``F5TTS`` as ``model_type``.
        ref_text: Transcript of the reference audio.
        ref_audio: Path of the reference audio file.
        gen_text: Text to synthesize.
        nfe_step: Value forwarded to ``F5TTS.infer`` as ``nfe_step``.

    Returns:
        Path of a newly created temporary ``.wav`` file (not deleted
        automatically), or ``None`` when ``file_checkpoint`` is not a file.
    """
    global last_checkpoint, last_device, tts_api

    if not os.path.isfile(file_checkpoint):
        return None

    # Once a training process handle exists, inference is forced onto the CPU;
    # otherwise ``device=None`` is passed through to F5TTS.
    device_test = "cpu" if training_process is not None else None

    # start_training() runs ``del tts_api`` on the global, so the name itself
    # may be gone; look it up defensively instead of risking a NameError when
    # the checkpoint and device happen to be unchanged.
    model_missing = globals().get("tts_api") is None

    if model_missing or last_checkpoint != file_checkpoint or last_device != device_test:
        tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)

        # Record the cache key only after the model loaded successfully, so a
        # failed load cannot leave the markers pointing at a stale model.
        last_checkpoint = file_checkpoint
        last_device = device_test

        print("update", device_test, file_checkpoint)

    # Reserve a persistent temp path, then release our handle before the
    # model writes the generated audio to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        wav_path = f.name

    tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=wav_path)
    return wav_path
|
| 725 |
+
|
| 726 |
+
|
| 727 |
with gr.Blocks() as app:
|
| 728 |
with gr.Row():
|
| 729 |
project_name = gr.Textbox(label="project name", value="my_speak")
|
|
|
|
| 760 |
)
|
| 761 |
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
|
| 762 |
|
| 763 |
+
random_sample_transcribe = gr.Button("random sample")
|
| 764 |
+
|
| 765 |
+
with gr.Row():
|
| 766 |
+
random_text_transcribe = gr.Text(label="Text")
|
| 767 |
+
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
|
| 768 |
+
|
| 769 |
+
random_sample_transcribe.click(
|
| 770 |
+
fn=get_random_sample_transcribe,
|
| 771 |
+
inputs=[project_name],
|
| 772 |
+
outputs=[random_text_transcribe, random_audio_transcribe],
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
with gr.TabItem("prepare Data"):
|
| 776 |
gr.Markdown(
|
| 777 |
"""```plaintext
|
|
|
|
| 798 |
txt_info_prepare = gr.Text(label="info", value="")
|
| 799 |
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
|
| 800 |
|
| 801 |
+
random_sample_prepare = gr.Button("random sample")
|
| 802 |
+
|
| 803 |
+
with gr.Row():
|
| 804 |
+
random_text_prepare = gr.Text(label="Pinyin")
|
| 805 |
+
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
|
| 806 |
+
|
| 807 |
+
random_sample_prepare.click(
|
| 808 |
+
fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
with gr.TabItem("train Data"):
|
| 812 |
with gr.Row():
|
| 813 |
bt_calculate = bt_create = gr.Button("Auto Settings")
|
|
|
|
| 817 |
|
| 818 |
with gr.Row():
|
| 819 |
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
| 820 |
+
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
|
| 821 |
|
| 822 |
with gr.Row():
|
| 823 |
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
| 824 |
+
max_samples = gr.Number(label="Max Samples", value=64)
|
| 825 |
|
| 826 |
with gr.Row():
|
| 827 |
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
|
|
|
| 899 |
txt_info_check = gr.Text(label="info", value="")
|
| 900 |
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
|
| 901 |
|
| 902 |
+
with gr.TabItem("test model"):
|
| 903 |
+
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
|
| 904 |
+
nfe_step = gr.Number(label="n_step", value=32)
|
| 905 |
+
file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
|
| 906 |
+
|
| 907 |
+
random_sample_infer = gr.Button("random sample")
|
| 908 |
+
|
| 909 |
+
ref_text = gr.Textbox(label="ref text")
|
| 910 |
+
ref_audio = gr.Audio(label="audio ref", type="filepath")
|
| 911 |
+
gen_text = gr.Textbox(label="gen text")
|
| 912 |
+
random_sample_infer.click(
|
| 913 |
+
fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
|
| 914 |
+
)
|
| 915 |
+
check_button_infer = gr.Button("infer")
|
| 916 |
+
gen_audio = gr.Audio(label="audio gen", type="filepath")
|
| 917 |
+
|
| 918 |
+
check_button_infer.click(
|
| 919 |
+
fn=infer,
|
| 920 |
+
inputs=[project_name, file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
|
| 921 |
+
outputs=[gen_audio],
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
|
| 925 |
@click.command()
|
| 926 |
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|