# training script.

from importlib.resources import files

from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer

# -------------------------- Dataset Settings --------------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'

tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
tokenizer_path = None  # used only when tokenizer == 'custom': path to the vocab file (vocab.txt) to use
dataset_name = "Emilia_ZH_EN"

# -------------------------- Training Settings -------------------------- #

exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base

learning_rate = 7.5e-5

batch_size_per_gpu = 38400  # 8 GPUs, 8 * 38400 = 307200
batch_size_type = "frame"  # "frame" or "sample"
max_samples = 64  # max sequences per batch when batch_size_type == "frame"; 32 for small models, 64 for base models
grad_accumulation_steps = 1  # note: updates = steps / grad_accumulation_steps
max_grad_norm = 1.0

epochs = 11  # linear LR decay is used, so the number of epochs controls the decay slope
num_warmup_updates = 20000  # linear LR warmup over this many updates
save_per_updates = 50000  # save a checkpoint every this many updates
last_per_steps = 5000  # refresh the "last" checkpoint every this many steps
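
# Rough arithmetic for the frame-based batch above (a sanity check, assuming 8 GPUs and
# grad_accumulation_steps = 1): at 24 kHz with hop_length = 256 there are 24000 / 256 = 93.75
# mel frames per second, so 38400 frames ≈ 409.6 s of audio per GPU per update, and
# 8 * 38400 = 307200 frames ≈ 0.91 h of audio per optimizer update in total.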

# model params
if exp_name == "F5TTS_Base":
    wandb_resume_id = None
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
    wandb_resume_id = None
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
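
# wandb_resume_id: set to an existing Weights & Biases run id to resume that run's logging
# (an assumption based on the Trainer's wandb_* arguments below); None starts a fresh run.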


# ----------------------------------------------------------------------- #


def main():
    if tokenizer == "custom":
        if tokenizer_path is None:
            raise ValueError("tokenizer_path must point to a vocab.txt file when tokenizer == 'custom'")
        vocab_path = tokenizer_path  # avoid rebinding the module-level name inside main()
    else:
        vocab_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(vocab_path, tokenizer)

    mel_spec_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    trainer = Trainer(
        model,
        epochs,
        learning_rate,
        num_warmup_updates=num_warmup_updates,
        save_per_updates=save_per_updates,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")),
        batch_size=batch_size_per_gpu,
        batch_size_type=batch_size_type,
        max_samples=max_samples,
        grad_accumulation_steps=grad_accumulation_steps,
        max_grad_norm=max_grad_norm,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=last_per_steps,
        log_samples=True,
        mel_spec_type=mel_spec_type,
    )

    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
    trainer.train(
        train_dataset,
        resumable_with_seed=666,  # seed for shuffling dataset
    )


if __name__ == "__main__":
    main()
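
# A possible way to launch this script (a sketch, assuming it lives at
# src/f5_tts/train/train.py and that `accelerate` is installed and configured,
# e.g. for multi-GPU DDP with mixed precision):
#
#   accelerate config
#   accelerate launch src/f5_tts/train/train.py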