update finetune-cli -gradio
Files changed:
- README.md +1 -0
- pyproject.toml +1 -0
- src/f5_tts/eval/eval_infer_batch.py +0 -3
- src/f5_tts/eval/eval_infer_batch.sh +6 -6
- src/f5_tts/infer/infer_gradio.py +172 -1
- src/f5_tts/train/datasets/prepare_csv_wavs.py +1 -1
- src/f5_tts/train/finetune_cli.py +25 -9
- src/f5_tts/train/finetune_gradio.py +182 -44
README.md
CHANGED
@@ -183,6 +183,7 @@ Currently supported features:
 - Chunk inference
 - Podcast Generation
 - Multiple Speech-Type Generation
+- Voice Chat powered by Qwen2.5-3B-Instruct

 You can launch a Gradio app (web interface) to launch a GUI for inference (will load ckpt from Huggingface, you may also use local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`.

pyproject.toml
CHANGED
@@ -35,6 +35,7 @@ dependencies = [
     "torchdiffeq",
     "tqdm>=4.65.0",
     "transformers",
+    "transformers_stream_generator",
     "vocos",
     "wandb",
     "x_transformers>=1.31.14",
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -4,7 +4,6 @@ import os
 sys.path.append(os.getcwd())

 import time
-import random
 from tqdm import tqdm
 import argparse
 from importlib.resources import files
@@ -97,8 +96,6 @@ def main():
     metainfo = get_seedtts_testset_metainfo(metalst)

     # path to save genereted wavs
-    if seed is None:
-        seed = random.randint(-10000, 10000)
     output_dir = (
         f"results/{exp_name}_{ckpt_step}/{testset}/"
         f"seed{seed}_{ode_method}_nfe{nfe_step}"
src/f5_tts/eval/eval_infer_batch.sh
CHANGED
@@ -1,13 +1,13 @@
 #!/bin/bash

 # e.g. F5-TTS, 16 NFE
-accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
-accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
-accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16

 # e.g. Vanilla E2 TTS, 32 NFE
-accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
-accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
-accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0

 # etc.
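The two eval changes belong together: `eval_infer_batch.py` no longer invents a random seed when none is given, and the batch script now pins one with `-s 0`, so result directories are reproducible across runs. A minimal sketch of the new contract (the `-s` flag is taken from the updated script and its long name is assumed; the path format comes from `eval_infer_batch.py` above, other values are illustrative):

```python
# With the random fallback deleted, the seed must come from the caller.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--seed", type=int, default=None)  # assumed long name
args = parser.parse_args(["-s", "0"])

# The seed flows straight into the results path, so identical flags now
# always land in the same directory (previously: random.randint(-10000, 10000)).
print(f"results/F5TTS_Base_1200000/seedtts_test_zh/seed{args.seed}_euler_nfe16")
```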
src/f5_tts/infer/infer_gradio.py
CHANGED
@@ -11,6 +11,7 @@ import soundfile as sf
 import torchaudio
 from cached_path import cached_path
 from pydub import AudioSegment
+from transformers import AutoModelForCausalLM, AutoTokenizer

 try:
     import spaces
@@ -51,6 +52,33 @@ E2TTS_ema_model = load_model(
     UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
 )

+# Initialize Qwen model and tokenizer
+model_name = "Qwen/Qwen2.5-3B-Instruct"
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+
+def generate_response(messages):
+    """Generate response using Qwen"""
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    generated_ids = model.generate(
+        **model_inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.95,
+    )
+
+    generated_ids = [
+        output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+    ]
+    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+

 @gpu_decorator
 def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
@@ -490,6 +518,146 @@ with gr.Blocks() as app_emotional:
         outputs=generate_emotional_btn,
     )

+
+with gr.Blocks() as app_chat:
+    gr.Markdown(
+        """
+# Voice Chat
+Have a conversation with an AI using your reference voice!
+1. Upload a reference audio clip and optionally its transcript.
+2. Record your message through your microphone.
+3. The AI will respond using the reference voice.
+"""
+    )
+
+    with gr.Row():
+        with gr.Column():
+            ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
+
+        with gr.Column():
+            with gr.Accordion("Advanced Settings", open=False):
+                model_choice_chat = gr.Radio(
+                    choices=["F5-TTS", "E2-TTS"],
+                    label="TTS Model",
+                    value="F5-TTS",
+                )
+                remove_silence_chat = gr.Checkbox(
+                    label="Remove Silences",
+                    value=True,
+                )
+                ref_text_chat = gr.Textbox(
+                    label="Reference Text",
+                    info="Optional: Leave blank to auto-transcribe",
+                    lines=2,
+                )
+                system_prompt_chat = gr.Textbox(
+                    label="System Prompt",
+                    value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                    lines=2,
+                )
+
+    chatbot_interface = gr.Chatbot(label="Conversation")
+
+    with gr.Row():
+        with gr.Column():
+            audio_output_chat = gr.Audio(autoplay=True)
+        with gr.Column():
+            audio_input_chat = gr.Microphone(
+                label="Or speak your message",
+                type="filepath",
+            )
+
+    clear_btn_chat = gr.Button("Clear Conversation")
+
+    conversation_state = gr.State(
+        value=[
+            {
+                "role": "system",
+                "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+            }
+        ]
+    )
+
+    def process_audio_input(audio_path, history, conv_state):
+        """Handle audio input from user"""
+        if not audio_path:
+            return history, conv_state, ""
+
+        text = ""
+        text = preprocess_ref_audio_text(audio_path, text)[1]
+
+        if not text.strip():
+            return history, conv_state, ""
+
+        conv_state.append({"role": "user", "content": text})
+        history.append((text, None))
+
+        response = generate_response(conv_state)
+
+        conv_state.append({"role": "assistant", "content": response})
+        history[-1] = (text, response)
+
+        return history, conv_state, ""
+
+    def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
+        """Generate TTS audio for AI response"""
+        if not history or not ref_audio:
+            return None
+
+        last_user_message, last_ai_response = history[-1]
+        if not last_ai_response:
+            return None
+
+        audio_result, _ = infer(
+            ref_audio,
+            ref_text,
+            last_ai_response,
+            model,
+            remove_silence,
+            cross_fade_duration=0.15,
+            speed=1.0,
+        )
+        return audio_result
+
+    def clear_conversation():
+        """Reset the conversation"""
+        return [], [
+            {
+                "role": "system",
+                "content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.",
+            }
+        ]
+
+    def update_system_prompt(new_prompt):
+        """Update the system prompt and reset the conversation"""
+        new_conv_state = [{"role": "system", "content": new_prompt}]
+        return [], new_conv_state
+
+    # Handle audio input
+    audio_input_chat.stop_recording(
+        process_audio_input,
+        inputs=[audio_input_chat, chatbot_interface, conversation_state],
+        outputs=[chatbot_interface, conversation_state],
+    ).then(
+        generate_audio_response,
+        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
+        outputs=audio_output_chat,
+    )
+
+    # Handle clear button
+    clear_btn_chat.click(
+        clear_conversation,
+        outputs=[chatbot_interface, conversation_state],
+    )
+
+    # Handle system prompt change and reset conversation
+    system_prompt_chat.change(
+        update_system_prompt,
+        inputs=system_prompt_chat,
+        outputs=[chatbot_interface, conversation_state],
+    )
+
+
 with gr.Blocks() as app:
     gr.Markdown(
         """
@@ -507,7 +675,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
 **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
 """
     )
-    gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
+    gr.TabbedInterface(
+        [app_tts, app_podcast, app_emotional, app_chat, app_credits],
+        ["TTS", "Podcast", "Multi-Style", "Voice-Chat", "Credits"],
+    )


 @click.command()
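For orientation, one chat turn in the new Voice-Chat tab flows transcription → Qwen → TTS. A minimal sketch using the helpers added above (`generate_response`, the tab's conversation state, and `infer`); the reference file path and prompts are placeholders:

```python
# One turn, mirroring process_audio_input + generate_audio_response above.
conv_state = [{"role": "system", "content": "Keep your responses concise."}]

# 1. The tab transcribes microphone input via preprocess_ref_audio_text;
#    here we supply the user text directly.
conv_state.append({"role": "user", "content": "Introduce yourself in one sentence."})

# 2. Qwen2.5-3B-Instruct produces the text reply.
reply = generate_response(conv_state)
conv_state.append({"role": "assistant", "content": reply})

# 3. The reply is synthesized in the reference voice; the second return
#    value is discarded, exactly as in generate_audio_response.
audio, _ = infer("ref.wav", "", reply, "F5-TTS", remove_silence=True)
```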
src/f5_tts/train/datasets/prepare_csv_wavs.py
CHANGED
@@ -60,7 +60,7 @@ def read_audio_text_pairs(csv_file_path):
     audio_text_pairs = []

     parent = Path(csv_file_path).parent
-    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
+    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
         reader = csv.reader(csvfile, delimiter="|")
         next(reader)  # Skip the header row
         for row in reader:
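The `utf-8` → `utf-8-sig` switch here (repeated for the metadata and vocab reads in `finetune_gradio.py` below) matters for CSVs written by Windows tools, which often carry a UTF-8 byte-order mark: read with plain `utf-8`, the BOM stays glued to the first header field. A small self-contained demonstration:

```python
# "utf-8-sig" strips a leading BOM if present and is a no-op otherwise.
path = "demo.csv"  # throwaway file, for illustration only
with open(path, "w", encoding="utf-8-sig") as f:  # writes a BOM up front
    f.write("audio_file|text\nwavs/a.wav|hello\n")

with open(path, encoding="utf-8") as f:
    print(repr(f.readline()))  # '\ufeffaudio_file|text\n' -- BOM leaks through
with open(path, encoding="utf-8-sig") as f:
    print(repr(f.readline()))  # 'audio_file|text\n'
```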
src/f5_tts/train/finetune_cli.py
CHANGED
@@ -15,26 +15,35 @@ hop_length = 256

 # -------------------------- Argument Parsing --------------------------- #
 def parse_args():
+    # batch_size_per_gpu = 1000 settting for gpu 8GB
+    # batch_size_per_gpu = 1600 settting for gpu 12GB
+    # batch_size_per_gpu = 2000 settting for gpu 16GB
+    # batch_size_per_gpu = 3200 settting for gpu 24GB
+
+    # num_warmup_updates 10000 sample = 500
+
+    # change save_per_updates , last_per_steps what you need ,
+
     parser = argparse.ArgumentParser(description="Train CFM Model")

     parser.add_argument(
         "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
     )
     parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
-    parser.add_argument("--learning_rate", type=float, default=1e-
-    parser.add_argument("--batch_size_per_gpu", type=int, default=
+    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
+    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
     parser.add_argument(
         "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
     )
-    parser.add_argument("--max_samples", type=int, default=
+    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
     parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=
-    parser.add_argument("--save_per_updates", type=int, default=
-    parser.add_argument("--last_per_steps", type=int, default=
+    parser.add_argument("--num_warmup_updates", type=int, default=500, help="Warmup steps")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
+    parser.add_argument("--last_per_steps", type=int, default=20000, help="Save last checkpoint every X steps")
     parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
-
+    parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
     parser.add_argument(
         "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
     )
@@ -60,13 +69,19 @@ def main():
         model_cls = DiT
         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         if args.finetune:
-            ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain
     elif args.exp_name == "E2TTS_Base":
         wandb_resume_id = None
         model_cls = UNetT
         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
         if args.finetune:
-            ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain

     if args.finetune:
         path_ckpt = os.path.join("ckpts", args.dataset_name)
@@ -118,6 +133,7 @@ def main():
     )

     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+
     trainer.train(
         train_dataset,
         resumable_with_seed=666,  # seed for shuffling dataset
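The new comments size batches in mel frames, which is easier to judge as seconds of audio. A quick conversion, assuming the mel settings visible in this commit (`hop_length = 256` at the top of this file, 24 kHz sampling rate as in `calculate_train` below):

```python
# Frame budgets from the new comments, expressed as audio seconds per batch.
hop_length, sampling_rate = 256, 24000
for frames, gpu in [(1000, "8GB"), (1600, "12GB"), (2000, "16GB"), (3200, "24GB")]:
    seconds = frames * hop_length / sampling_rate
    print(f"{gpu}: {frames} frames ~= {seconds:.1f}s of audio per batch")
```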
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -20,6 +20,7 @@ import torch
 import torchaudio
 from datasets import Dataset as Dataset_
 from datasets.arrow_writer import ArrowWriter
+from safetensors.torch import save_file
 from scipy.io import wavfile
 from transformers import pipeline

@@ -247,6 +248,9 @@ def start_training(
     save_per_updates=400,
     last_per_steps=800,
     finetune=True,
+    file_checkpoint_train="",
+    tokenizer_type="pinyin",
+    tokenizer_file="",
 ):
     global training_process, tts_api

@@ -256,7 +260,7 @@ def start_training(
         torch.cuda.empty_cache()
         tts_api = None

-    path_project = os.path.join(path_data, dataset_name + "_pinyin")
+    path_project = os.path.join(path_data, dataset_name)

     if not os.path.isdir(path_project):
         yield (
@@ -278,6 +282,7 @@ def start_training(
     yield "start train", gr.update(interactive=False), gr.update(interactive=False)

     # Command to run the training script with the specified arguments
+    dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")
     cmd = (
         f"accelerate launch finetune-cli.py --exp_name {exp_name} "
         f"--learning_rate {learning_rate} "
@@ -295,6 +300,13 @@ def start_training(
     if finetune:
         cmd += f" --finetune {finetune}"

+    if file_checkpoint_train != "":
+        cmd += f" --file_checkpoint_train {file_checkpoint_train}"
+
+    if tokenizer_file != "":
+        cmd += f" --tokenizer_path {tokenizer_file}"
+    cmd += f" --tokenizer {tokenizer_type} "
+
     print(cmd)

     try:
@@ -331,10 +343,28 @@ def stop_training():
     return "train stop", gr.update(interactive=True), gr.update(interactive=False)


-def create_data_project(name):
-    name += "_pinyin"
+def get_list_projects():
+    project_list = []
+    for folder in os.listdir("data"):
+        path_folder = os.path.join("data", folder)
+        if not os.path.isdir(path_folder):
+            continue
+        folder = folder.lower()
+        if folder == "emilia_zh_en_pinyin":
+            continue
+        project_list.append(folder)
+
+    projects_selelect = None if not project_list else project_list[-1]
+
+    return project_list, projects_selelect
+
+
+def create_data_project(name, tokenizer_type):
+    name += "_" + tokenizer_type
     os.makedirs(os.path.join(path_data, name), exist_ok=True)
     os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
+    project_list, projects_selelect = get_list_projects()
+    return gr.update(choices=project_list, value=name)


 def transcribe(file_audio, language="english"):
@@ -359,14 +389,14 @@ def transcribe(file_audio, language="english"):


 def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     path_dataset = os.path.join(path_project, "dataset")
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")

-    if audio_files is None:
-        return "You need to load an audio file."
+    if not user:
+        if audio_files is None:
+            return "You need to load an audio file."

     if os.path.isdir(path_project_wavs):
         shutil.rmtree(path_project_wavs)
@@ -418,7 +448,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
     except:  # noqa: E722
         error_num += 1

-    with open(file_metadata, "w", encoding="utf-8") as f:
+    with open(file_metadata, "w", encoding="utf-8-sig") as f:
         f.write(data)

     if error_num != []:
@@ -437,7 +467,6 @@ def format_seconds_to_hms(seconds):


 def create_metadata(name_project, progress=gr.Progress()):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")
@@ -448,7 +477,7 @@ def create_metadata(name_project, progress=gr.Progress()):
     if not os.path.isfile(file_metadata):
         return "The file was not found in " + file_metadata

-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()

     audio_path_list = []
@@ -499,7 +528,7 @@ def create_metadata(name_project, progress=gr.Progress()):
     for line in progress.tqdm(result, total=len(result), desc="prepare data"):
         writer.write(line)

-    with open(file_duration, "w", encoding="utf-8") as f:
+    with open(file_duration, "w") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)

     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
@@ -529,7 +558,6 @@ def calculate_train(
     last_per_steps,
     finetune,
 ):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     file_duraction = os.path.join(path_project, "duration.json")

@@ -548,8 +576,8 @@ def calculate_train(
         data = json.load(file)

     duration_list = data["duration"]
-
     samples = len(duration_list)
+    hours = sum(duration_list) / 3600

     if torch.cuda.is_available():
         gpu_properties = torch.cuda.get_device_properties(0)
@@ -583,34 +611,67 @@ def calculate_train(
     save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
     last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)

+    total_hours = hours
+    mel_hop_length = 256
+    mel_sampling_rate = 24000
+
+    # target
+    wanted_max_updates = 1000000
+
+    # train params
+    gpus = 1
+    frames_per_gpu = batch_size_per_gpu  # 8 * 38400 = 307200
+    grad_accum = 1
+
+    # intermediate
+    mini_batch_frames = frames_per_gpu * grad_accum * gpus
+    mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
+    updates_per_epoch = total_hours / mini_batch_hours
+    # steps_per_epoch = updates_per_epoch * grad_accum
+    epochs = wanted_max_updates / updates_per_epoch
+
     if finetune:
         learning_rate = 1e-5
     else:
         learning_rate = 7.5e-5

-    return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
+    return (
+        batch_size_per_gpu,
+        max_samples,
+        num_warmup_updates,
+        save_per_updates,
+        last_per_steps,
+        samples,
+        learning_rate,
+        int(epochs),
+    )


-def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> str:
+def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
     try:
         checkpoint = torch.load(checkpoint_path)
         print("Original Checkpoint Keys:", checkpoint.keys())

         ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
+        if ema_model_state_dict is None:
+            return "No 'ema_model_state_dict' found in the checkpoint."

-        if ema_model_state_dict is not None:
+        if safetensors:
+            new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
+            save_file(ema_model_state_dict, new_checkpoint_path)
+        else:
+            new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
             new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
             torch.save(new_checkpoint, new_checkpoint_path)
-            return f"New checkpoint saved at: {new_checkpoint_path}"
-        else:
-            return "No 'ema_model_state_dict' found in the checkpoint."
+
+        return f"New checkpoint saved at: {new_checkpoint_path}"

     except Exception as e:
         return f"An error occurred: {e}"


 def vocab_check(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)

     file_metadata = os.path.join(path_project, "metadata.csv")
@@ -619,15 +680,15 @@ def vocab_check(project_name):
     if not os.path.isfile(file_vocab):
         return f"the file {file_vocab} not found !"

-    with open(file_vocab, "r", encoding="utf-8") as f:
+    with open(file_vocab, "r", encoding="utf-8-sig") as f:
         data = f.read()
-    vocab = data.split("\n")
-
+    vocab = data.split("\n")
+    vocab = set(vocab)

     if not os.path.isfile(file_metadata):
         return f"the file {file_metadata} not found !"

-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()

     miss_symbols = []
@@ -652,7 +713,7 @@ def vocab_check(project_name):


 def get_random_sample_prepare(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
     file_arrow = os.path.join(path_project, "raw.arrow")
     if not os.path.isfile(file_arrow):
@@ -665,14 +726,14 @@ def get_random_sample_prepare(project_name):


 def get_random_sample_transcribe(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
     file_metadata = os.path.join(path_project, "metadata.csv")
     if not os.path.isfile(file_metadata):
         return "", None

     data = ""
-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()

     list_data = []
@@ -703,13 +764,14 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
     global last_checkpoint, last_device, tts_api

     if not os.path.isfile(file_checkpoint):
-        return None
+        return None, "checkpoint not found!"

     if training_process is not None:
         device_test = "cpu"
     else:
         device_test = None

+    device_test = "cpu"
     if last_checkpoint != file_checkpoint or last_device != device_test:
         if last_checkpoint != file_checkpoint:
             last_checkpoint = file_checkpoint
@@ -722,19 +784,67 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):

     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
-        return f.name
+        return f.name, tts_api.device
+
+
+def check_finetune(finetune):
+    return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune)
+
+
+def get_checkpoints_project(project_name, is_gradio=True):
+    if project_name is None:
+        return [], ""
+    project_name = project_name.replace("_pinyin", "").replace("_char", "")
+    path_project_ckpts = os.path.join("ckpts", project_name)
+
+    if os.path.isdir(path_project_ckpts):
+        files_checkpoints = glob(os.path.join(path_project_ckpts, "*.pt"))
+        files_checkpoints = sorted(
+            files_checkpoints,
+            key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
+            if os.path.basename(x) != "model_last.pt"
+            else float("inf"),
+        )
+    else:
+        files_checkpoints = []
+
+    selelect_checkpoint = None if not files_checkpoints else files_checkpoints[0]
+
+    if is_gradio:
+        return gr.update(choices=files_checkpoints, value=selelect_checkpoint)
+
+    return files_checkpoints, selelect_checkpoint


 with gr.Blocks() as app:
+    gr.Markdown(
+        """
+# E2/F5 TTS AUTOMATIC FINETUNE
+
+This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
+
+* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
+* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
+
+The checkpoints support English and Chinese.
+
+for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
+"""
+    )
+
     with gr.Row():
+        projects, projects_selelect = get_list_projects()
+        tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char"], value="pinyin")
         project_name = gr.Textbox(label="project name", value="my_speak")
         bt_create = gr.Button("create new project")

-    bt_create.click(fn=create_data_project, inputs=[project_name])
+    cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True)
+
+    bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])

     with gr.Tabs():
         with gr.TabItem("transcribe Data"):
-            ch_manual = gr.Checkbox(label="
+            ch_manual = gr.Checkbox(label="audio from path", value=False)

             mark_info_transcribe = gr.Markdown(
                 """```plaintext
@@ -756,7 +866,7 @@ with gr.Blocks() as app:
             txt_info_transcribe = gr.Text(label="info", value="")
             bt_transcribe.click(
                 fn=transcribe_all,
-                inputs=[project_name, audio_speaker, txt_lang, ch_manual],
+                inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
                 outputs=[txt_info_transcribe],
             )
             ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
@@ -769,7 +879,7 @@ with gr.Blocks() as app:

             random_sample_transcribe.click(
                 fn=get_random_sample_transcribe,
-                inputs=[project_name],
+                inputs=[cm_project],
                 outputs=[random_text_transcribe, random_audio_transcribe],
             )

@@ -797,7 +907,7 @@ with gr.Blocks() as app:

             bt_prepare = bt_create = gr.Button("prepare")
             txt_info_prepare = gr.Text(label="info", value="")
-            bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
+            bt_prepare.click(fn=create_metadata, inputs=[cm_project], outputs=[txt_info_prepare])

             random_sample_prepare = gr.Button("random sample")

@@ -806,16 +916,20 @@ with gr.Blocks() as app:
             random_audio_prepare = gr.Audio(label="Audio", type="filepath")

             random_sample_prepare.click(
-                fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
+                fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
             )

         with gr.TabItem("train Data"):
             with gr.Row():
                 bt_calculate = bt_create = gr.Button("Auto Settings")
-                ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
                 lb_samples = gr.Label(label="samples")
                 batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")

+            with gr.Row():
+                ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
+                tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
+                file_checkpoint_train = gr.Textbox(label="Checkpoint", value="")
+
             with gr.Row():
                 exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
                 learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
@@ -844,7 +958,7 @@ with gr.Blocks() as app:
         start_button.click(
             fn=start_training,
             inputs=[
-                project_name,
+                cm_project,
                 exp_name,
                 learning_rate,
                 batch_size_per_gpu,
@@ -857,14 +971,18 @@ with gr.Blocks() as app:
                 save_per_updates,
                 last_per_steps,
                 ch_finetune,
+                file_checkpoint_train,
+                tokenizer_type,
+                tokenizer_file,
             ],
             outputs=[txt_info_train, start_button, stop_button],
         )
         stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
+
         bt_calculate.click(
             fn=calculate_train,
             inputs=[
-                project_name,
+                cm_project,
                 batch_size_type,
                 max_samples,
                 learning_rate,
@@ -881,29 +999,42 @@ with gr.Blocks() as app:
                 last_per_steps,
                 lb_samples,
                 learning_rate,
+                epochs,
             ],
         )

+        ch_finetune.change(
+            check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
+        )
+
         with gr.TabItem("reduse checkpoint"):
             txt_path_checkpoint = gr.Text(label="path checkpoint :")
             txt_path_checkpoint_small = gr.Text(label="path output :")
+            ch_safetensors = gr.Checkbox(label="safetensors", value="")
             txt_info_reduse = gr.Text(label="info", value="")
             reduse_button = gr.Button("reduse")
             reduse_button.click(
                 fn=extract_and_save_ema_model,
-                inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
+                inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
                 outputs=[txt_info_reduse],
             )

         with gr.TabItem("vocab check experiment"):
             check_button = gr.Button("check vocab")
             txt_info_check = gr.Text(label="info", value="")
-            check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
+            check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])

         with gr.TabItem("test model"):
             exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
+            list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
+
             nfe_step = gr.Number(label="n_step", value=32)
-
+
+            with gr.Row():
+                cm_checkpoint = gr.Dropdown(
+                    choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
+                )
+                bt_checkpoint_refresh = gr.Button("refresh")

             random_sample_infer = gr.Button("random sample")

@@ -911,17 +1042,24 @@ with gr.Blocks() as app:
             ref_audio = gr.Audio(label="audio ref", type="filepath")
             gen_text = gr.Textbox(label="gen text")
             random_sample_infer.click(
-                fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
+                fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
             )
-            check_button_infer = gr.Button("infer")
+
+            with gr.Row():
+                txt_info_gpu = gr.Textbox("", label="device")
+                check_button_infer = gr.Button("infer")
+
             gen_audio = gr.Audio(label="audio gen", type="filepath")

             check_button_infer.click(
                 fn=infer,
-                inputs=[
-                outputs=[gen_audio],
+                inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
+                outputs=[gen_audio, txt_info_gpu],
             )

+            bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+            cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+

 @click.command()
 @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
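A hypothetical call to the updated checkpoint reducer, showing how the new `safetensors` flag and the extension rewriting interact (paths are placeholders):

```python
# extract_and_save_ema_model now returns a status string on every path.
msg = extract_and_save_ema_model(
    "ckpts/my_speak/model_last.pt",   # full training checkpoint (placeholder)
    "ckpts/my_speak/model_small.pt",  # written as model_small.safetensors when safetensors=True
    safetensors=True,
)
print(msg)  # "New checkpoint saved at: ..." or an error message
```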