SWivid committed on
Commit
a846ae6
·
1 Parent(s): ba4b04b

finish train dependencies

Browse files
README.md CHANGED
@@ -65,70 +65,6 @@ pre-commit run --all-files
65
 
66
  Note: Some model components have linting exceptions for E722 to accommodate tensor notation
67
 
68
- ## Prepare Dataset
69
-
70
- Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `f5_tts/model/dataset.py`.
71
-
72
- ```bash
73
- # switch to the main directory
74
- cd f5_tts
75
-
76
- # prepare custom dataset up to your need
77
- # download corresponding dataset first, and fill in the path in scripts
78
-
79
- # Prepare the Emilia dataset
80
- python scripts/prepare_emilia.py
81
-
82
- # Prepare the Wenetspeech4TTS dataset
83
- python scripts/prepare_wenetspeech4tts.py
84
-
85
- # https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029
86
- python scripts/prepare_csv_wavs.py
87
- ```
88
-
89
- ## Training & Finetuning
90
-
91
- Once your datasets are prepared, you can start the training process.
92
-
93
- ```bash
94
- # switch to the main directory
95
- cd f5_tts
96
-
97
- # setup accelerate config, e.g. use multi-gpu ddp, fp16
98
- # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
99
- accelerate config
100
- accelerate launch train.py
101
- ```
102
- An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
103
-
104
- Gradio UI finetuning with `f5_tts/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
105
-
106
- ### Wandb Logging
107
-
108
- By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
109
-
110
- To turn on wandb logging, you can either:
111
-
112
- 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
113
- 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
114
-
115
- On Mac & Linux:
116
-
117
- ```
118
- export WANDB_API_KEY=<YOUR WANDB API KEY>
119
- ```
120
-
121
- On Windows:
122
-
123
- ```
124
- set WANDB_API_KEY=<YOUR WANDB API KEY>
125
- ```
126
- Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows:
127
-
128
- ```
129
- export WANDB_MODE=offline
130
- ```
131
-
132
  ## Inference
133
 
134
  ```python
@@ -215,6 +151,8 @@ To test speech editing capabilities, use the following command.
215
  python f5_tts/speech_edit.py
216
  ```
217
 
 
 
218
  ## [Evaluation](src/f5_tts/eval/README.md)
219
 
220
  ## Acknowledgements
 
65
 
66
  Note: Some model components have linting exceptions for E722 to accommodate tensor notation
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ## Inference
69
 
70
  ```python
 
151
  python f5_tts/speech_edit.py
152
  ```
153
 
154
+ ## [Training](src/f5_tts/train/README.md)
155
+
156
  ## [Evaluation](src/f5_tts/eval/README.md)
157
 
158
  ## Acknowledgements
src/f5_tts/model/dataset.py CHANGED
@@ -1,14 +1,15 @@
1
  import json
2
  import random
 
3
  from tqdm import tqdm
4
 
5
  import torch
6
  import torch.nn.functional as F
7
- from torch.utils.data import Dataset, Sampler
8
  import torchaudio
 
 
9
  from datasets import load_from_disk
10
  from datasets import Dataset as Dataset_
11
- from torch import nn
12
 
13
  from f5_tts.model.modules import MelSpec
14
  from f5_tts.model.utils import default
@@ -221,16 +222,17 @@ def load_dataset(
221
  print("Loading dataset ...")
222
 
223
  if dataset_type == "CustomDataset":
 
224
  if audio_type == "raw":
225
  try:
226
- train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
227
  except: # noqa: E722
228
- train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
229
  preprocessed_mel = False
230
  elif audio_type == "mel":
231
- train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
232
  preprocessed_mel = True
233
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
234
  data_dict = json.load(f)
235
  durations = data_dict["duration"]
236
  train_dataset = CustomDataset(
@@ -261,7 +263,7 @@ def load_dataset(
261
  )
262
  pre, post = dataset_name.split("_")
263
  train_dataset = HFDataset(
264
- load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
265
  )
266
 
267
  return train_dataset
 
1
  import json
2
  import random
3
+ from importlib.resources import files
4
  from tqdm import tqdm
5
 
6
  import torch
7
  import torch.nn.functional as F
 
8
  import torchaudio
9
+ from torch import nn
10
+ from torch.utils.data import Dataset, Sampler
11
  from datasets import load_from_disk
12
  from datasets import Dataset as Dataset_
 
13
 
14
  from f5_tts.model.modules import MelSpec
15
  from f5_tts.model.utils import default
 
222
  print("Loading dataset ...")
223
 
224
  if dataset_type == "CustomDataset":
225
+ rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
226
  if audio_type == "raw":
227
  try:
228
+ train_dataset = load_from_disk(f"{rel_data_path}/raw")
229
  except: # noqa: E722
230
+ train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
231
  preprocessed_mel = False
232
  elif audio_type == "mel":
233
+ train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
234
  preprocessed_mel = True
235
+ with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
236
  data_dict = json.load(f)
237
  durations = data_dict["duration"]
238
  train_dataset = CustomDataset(
 
263
  )
264
  pre, post = dataset_name.split("_")
265
  train_dataset = HFDataset(
266
+ load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
267
  )
268
 
269
  return train_dataset
src/f5_tts/train/README.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Prepare Dataset
3
+
4
+ Example data processing scripts are provided for Emilia and Wenetspeech4TTS; you may tailor your own script, along with a Dataset class in `src/f5_tts/model/dataset.py`.
5
+
6
+ ### 1. Datasets used for pretrained models
7
+ Download the corresponding dataset first, and fill in its path in the scripts.
8
+
9
+ ```bash
10
+ # Prepare the Emilia dataset
11
+ python src/f5_tts/train/datasets/prepare_emilia.py
12
+
13
+ # Prepare the Wenetspeech4TTS dataset
14
+ python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
15
+ ```
16
+
17
+ ### 2. Create custom dataset with metadata.csv
18
+ For usage guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
19
+
20
+ ```bash
21
+ python src/f5_tts/train/datasets/prepare_csv_wavs.py
22
+ ```
23
+
24
+ ## Training & Finetuning
25
+
26
+ Once your datasets are prepared, you can start the training process.
27
+
28
+ ### 1. Training script used for pretrained model
29
+
30
+ ```bash
31
+ # setup accelerate config, e.g. use multi-gpu ddp, fp16
32
+ # the config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
33
+ accelerate config
34
+ accelerate launch src/f5_tts/train/train.py
35
+ ```
36
+
37
+ ### 2. Finetuning practice
38
+ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
39
+
40
+ For Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
41
+
42
+ ### 3. Wandb Logging
43
+
44
+ The `wandb/` dir will be created under the path from which you run the training/finetuning scripts.
45
+
46
+ By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
47
+
48
+ To turn on wandb logging, you can either:
49
+
50
+ 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
51
+ 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
52
+
53
+ On Mac & Linux:
54
+
55
+ ```
56
+ export WANDB_API_KEY=<YOUR WANDB API KEY>
57
+ ```
58
+
59
+ On Windows:
60
+
61
+ ```
62
+ set WANDB_API_KEY=<YOUR WANDB API KEY>
63
+ ```
64
+ Moreover, if you can't access Wandb and want to log metrics offline, you can set the environment variable as follows:
65
+
66
+ ```
67
+ export WANDB_MODE=offline
68
+ ```
src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -1,14 +1,15 @@
1
- import sys
2
  import os
 
3
 
4
  sys.path.append(os.getcwd())
5
 
6
- from pathlib import Path
 
7
  import json
8
  import shutil
9
- import argparse
 
10
 
11
- import csv
12
  import torchaudio
13
  from tqdm import tqdm
14
  from datasets.arrow_writer import ArrowWriter
@@ -17,7 +18,8 @@ from f5_tts.model.utils import (
17
  convert_char_to_pinyin,
18
  )
19
 
20
- PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
 
21
 
22
 
23
  def is_csv_wavs_format(input_dataset_dir):
@@ -80,7 +82,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
80
  print(f"\nSaving to {out_dir} ...")
81
 
82
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
83
- # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
84
  raw_arrow_path = out_dir / "raw.arrow"
85
  with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
86
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
 
 
1
  import os
2
+ import sys
3
 
4
  sys.path.append(os.getcwd())
5
 
6
+ import argparse
7
+ import csv
8
  import json
9
  import shutil
10
+ from importlib.resources import files
11
+ from pathlib import Path
12
 
 
13
  import torchaudio
14
  from tqdm import tqdm
15
  from datasets.arrow_writer import ArrowWriter
 
18
  convert_char_to_pinyin,
19
  )
20
 
21
+
22
+ PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
23
 
24
 
25
  def is_csv_wavs_format(input_dataset_dir):
 
82
  print(f"\nSaving to {out_dir} ...")
83
 
84
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
85
+ # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
86
  raw_arrow_path = out_dir / "raw.arrow"
87
  with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
88
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
src/f5_tts/train/datasets/prepare_emilia.py CHANGED
@@ -4,15 +4,16 @@
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
7
- import sys
8
  import os
 
9
 
10
  sys.path.append(os.getcwd())
11
 
12
- from pathlib import Path
13
  import json
14
- from tqdm import tqdm
15
  from concurrent.futures import ProcessPoolExecutor
 
 
 
16
 
17
  from datasets.arrow_writer import ArrowWriter
18
 
@@ -173,24 +174,25 @@ def main():
173
  executor.shutdown()
174
 
175
  # save preprocessed dataset to disk
176
- if not os.path.exists(f"data/{dataset_name}"):
177
- os.makedirs(f"data/{dataset_name}")
178
- print(f"\nSaving to data/{dataset_name} ...")
 
179
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
180
- # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
181
- with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
182
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
183
  writer.write(line)
184
 
185
  # dup a json separately saving duration in case for DynamicBatchSampler ease
186
- with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
187
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
188
 
189
  # vocab map, i.e. tokenizer
190
  # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
191
  # if tokenizer == "pinyin":
192
  # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
193
- with open(f"data/{dataset_name}/vocab.txt", "w") as f:
194
  for vocab in sorted(text_vocab_set):
195
  f.write(vocab + "\n")
196
 
@@ -212,7 +214,8 @@ if __name__ == "__main__":
212
  langs = ["ZH", "EN"]
213
  dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
214
  dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
215
- print(f"\nPrepare for {dataset_name}\n")
 
216
 
217
  main()
218
 
 
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
 
7
  import os
8
+ import sys
9
 
10
  sys.path.append(os.getcwd())
11
 
 
12
  import json
 
13
  from concurrent.futures import ProcessPoolExecutor
14
+ from importlib.resources import files
15
+ from pathlib import Path
16
+ from tqdm import tqdm
17
 
18
  from datasets.arrow_writer import ArrowWriter
19
 
 
174
  executor.shutdown()
175
 
176
  # save preprocessed dataset to disk
177
+ if not os.path.exists(f"{save_dir}"):
178
+ os.makedirs(f"{save_dir}")
179
+ print(f"\nSaving to {save_dir} ...")
180
+
181
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
182
+ # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")
183
+ with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
184
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
185
  writer.write(line)
186
 
187
  # dup a json separately saving duration in case for DynamicBatchSampler ease
188
+ with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
189
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
190
 
191
  # vocab map, i.e. tokenizer
192
  # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
193
  # if tokenizer == "pinyin":
194
  # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
195
+ with open(f"{save_dir}/vocab.txt", "w") as f:
196
  for vocab in sorted(text_vocab_set):
197
  f.write(vocab + "\n")
198
 
 
214
  langs = ["ZH", "EN"]
215
  dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
216
  dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
217
+ save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
218
+ print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
219
 
220
  main()
221
 
src/f5_tts/train/datasets/prepare_wenetspeech4tts.py CHANGED
@@ -1,14 +1,15 @@
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
4
- import sys
5
  import os
 
6
 
7
  sys.path.append(os.getcwd())
8
 
9
  import json
10
- from tqdm import tqdm
11
  from concurrent.futures import ProcessPoolExecutor
 
 
12
 
13
  import torchaudio
14
  from datasets import Dataset
@@ -66,11 +67,11 @@ def main():
66
  if not os.path.exists("data"):
67
  os.makedirs("data")
68
 
69
- print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
70
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
71
- dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
72
 
73
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
74
  json.dump(
75
  {"duration": duration_list}, f, ensure_ascii=False
76
  ) # dup a json separately saving duration in case for DynamicBatchSampler ease
@@ -84,7 +85,7 @@ def main():
84
  if tokenizer == "pinyin":
85
  text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
86
 
87
- with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
88
  for vocab in sorted(text_vocab_set):
89
  f.write(vocab + "\n")
90
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
@@ -98,13 +99,18 @@ if __name__ == "__main__":
98
  polyphone = True
99
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
100
 
101
- dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
 
 
 
 
102
  dataset_paths = [
103
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
104
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
105
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
106
  ][-dataset_choice:]
107
- print(f"\nChoose Dataset: {dataset_name}\n")
 
108
 
109
  main()
110
 
 
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
 
4
  import os
5
+ import sys
6
 
7
  sys.path.append(os.getcwd())
8
 
9
  import json
 
10
  from concurrent.futures import ProcessPoolExecutor
11
+ from importlib.resources import files
12
+ from tqdm import tqdm
13
 
14
  import torchaudio
15
  from datasets import Dataset
 
67
  if not os.path.exists("data"):
68
  os.makedirs("data")
69
 
70
+ print(f"\nSaving to {save_dir} ...")
71
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
72
+ dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB") # arrow format
73
 
74
+ with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
75
  json.dump(
76
  {"duration": duration_list}, f, ensure_ascii=False
77
  ) # dup a json separately saving duration in case for DynamicBatchSampler ease
 
85
  if tokenizer == "pinyin":
86
  text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
87
 
88
+ with open(f"{save_dir}/vocab.txt", "w") as f:
89
  for vocab in sorted(text_vocab_set):
90
  f.write(vocab + "\n")
91
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
 
99
  polyphone = True
100
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
101
 
102
+ dataset_name = (
103
+ ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
104
+ + "_"
105
+ + tokenizer
106
+ )
107
  dataset_paths = [
108
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
109
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
110
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
111
  ][-dataset_choice:]
112
+ save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
113
+ print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")
114
 
115
  main()
116
 
src/f5_tts/train/finetune_cli.py CHANGED
@@ -7,6 +7,7 @@ from f5_tts.model import CFM, UNetT, DiT, Trainer
7
  from f5_tts.model.utils import get_tokenizer
8
  from f5_tts.model.dataset import load_dataset
9
 
 
10
  # -------------------------- Dataset Settings --------------------------- #
11
  target_sample_rate = 24000
12
  n_mel_channels = 100
@@ -20,9 +21,9 @@ def parse_args():
20
  # batch_size_per_gpu = 2000 settting for gpu 16GB
21
  # batch_size_per_gpu = 3200 settting for gpu 24GB
22
 
23
- # num_warmup_updates 10000 sample = 500
24
 
25
- # change save_per_updates , last_per_steps what you need ,
26
 
27
  parser = argparse.ArgumentParser(description="Train CFM Model")
28
 
@@ -39,9 +40,9 @@ def parse_args():
39
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
40
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
41
  parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
42
- parser.add_argument("--num_warmup_updates", type=int, default=500, help="Warmup steps")
43
  parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
44
- parser.add_argument("--last_per_steps", type=int, default=20000, help="Save last checkpoint every X steps")
45
  parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
46
  parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
47
  parser.add_argument(
@@ -126,7 +127,7 @@ def main():
126
  max_samples=args.max_samples,
127
  grad_accumulation_steps=args.grad_accumulation_steps,
128
  max_grad_norm=args.max_grad_norm,
129
- wandb_project="CFM-TTS",
130
  wandb_run_name=args.exp_name,
131
  wandb_resume_id=wandb_resume_id,
132
  last_per_steps=args.last_per_steps,
 
7
  from f5_tts.model.utils import get_tokenizer
8
  from f5_tts.model.dataset import load_dataset
9
 
10
+
11
  # -------------------------- Dataset Settings --------------------------- #
12
  target_sample_rate = 24000
13
  n_mel_channels = 100
 
21
  # batch_size_per_gpu = 2000 settting for gpu 16GB
22
  # batch_size_per_gpu = 3200 settting for gpu 24GB
23
 
24
+ # num_warmup_updates = 300 for 5000 sample about 10 hours
25
 
26
+ # change save_per_updates , last_per_steps change this value what you need ,
27
 
28
  parser = argparse.ArgumentParser(description="Train CFM Model")
29
 
 
40
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
41
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
42
  parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
43
+ parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
44
  parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
45
+ parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
46
  parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
47
  parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
48
  parser.add_argument(
 
127
  max_samples=args.max_samples,
128
  grad_accumulation_steps=args.grad_accumulation_steps,
129
  max_grad_norm=args.max_grad_norm,
130
+ wandb_project=args.dataset_name,
131
  wandb_run_name=args.exp_name,
132
  wandb_resume_id=wandb_resume_id,
133
  last_per_steps=args.last_per_steps,
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -251,6 +251,7 @@ def start_training(
251
  file_checkpoint_train="",
252
  tokenizer_type="pinyin",
253
  tokenizer_file="",
 
254
  ):
255
  global training_process, tts_api
256
 
@@ -282,9 +283,24 @@ def start_training(
282
  yield "start train", gr.update(interactive=False), gr.update(interactive=False)
283
 
284
  # Command to run the training script with the specified arguments
 
 
 
 
 
 
 
 
 
285
  dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")
 
 
 
 
 
 
286
  cmd = (
287
- f"accelerate launch finetune-cli.py --exp_name {exp_name} "
288
  f"--learning_rate {learning_rate} "
289
  f"--batch_size_per_gpu {batch_size_per_gpu} "
290
  f"--batch_size_type {batch_size_type} "
@@ -305,7 +321,8 @@ def start_training(
305
 
306
  if tokenizer_file != "":
307
  cmd += f" --tokenizer_path {tokenizer_file}"
308
- cmd += f" --tokenizer {tokenizer_type} "
 
309
 
310
  print(cmd)
311
 
@@ -466,7 +483,7 @@ def format_seconds_to_hms(seconds):
466
  return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
467
 
468
 
469
- def create_metadata(name_project, progress=gr.Progress()):
470
  path_project = os.path.join(path_data, name_project)
471
  path_project_wavs = os.path.join(path_project, "wavs")
472
  file_metadata = os.path.join(path_project, "metadata.csv")
@@ -475,7 +492,7 @@ def create_metadata(name_project, progress=gr.Progress()):
475
  file_vocab = os.path.join(path_project, "vocab.txt")
476
 
477
  if not os.path.isfile(file_metadata):
478
- return "The file was not found in " + file_metadata
479
 
480
  with open(file_metadata, "r", encoding="utf-8-sig") as f:
481
  data = f.read()
@@ -488,6 +505,7 @@ def create_metadata(name_project, progress=gr.Progress()):
488
  lenght = 0
489
  result = []
490
  error_files = []
 
491
  for line in progress.tqdm(data.split("\n"), total=count):
492
  sp_line = line.split("|")
493
  if len(sp_line) != 2:
@@ -497,29 +515,38 @@ def create_metadata(name_project, progress=gr.Progress()):
497
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
498
 
499
  if not os.path.isfile(file_audio):
500
- error_files.append(file_audio)
501
  continue
502
 
503
- duraction = get_audio_duration(file_audio)
504
- if duraction < 2 and duraction > 15:
 
 
 
 
 
 
 
505
  continue
506
  if len(text) < 4:
 
507
  continue
508
 
509
  text = clear_text(text)
510
  text = convert_char_to_pinyin([text], polyphone=True)[0]
511
 
512
  audio_path_list.append(file_audio)
513
- duration_list.append(duraction)
514
  text_list.append(text)
515
 
516
- result.append({"audio_path": file_audio, "text": text, "duration": duraction})
 
 
517
 
518
- lenght += duraction
519
 
520
  if duration_list == []:
521
- error_files_text = "\n".join(error_files)
522
- return f"Error: No audio files found in the specified path : \n{error_files_text}"
523
 
524
  min_second = round(min(duration_list), 2)
525
  max_second = round(max(duration_list), 2)
@@ -531,17 +558,35 @@ def create_metadata(name_project, progress=gr.Progress()):
531
  with open(file_duration, "w") as f:
532
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
533
 
534
- file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
535
- if not os.path.isfile(file_vocab_finetune):
536
- return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
537
- shutil.copy2(file_vocab_finetune, file_vocab)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
  if error_files != []:
540
- error_text = "error files\n" + "\n".join(error_files)
541
  else:
542
  error_text = ""
543
 
544
- return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
 
 
 
545
 
546
 
547
  def check_user(value):
@@ -579,10 +624,21 @@ def calculate_train(
579
  samples = len(duration_list)
580
  hours = sum(duration_list) / 3600
581
 
 
 
 
 
 
 
582
  if torch.cuda.is_available():
583
- gpu_properties = torch.cuda.get_device_properties(0)
584
- total_memory = gpu_properties.total_memory / (1024**3)
 
 
 
 
585
  elif torch.backends.mps.is_available():
 
586
  total_memory = psutil.virtual_memory().available / (1024**3)
587
 
588
  if batch_size_type == "frame":
@@ -619,7 +675,7 @@ def calculate_train(
619
  wanted_max_updates = 1000000
620
 
621
  # train params
622
- gpus = 1
623
  frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
624
  grad_accum = 1
625
 
@@ -816,6 +872,73 @@ def get_checkpoints_project(project_name, is_gradio=True):
816
  return files_checkpoints, selelect_checkpoint
817
 
818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
819
  with gr.Blocks() as app:
820
  gr.Markdown(
821
  """
@@ -904,10 +1027,13 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
904
 
905
  ```"""
906
  )
907
-
908
  bt_prepare = bt_create = gr.Button("prepare")
909
  txt_info_prepare = gr.Text(label="info", value="")
910
- bt_prepare.click(fn=create_metadata, inputs=[cm_project], outputs=[txt_info_prepare])
 
 
 
911
 
912
  random_sample_prepare = gr.Button("random sample")
913
 
@@ -928,7 +1054,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
928
  with gr.Row():
929
  ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
930
  tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
931
- file_checkpoint_train = gr.Textbox(label="Checkpoint", value="")
932
 
933
  with gr.Row():
934
  exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
@@ -951,6 +1077,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
951
  last_per_steps = gr.Number(label="Last per Steps", value=50)
952
 
953
  with gr.Row():
 
954
  start_button = gr.Button("Start Training")
955
  stop_button = gr.Button("Stop Training", interactive=False)
956
 
@@ -974,6 +1101,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
974
  file_checkpoint_train,
975
  tokenizer_type,
976
  tokenizer_file,
 
977
  ],
978
  outputs=[txt_info_train, start_button, stop_button],
979
  )
@@ -1019,7 +1147,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
1019
  outputs=[txt_info_reduse],
1020
  )
1021
 
1022
- with gr.TabItem("vocab check experiment"):
1023
  check_button = gr.Button("check vocab")
1024
  txt_info_check = gr.Text(label="info", value="")
1025
  check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
@@ -1060,6 +1188,20 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
1060
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1061
  cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1063
 
1064
  @click.command()
1065
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
 
251
  file_checkpoint_train="",
252
  tokenizer_type="pinyin",
253
  tokenizer_file="",
254
+ mixed_precision="fp16",
255
  ):
256
  global training_process, tts_api
257
 
 
283
  yield "start train", gr.update(interactive=False), gr.update(interactive=False)
284
 
285
  # Command to run the training script with the specified arguments
286
+
287
+ if tokenizer_file == "":
288
+ if dataset_name.endswith("_pinyin"):
289
+ tokenizer_type = "pinyin"
290
+ elif dataset_name.endswith("_char"):
291
+ tokenizer_type = "char"
292
+ else:
293
+ tokenizer_file = "custom"
294
+
295
  dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")
296
+
297
+ if mixed_precision != "none":
298
+ fp16 = f"--mixed_precision={mixed_precision}"
299
+ else:
300
+ fp16 = ""
301
+
302
  cmd = (
303
+ f"accelerate launch {fp16} finetune-cli.py --exp_name {exp_name} "
304
  f"--learning_rate {learning_rate} "
305
  f"--batch_size_per_gpu {batch_size_per_gpu} "
306
  f"--batch_size_type {batch_size_type} "
 
321
 
322
  if tokenizer_file != "":
323
  cmd += f" --tokenizer_path {tokenizer_file}"
324
+
325
+ cmd += f" --tokenizer {tokenizer_type} "
326
 
327
  print(cmd)
328
 
 
483
  return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
484
 
485
 
486
+ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
487
  path_project = os.path.join(path_data, name_project)
488
  path_project_wavs = os.path.join(path_project, "wavs")
489
  file_metadata = os.path.join(path_project, "metadata.csv")
 
492
  file_vocab = os.path.join(path_project, "vocab.txt")
493
 
494
  if not os.path.isfile(file_metadata):
495
+ return "The file was not found in " + file_metadata, ""
496
 
497
  with open(file_metadata, "r", encoding="utf-8-sig") as f:
498
  data = f.read()
 
505
  lenght = 0
506
  result = []
507
  error_files = []
508
+ text_vocab_set = set()
509
  for line in progress.tqdm(data.split("\n"), total=count):
510
  sp_line = line.split("|")
511
  if len(sp_line) != 2:
 
515
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
516
 
517
  if not os.path.isfile(file_audio):
518
+ error_files.append([file_audio, "error path"])
519
  continue
520
 
521
+ try:
522
+ duration = get_audio_duration(file_audio)
523
+ except Exception as e:
524
+ error_files.append([file_audio, "duration"])
525
+ print(f"Error processing {file_audio}: {e}")
526
+ continue
527
+
528
+ if duration < 1 and duration > 25:
529
+ error_files.append([file_audio, "duration < 1 and > 25 "])
530
  continue
531
  if len(text) < 4:
532
+ error_files.append([file_audio, "very small text len 3"])
533
  continue
534
 
535
  text = clear_text(text)
536
  text = convert_char_to_pinyin([text], polyphone=True)[0]
537
 
538
  audio_path_list.append(file_audio)
539
+ duration_list.append(duration)
540
  text_list.append(text)
541
 
542
+ result.append({"audio_path": file_audio, "text": text, "duration": duration})
543
+ if ch_tokenizer:
544
+ text_vocab_set.update(list(text))
545
 
546
+ lenght += duration
547
 
548
  if duration_list == []:
549
+ return f"Error: No audio files found in the specified path : {path_project_wavs}", ""
 
550
 
551
  min_second = round(min(duration_list), 2)
552
  max_second = round(max(duration_list), 2)
 
558
  with open(file_duration, "w") as f:
559
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
560
 
561
+ new_vocal = ""
562
+ if not ch_tokenizer:
563
+ file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
564
+ if not os.path.isfile(file_vocab_finetune):
565
+ return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
566
+ shutil.copy2(file_vocab_finetune, file_vocab)
567
+
568
+ with open(file_vocab, "r", encoding="utf-8-sig") as f:
569
+ vocab_char_map = {}
570
+ for i, char in enumerate(f):
571
+ vocab_char_map[char[:-1]] = i
572
+ vocab_size = len(vocab_char_map)
573
+
574
+ else:
575
+ with open(file_vocab, "w", encoding="utf-8-sig") as f:
576
+ for vocab in sorted(text_vocab_set):
577
+ f.write(vocab + "\n")
578
+ new_vocal += vocab + "\n"
579
+ vocab_size = len(text_vocab_set)
580
 
581
  if error_files != []:
582
+ error_text = "\n".join([" = ".join(item) for item in error_files])
583
  else:
584
  error_text = ""
585
 
586
+ return (
587
+ f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\nvocab : {vocab_size}\n{error_text}",
588
+ new_vocal,
589
+ )
590
 
591
 
592
  def check_user(value):
 
624
  samples = len(duration_list)
625
  hours = sum(duration_list) / 3600
626
 
627
+ # if torch.cuda.is_available():
628
+ # gpu_properties = torch.cuda.get_device_properties(0)
629
+ # total_memory = gpu_properties.total_memory / (1024**3)
630
+ # elif torch.backends.mps.is_available():
631
+ # total_memory = psutil.virtual_memory().available / (1024**3)
632
+
633
  if torch.cuda.is_available():
634
+ gpu_count = torch.cuda.device_count()
635
+ total_memory = 0
636
+ for i in range(gpu_count):
637
+ gpu_properties = torch.cuda.get_device_properties(i)
638
+ total_memory += gpu_properties.total_memory / (1024**3) # in GB
639
+
640
  elif torch.backends.mps.is_available():
641
+ gpu_count = 1
642
  total_memory = psutil.virtual_memory().available / (1024**3)
643
 
644
  if batch_size_type == "frame":
 
675
  wanted_max_updates = 1000000
676
 
677
  # train params
678
+ gpus = gpu_count
679
  frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
680
  grad_accum = 1
681
 
 
872
  return files_checkpoints, selelect_checkpoint
873
 
874
 
875
def get_gpu_stats():
    """Build a plain-text report of accelerator memory for the system-info tab.

    Returns:
        str: One four-line section per CUDA device (name, total memory in GB,
        allocated and reserved memory in MB as returned by
        ``torch.cuda.memory_allocated`` / ``torch.cuda.memory_reserved``) when
        CUDA is available; a short MPS section based on total system memory
        when only MPS is available; otherwise ``"No GPU available"``.
    """
    gib = 1024**3
    mib = 1024**2

    if torch.cuda.is_available():
        sections = []
        for index in range(torch.cuda.device_count()):
            name = torch.cuda.get_device_name(index)
            props = torch.cuda.get_device_properties(index)
            total_gb = props.total_memory / gib
            allocated_mb = torch.cuda.memory_allocated(index) / mib
            reserved_mb = torch.cuda.memory_reserved(index) / mib

            sections.append(
                f"GPU {index} Name: {name}\n"
                f"Total GPU memory (GPU {index}): {total_gb:.2f} GB\n"
                f"Allocated GPU memory (GPU {index}): {allocated_mb:.2f} MB\n"
                f"Reserved GPU memory (GPU {index}): {reserved_mb:.2f} MB\n\n"
            )
        # An empty device list yields an empty report, exactly as before.
        return "".join(sections)

    if torch.backends.mps.is_available():
        # MPS has no dedicated memory pool to query here, so total system RAM
        # is reported and the allocated/reserved figures are fixed at zero.
        total_gb = psutil.virtual_memory().total / gib
        allocated_mb = 0
        reserved_mb = 0
        return "MPS GPU\n" + (
            f"Total system memory: {total_gb:.2f} GB\n"
            f"Allocated GPU memory (MPS): {allocated_mb:.2f} MB\n"
            f"Reserved GPU memory (MPS): {reserved_mb:.2f} MB\n"
        )

    return "No GPU available"
913
+
914
+
915
def get_cpu_stats():
    """Build a plain-text report of CPU load, system memory and process priority.

    Returns:
        str: Three lines -- CPU usage percent, used/total system memory in MB
        with the percent used, and the nice value of the current process.
    """
    mib = 1024**2

    # NOTE(review): per the psutil docs, a positive interval makes cpu_percent
    # sample over that many seconds, so this call should take about one second.
    usage_pct = psutil.cpu_percent(interval=1)
    vmem = psutil.virtual_memory()
    used_mb = vmem.used / mib
    total_mb = vmem.total / mib
    used_pct = vmem.percent

    niceness = psutil.Process(os.getpid()).nice()

    return (
        f"CPU Usage: {usage_pct:.2f}%\n"
        f"System Memory: {used_mb:.2f} MB used / {total_mb:.2f} MB total ({used_pct}% used)\n"
        f"Process Priority (Nice value): {niceness}"
    )
933
+
934
+
935
def get_combined_stats():
    """Return the GPU report followed by the CPU report, each under a ``###`` heading.

    The GPU report is gathered first, then the CPU report.
    """
    return f"### GPU Stats\n{get_gpu_stats()}\n\n### CPU Stats\n{get_cpu_stats()}"
940
+
941
+
942
  with gr.Blocks() as app:
943
  gr.Markdown(
944
  """
 
1027
 
1028
  ```"""
1029
  )
1030
+ ch_tokenizern = gr.Checkbox(label="create vocabulary from dataset", value=False)
1031
  bt_prepare = bt_create = gr.Button("prepare")
1032
  txt_info_prepare = gr.Text(label="info", value="")
1033
+ txt_vocab_prepare = gr.Text(label="vocab", value="")
1034
+ bt_prepare.click(
1035
+ fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
1036
+ )
1037
 
1038
  random_sample_prepare = gr.Button("random sample")
1039
 
 
1054
  with gr.Row():
1055
  ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
1056
  tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
1057
+ file_checkpoint_train = gr.Textbox(label="Pretrain Model", value="")
1058
 
1059
  with gr.Row():
1060
  exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
 
1077
  last_per_steps = gr.Number(label="Last per Steps", value=50)
1078
 
1079
  with gr.Row():
1080
+ mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
1081
  start_button = gr.Button("Start Training")
1082
  stop_button = gr.Button("Stop Training", interactive=False)
1083
 
 
1101
  file_checkpoint_train,
1102
  tokenizer_type,
1103
  tokenizer_file,
1104
+ mixed_precision,
1105
  ],
1106
  outputs=[txt_info_train, start_button, stop_button],
1107
  )
 
1147
  outputs=[txt_info_reduse],
1148
  )
1149
 
1150
+ with gr.TabItem("vocab check"):
1151
  check_button = gr.Button("check vocab")
1152
  txt_info_check = gr.Text(label="info", value="")
1153
  check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
 
1188
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1189
  cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1190
 
1191
+ with gr.TabItem("system info"):
1192
+ output_box = gr.Textbox(label="GPU and CPU Information", lines=20)
1193
+
1194
+ def update_stats():
1195
+ return get_combined_stats()
1196
+
1197
+ update_button = gr.Button("Update Stats")
1198
+ update_button.click(fn=update_stats, outputs=output_box)
1199
+
1200
+ def auto_update():
1201
+ yield gr.update(value=update_stats())
1202
+
1203
+ gr.update(fn=auto_update, inputs=[], outputs=output_box)
1204
+
1205
 
1206
  @click.command()
1207
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
src/f5_tts/train/train.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  from f5_tts.model import CFM, UNetT, DiT, Trainer
2
  from f5_tts.model.utils import get_tokenizer
3
  from f5_tts.model.dataset import load_dataset
@@ -69,7 +73,7 @@ def main():
69
  learning_rate,
70
  num_warmup_updates=num_warmup_updates,
71
  save_per_updates=save_per_updates,
72
- checkpoint_path=f"ckpts/{exp_name}",
73
  batch_size=batch_size_per_gpu,
74
  batch_size_type=batch_size_type,
75
  max_samples=max_samples,
 
1
+ # training script.
2
+
3
+ from importlib.resources import files
4
+
5
  from f5_tts.model import CFM, UNetT, DiT, Trainer
6
  from f5_tts.model.utils import get_tokenizer
7
  from f5_tts.model.dataset import load_dataset
 
73
  learning_rate,
74
  num_warmup_updates=num_warmup_updates,
75
  save_per_updates=save_per_updates,
76
+ checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")),
77
  batch_size=batch_size_per_gpu,
78
  batch_size_type=batch_size_type,
79
  max_samples=max_samples,