zkniu committed · commit 6523beb · 1 parent: 6b27dbe

support finetune_cli hydra and fix some minor bugs

src/f5_tts/config/E2TTS_Base_finetune.yaml ADDED
@@ -0,0 +1,44 @@
+ hydra:
+   run:
+     dir: ckpts/finetune_${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+ datasets:
+   name: Emilia_ZH_EN # dataset name
+   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+   batch_size_type: frame # "frame" or "sample"
+   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16 # number of workers
+
+ optim:
+   epochs: 15 # max epochs
+   learning_rate: 7.5e-5 # learning rate
+   num_warmup_updates: 20000 # warmup steps
+   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+   max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb optimizer or not
+
+ model:
+   name: E2TTS_Base # model name
+   tokenizer: pinyin # tokenizer type
+   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   arch:
+     dim: 1024 # model dim
+     depth: 24 # model depth
+     heads: 16 # model heads
+     ff_mult: 4 # feedforward expansion
+   mel_spec:
+     target_sample_rate: 24000 # target sample rate
+     n_mel_channels: 100 # mel channel
+     hop_length: 256 # hop length
+     win_length: 1024 # window length
+     n_fft: 1024 # fft length
+     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+     is_local_vocoder: False # use local vocoder or not
+     local_vocoder_path: None # local vocoder path
+
+ ckpts:
+   logger: wandb # wandb | tensorboard | None
+   save_per_updates: 50000 # save checkpoint per steps
+   last_per_steps: 5000 # save last checkpoint per steps
+   pretain_ckpt_path: ckpts/E2TTS_Base/model_1200000.pt
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
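The `${...}` entries above are OmegaConf interpolations that Hydra resolves at launch time (the `now:` resolver fills in the date and time). A minimal sketch of how the pieces compose, assuming the file sits at the path shown in this commit and substituting a local stand-in for Hydra's `now:` resolver:

```python
# Sketch only: inspect the finetune config and resolve its interpolations outside of a Hydra run.
# Assumes omegaconf is installed and the YAML above is on disk at the committed path.
from datetime import datetime
from omegaconf import OmegaConf

# Stand-in for Hydra's built-in now: resolver, which formats the current time with strftime.
OmegaConf.register_new_resolver("now", lambda fmt: datetime.now().strftime(fmt), replace=True)

cfg = OmegaConf.load("src/f5_tts/config/E2TTS_Base_finetune.yaml")
print(cfg.optim.learning_rate)  # 7.5e-05
print(cfg.ckpts.save_dir)       # e.g. ckpts/E2TTS_Base_vocos_pinyin_Emilia_ZH_EN/2024-11-11/10-30-00
```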
src/f5_tts/config/E2TTS_Base_train.yaml CHANGED
@@ -7,6 +7,7 @@ datasets:
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
    batch_size_type: frame # "frame" or "sample"
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16 # number of workers
 
  optim:
    epochs: 15 # max epochs
@@ -14,6 +15,7 @@ optim:
    num_warmup_updates: 20000 # warmup steps
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
    max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb optimizer or not
 
  model:
    name: E2TTS_Base # model name
@@ -35,6 +37,7 @@ model:
      local_vocoder_path: None # path to local vocoder
 
  ckpts:
+   logger: wandb # wandb | tensorboard | None
    save_per_updates: 50000 # save checkpoint per steps
    last_per_steps: 5000 # save last checkpoint per steps
    save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/E2TTS_Small_train.yaml CHANGED
@@ -7,6 +7,7 @@ datasets:
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
    batch_size_type: frame # "frame" or "sample"
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16 # number of workers
 
  optim:
    epochs: 15
@@ -14,6 +15,7 @@ optim:
    num_warmup_updates: 20000 # warmup steps
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
    max_grad_norm: 1.0
+   bnb_optimizer: False
 
  model:
    name: E2TTS_Small
@@ -35,6 +37,7 @@ model:
      local_vocoder_path: None
 
  ckpts:
+   logger: wandb # wandb | tensorboard | None
    save_per_updates: 50000 # save checkpoint per steps
    last_per_steps: 5000 # save last checkpoint per steps
    save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/F5TTS_Base_finetune.yaml ADDED
@@ -0,0 +1,46 @@
+ hydra:
+   run:
+     dir: ckpts/finetune_${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+ datasets:
+   name: Emilia_ZH_EN # dataset name
+   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+   batch_size_type: frame # "frame" or "sample"
+   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16 # number of workers
+
+ optim:
+   epochs: 15 # max epochs
+   learning_rate: 7.5e-5 # learning rate
+   num_warmup_updates: 20000 # warmup steps
+   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+   max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb optimizer or not
+
+ model:
+   name: F5TTS_Base # model name
+   tokenizer: pinyin # tokenizer type
+   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   arch:
+     dim: 1024 # model dim
+     depth: 22 # model depth
+     heads: 16 # model heads
+     ff_mult: 2 # feedforward expansion
+     text_dim: 512 # text encoder dim
+     conv_layers: 4 # convolution layers
+   mel_spec:
+     target_sample_rate: 24000 # target sample rate
+     n_mel_channels: 100 # mel channel
+     hop_length: 256 # hop length
+     win_length: 1024 # window length
+     n_fft: 1024 # fft length
+     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+     is_local_vocoder: False # use local vocoder or not
+     local_vocoder_path: None # local vocoder path
+
+ ckpts:
+   logger: wandb # wandb | tensorboard | None
+   save_per_updates: 50000 # save checkpoint per steps
+   last_per_steps: 5000 # save last checkpoint per steps
+   pretain_ckpt_path: ckpts/F5TTS_Base/model_1200000.pt
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/F5TTS_Base_train.yaml CHANGED
@@ -7,6 +7,7 @@ datasets:
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
    batch_size_type: frame # "frame" or "sample"
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16 # number of workers
 
  optim:
    epochs: 15 # max epochs
@@ -14,6 +15,7 @@ optim:
    num_warmup_updates: 20000 # warmup steps
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
    max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb optimizer or not
 
  model:
    name: F5TTS_Base # model name
@@ -37,6 +39,7 @@ model:
      local_vocoder_path: None # local vocoder path
 
  ckpts:
+   logger: wandb # wandb | tensorboard | None
    save_per_updates: 50000 # save checkpoint per steps
    last_per_steps: 5000 # save last checkpoint per steps
    save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/config/F5TTS_Small_train.yaml CHANGED
@@ -7,6 +7,7 @@ datasets:
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
    batch_size_type: frame # "frame" or "sample"
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16 # number of workers
 
  optim:
    epochs: 15
@@ -14,6 +15,7 @@ optim:
    num_warmup_updates: 20000 # warmup steps
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
    max_grad_norm: 1.0
+   bnb_optimizer: False
 
  model:
    name: F5TTS_Small
@@ -37,6 +39,7 @@ model:
      local_vocoder_path: None
 
  ckpts:
+   logger: wandb # wandb | tensorboard | None
    save_per_updates: 50000 # save checkpoint per steps
    last_per_steps: 5000 # save last checkpoint per steps
    save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
src/f5_tts/model/trainer.py CHANGED
@@ -91,7 +91,7 @@ class Trainer:
          elif self.logger == "tensorboard":
              from torch.utils.tensorboard import SummaryWriter
 
-             self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
+             self.writer = SummaryWriter(log_dir=f"{checkpoint_path}/runs/{wandb_run_name}")
 
          self.model = model
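With this change the TensorBoard event files are written under the run's checkpoint directory rather than a top-level `runs/` folder. A small illustrative sketch (paths are hypothetical, not taken from the repo; assumes torch and tensorboard are installed):

```python
# Sketch: where the writer now puts its event files, and how to view them.
from torch.utils.tensorboard import SummaryWriter

checkpoint_path = "ckpts/finetune_F5TTS_Base_vocos_pinyin_Emilia_ZH_EN"  # hypothetical save_dir
wandb_run_name = "finetune_F5TTS_Base_vocos_pinyin_Emilia_ZH_EN"

writer = SummaryWriter(log_dir=f"{checkpoint_path}/runs/{wandb_run_name}")
writer.add_scalar("loss", 0.123, global_step=1)  # dummy value just to create an event file
writer.close()
# The run can then be inspected with: tensorboard --logdir ckpts/finetune_F5TTS_Base_vocos_pinyin_Emilia_ZH_EN/runs
```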
 
src/f5_tts/model/utils.py CHANGED
@@ -113,7 +113,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
      with open(tokenizer_path, "r", encoding="utf-8") as f:
          vocab_char_map = {}
          for i, char in enumerate(f):
-             vocab_char_map[char[:-1]] = i
+             vocab_char_map[char.rstrip("\n")] = i  # ignore trailing \n, but keep the space entry intact
      vocab_size = len(vocab_char_map)
      assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
 
@@ -125,7 +125,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
      with open(dataset_name, "r", encoding="utf-8") as f:
          vocab_char_map = {}
          for i, char in enumerate(f):
-             vocab_char_map[char[:-1]] = i
+             vocab_char_map[char.rstrip("\n")] = i  # ignore trailing \n
      vocab_size = len(vocab_char_map)
 
      return vocab_char_map, vocab_size
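For reference, a tiny standalone comparison of the trimming strategies this line has to choose between (the values are illustrative; the vocab format is inferred from the assert above, where the first entry is a single space):

```python
# Illustrative only: how each trimming strategy treats vocab.txt lines.
lines = [" \n", "a\n", "e"]  # space entry first (idx 0); the last line may lack a trailing newline

print([ln[:-1] for ln in lines])          # [' ', 'a', '']  -> chops the last token when the file has no final newline
print([ln.strip() for ln in lines])       # ['', 'a', 'e']  -> collapses the space entry, breaking the idx-0 assert
print([ln.rstrip("\n") for ln in lines])  # [' ', 'a', 'e'] -> trims only the newline and keeps both edge cases intact
```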
src/f5_tts/train/finetune_cli.py CHANGED
@@ -1,6 +1,6 @@
- import argparse
  import os
  import shutil
+ import hydra
 
  from cached_path import cached_path
  from f5_tts.model import CFM, UNetT, DiT, Trainer
@@ -9,163 +9,76 @@ from f5_tts.model.dataset import load_dataset
  from importlib.resources import files
 
 
- # -------------------------- Dataset Settings --------------------------- #
- target_sample_rate = 24000
- n_mel_channels = 100
- hop_length = 256
- win_length = 1024
- n_fft = 1024
- mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
-
-
- # -------------------------- Argument Parsing --------------------------- #
- def parse_args():
-     # batch_size_per_gpu = 1000 settting for gpu 8GB
-     # batch_size_per_gpu = 1600 settting for gpu 12GB
-     # batch_size_per_gpu = 2000 settting for gpu 16GB
-     # batch_size_per_gpu = 3200 settting for gpu 24GB
-
-     # num_warmup_updates = 300 for 5000 sample about 10 hours
-
-     # change save_per_updates , last_per_steps change this value what you need ,
-
-     parser = argparse.ArgumentParser(description="Train CFM Model")
-
-     parser.add_argument(
-         "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
-     )
-     parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
-     parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
-     parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
-     parser.add_argument(
-         "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
-     )
-     parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
-     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
-     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
-     parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
-     parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
-     parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
-     parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
-     parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
-     parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
-     parser.add_argument(
-         "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
-     )
-     parser.add_argument(
-         "--tokenizer_path",
-         type=str,
-         default=None,
-         help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
-     )
-     parser.add_argument(
-         "--log_samples",
-         type=bool,
-         default=False,
-         help="Log inferenced samples per ckpt save steps",
-     )
-     parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
-     parser.add_argument(
-         "--bnb_optimizer",
-         type=bool,
-         default=False,
-         help="Use 8-bit Adam optimizer from bitsandbytes",
-     )
-
-     return parser.parse_args()
-
-
- # -------------------------- Training Settings -------------------------- #
-
-
- def main():
-     args = parse_args()
-
-     checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
+ @hydra.main(config_path=os.path.join("..", "config"), config_name=None)
+ def main(cfg):
+     tokenizer = cfg.model.tokenizer
+     mel_spec_type = cfg.model.mel_spec.mel_spec_type
+     exp_name = f"finetune_{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
+
+     # set text tokenizer
+     if tokenizer != "custom":
+         tokenizer_path = cfg.datasets.name
+     else:
+         tokenizer_path = cfg.model.tokenizer_path
+     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
+
+     print("\nvocab : ", vocab_size)
+     print("\nvocoder : ", mel_spec_type)
 
      # Model parameters based on experiment name
-     if args.exp_name == "F5TTS_Base":
-         wandb_resume_id = None
+     if "F5TTS" in cfg.model.name:
          model_cls = DiT
-         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-         if args.finetune:
-             if args.pretrain is None:
-                 ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
-             else:
-                 ckpt_path = args.pretrain
-     elif args.exp_name == "E2TTS_Base":
-         wandb_resume_id = None
+         ckpt_path = cfg.ckpts.pretain_ckpt_path or str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+     elif "E2TTS" in cfg.model.name:
          model_cls = UNetT
-         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-         if args.finetune:
-             if args.pretrain is None:
-                 ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
-             else:
-                 ckpt_path = args.pretrain
-
-     if args.finetune:
-         if not os.path.isdir(checkpoint_path):
-             os.makedirs(checkpoint_path, exist_ok=True)
-
-         file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
-         if not os.path.isfile(file_checkpoint):
-             shutil.copy2(ckpt_path, file_checkpoint)
-             print("copy checkpoint for finetune")
-
-     # Use the tokenizer and tokenizer_path provided in the command line arguments
-     tokenizer = args.tokenizer
-     if tokenizer == "custom":
-         if not args.tokenizer_path:
-             raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
-         tokenizer_path = args.tokenizer_path
-     else:
-         tokenizer_path = args.dataset_name
-
-     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
-
-     print("\nvocab : ", vocab_size)
-     print("\nvocoder : ", mel_spec_type)
-
-     mel_spec_kwargs = dict(
-         n_fft=n_fft,
-         hop_length=hop_length,
-         win_length=win_length,
-         n_mel_channels=n_mel_channels,
-         target_sample_rate=target_sample_rate,
-         mel_spec_type=mel_spec_type,
-     )
+         ckpt_path = cfg.ckpts.pretain_ckpt_path or str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+     wandb_resume_id = None
+
+     checkpoint_path = str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}"))
+
+     if not os.path.isdir(checkpoint_path):
+         os.makedirs(checkpoint_path, exist_ok=True)
+
+     file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+     if not os.path.isfile(file_checkpoint):
+         shutil.copy2(ckpt_path, file_checkpoint)
+         print("copy checkpoint for finetune")
 
      model = CFM(
-         transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
-         mel_spec_kwargs=mel_spec_kwargs,
+         transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
+         mel_spec_kwargs=cfg.model.mel_spec,
          vocab_char_map=vocab_char_map,
     )
 
      trainer = Trainer(
          model,
-         args.epochs,
-         args.learning_rate,
-         num_warmup_updates=args.num_warmup_updates,
-         save_per_updates=args.save_per_updates,
+         epochs=cfg.optim.epochs,
+         learning_rate=cfg.optim.learning_rate,
+         num_warmup_updates=cfg.optim.num_warmup_updates,
+         save_per_updates=cfg.ckpts.save_per_updates,
          checkpoint_path=checkpoint_path,
-         batch_size=args.batch_size_per_gpu,
-         batch_size_type=args.batch_size_type,
-         max_samples=args.max_samples,
-         grad_accumulation_steps=args.grad_accumulation_steps,
-         max_grad_norm=args.max_grad_norm,
-         logger=args.logger,
-         wandb_project=args.dataset_name,
-         wandb_run_name=args.exp_name,
+         batch_size=cfg.datasets.batch_size_per_gpu,
+         batch_size_type=cfg.datasets.batch_size_type,
+         max_samples=cfg.datasets.max_samples,
+         grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
+         max_grad_norm=cfg.optim.max_grad_norm,
+         logger=cfg.ckpts.logger,
+         wandb_project=cfg.datasets.name,
+         wandb_run_name=exp_name,
          wandb_resume_id=wandb_resume_id,
-         log_samples=args.log_samples,
-         last_per_steps=args.last_per_steps,
-         bnb_optimizer=args.bnb_optimizer,
+         log_samples=True,
+         last_per_steps=cfg.ckpts.last_per_steps,
+         bnb_optimizer=cfg.optim.bnb_optimizer,
+         mel_spec_type=mel_spec_type,
+         is_local_vocoder=cfg.model.mel_spec.is_local_vocoder,
+         local_vocoder_path=cfg.model.mel_spec.local_vocoder_path,
      )
 
-     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
 
      trainer.train(
          train_dataset,
+         num_workers=cfg.datasets.num_workers,
          resumable_with_seed=666, # seed for shuffling dataset
      )
 
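With the argparse flow gone, the entry point is configured entirely by one of the `*_finetune.yaml` files above, and because the decorator leaves `config_name=None`, the config has to be named at launch time (Hydra's `--config-name` flag). A rough sketch of how the decorated `main` receives its `cfg`, using Hydra's compose API instead of the CLI; the config directory and file names are taken from this commit, everything else (paths, overrides) is illustrative:

```python
# Sketch only: build the same cfg object that @hydra.main would pass to main(cfg).
# Assumes Hydra >= 1.2; config_dir must be an absolute path to src/f5_tts/config.
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

with initialize_config_dir(version_base=None, config_dir="/abs/path/to/src/f5_tts/config"):
    cfg = compose(
        config_name="F5TTS_Base_finetune",
        overrides=["optim.epochs=1", "datasets.batch_size_per_gpu=3200"],  # same syntax as CLI overrides
    )

print(cfg.model.name, cfg.optim.epochs)   # F5TTS_Base 1
print(OmegaConf.to_yaml(cfg.model.arch))  # the dict that is splatted into model_cls(**cfg.model.arch)
```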
 
src/f5_tts/train/train.py CHANGED
@@ -48,11 +48,13 @@ def main(cfg):
          max_samples=cfg.datasets.max_samples,
          grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
          max_grad_norm=cfg.optim.max_grad_norm,
+         logger=cfg.ckpts.logger,
          wandb_project="CFM-TTS",
          wandb_run_name=exp_name,
          wandb_resume_id=wandb_resume_id,
          last_per_steps=cfg.ckpts.last_per_steps,
          log_samples=True,
+         bnb_optimizer=cfg.optim.bnb_optimizer,
          mel_spec_type=mel_spec_type,
          is_local_vocoder=cfg.model.mel_spec.is_local_vocoder,
          local_vocoder_path=cfg.model.mel_spec.local_vocoder_path,
@@ -61,6 +63,7 @@ def main(cfg):
      train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
      trainer.train(
          train_dataset,
+         num_workers=cfg.datasets.num_workers,
          resumable_with_seed=666, # seed for shuffling dataset
      )
 