# training script.
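# Configures dataset/mel settings, selects a DiT (F5-TTS) or UNetT (E2-TTS) backbone,
# wraps it in a CFM model, and trains it with the Trainer on the Emilia_ZH_EN dataset.
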
from importlib.resources import files
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer == "custom", path to the custom vocab file to use (a vocab.txt)
dataset_name = "Emilia_ZH_EN"
# -------------------------- Training Settings -------------------------- #
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
learning_rate = 7.5e-5
batch_size_per_gpu = 38400 # frames per GPU when batch_size_type == "frame"; 8 GPUs, 8 * 38400 = 307200
batch_size_type = "frame" # "frame" or "sample"
max_samples = 64 # max sequences per batch when using frame-wise batch_size; 32 for small models, 64 for base models
grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm = 1.0
epochs = 11 # the LR schedule uses linear decay, so the epoch count sets the decay slope
num_warmup_updates = 20000 # linear warmup updates before decay
save_per_updates = 50000 # save a checkpoint every this many updates
last_per_steps = 5000 # refresh the rolling "last" checkpoint every this many steps
# model params
if exp_name == "F5TTS_Base":
    wandb_resume_id = None
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
    wandb_resume_id = None
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
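# Note: F5TTS_Base pairs the DiT backbone with convolutional text-conditioning layers
# (conv_layers=4), while E2TTS_Base uses the flat UNet-style transformer (UNetT) with a
# larger feed-forward multiplier; both share dim=1024 and 16 attention heads.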
# ----------------------------------------------------------------------- #
def main():
    if tokenizer == "custom":
        vocab_path = tokenizer_path # local name avoids shadowing the module-level tokenizer_path
    else:
        vocab_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(vocab_path, tokenizer)
    mel_spec_kwargs = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )
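
    # vocab_size sizes the transformer's text embedding table (text_num_embeds below);
    # vocab_char_map lets the CFM wrapper map raw text to token ids during training.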
    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    trainer = Trainer(
        model,
        epochs,
        learning_rate,
        num_warmup_updates=num_warmup_updates,
        save_per_updates=save_per_updates,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")),
        batch_size=batch_size_per_gpu,
        batch_size_type=batch_size_type,
        max_samples=max_samples,
        grad_accumulation_steps=grad_accumulation_steps,
        max_grad_norm=max_grad_norm,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=last_per_steps,
        log_samples=True,
        mel_spec_type=mel_spec_type,
    )

    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
    trainer.train(
        train_dataset,
        resumable_with_seed=666, # seed for shuffling dataset
    )
if __name__ == "__main__":
    main()
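
# Typical launch, as a sketch (assumes the repo's usual Hugging Face accelerate setup;
# adjust the script path to match your checkout):
#   accelerate config
#   accelerate launch train.py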