# finish train dependencies
Files changed:

- `README.md` (+2 -64)
- `src/f5_tts/model/dataset.py` (+9 -7)
- `src/f5_tts/train/README.md` (+68 -0, new file)
- `src/f5_tts/train/datasets/prepare_csv_wavs.py` (+8 -6)
- `src/f5_tts/train/datasets/prepare_emilia.py` (+14 -11)
- `src/f5_tts/train/datasets/prepare_wenetspeech4tts.py` (+14 -8)
- `src/f5_tts/train/finetune_cli.py` (+6 -5)
- `src/f5_tts/train/finetune_gradio.py` (+167 -25)
- `src/f5_tts/train/train.py` (+5 -1)
## README.md

The first hunk (`@@ -65,70 +65,6 @@ pre-commit run --all-files`, just after the note "Some model components have linting exceptions for E722 to accommodate tensor notation") deletes the whole "Prepare Dataset" and "Training & Finetuning" walkthrough: the dataset preparation commands for the Emilia, Wenetspeech4TTS, and csv/wavs datasets, the `accelerate config` / `accelerate launch train.py` instructions, the finetuning pointers to discussions [#57](https://github.com/SWivid/F5-TTS/discussions/57) and [#143](https://github.com/SWivid/F5-TTS/discussions/143), and the Wandb logging how-to. The README now jumps from the linting note straight to "## Inference". All of that content reappears, lightly revised, in the new `src/f5_tts/train/README.md` below, with the script paths updated from `scripts/*.py` to `src/f5_tts/train/datasets/*.py` and the launch target from `train.py` to `src/f5_tts/train/train.py`.

The second hunk adds a link to the relocated training docs:

````diff
@@ -215,6 +151,8 @@ To test speech editing capabilities, use the following command.
 python f5_tts/speech_edit.py
 ```
 
+## [Training](src/f5_tts/train/README.md)
+
 ## [Evaluation](src/f5_tts/eval/README.md)
 
 ## Acknowledgements
````
## src/f5_tts/model/dataset.py

Datasets and the Hugging Face cache are now located relative to the installed `f5_tts` package rather than through cwd-relative `data/...` paths, and the import block is regrouped:

```diff
@@ -1,14 +1,15 @@
 import json
 import random
+from importlib.resources import files
 from tqdm import tqdm
 
 import torch
 import torch.nn.functional as F
-from torch.utils.data import Dataset, Sampler
 import torchaudio
+from torch import nn
+from torch.utils.data import Dataset, Sampler
 from datasets import load_from_disk
 from datasets import Dataset as Dataset_
-from torch import nn
 
 from f5_tts.model.modules import MelSpec
 from f5_tts.model.utils import default
```

In `load_dataset` (`@@ -221,16 +222,17 @@`), a package-relative `rel_data_path` replaces the inline paths:

```python
    print("Loading dataset ...")

    if dataset_type == "CustomDataset":
        rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
        if audio_type == "raw":
            try:
                train_dataset = load_from_disk(f"{rel_data_path}/raw")
            except:  # noqa: E722
                train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
            preprocessed_mel = False
        elif audio_type == "mel":
            train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
            preprocessed_mel = True
        with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(
            # ...
```

and the HF-hosted branch (`@@ -261,7 +263,7 @@`) gets the same treatment:

```python
        pre, post = dataset_name.split("_")
        train_dataset = HFDataset(
            load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
        )

    return train_dataset
```
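For orientation, `files("f5_tts")` resolves to the installed package directory, which in this repo layout is `<repo>/src/f5_tts`, so two `..` hops land on the repo root, where `data/` sits beside `src/`. A minimal sketch; the printed path is illustrative, not from the commit:

```python
from importlib.resources import files

# files("f5_tts") -> .../src/f5_tts for this repo layout, so ../../data
# is the repo-level data directory, independent of the current working dir.
data_root = files("f5_tts").joinpath("../../data")
print(str(data_root))  # e.g. /home/user/F5-TTS/src/f5_tts/../../data
```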
## src/f5_tts/train/README.md (new file)

The training documentation removed from the top-level README lands here, reorganized into numbered subsections. The new file:

````markdown
## Prepare Dataset

Example data processing scripts are provided for Emilia and Wenetspeech4TTS, and you may tailor your own along with a Dataset class in `src/f5_tts/model/dataset.py`.

### 1. Datasets used for pretrained models
Download the corresponding dataset first, and fill in the path in the scripts.

```bash
# Prepare the Emilia dataset
python src/f5_tts/train/datasets/prepare_emilia.py

# Prepare the Wenetspeech4TTS dataset
python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
```

### 2. Create custom dataset with metadata.csv
For guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).

```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py
```

## Training & Finetuning

Once your datasets are prepared, you can start the training process.

### 1. Training script used for pretrained model

```bash
# setup accelerate config, e.g. use multi-gpu ddp, fp16
# will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
accelerate launch src/f5_tts/train/train.py
```

### 2. Finetuning practice
Discussion board for finetuning: [#57](https://github.com/SWivid/F5-TTS/discussions/57).

For Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).

### 3. Wandb Logging

The `wandb/` dir will be created under the path from which you run the training/finetuning scripts.

By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).

To turn on wandb logging, you can either:

1. Manually login with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Automatically login programmatically by setting an environment variable: get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:

On Mac & Linux:

```
export WANDB_API_KEY=<YOUR WANDB API KEY>
```

On Windows:

```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```

Moreover, if you can't access Wandb and want to log metrics offline, set the environment variable as follows:

```
export WANDB_MODE=offline
```
````
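Pulling the paths of this commit together, the repo layout these docs assume looks like the sketch below; directory names other than the ones the scripts actually write are illustrative:

```
F5-TTS/
├── src/f5_tts/                  # the installed package; files("f5_tts") points here
├── data/                        # written by the prepare_* scripts
│   └── Emilia_ZH_EN_pinyin/
│       ├── raw.arrow
│       ├── duration.json
│       └── vocab.txt
└── ckpts/                       # training checkpoints, e.g. ckpts/F5TTS_Base
```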
## src/f5_tts/train/datasets/prepare_csv_wavs.py

Imports are sorted and completed, and a module-level constant now resolves the pretrained vocab through the package location. The new header:

```python
import os
import sys

sys.path.append(os.getcwd())

import argparse
import csv
import json
import shutil
from importlib.resources import files
from pathlib import Path

import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter

from f5_tts.model.utils import (
    # ...
    convert_char_to_pinyin,
)


PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
```

In `save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune)` (`@@ -80,7 +82,7 @@`), the commented-out `save_to_disk` example is updated to refer to `out_dir`:

```python
    print(f"\nSaving to {out_dir} ...")

    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
    # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
```
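For reference, the csv/wavs input this script (and the Gradio `create_metadata` below) consumes is, per the linked discussion, a folder holding a pipe-separated `metadata.csv` beside a `wavs/` directory. A sketch with illustrative names:

```
my_dataset/
├── metadata.csv      # one "audio_file|text" row per clip, e.g.
│                     # wavs/clip_0001.wav|Transcript of the first clip.
└── wavs/
    └── clip_0001.wav
```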
## src/f5_tts/train/datasets/prepare_emilia.py

Same treatment: the import block is regrouped,

```diff
@@ -4,15 +4,16 @@
 # generate audio text map for Emilia ZH & EN
 # evaluate for vocab size
 
-import sys
 import os
+import sys
 
 sys.path.append(os.getcwd())
 
-from pathlib import Path
 import json
-from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor
+from importlib.resources import files
+from pathlib import Path
+from tqdm import tqdm
 
 from datasets.arrow_writer import ArrowWriter
 
```

and the output location is derived from the package path (`save_dir`) instead of being spelled out inline. The updated tail of `main()` (`@@ -173,24 +174,25 @@`):

```python
    executor.shutdown()

    # save preprocessed dataset to disk
    if not os.path.exists(f"{save_dir}"):
        os.makedirs(f"{save_dir}")
    print(f"\nSaving to {save_dir} ...")

    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
    # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")
    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    with open(f"{save_dir}/vocab.txt", "w") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")
```

with `save_dir` defined in the `__main__` block (`@@ -212,7 +214,8 @@`):

```python
    langs = ["ZH", "EN"]
    dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
    dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")

    main()
```
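The sidecar format is fixed by the `json.dump` call above: a single `"duration"` key holding one float per written sample, in writing order, which `dataset.py` later hands to the `DynamicBatchSampler`. A quick sanity check, with an illustrative path:

```python
import json

# duration.json holds {"duration": [sec, sec, ...]}
with open("data/Emilia_ZH_EN_pinyin/duration.json", encoding="utf-8") as f:
    durations = json.load(f)["duration"]

print(f"{len(durations)} samples, {sum(durations) / 3600:.1f} h total")
```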
## src/f5_tts/train/datasets/prepare_wenetspeech4tts.py

The import block is regrouped the same way,

```diff
@@ -1,14 +1,15 @@
 # generate audio text map for WenetSpeech4TTS
 # evaluate for vocab size
 
-import sys
 import os
+import sys
 
 sys.path.append(os.getcwd())
 
 import json
-from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor
+from importlib.resources import files
+from tqdm import tqdm
 
 import torchaudio
 from datasets import Dataset
```

and all outputs move under a package-relative `save_dir`. In `main()` (`@@ -66,11 +67,11 @@` and `@@ -84,7 +85,7 @@`):

```python
    if not os.path.exists("data"):
        os.makedirs("data")

    print(f"\nSaving to {save_dir} ...")
    dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
    dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")  # arrow format

    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
        json.dump(
            {"duration": duration_list}, f, ensure_ascii=False
        )  # dup a json separately saving duration in case for DynamicBatchSampler ease
    # ...
    if tokenizer == "pinyin":
        text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])

    with open(f"{save_dir}/vocab.txt", "w") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")
    print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
```

and in the `__main__` block (`@@ -98,13 +99,18 @@`), where the dataset name is now assembled over several lines:

```python
    polyphone = True
    dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic

    dataset_name = (
        ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
        + "_"
        + tokenizer
    )
    dataset_paths = [
        "<SOME_PATH>/WenetSpeech4TTS/Basic",
        "<SOME_PATH>/WenetSpeech4TTS/Standard",
        "<SOME_PATH>/WenetSpeech4TTS/Premium",
    ][-dataset_choice:]
    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
    print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")

    main()
```
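One subtlety worth spelling out: `dataset_choice` drives both the name lookup (an index into a Premium-first list) and the path selection (a negative slice over a Basic-first list), so choice 1 trains on Premium only while choice 3 takes all three tiers:

```python
tiers = ["Basic", "Standard", "Premium"]  # same order as dataset_paths above

print(tiers[-1:])  # dataset_choice = 1 -> ['Premium']
print(tiers[-3:])  # dataset_choice = 3 -> ['Basic', 'Standard', 'Premium']
```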
## src/f5_tts/train/finetune_cli.py

A separating blank line lands after the imports:

```diff
@@ -7,6 +7,7 @@ from f5_tts.model import CFM, UNetT, DiT, Trainer
 from f5_tts.model.utils import get_tokenizer
 from f5_tts.model.dataset import load_dataset
 
+
 # -------------------------- Dataset Settings --------------------------- #
 target_sample_rate = 24000
 n_mel_channels = 100
```

The guidance comments in `parse_args()` are expanded, and defaults are filled in for the warmup and last-checkpoint arguments:

```python
    # batch_size_per_gpu = 2000 setting for gpu 16GB
    # batch_size_per_gpu = 3200 setting for gpu 24GB

    # num_warmup_updates = 300 for 5000 samples, about 10 hours

    # change save_per_updates and last_per_steps to what you need

    parser = argparse.ArgumentParser(description="Train CFM Model")
    # ...
    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
    parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
    parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
    parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
```

And in the `Trainer` construction inside `main()`, the wandb project is now named after the dataset instead of being fixed:

```python
        max_samples=args.max_samples,
        grad_accumulation_steps=args.grad_accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        wandb_project=args.dataset_name,
        wandb_run_name=args.exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=args.last_per_steps,
```
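Putting those arguments together, a finetuning launch might look like the following. This is a sketch, not a command from the commit: the flag names follow the `args.*` fields used above, and the values are placeholders:

```bash
accelerate launch src/f5_tts/train/finetune_cli.py \
    --exp_name F5TTS_Base \
    --dataset_name my_speak \
    --learning_rate 1e-5 \
    --batch_size_per_gpu 3200 \
    --batch_size_type frame \
    --num_warmup_updates 300 \
    --save_per_updates 10000 \
    --last_per_steps 50000 \
    --finetune True
```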
## src/f5_tts/train/finetune_gradio.py

The largest change of the commit. `start_training()` gains a `mixed_precision` parameter; when no tokenizer file is given, the tokenizer type is inferred from the project-name suffix, and the precision flag is forwarded to `accelerate launch`:

```python
def start_training(
    # ... unchanged arguments ...
    file_checkpoint_train="",
    tokenizer_type="pinyin",
    tokenizer_file="",
    mixed_precision="fp16",
):
    global training_process, tts_api
    # ...
    yield "start train", gr.update(interactive=False), gr.update(interactive=False)

    # Command to run the training script with the specified arguments

    # infer the tokenizer from the project-name suffix unless a file was given
    if tokenizer_file == "":
        if dataset_name.endswith("_pinyin"):
            tokenizer_type = "pinyin"
        elif dataset_name.endswith("_char"):
            tokenizer_type = "char"
    else:
        tokenizer_file = "custom"

    dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")

    if mixed_precision != "none":
        fp16 = f"--mixed_precision={mixed_precision}"
    else:
        fp16 = ""

    cmd = (
        f"accelerate launch {fp16} finetune-cli.py --exp_name {exp_name} "
        f"--learning_rate {learning_rate} "
        f"--batch_size_per_gpu {batch_size_per_gpu} "
        f"--batch_size_type {batch_size_type} "
        # ... remaining flags unchanged ...
    )
    # ...
    if tokenizer_file != "":
        cmd += f" --tokenizer_path {tokenizer_file}"

    cmd += f" --tokenizer {tokenizer_type} "

    print(cmd)
```

`create_metadata()` gains a `ch_tokenizer` flag and now returns an `(info, new_vocab)` pair. It validates each clip before accepting it, collects per-file errors instead of failing silently, and either copies the pretrained vocab or builds a fresh one from the dataset text:

```python
def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
    path_project = os.path.join(path_data, name_project)
    path_project_wavs = os.path.join(path_project, "wavs")
    file_metadata = os.path.join(path_project, "metadata.csv")
    # ...
    file_vocab = os.path.join(path_project, "vocab.txt")

    if not os.path.isfile(file_metadata):
        return "The file was not found in " + file_metadata, ""

    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        data = f.read()
    # ...
    lenght = 0
    result = []
    error_files = []
    text_vocab_set = set()
    for line in progress.tqdm(data.split("\n"), total=count):
        sp_line = line.split("|")
        if len(sp_line) != 2:
            continue
        # ... (name_audio, text unpacked from sp_line) ...
        file_audio = os.path.join(path_project_wavs, name_audio + ".wav")

        if not os.path.isfile(file_audio):
            error_files.append([file_audio, "error path"])
            continue

        try:
            duration = get_audio_duration(file_audio)
        except Exception as e:
            error_files.append([file_audio, "duration"])
            print(f"Error processing {file_audio}: {e}")
            continue

        if duration < 1 or duration > 25:
            error_files.append([file_audio, "duration < 1 or > 25"])
            continue
        if len(text) < 4:
            error_files.append([file_audio, "very small text len 3"])
            continue

        text = clear_text(text)
        text = convert_char_to_pinyin([text], polyphone=True)[0]

        audio_path_list.append(file_audio)
        duration_list.append(duration)
        text_list.append(text)

        result.append({"audio_path": file_audio, "text": text, "duration": duration})
        if ch_tokenizer:
            text_vocab_set.update(list(text))

        lenght += duration

    if duration_list == []:
        return f"Error: No audio files found in the specified path : {path_project_wavs}", ""

    min_second = round(min(duration_list), 2)
    max_second = round(max(duration_list), 2)
    # ...
    with open(file_duration, "w") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    new_vocal = ""
    if not ch_tokenizer:  # reuse the pretrained vocab
        file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
        if not os.path.isfile(file_vocab_finetune):
            return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", ""
        shutil.copy2(file_vocab_finetune, file_vocab)

        with open(file_vocab, "r", encoding="utf-8-sig") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)

    else:  # build a fresh vocab from the dataset text
        with open(file_vocab, "w", encoding="utf-8-sig") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")
                new_vocal += vocab + "\n"
        vocab_size = len(text_vocab_set)

    if error_files != []:
        error_text = "\n".join([" = ".join(item) for item in error_files])
    else:
        error_text = ""

    return (
        f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\n"
        f"min sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\nvocab : {vocab_size}\n{error_text}",
        new_vocal,
    )
```

`calculate_train()` keeps the old single-GPU probe as a comment and now sums memory across all CUDA devices, with the device count feeding the `gpus` train parameter:

```python
    samples = len(duration_list)
    hours = sum(duration_list) / 3600

    # if torch.cuda.is_available():
    #     gpu_properties = torch.cuda.get_device_properties(0)
    #     total_memory = gpu_properties.total_memory / (1024**3)
    # elif torch.backends.mps.is_available():
    #     total_memory = psutil.virtual_memory().available / (1024**3)

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        total_memory = 0
        for i in range(gpu_count):
            gpu_properties = torch.cuda.get_device_properties(i)
            total_memory += gpu_properties.total_memory / (1024**3)  # in GB

    elif torch.backends.mps.is_available():
        gpu_count = 1
        total_memory = psutil.virtual_memory().available / (1024**3)

    if batch_size_type == "frame":
        # ...

    wanted_max_updates = 1000000

    # train params
    gpus = gpu_count
    frames_per_gpu = batch_size_per_gpu  # 8 * 38400 = 307200
    grad_accum = 1
```

New helpers report hardware utilization for a new "system info" tab:

```python
def get_gpu_stats():
    gpu_stats = ""

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        for i in range(gpu_count):
            gpu_name = torch.cuda.get_device_name(i)
            gpu_properties = torch.cuda.get_device_properties(i)
            total_memory = gpu_properties.total_memory / (1024**3)  # in GB
            allocated_memory = torch.cuda.memory_allocated(i) / (1024**2)  # in MB
            reserved_memory = torch.cuda.memory_reserved(i) / (1024**2)  # in MB

            gpu_stats += (
                f"GPU {i} Name: {gpu_name}\n"
                f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
                f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
                f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
            )

    elif torch.backends.mps.is_available():
        gpu_count = 1
        gpu_stats += "MPS GPU\n"
        total_memory = psutil.virtual_memory().total / (
            1024**3
        )  # Total system memory (MPS doesn't have its own memory)
        allocated_memory = 0
        reserved_memory = 0

        gpu_stats += (
            f"Total system memory: {total_memory:.2f} GB\n"
            f"Allocated GPU memory (MPS): {allocated_memory:.2f} MB\n"
            f"Reserved GPU memory (MPS): {reserved_memory:.2f} MB\n"
        )

    else:
        gpu_stats = "No GPU available"

    return gpu_stats


def get_cpu_stats():
    cpu_usage = psutil.cpu_percent(interval=1)
    memory_info = psutil.virtual_memory()
    memory_used = memory_info.used / (1024**2)
    memory_total = memory_info.total / (1024**2)
    memory_percent = memory_info.percent

    pid = os.getpid()
    process = psutil.Process(pid)
    nice_value = process.nice()

    cpu_stats = (
        f"CPU Usage: {cpu_usage:.2f}%\n"
        f"System Memory: {memory_used:.2f} MB used / {memory_total:.2f} MB total ({memory_percent}% used)\n"
        f"Process Priority (Nice value): {nice_value}"
    )

    return cpu_stats


def get_combined_stats():
    gpu_stats = get_gpu_stats()
    cpu_stats = get_cpu_stats()
    combined_stats = f"### GPU Stats\n{gpu_stats}\n\n### CPU Stats\n{cpu_stats}"
    return combined_stats
```

UI wiring: a "create vocabulary from dataset" checkbox and a vocab output box are attached to the prepare button, the checkpoint textbox is relabeled "Pretrain Model", a `mixed_precision` radio is added beside the Start/Stop buttons and appended to the `start_button.click` inputs after `tokenizer_file`, the "vocab check" tab label is tidied, and the new "system info" tab renders the combined stats:

```python
    ch_tokenizern = gr.Checkbox(label="create vocabulary from dataset", value=False)
    bt_prepare = bt_create = gr.Button("prepare")
    txt_info_prepare = gr.Text(label="info", value="")
    txt_vocab_prepare = gr.Text(label="vocab", value="")
    bt_prepare.click(
        fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
    )
    # ...
    file_checkpoint_train = gr.Textbox(label="Pretrain Model", value="")
    # ...
    with gr.Row():
        mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
        start_button = gr.Button("Start Training")
        stop_button = gr.Button("Stop Training", interactive=False)
    # ...
    with gr.TabItem("vocab check"):
        check_button = gr.Button("check vocab")
        txt_info_check = gr.Text(label="info", value="")
        check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
    # ...
    with gr.TabItem("system info"):
        output_box = gr.Textbox(label="GPU and CPU Information", lines=20)

        def update_stats():
            return get_combined_stats()

        update_button = gr.Button("Update Stats")
        update_button.click(fn=update_stats, outputs=output_box)

        def auto_update():
            yield gr.update(value=update_stats())

        # intended as an auto-refresh hook; as written, this bare gr.update()
        # call does not register a callback
        gr.update(fn=auto_update, inputs=[], outputs=output_box)
```
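For a concrete picture of what `start_training` ends up executing: with `mixed_precision="fp16"` and a project named `my_speak_pinyin`, the assembled command begins roughly like this (illustrative values, truncated):

```bash
accelerate launch --mixed_precision=fp16 finetune-cli.py --exp_name F5TTS_Base \
    --learning_rate 1e-05 --batch_size_per_gpu 3200 --batch_size_type frame \
    ... --dataset_name my_speak --tokenizer pinyin
```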
## src/f5_tts/train/train.py

A header comment and the `importlib.resources` import are added, and the checkpoint directory is resolved relative to the package as well:

```diff
@@ -1,3 +1,7 @@
+# training script.
+
+from importlib.resources import files
+
 from f5_tts.model import CFM, UNetT, DiT, Trainer
 from f5_tts.model.utils import get_tokenizer
 from f5_tts.model.dataset import load_dataset
@@ -69,7 +73,7 @@ def main():
         learning_rate,
         num_warmup_updates=num_warmup_updates,
         save_per_updates=save_per_updates,
-        checkpoint_path=f"ckpts/{exp_name}",
+        checkpoint_path=str(files("f5_tts").joinpath(f"../../ckpts/{exp_name}")),
         batch_size=batch_size_per_gpu,
         batch_size_type=batch_size_type,
         max_samples=max_samples,
```