Support Hydra configs in finetune_cli and fix some minor bugs
Files changed:
- src/f5_tts/config/E2TTS_Base_finetune.yaml +46 -0
- src/f5_tts/config/E2TTS_Base_train.yaml +3 -0
- src/f5_tts/config/E2TTS_Small_train.yaml +3 -0
- src/f5_tts/config/F5TTS_Base_finetune.yaml +46 -0
- src/f5_tts/config/F5TTS_Base_train.yaml +3 -0
- src/f5_tts/config/F5TTS_Small_train.yaml +3 -0
- src/f5_tts/model/trainer.py +1 -1
- src/f5_tts/model/utils.py +2 -2
- src/f5_tts/train/finetune_cli.py +48 -135
- src/f5_tts/train/train.py +3 -0
src/f5_tts/config/E2TTS_Base_finetune.yaml
ADDED
@@ -0,0 +1,46 @@
+hydra:
+  run:
+    dir: ckpts/finetune_${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15 # max epochs
+  learning_rate: 7.5e-5 # learning rate
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
+
+model:
+  name: F5TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024 # model dim
+    depth: 22 # model depth
+    heads: 16 # model heads
+    ff_mult: 2 # feedforward expansion
+    text_dim: 512 # text encoder dim
+    conv_layers: 4 # convolution layers
+  mel_spec:
+    target_sample_rate: 24000 # target sample rate
+    n_mel_channels: 100 # mel channel
+    hop_length: 256 # hop length
+    win_length: 1024 # window length
+    n_fft: 1024 # fft length
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False # use local vocoder or not
+    local_vocoder_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  pretain_ckpt_path: ckpts/E2TTS_Base/model_1200000.pt
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/E2TTS_Base_train.yaml
CHANGED
@@ -7,6 +7,7 @@ datasets:
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame # "frame" or "sample"
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers

optim:
  epochs: 15 # max epochs
@@ -14,6 +15,7 @@ optim:
  num_warmup_updates: 20000 # warmup steps
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not

model:
  name: E2TTS_Base # model name
@@ -35,6 +37,7 @@ model:
    local_vocoder_path: None # path to local vocoder

ckpts:
+  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per steps
  last_per_steps: 5000 # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/E2TTS_Small_train.yaml
CHANGED
@@ -7,6 +7,7 @@ datasets:
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame # "frame" or "sample"
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers

optim:
  epochs: 15
@@ -14,6 +15,7 @@ optim:
  num_warmup_updates: 20000 # warmup steps
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0
+  bnb_optimizer: False

model:
  name: E2TTS_Small
@@ -35,6 +37,7 @@ model:
    local_vocoder_path: None

ckpts:
+  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per steps
  last_per_steps: 5000 # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/F5TTS_Base_finetune.yaml
ADDED
@@ -0,0 +1,46 @@
+hydra:
+  run:
+    dir: ckpts/finetune_${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN # dataset name
+  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame # "frame" or "sample"
+  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers
+
+optim:
+  epochs: 15 # max epochs
+  learning_rate: 7.5e-5 # learning rate
+  num_warmup_updates: 20000 # warmup steps
+  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not
+
+model:
+  name: F5TTS_Base # model name
+  tokenizer: pinyin # tokenizer type
+  tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024 # model dim
+    depth: 22 # model depth
+    heads: 16 # model heads
+    ff_mult: 2 # feedforward expansion
+    text_dim: 512 # text encoder dim
+    conv_layers: 4 # convolution layers
+  mel_spec:
+    target_sample_rate: 24000 # target sample rate
+    n_mel_channels: 100 # mel channel
+    hop_length: 256 # hop length
+    win_length: 1024 # window length
+    n_fft: 1024 # fft length
+    mel_spec_type: vocos # 'vocos' or 'bigvgan'
+    is_local_vocoder: False # use local vocoder or not
+    local_vocoder_path: None # local vocoder path
+
+ckpts:
+  logger: wandb # wandb | tensorboard | None
+  save_per_updates: 50000 # save checkpoint per steps
+  last_per_steps: 5000 # save last checkpoint per steps
+  pretain_ckpt_path: ckpts/F5TTS_Base/model_1200000.pt
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
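Both finetune configs assemble their run and save directories from OmegaConf interpolations. Below is a minimal sketch of how the ${...} references resolve, using plain OmegaConf with values mirroring the F5TTS config above; the ${now:...} date and time segments are left out here because the "now" resolver is registered by Hydra when the app actually runs, and this sketch is not part of the commit.

# Minimal sketch of the ${...} interpolation used by the configs above (plain
# OmegaConf, no Hydra); the ${now:...} parts of save_dir are omitted on purpose.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "datasets": {"name": "Emilia_ZH_EN"},
        "model": {
            "name": "F5TTS_Base",
            "tokenizer": "pinyin",
            "mel_spec": {"mel_spec_type": "vocos"},
        },
        "ckpts": {
            "save_dir": "ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}"
        },
    }
)

# Interpolations resolve lazily on access, relative to the config root.
print(cfg.ckpts.save_dir)  # ckpts/F5TTS_Base_vocos_pinyin_Emilia_ZH_EN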
src/f5_tts/config/F5TTS_Base_train.yaml
CHANGED
@@ -7,6 +7,7 @@ datasets:
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame # "frame" or "sample"
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers

optim:
  epochs: 15 # max epochs
@@ -14,6 +15,7 @@ optim:
  num_warmup_updates: 20000 # warmup steps
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0 # gradient clipping
+  bnb_optimizer: False # use bnb optimizer or not

model:
  name: F5TTS_Base # model name
@@ -37,6 +39,7 @@ model:
    local_vocoder_path: None # local vocoder path

ckpts:
+  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per steps
  last_per_steps: 5000 # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/F5TTS_Small_train.yaml
CHANGED
@@ -7,6 +7,7 @@ datasets:
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame # "frame" or "sample"
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16 # number of workers

optim:
  epochs: 15
@@ -14,6 +15,7 @@ optim:
  num_warmup_updates: 20000 # warmup steps
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0
+  bnb_optimizer: False

model:
  name: F5TTS_Small
@@ -37,6 +39,7 @@ model:
    local_vocoder_path: None

ckpts:
+  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per steps
  last_per_steps: 5000 # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
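All four train configs above gain the same three keys (datasets.num_workers, optim.bnb_optimizer, ckpts.logger). A minimal sketch of tweaking them per run without editing the YAML, via Hydra's compose API; it assumes Hydra >= 1.2 and that config_path points at the directory holding these files relative to the calling script, and the same dotted overrides can be passed on the train.py or finetune_cli.py command line.

# Sketch only (config directory location and file name are assumptions, not part
# of this commit): read one train config and override the newly added keys.
from hydra import compose, initialize

with initialize(version_base=None, config_path="src/f5_tts/configs"):
    cfg = compose(
        config_name="F5TTS_Base_train",
        overrides=[
            "datasets.num_workers=8",    # fewer dataloader workers on a small machine
            "optim.bnb_optimizer=true",  # switch to the 8-bit bitsandbytes optimizer
            "ckpts.logger=tensorboard",  # log to TensorBoard instead of wandb
        ],
    )
    print(cfg.datasets.num_workers, cfg.optim.bnb_optimizer, cfg.ckpts.logger)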
src/f5_tts/model/trainer.py
CHANGED
@@ -91,7 +91,7 @@ class Trainer:
        elif self.logger == "tensorboard":
            from torch.utils.tensorboard import SummaryWriter

-            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
+            self.writer = SummaryWriter(log_dir=f"{checkpoint_path}/runs/{wandb_run_name}")

        self.model = model
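With this one-line change, TensorBoard event files land inside the run's checkpoint directory instead of a global runs/ folder under the working directory. A small sketch with hypothetical names, mirroring the changed line:

# Hypothetical paths, only to show where event files end up after this change.
from torch.utils.tensorboard import SummaryWriter

checkpoint_path = "ckpts/finetune_F5TTS_Base_vocos_pinyin_my_dataset"  # hypothetical
wandb_run_name = "finetune_F5TTS_Base_vocos_pinyin_my_dataset"         # reused as the TB run name

writer = SummaryWriter(log_dir=f"{checkpoint_path}/runs/{wandb_run_name}")
writer.add_scalar("loss", 0.123, global_step=1)
writer.close()
# Inspect with: tensorboard --logdir ckpts/finetune_F5TTS_Base_vocos_pinyin_my_dataset/runs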
src/f5_tts/model/utils.py
CHANGED
@@ -113,7 +113,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
+                vocab_char_map[char.strip()] = i  # ignore \n
        vocab_size = len(vocab_char_map)
        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"

@@ -125,7 +125,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
        with open(dataset_name, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
+                vocab_char_map[char.strip()] = i
        vocab_size = len(vocab_char_map)

    return vocab_char_map, vocab_size
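get_tokenizer builds the char-to-id map one vocab.txt line at a time, and trimming the line ending is what keeps lookups like vocab_char_map["a"] from missing. A small self-contained illustration (toy data, not the real vocab file):

# Toy stand-in for open("vocab.txt", encoding="utf-8"); each line is one token.
import io

toy_vocab = io.StringIO("a\nb\nc\n")

vocab_char_map = {}
for i, char in enumerate(toy_vocab):
    vocab_char_map[char.strip()] = i  # trim the trailing "\n" (and "\r" from Windows files)

assert vocab_char_map == {"a": 0, "b": 1, "c": 2}
print(len(vocab_char_map))  # vocab_size = 3

One caveat: the real vocab.txt reserves a lone space at index 0 (see the assert in the first hunk), and a bare strip() removes that space as well, so a newline-only trim such as char.rstrip("\n") may be the safer choice for that file.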
src/f5_tts/train/finetune_cli.py
CHANGED
@@ -1,6 +1,6 @@
-import argparse
 import os
 import shutil
+import hydra

 from cached_path import cached_path
 from f5_tts.model import CFM, UNetT, DiT, Trainer
@@ -9,163 +9,76 @@ from f5_tts.model.dataset import load_dataset
 from importlib.resources import files


-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
-mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
-
-
-def parse_args():
-    # batch_size_per_gpu = 1000 settting for gpu 8GB
-    # batch_size_per_gpu = 1600 settting for gpu 12GB
-    # batch_size_per_gpu = 2000 settting for gpu 16GB
-    # batch_size_per_gpu = 3200 settting for gpu 24GB
-
-    # num_warmup_updates = 300 for 5000 sample about 10 hours
-
-    # change save_per_updates , last_per_steps change this value what you need ,
-
-    parser = argparse.ArgumentParser(description="Train CFM Model")
-
-    parser.add_argument(
-        "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
-    )
-    parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
-    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
-    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
-    parser.add_argument(
-        "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
-    )
-    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
-    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
-    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
-    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
-    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
-    parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
-    parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
-    parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
-    parser.add_argument(
-        "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
-    )
-    parser.add_argument(
-        "--tokenizer_path",
-        type=str,
-        default=None,
-        help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
-    )
-    parser.add_argument(
-        "--log_samples",
-        type=bool,
-        default=False,
-        help="Log inferenced samples per ckpt save steps",
-    )
-    parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
-    parser.add_argument(
-        "--bnb_optimizer",
-        type=bool,
-        default=False,
-        help="Use 8-bit Adam optimizer from bitsandbytes",
-    )
-
-    return parser.parse_args()
-
-
-# -------------------------- Training Settings -------------------------- #
-
-
-def main():
-    args = parse_args()
-
-    checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
+@hydra.main(config_path=os.path.join("..", "configs"), config_name=None)
+def main(cfg):
+    tokenizer = cfg.model.tokenizer
+    mel_spec_type = cfg.model.mel_spec.mel_spec_type
+    exp_name = f"finetune_{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
+
+    # set text tokenizer
+    if tokenizer != "custom":
+        tokenizer_path = cfg.datasets.name
+    else:
+        tokenizer_path = cfg.model.tokenizer_path
+    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
+
+    print("\nvocab : ", vocab_size)
+    print("\nvocoder : ", mel_spec_type)

     # Model parameters based on experiment name
-    if args.exp_name == "F5TTS_Base":
-        wandb_resume_id = None
+    if "F5TTS" in cfg.model.name:
         model_cls = DiT
-        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-        if args.pretrain is None:
-            ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
-        else:
-            ckpt_path = args.pretrain
-    elif args.exp_name == "E2TTS_Base":
-        wandb_resume_id = None
+        ckpt_path = cfg.ckpts.pretain_ckpt_path or str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+    elif "E2TTS" in cfg.model.name:
         model_cls = UNetT
-        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-        if args.pretrain is None:
-            ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
-        else:
-            ckpt_path = args.pretrain
-
-    if args.finetune:
-        if not os.path.isdir(checkpoint_path):
-            os.makedirs(checkpoint_path, exist_ok=True)
-
-        file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
-        if not os.path.isfile(file_checkpoint):
-            shutil.copy2(ckpt_path, file_checkpoint)
-            print("copy checkpoint for finetune")
-
-    tokenizer = args.tokenizer
-    if tokenizer == "custom":
-        if not args.tokenizer_path:
-            raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
-        tokenizer_path = args.tokenizer_path
-    else:
-        tokenizer_path = args.dataset_name
-
-    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
-
-    print("\nvocab : ", vocab_size)
-    print("\nvocoder : ", mel_spec_type)
-
-    mel_spec_kwargs = dict(
-        n_fft=n_fft,
-        hop_length=hop_length,
-        win_length=win_length,
-        n_mel_channels=n_mel_channels,
-        target_sample_rate=target_sample_rate,
-        mel_spec_type=mel_spec_type,
-    )
+        ckpt_path = cfg.ckpts.pretain_ckpt_path or str(cached_path("hf://SWivid/F5-TTS/E2TTS_Base/model_1200000.pt"))
+    wandb_resume_id = None
+
+    checkpoint_path = str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}"))
+
+    if not os.path.isdir(checkpoint_path):
+        os.makedirs(checkpoint_path, exist_ok=True)
+
+    file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+    if not os.path.isfile(file_checkpoint):
+        shutil.copy2(ckpt_path, file_checkpoint)
+        print("copy checkpoint for finetune")

     model = CFM(
-        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
-        mel_spec_kwargs=mel_spec_kwargs,
+        transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
+        mel_spec_kwargs=cfg.model.mel_spec,
         vocab_char_map=vocab_char_map,
     )

     trainer = Trainer(
         model,
-        epochs=args.epochs,
-        learning_rate=args.learning_rate,
-        num_warmup_updates=args.num_warmup_updates,
-        save_per_updates=args.save_per_updates,
+        epochs=cfg.optim.epochs,
+        learning_rate=cfg.optim.learning_rate,
+        num_warmup_updates=cfg.optim.num_warmup_updates,
+        save_per_updates=cfg.ckpts.save_per_updates,
         checkpoint_path=checkpoint_path,
-        batch_size=args.batch_size_per_gpu,
-        batch_size_type=args.batch_size_type,
-        max_samples=args.max_samples,
-        grad_accumulation_steps=args.grad_accumulation_steps,
-        max_grad_norm=args.max_grad_norm,
-        logger=args.logger,
-        wandb_project=args.dataset_name,
-        wandb_run_name=args.exp_name,
+        batch_size=cfg.datasets.batch_size_per_gpu,
+        batch_size_type=cfg.datasets.batch_size_type,
+        max_samples=cfg.datasets.max_samples,
+        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
+        max_grad_norm=cfg.optim.max_grad_norm,
+        logger=cfg.ckpts.logger,
+        wandb_project=cfg.datasets.name,
+        wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-        log_samples=args.log_samples,
-        last_per_steps=args.last_per_steps,
-        bnb_optimizer=args.bnb_optimizer,
+        log_samples=True,
+        last_per_steps=cfg.ckpts.last_per_steps,
+        bnb_optimizer=cfg.optim.bnb_optimizer,
+        mel_spec_type=mel_spec_type,
+        is_local_vocoder=cfg.model.mel_spec.is_local_vocoder,
+        local_vocoder_path=cfg.model.mel_spec.local_vocoder_path,
     )

-    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)

     trainer.train(
         train_dataset,
+        num_workers=cfg.datasets.num_workers,
         resumable_with_seed=666,  # seed for shuffling dataset
     )
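The rewritten CLI drops all argparse flags; everything now comes from a YAML chosen at launch via Hydra's --config-name flag (for example --config-name F5TTS_Base_finetune, exact invocation assumed), and individual keys can still be overridden in dotted form such as optim.epochs=30. A stripped-down, standalone sketch of the same entry-point pattern; the conf/ directory and toy usage are hypothetical, and version_base=None assumes Hydra >= 1.2.

# Sketch of the @hydra.main pattern the new finetune_cli.py uses; a conf/
# directory with YAML files is assumed to sit next to this script.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="conf", config_name=None)
def main(cfg: DictConfig) -> None:
    # With config_name=None no default config is assumed, so the YAML is picked
    # at launch time; ${...} interpolations inside it resolve lazily on access.
    print(OmegaConf.to_yaml(cfg, resolve=False))


if __name__ == "__main__":
    main()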
src/f5_tts/train/train.py
CHANGED
@@ -48,11 +48,13 @@ def main(cfg):
         max_samples=cfg.datasets.max_samples,
         grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
         max_grad_norm=cfg.optim.max_grad_norm,
+        logger=cfg.ckpts.logger,
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
         last_per_steps=cfg.ckpts.last_per_steps,
         log_samples=True,
+        bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
         is_local_vocoder=cfg.model.mel_spec.is_local_vocoder,
         local_vocoder_path=cfg.model.mel_spec.local_vocoder_path,
@@ -61,6 +63,7 @@ def main(cfg):
     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
     trainer.train(
         train_dataset,
+        num_workers=cfg.datasets.num_workers,
         resumable_with_seed=666,  # seed for shuffling dataset
     )