1.0.0 F5-TTS v1 base model with better training and inference performance
- .github/workflows/publish-pypi.yaml +66 -0
- README.md +3 -2
- ckpts/README.md +5 -3
- pyproject.toml +1 -2
- src/f5_tts/api.py +50 -59
- src/f5_tts/configs/{E2TTS_Base_train.yaml → E2TTS_Base.yaml} +11 -7
- src/f5_tts/configs/{E2TTS_Small_train.yaml → E2TTS_Small.yaml} +11 -7
- src/f5_tts/configs/{F5TTS_Base_train.yaml → F5TTS_Base.yaml} +11 -7
- src/f5_tts/configs/{F5TTS_Small_train.yaml → F5TTS_Small.yaml} +11 -7
- src/f5_tts/configs/F5TTS_v1_Base.yaml +53 -0
- src/f5_tts/eval/eval_infer_batch.py +22 -27
- src/f5_tts/eval/eval_infer_batch.sh +11 -6
- src/f5_tts/eval/eval_librispeech_test_clean.py +21 -27
- src/f5_tts/eval/eval_seedtts_testset.py +21 -27
- src/f5_tts/eval/eval_utmos.py +14 -16
- src/f5_tts/eval/utils_eval.py +11 -6
- src/f5_tts/infer/README.md +20 -85
- src/f5_tts/infer/SHARED.md +19 -9
- src/f5_tts/infer/infer_cli.py +26 -31
- src/f5_tts/infer/infer_gradio.py +36 -11
- src/f5_tts/infer/speech_edit.py +25 -26
- src/f5_tts/infer/utils_infer.py +6 -6
- src/f5_tts/model/backbones/README.md +2 -2
- src/f5_tts/model/backbones/dit.py +63 -8
- src/f5_tts/model/backbones/mmdit.py +52 -9
- src/f5_tts/model/backbones/unett.py +36 -5
- src/f5_tts/model/cfm.py +3 -2
- src/f5_tts/model/dataset.py +5 -2
- src/f5_tts/model/modules.py +115 -42
- src/f5_tts/model/trainer.py +29 -18
- src/f5_tts/model/utils.py +4 -3
- src/f5_tts/scripts/count_max_epoch.py +1 -1
- src/f5_tts/socket_client.py +61 -0
- src/f5_tts/socket_server.py +19 -9
- src/f5_tts/train/README.md +5 -5
- src/f5_tts/train/finetune_cli.py +47 -15
- src/f5_tts/train/finetune_gradio.py +128 -148
- src/f5_tts/train/train.py +11 -11
.github/workflows/publish-pypi.yaml
ADDED
@@ -0,0 +1,66 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  release-build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.x"

      - name: Build release distributions
        run: |
          # NOTE: put your own distribution build steps here.
          python -m pip install build
          python -m build

      - name: Upload distributions
        uses: actions/upload-artifact@v4
        with:
          name: release-dists
          path: dist/

  pypi-publish:
    runs-on: ubuntu-latest

    needs:
      - release-build

    permissions:
      # IMPORTANT: this permission is mandatory for trusted publishing
      id-token: write

    # Dedicated environments with protections for publishing are strongly recommended.
    environment:
      name: pypi
      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
      # url: https://pypi.org/p/YOURPROJECT

    steps:
      - name: Retrieve release distributions
        uses: actions/download-artifact@v4
        with:
          name: release-dists
          path: dist/

      - name: Publish release distributions to PyPI
        uses: pypa/gh-action-pypi-publish@6f7e8d9c0b1a2c3d4e5f6a7b8c9d0e1f2a3b4c5d
README.md
CHANGED
@@ -18,6 +18,7 @@
 ### Thanks to all the contributors !

 ## News
+- **2025/03/12**: F5-TTS v1 base model with better training and inference performance.
 - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).

 ## Installation
@@ -37,7 +38,7 @@ conda activate f5-tts
 > ```bash
 > # Install pytorch with your CUDA version, e.g.
-> pip install torch==2.…
+> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
 > ```

 </details>
@@ -159,7 +160,7 @@ volumes:
 # Run with flags
 # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
 f5-tts_infer-cli \
---model "F5-…
+--model "F5-TTS_v1" \
 --ref_audio "ref_audio.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
ckpts/README.md
CHANGED
@@ -3,8 +3,10 @@ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS

 ```
 ckpts/
+    F5TTS_v1_Base/
+        model_1250000.safetensors
     F5TTS_Base/
-        model_1200000.…
+        model_1200000.safetensors
+    E2TTS_Base/
+        model_1200000.safetensors
 ```
pyproject.toml
CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "f5-tts"
-version = "0.…"
+version = "1.0.0"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -25,7 +25,6 @@ dependencies = [
     "jieba",
     "librosa",
     "matplotlib",
-    "nltk",
     "numpy<=1.26.4",
     "pydub",
     "pypinyin",
src/f5_tts/api.py
CHANGED
@@ -5,43 +5,43 @@ from importlib.resources import files
 import soundfile as sf
 import tqdm
 from cached_path import cached_path
+from omegaconf import OmegaConf

 from f5_tts.infer.utils_infer import (
     load_model,
     load_vocoder,
+    transcribe,
     preprocess_ref_audio_text,
+    infer_process,
     remove_silence_for_generated_wav,
     save_spectrogram,
 )
+from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import seed_everything


 class F5TTS:
     def __init__(
         self,
+        model="F5TTS_v1_Base",
         ckpt_file="",
         vocab_file="",
         ode_method="euler",
         use_ema=True,
+        vocoder_local_path=None,
         device=None,
         hf_cache_dir=None,
     ):
+        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+        model_cls = globals()[model_cfg.model.backbone]
+        model_arc = model_cfg.model.arch
+
+        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+
+        self.ode_method = ode_method
+        self.use_ema = use_ema
+
         if device is not None:
             self.device = device
         else:
@@ -58,39 +58,31 @@ class F5TTS:
         )

         # Load models
+        self.vocoder = load_vocoder(
+            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
         )

+        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
+
+        # override for previous models
+        if model == "F5TTS_Base":
+            if self.mel_spec_type == "vocos":
+                ckpt_step = 1200000
+            elif self.mel_spec_type == "bigvgan":
+                model = "F5TTS_Base_bigvgan"
+                ckpt_type = "pt"
+        elif model == "E2TTS_Base":
+            repo_name = "E2-TTS"
+            ckpt_step = 1200000
         else:
+            raise ValueError(f"Unknown model type: {model}")

+        if not ckpt_file:
+            ckpt_file = str(
+                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
+            )
         self.ema_model = load_model(
+            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
         )

     def transcribe(self, ref_audio, language=None):
@@ -102,8 +94,8 @@ class F5TTS:
         if remove_silence:
             remove_silence_for_generated_wav(file_wave)

+    def export_spectrogram(self, spec, file_spec):
+        save_spectrogram(spec, file_spec)

     def infer(
         self,
@@ -121,17 +113,16 @@ class F5TTS:
         fix_duration=None,
         remove_silence=False,
         file_wave=None,
+        file_spec=None,
+        seed=None,
     ):
+        if seed is None:
+            self.seed = random.randint(0, sys.maxsize)
+        seed_everything(self.seed)

         ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)

+        wav, sr, spec = infer_process(
             ref_file,
             ref_text,
             gen_text,
@@ -153,22 +144,22 @@ class F5TTS:
         if file_wave is not None:
             self.export_wav(wav, file_wave, remove_silence)

+        if file_spec is not None:
+            self.export_spectrogram(spec, file_spec)

+        return wav, sr, spec


 if __name__ == "__main__":
     f5tts = F5TTS()

+    wav, sr, spec = f5tts.infer(
         ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
         ref_text="some call me nature, others call me mother nature.",
         gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
         file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
+        file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+        seed=None,
     )

     print("seed :", f5tts.seed)
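For reference, a minimal usage sketch of the updated `F5TTS` API, assembled from the `__init__` signature and `__main__` block above; the audio/text paths below are placeholders, not files shipped with the package:

```python
from f5_tts.api import F5TTS

# Assumes the F5TTS_v1_Base checkpoint can be fetched from the Hugging Face hub
# (as __init__ does via cached_path) and that a GPU/CPU device is available.
tts = F5TTS(model="F5TTS_v1_Base")

wav, sr, spec = tts.infer(
    ref_file="ref_audio.wav",  # reference audio to clone (placeholder path)
    ref_text="The transcription of the reference audio.",
    gen_text="Some text you want the TTS model to generate for you.",
    file_wave="out.wav",       # optional: write the waveform to disk
    file_spec="out.png",       # optional: save the spectrogram plot
    seed=None,                 # None -> random seed, stored in tts.seed
)
print("seed :", tts.seed)
```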
src/f5_tts/configs/{E2TTS_Base_train.yaml → E2TTS_Base.yaml}
RENAMED
@@ -1,16 +1,16 @@
 hydra:
   run:
     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

 datasets:
   name: Emilia_ZH_EN # dataset name
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-  batch_size_type: frame # …
+  batch_size_type: frame # frame | sample
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
   num_workers: 16

 optim:
-  epochs: …
+  epochs: 11
   learning_rate: 7.5e-5
   num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
 model:
   name: E2TTS_Base
   tokenizer: pinyin
-  tokenizer_path: …
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: UNetT
   arch:
     dim: 1024
     depth: 24
     heads: 16
     ff_mult: 4
+    text_mask_padding: False
+    pe_attn_head: 1
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
     hop_length: 256
     win_length: 1024
     n_fft: 1024
-    mel_spec_type: vocos # …
+    mel_spec_type: vocos # vocos | bigvgan
   vocoder:
     is_local: False # use local offline ckpt or not
-    local_path: …
+    local_path: null # local vocoder path

 ckpts:
-  logger: wandb # wandb | tensorboard | …
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
   save_per_updates: 50000 # save checkpoint per updates
   keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/{E2TTS_Small_train.yaml → E2TTS_Small.yaml}
RENAMED
@@ -1,16 +1,16 @@
 hydra:
   run:
     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

 datasets:
   name: Emilia_ZH_EN
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-  batch_size_type: frame # …
+  batch_size_type: frame # frame | sample
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
   num_workers: 16

 optim:
-  epochs: …
+  epochs: 11
   learning_rate: 7.5e-5
   num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
 model:
   name: E2TTS_Small
   tokenizer: pinyin
-  tokenizer_path: …
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: UNetT
   arch:
     dim: 768
     depth: 20
     heads: 12
     ff_mult: 4
+    text_mask_padding: False
+    pe_attn_head: 1
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
     hop_length: 256
     win_length: 1024
     n_fft: 1024
-    mel_spec_type: vocos # …
+    mel_spec_type: vocos # vocos | bigvgan
   vocoder:
     is_local: False # use local offline ckpt or not
-    local_path: …
+    local_path: null # local vocoder path

 ckpts:
-  logger: wandb # wandb | tensorboard | …
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
   save_per_updates: 50000 # save checkpoint per updates
   keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/{F5TTS_Base_train.yaml → F5TTS_Base.yaml}
RENAMED
@@ -1,16 +1,16 @@
 hydra:
   run:
     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

 datasets:
   name: Emilia_ZH_EN # dataset name
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-  batch_size_type: frame # …
+  batch_size_type: frame # frame | sample
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
   num_workers: 16

 optim:
-  epochs: …
+  epochs: 11
   learning_rate: 7.5e-5
   num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
 model:
   name: F5TTS_Base # model name
   tokenizer: pinyin # tokenizer type
-  tokenizer_path: …
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
   arch:
     dim: 1024
     depth: 22
     heads: 16
     ff_mult: 2
     text_dim: 512
+    text_mask_padding: False
     conv_layers: 4
+    pe_attn_head: 1
     checkpoint_activations: False # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
     hop_length: 256
     win_length: 1024
     n_fft: 1024
-    mel_spec_type: vocos # …
+    mel_spec_type: vocos # vocos | bigvgan
   vocoder:
     is_local: False # use local offline ckpt or not
-    local_path: …
+    local_path: null # local vocoder path

 ckpts:
-  logger: wandb # wandb | tensorboard | …
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
   save_per_updates: 50000 # save checkpoint per updates
   keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/{F5TTS_Small_train.yaml → F5TTS_Small.yaml}
RENAMED
@@ -1,16 +1,16 @@
 hydra:
   run:
     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

 datasets:
   name: Emilia_ZH_EN
   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-  batch_size_type: frame # …
+  batch_size_type: frame # frame | sample
   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
   num_workers: 16

 optim:
-  epochs: …
+  epochs: 11
   learning_rate: 7.5e-5
   num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
 model:
   name: F5TTS_Small
   tokenizer: pinyin
-  tokenizer_path: …
+  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+  backbone: DiT
   arch:
     dim: 768
     depth: 18
     heads: 12
     ff_mult: 2
     text_dim: 512
+    text_mask_padding: False
     conv_layers: 4
+    pe_attn_head: 1
     checkpoint_activations: False # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
     hop_length: 256
     win_length: 1024
     n_fft: 1024
-    mel_spec_type: vocos # …
+    mel_spec_type: vocos # vocos | bigvgan
   vocoder:
     is_local: False # use local offline ckpt or not
-    local_path: …
+    local_path: null # local vocoder path

 ckpts:
-  logger: wandb # wandb | tensorboard | …
+  logger: wandb # wandb | tensorboard | null
+  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
   save_per_updates: 50000 # save checkpoint per updates
   keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/F5TTS_v1_Base.yaml
ADDED
@@ -0,0 +1,53 @@
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN # dataset name
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame # frame | sample
  max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 11
  learning_rate: 7.5e-5
  num_warmup_updates: 20000 # warmup updates
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0 # gradient clipping
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not

model:
  name: F5TTS_v1_Base # model name
  tokenizer: pinyin # tokenizer type
  tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
  backbone: DiT
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    text_mask_padding: True
    qk_norm: null # null | rms_norm
    conv_layers: 4
    pe_attn_head: null
    checkpoint_activations: False # recompute activations and save memory for extra compute
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos # vocos | bigvgan
  vocoder:
    is_local: False # use local offline ckpt or not
    local_path: null # local vocoder path

ckpts:
  logger: wandb # wandb | tensorboard | null
  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000 # save checkpoint per updates
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000 # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
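Both `api.py` and the eval scripts now resolve the backbone class and architecture kwargs from these YAML configs instead of hard-coded dicts. A minimal sketch of that lookup, following the pattern used in this commit (`DiT`/`UNetT` are imported only so `globals()` can find them by name):

```python
from importlib.resources import files

from omegaconf import OmegaConf

from f5_tts.model import DiT, UNetT  # noqa: F401. looked up by name from the config

model = "F5TTS_v1_Base"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))

model_cls = globals()[model_cfg.model.backbone]  # "DiT" -> DiT class
model_arc = model_cfg.model.arch                 # kwargs forwarded to the backbone

# e.g. transformer = model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
print(model_cfg.model.backbone, OmegaConf.to_container(model_arc))
```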
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -10,6 +10,7 @@ from importlib.resources import files
 import torch
 import torchaudio
 from accelerate import Accelerator
+from omegaconf import OmegaConf
 from tqdm import tqdm

 from f5_tts.eval.utils_eval import (
@@ -18,36 +19,26 @@ from f5_tts.eval.utils_eval import (
     get_seedtts_testset_metainfo,
 )
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
-from f5_tts.model import CFM, DiT, UNetT
+from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import get_tokenizer

 accelerator = Accelerator()
 device = f"cuda:{accelerator.process_index}"


-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
+use_ema = True
 target_rms = 0.1

 rel_path = str(files("f5_tts").joinpath("../../"))


 def main():
-    # ---------------------- infer setting ---------------------- #
-
     parser = argparse.ArgumentParser(description="batch inference")

     parser.add_argument("-s", "--seed", default=None, type=int)
-    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
     parser.add_argument("-n", "--expname", required=True)
-    parser.add_argument("-c", "--ckptstep", default=…)
-    parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
-    parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
+    parser.add_argument("-c", "--ckptstep", default=1250000, type=int)

     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -58,12 +49,8 @@ def main():
     args = parser.parse_args()

     seed = args.seed
-    dataset_name = args.dataset
     exp_name = args.expname
     ckpt_step = args.ckptstep
-    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
-    mel_spec_type = args.mel_spec_type
-    tokenizer = args.tokenizer

     nfe_step = args.nfestep
     ode_method = args.odemethod
@@ -77,13 +64,19 @@ def main():
     use_truth_duration = False
     no_ref_audio = False

+    model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
+    model_cls = globals()[model_cfg.model.backbone]
+    model_arc = model_cfg.model.arch
+
+    dataset_name = model_cfg.datasets.name
+    tokenizer = model_cfg.model.tokenizer
+
+    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+    target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+    n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
+    hop_length = model_cfg.model.mel_spec.hop_length
+    win_length = model_cfg.model.mel_spec.win_length
+    n_fft = model_cfg.model.mel_spec.n_fft

     if testset == "ls_pc_test_clean":
         metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -111,8 +104,6 @@ def main():

     # -------------------------------------------------#

-    use_ema = True
-
     prompts_all = get_inference_prompt(
         metainfo,
         speed=speed,
@@ -139,7 +130,7 @@ def main():

     # Model
     model = CFM(
+        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
         mel_spec_kwargs=dict(
             n_fft=n_fft,
             hop_length=hop_length,
@@ -154,6 +145,10 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)

+    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
+    if not os.path.exists(ckpt_path):
+        print("Loading from self-organized training checkpoints rather than released pretrained.")
+        ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
     dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
src/f5_tts/eval/eval_infer_batch.sh
CHANGED
@@ -1,13 +1,18 @@
 #!/bin/bash

 # e.g. F5-TTS, 16 NFE
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16

 # e.g. Vanilla E2 TTS, 32 NFE
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
-accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
+
+# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
+python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0

 # etc.
src/f5_tts/eval/eval_librispeech_test_clean.py
CHANGED
@@ -53,43 +53,37 @@ def main():
     asr_ckpt_dir = "" # auto download to cache dir
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

+    # --------------------------------------------------------------------------

+    full_results = []
+    metrics = []

+    if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for r in results:
+                full_results.extend(r)
+    elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
             for r in results:
+                full_results.extend(r)
+    else:
+        raise ValueError(f"Unknown metric type: {eval_task}")

+    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+    with open(result_path, "w") as f:
+        for line in full_results:
+            metrics.append(line[eval_task])
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        metric = round(np.mean(metrics), 5)
+        f.write(f"\n{eval_task.upper()}: {metric}\n")

+    print(f"\nTotal {len(metrics)} samples")
+    print(f"{eval_task.upper()}: {metric}")
+    print(f"{eval_task.upper()} results saved to {result_path}")


 if __name__ == "__main__":
src/f5_tts/eval/eval_seedtts_testset.py
CHANGED
@@ -52,43 +52,37 @@ def main():
     asr_ckpt_dir = "" # auto download to cache dir
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

+    # --------------------------------------------------------------------------

+    full_results = []
+    metrics = []

+    if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for r in results:
+                full_results.extend(r)
+    elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
             for r in results:
+                full_results.extend(r)
+    else:
+        raise ValueError(f"Unknown metric type: {eval_task}")

+    result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+    with open(result_path, "w") as f:
+        for line in full_results:
+            metrics.append(line[eval_task])
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        metric = round(np.mean(metrics), 5)
+        f.write(f"\n{eval_task.upper()}: {metric}\n")

+    print(f"\nTotal {len(metrics)} samples")
+    print(f"{eval_task.upper()}: {metric}")
+    print(f"{eval_task.upper()} results saved to {result_path}")


 if __name__ == "__main__":
src/f5_tts/eval/eval_utmos.py
CHANGED
@@ -19,25 +19,23 @@ def main():
     predictor = predictor.to(device)

     audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
-    utmos_results = {}
     utmos_score = 0

+    utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
     with open(utmos_result_path, "w", encoding="utf-8") as f:
+        for audio_path in tqdm(audio_paths, desc="Processing"):
+            wav, sr = librosa.load(audio_path, sr=None, mono=True)
+            wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+            score = predictor(wav_tensor, sr)
+            line = {}
+            line["wav"], line["utmos"] = str(audio_path.stem), score.item()
+            utmos_score += score.item()
+            f.write(json.dumps(line, ensure_ascii=False) + "\n")
+        avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+        f.write(f"\nUTMOS: {avg_score:.4f}\n")

+    print(f"UTMOS: {avg_score:.4f}")
+    print(f"UTMOS results saved to {utmos_result_path}")


 if __name__ == "__main__":
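All three eval scripts now write per-sample scores to a `_<metric>_results.jsonl` file: one JSON object per line, followed by a blank line and a summary line. A small sketch, assuming such a file already exists (the path below is only illustrative), of re-aggregating the per-sample scores offline:

```python
import json

import numpy as np

# Hypothetical results file; any of the _wer/_sim/_utmos outputs above has this shape.
result_path = "results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0/_wer_results.jsonl"

scores = []
with open(result_path, encoding="utf-8") as f:
    for raw in f:
        raw = raw.strip()
        if not raw or not raw.startswith("{"):  # skip the blank line and trailing summary line
            continue
        scores.append(json.loads(raw)["wer"])   # use "sim" / "utmos" for the other metrics

print(f"{len(scores)} samples, mean WER = {np.mean(scores):.5f}")
```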
src/f5_tts/eval/utils_eval.py
CHANGED
@@ -389,10 +389,10 @@ def run_sim(args):
     model = model.cuda(device)
     model.eval()

+    sim_results = []
+    for gen_wav, prompt_wav, truth in tqdm(test_set):
+        wav1, sr1 = torchaudio.load(gen_wav)
+        wav2, sr2 = torchaudio.load(prompt_wav)

         resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
         resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,6 +408,11 @@ def run_sim(args):

         sim = F.cosine_similarity(emb1, emb2)[0].item()
         # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
+        sim_results.append(
+            {
+                "wav": Path(gen_wav).stem,
+                "sim": sim,
+            }
+        )

+    return sim_results
src/f5_tts/infer/README.md
CHANGED
@@ -68,14 +68,16 @@ Basically you can inference with flags:
 ```bash
 # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
 f5-tts_infer-cli \
+--model F5TTS_v1_Base \
 --ref_audio "ref_audio.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."

+# Use BigVGAN as vocoder. Currently only support F5TTS_Base.
+f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
+
+# Use custom path checkpoint, e.g.
+f5-tts_infer-cli --ckpt_file ckpts/F5TTS_Base/model_1200000.safetensors

 # More instructions
 f5-tts_infer-cli --help
@@ -90,8 +92,8 @@ f5-tts_infer-cli -c custom.toml
 For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:

 ```toml
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/basic/basic_ref_en.wav"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = "Some call me nature, others call me mother nature."
@@ -105,8 +107,8 @@ output_dir = "tests"
 You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.

 ```toml
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/multi/main.flac"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = ""
@@ -126,94 +128,27 @@ ref_text = ""
 ```
 You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.

+## Socket Real-time Service

+Real-time voice output with chunk stream:

 ```bash
+# Start socket server
 python src/f5_tts/socket_server.py

 # If PyAudio not installed
 sudo apt-get install portaudio19-dev
 pip install pyaudio

+# Communicate with socket client
+python src/f5_tts/socket_client.py
+```

+## Speech Editing

+To test speech editing capabilities, use the following command:

+```bash
+python src/f5_tts/infer/speech_edit.py
 ```
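The long inline socket-client example that previously lived in this README was moved into `src/f5_tts/socket_client.py` (added in this commit). A trimmed sketch of such a client, assuming (as the removed example did) that the server listens on port 9998, streams raw float32 PCM at 24 kHz, and terminates each utterance with an `END` marker:

```python
import socket

import numpy as np
import pyaudio

SERVER = ("localhost", 9998)  # assumed default host/port of socket_server.py


def speak(text: str) -> None:
    """Send text to the F5-TTS socket server and play the streamed audio chunks."""
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
    with socket.create_connection(SERVER) as sock:
        sock.sendall(text.encode("utf-8"))
        while True:
            data = sock.recv(8192)
            if not data or data == b"END":  # server signals end of audio
                break
            stream.write(np.frombuffer(data, dtype=np.float32).tobytes())
    stream.stop_stream()
    stream.close()
    p.terminate()


if __name__ == "__main__":
    speak("As a Reader assistant, I'm familiar with new technology.")
```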
src/f5_tts/infer/SHARED.md
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
<!-- omit in toc -->
|
17 |
### Supported Languages
|
18 |
- [Multilingual](#multilingual)
|
19 |
-
- [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
|
20 |
- [English](#english)
|
21 |
- [Finnish](#finnish)
|
22 |
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
|
@@ -37,7 +37,17 @@
|
|
37 |
|
38 |
## Multilingual
|
39 |
|
40 |
-
#### F5-TTS Base @ zh & en @ F5-TTS
|
41 |
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
42 |
|:---:|:------------:|:-----------:|:-------------:|
|
43 |
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
|
@@ -45,7 +55,7 @@
|
|
45 |
```bash
|
46 |
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
|
47 |
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
|
48 |
-
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
49 |
```
|
50 |
|
51 |
*Other info, e.g. author info, GitHub repo, links to sample results, usage instructions, tutorials (blog, video, etc.) ...*
|
@@ -64,7 +74,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
|
64 |
```bash
|
65 |
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
|
66 |
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
|
67 |
-
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
68 |
```
|
69 |
|
70 |
|
@@ -78,7 +88,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
|
78 |
```bash
|
79 |
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
|
80 |
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
|
81 |
-
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
82 |
```
|
83 |
|
84 |
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
|
@@ -96,7 +106,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
|
96 |
```bash
|
97 |
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
|
98 |
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
|
99 |
-
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
100 |
```
|
101 |
|
102 |
- Authors: SPRING Lab, Indian Institute of Technology, Madras
|
@@ -113,7 +123,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "c
|
|
113 |
```bash
|
114 |
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
|
115 |
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
|
116 |
-
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
117 |
```
|
118 |
|
119 |
- Trained by [Mithril Man](https://github.com/MithrilMan)
|
@@ -131,7 +141,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
|
131 |
```bash
|
132 |
Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
|
133 |
Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
|
134 |
-
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
135 |
```
|
136 |
|
137 |
|
@@ -148,7 +158,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
|
148 |
```bash
|
149 |
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
|
150 |
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
|
151 |
-
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
152 |
```
|
153 |
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
|
154 |
- Any improvements are welcome
|
|
|
16 |
<!-- omit in toc -->
|
17 |
### Supported Languages
|
18 |
- [Multilingual](#multilingual)
|
19 |
+
- [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
|
20 |
- [English](#english)
|
21 |
- [Finnish](#finnish)
|
22 |
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
|
|
|
37 |
|
38 |
## Multilingual
|
39 |
|
40 |
+
#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|
41 |
+
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
42 |
+
|:---:|:------------:|:-----------:|:-------------:|
|
43 |
+
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
|
44 |
+
|
45 |
+
```bash
|
46 |
+
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
|
47 |
+
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
|
48 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
49 |
+
```
|
50 |
+
|
51 |
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
52 |
|:---:|:------------:|:-----------:|:-------------:|
|
53 |
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
|
|
|
55 |
```bash
|
56 |
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
|
57 |
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
|
58 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
59 |
```
|
60 |
|
61 |
*Other info, e.g. author info, GitHub repo, links to sample results, usage instructions, tutorials (blog, video, etc.) ...*
|
|
|
74 |
```bash
|
75 |
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
|
76 |
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
|
77 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
78 |
```
|
79 |
|
80 |
|
|
|
88 |
```bash
|
89 |
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
|
90 |
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
|
91 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
92 |
```
|
93 |
|
94 |
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
|
|
|
106 |
```bash
|
107 |
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
|
108 |
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
|
109 |
+
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
110 |
```
|
111 |
|
112 |
- Authors: SPRING Lab, Indian Institute of Technology, Madras
|
|
|
123 |
```bash
|
124 |
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
|
125 |
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
|
126 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
127 |
```
|
128 |
|
129 |
- Trained by [Mithril Man](https://github.com/MithrilMan)
|
|
|
141 |
```bash
|
142 |
Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
|
143 |
Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
|
144 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
145 |
```
|
146 |
|
147 |
|
|
|
158 |
```bash
|
159 |
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
|
160 |
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
|
161 |
+
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
162 |
```
|
163 |
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
|
164 |
- Any improvements are welcome
|
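The Model/Vocab/Config triplets listed above map directly onto the Python inference utilities. A rough sketch of wiring one of the community checkpoints in (the Finnish model here, as an example) might look like this; it mirrors `load_custom` in `infer_gradio.py` and assumes the listed config dict matches that checkpoint.

```python
from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT

ckpt = str(cached_path("hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors"))
vocab = str(cached_path("hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt"))

# Config copied from the entry above: a previous-generation (v0-style) DiT.
model_cfg = dict(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512,
    text_mask_padding=False, conv_layers=4, pe_attn_head=1,
)

vocoder = load_vocoder()
ema_model = load_model(DiT, model_cfg, ckpt, vocab_file=vocab)
```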
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import (
|
|
27 |
preprocess_ref_audio_text,
|
28 |
remove_silence_for_generated_wav,
|
29 |
)
|
30 |
-
from f5_tts.model import DiT, UNetT
|
31 |
|
32 |
|
33 |
parser = argparse.ArgumentParser(
|
@@ -50,7 +50,7 @@ parser.add_argument(
|
|
50 |
"-m",
|
51 |
"--model",
|
52 |
type=str,
|
53 |
-
help="The model name:
|
54 |
)
|
55 |
parser.add_argument(
|
56 |
"-mc",
|
@@ -172,8 +172,7 @@ config = tomli.load(open(args.config, "rb"))
|
|
172 |
|
173 |
# command-line interface parameters
|
174 |
|
175 |
-
model = args.model or config.get("model", "F5-TTS")
|
176 |
-
model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
|
177 |
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
|
178 |
vocab_file = args.vocab_file or config.get("vocab_file", "")
|
179 |
|
@@ -245,36 +244,32 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
|
|
245 |
|
246 |
# load TTS model
|
247 |
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
ckpt_step = 1250000
|
262 |
-
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
263 |
-
|
264 |
-
elif model == "E2-TTS":
|
265 |
-
assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
|
266 |
-
assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
|
267 |
-
model_cls = UNetT
|
268 |
-
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
269 |
-
if not ckpt_file: # path not specified, download from repo
|
270 |
-
repo_name = "E2-TTS"
|
271 |
-
exp_name = "E2TTS_Base"
|
272 |
ckpt_step = 1200000
|
273 |
-
|
274 |
-
|
|
275 |
|
276 |
print(f"Using {model}...")
|
277 |
-
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
|
278 |
|
279 |
|
280 |
# inference process
|
|
|
27 |
preprocess_ref_audio_text,
|
28 |
remove_silence_for_generated_wav,
|
29 |
)
|
30 |
+
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
|
31 |
|
32 |
|
33 |
parser = argparse.ArgumentParser(
|
|
|
50 |
"-m",
|
51 |
"--model",
|
52 |
type=str,
|
53 |
+
help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
|
54 |
)
|
55 |
parser.add_argument(
|
56 |
"-mc",
|
|
|
172 |
|
173 |
# command-line interface parameters
|
174 |
|
175 |
+
model = args.model or config.get("model", "F5TTS_v1_Base")
|
|
|
176 |
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
|
177 |
vocab_file = args.vocab_file or config.get("vocab_file", "")
|
178 |
|
|
|
244 |
|
245 |
# load TTS model
|
246 |
|
247 |
+
model_cfg = OmegaConf.load(
|
248 |
+
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
249 |
+
).model
|
250 |
+
model_cls = globals()[model_cfg.backbone]
|
251 |
+
|
252 |
+
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
|
253 |
+
|
254 |
+
if model != "F5TTS_Base":
|
255 |
+
assert vocoder_name == model_cfg.mel_spec.mel_spec_type
|
256 |
+
|
257 |
+
# override for previous models
|
258 |
+
if model == "F5TTS_Base":
|
259 |
+
if vocoder_name == "vocos":
|
260 |
ckpt_step = 1200000
|
261 |
+
elif vocoder_name == "bigvgan":
|
262 |
+
model = "F5TTS_Base_bigvgan"
|
263 |
+
ckpt_type = "pt"
|
264 |
+
elif model == "E2TTS_Base":
|
265 |
+
repo_name = "E2-TTS"
|
266 |
+
ckpt_step = 1200000
|
267 |
+
|
268 |
+
if not ckpt_file:
|
269 |
+
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
270 |
|
271 |
print(f"Using {model}...")
|
272 |
+
ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
|
273 |
|
274 |
|
275 |
# inference process
|
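For readers following the `infer_cli.py` change above: the model architecture is now resolved from the packaged YAML config rather than a hard-coded dict. A condensed sketch of that flow, with field names taken from the diff (the local checkpoint path is just a placeholder):

```python
from importlib.resources import files

from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT, UNetT  # noqa: F401. looked up by name below

model = "F5TTS_v1_Base"  # or F5TTS_Base, E2TTS_Base, ...
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))).model

model_cls = globals()[model_cfg.backbone]  # e.g. DiT
ema_model = load_model(
    model_cls,
    model_cfg.arch,
    "ckpts/F5TTS_v1_Base/model_1250000.safetensors",  # placeholder local path
    mel_spec_type=model_cfg.mel_spec.mel_spec_type,
)
```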
src/f5_tts/infer/infer_gradio.py
CHANGED
@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
|
|
41 |
)
|
42 |
|
43 |
|
44 |
-
DEFAULT_TTS_MODEL = "F5-TTS"
|
45 |
tts_model_choice = DEFAULT_TTS_MODEL
|
46 |
|
47 |
DEFAULT_TTS_MODEL_CFG = [
|
48 |
-
"hf://SWivid/F5-TTS/
|
49 |
-
"hf://SWivid/F5-TTS/
|
50 |
json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
|
51 |
]
|
52 |
|
@@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [
|
|
56 |
vocoder = load_vocoder()
|
57 |
|
58 |
|
59 |
-
def load_f5tts(
|
60 |
-
|
|
|
61 |
return load_model(DiT, F5TTS_model_cfg, ckpt_path)
|
62 |
|
63 |
|
64 |
-
def load_e2tts(
|
65 |
-
|
|
|
66 |
return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
|
67 |
|
68 |
|
@@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
|
|
73 |
if vocab_path.startswith("hf://"):
|
74 |
vocab_path = str(cached_path(vocab_path))
|
75 |
if model_cfg is None:
|
76 |
-
model_cfg =
|
77 |
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
|
78 |
|
79 |
|
@@ -130,7 +132,7 @@ def infer(
|
|
130 |
|
131 |
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
|
132 |
|
133 |
-
if model == "F5-TTS":
|
134 |
ema_model = F5TTS_ema_model
|
135 |
elif model == "E2-TTS":
|
136 |
global E2TTS_ema_model
|
@@ -762,7 +764,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
|
762 |
"""
|
763 |
)
|
764 |
|
765 |
-
last_used_custom = files("f5_tts").joinpath("infer/.cache/
|
766 |
|
767 |
def load_last_used_custom():
|
768 |
try:
|
@@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
|
821 |
custom_model_cfg = gr.Dropdown(
|
822 |
choices=[
|
823 |
DEFAULT_TTS_MODEL_CFG[2],
|
824 |
-
json.dumps(
|
825 |
],
|
826 |
value=load_last_used_custom()[2],
|
827 |
allow_custom_value=True,
|
|
|
41 |
)
|
42 |
|
43 |
|
44 |
+
DEFAULT_TTS_MODEL = "F5-TTS_v1"
|
45 |
tts_model_choice = DEFAULT_TTS_MODEL
|
46 |
|
47 |
DEFAULT_TTS_MODEL_CFG = [
|
48 |
+
"hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
|
49 |
+
"hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
|
50 |
json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
|
51 |
]
|
52 |
|
|
|
56 |
vocoder = load_vocoder()
|
57 |
|
58 |
|
59 |
+
def load_f5tts():
|
60 |
+
ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
|
61 |
+
F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
|
62 |
return load_model(DiT, F5TTS_model_cfg, ckpt_path)
|
63 |
|
64 |
|
65 |
+
def load_e2tts():
|
66 |
+
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
|
67 |
+
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
|
68 |
return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
|
69 |
|
70 |
|
|
|
75 |
if vocab_path.startswith("hf://"):
|
76 |
vocab_path = str(cached_path(vocab_path))
|
77 |
if model_cfg is None:
|
78 |
+
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
|
79 |
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
|
80 |
|
81 |
|
|
|
132 |
|
133 |
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
|
134 |
|
135 |
+
if model == DEFAULT_TTS_MODEL:
|
136 |
ema_model = F5TTS_ema_model
|
137 |
elif model == "E2-TTS":
|
138 |
global E2TTS_ema_model
|
|
|
764 |
"""
|
765 |
)
|
766 |
|
767 |
+
last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
|
768 |
|
769 |
def load_last_used_custom():
|
770 |
try:
|
|
|
823 |
custom_model_cfg = gr.Dropdown(
|
824 |
choices=[
|
825 |
DEFAULT_TTS_MODEL_CFG[2],
|
826 |
+
json.dumps(
|
827 |
+
dict(
|
828 |
+
dim=1024,
|
829 |
+
depth=22,
|
830 |
+
heads=16,
|
831 |
+
ff_mult=2,
|
832 |
+
text_dim=512,
|
833 |
+
text_mask_padding=False,
|
834 |
+
conv_layers=4,
|
835 |
+
pe_attn_head=1,
|
836 |
+
)
|
837 |
+
),
|
838 |
+
json.dumps(
|
839 |
+
dict(
|
840 |
+
dim=768,
|
841 |
+
depth=18,
|
842 |
+
heads=12,
|
843 |
+
ff_mult=2,
|
844 |
+
text_dim=512,
|
845 |
+
text_mask_padding=False,
|
846 |
+
conv_layers=4,
|
847 |
+
pe_attn_head=1,
|
848 |
+
)
|
849 |
+
),
|
850 |
],
|
851 |
value=load_last_used_custom()[2],
|
852 |
allow_custom_value=True,
|
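The two dropdown choices added in `infer_gradio.py` above encode the practical difference between the v1 default and previous-generation checkpoints: older models need `text_mask_padding=False` and `pe_attn_head=1` to reproduce their training-time behaviour. As a plain-Python illustration of the JSON strings the dropdown stores:

```python
import json

# v1 default (matches DEFAULT_TTS_MODEL_CFG[2] in infer_gradio.py)
v1_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

# override for previous-generation (v0-style) Base-sized checkpoints
v0_cfg = dict(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512,
    text_mask_padding=False, conv_layers=4, pe_attn_head=1,
)

print(json.dumps(v1_cfg))
print(json.dumps(v0_cfg))
```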
src/f5_tts/infer/speech_edit.py
CHANGED
@@ -2,12 +2,15 @@ import os
|
|
2 |
|
3 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
4 |
|
|
|
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
import torchaudio
|
|
|
8 |
|
9 |
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
|
10 |
-
from f5_tts.model import CFM, DiT, UNetT
|
11 |
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
12 |
|
13 |
device = (
|
@@ -21,44 +24,40 @@ device = (
|
|
21 |
)
|
22 |
|
23 |
|
24 |
-
# --------------------- Dataset Settings -------------------- #
|
25 |
-
|
26 |
-
target_sample_rate = 24000
|
27 |
-
n_mel_channels = 100
|
28 |
-
hop_length = 256
|
29 |
-
win_length = 1024
|
30 |
-
n_fft = 1024
|
31 |
-
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
|
32 |
-
target_rms = 0.1
|
33 |
-
|
34 |
-
tokenizer = "pinyin"
|
35 |
-
dataset_name = "Emilia_ZH_EN"
|
36 |
-
|
37 |
-
|
38 |
# ---------------------- infer setting ---------------------- #
|
39 |
|
40 |
seed = None # int | None
|
41 |
|
42 |
-
exp_name = "
|
43 |
-
ckpt_step =
|
44 |
|
45 |
nfe_step = 32 # 16, 32
|
46 |
cfg_strength = 2.0
|
47 |
ode_method = "euler" # euler | midpoint
|
48 |
sway_sampling_coef = -1.0
|
49 |
speed = 1.0
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
58 |
|
59 |
-
|
|
|
60 |
output_dir = "tests"
|
61 |
|
|
|
62 |
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
63 |
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
|
64 |
# [write the origin_text into a file, e.g. tests/test_edit.txt]
|
@@ -67,7 +66,7 @@ output_dir = "tests"
|
|
67 |
# [--language "zho" for Chinese, "eng" for English]
|
68 |
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
|
69 |
|
70 |
-
audio_to_edit = "
|
71 |
origin_text = "Some call me nature, others call me mother nature."
|
72 |
target_text = "Some call me optimist, others call me realist."
|
73 |
parts_to_edit = [
|
@@ -106,7 +105,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
|
106 |
|
107 |
# Model
|
108 |
model = CFM(
|
109 |
-
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
110 |
mel_spec_kwargs=dict(
|
111 |
n_fft=n_fft,
|
112 |
hop_length=hop_length,
|
|
|
2 |
|
3 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
4 |
|
5 |
+
from importlib.resources import files
|
6 |
+
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
import torchaudio
|
10 |
+
from omegaconf import OmegaConf
|
11 |
|
12 |
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
|
13 |
+
from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
|
14 |
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
15 |
|
16 |
device = (
|
|
|
24 |
)
|
25 |
|
26 |
|
27 |
# ---------------------- infer setting ---------------------- #
|
28 |
|
29 |
seed = None # int | None
|
30 |
|
31 |
+
exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
|
32 |
+
ckpt_step = 1250000
|
33 |
|
34 |
nfe_step = 32 # 16, 32
|
35 |
cfg_strength = 2.0
|
36 |
ode_method = "euler" # euler | midpoint
|
37 |
sway_sampling_coef = -1.0
|
38 |
speed = 1.0
|
39 |
+
target_rms = 0.1
|
40 |
+
|
41 |
+
|
42 |
+
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
|
43 |
+
model_cls = globals()[model_cfg.model.backbone]
|
44 |
+
model_arc = model_cfg.model.arch
|
45 |
|
46 |
+
dataset_name = model_cfg.datasets.name
|
47 |
+
tokenizer = model_cfg.model.tokenizer
|
|
|
48 |
|
49 |
+
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
50 |
+
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
|
51 |
+
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
|
52 |
+
hop_length = model_cfg.model.mel_spec.hop_length
|
53 |
+
win_length = model_cfg.model.mel_spec.win_length
|
54 |
+
n_fft = model_cfg.model.mel_spec.n_fft
|
55 |
|
56 |
+
|
57 |
+
ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
58 |
output_dir = "tests"
|
59 |
|
60 |
+
|
61 |
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
62 |
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
|
63 |
# [write the origin_text into a file, e.g. tests/test_edit.txt]
|
|
|
66 |
# [--language "zho" for Chinese, "eng" for English]
|
67 |
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
|
68 |
|
69 |
+
audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
|
70 |
origin_text = "Some call me nature, others call me mother nature."
|
71 |
target_text = "Some call me optimist, others call me realist."
|
72 |
parts_to_edit = [
|
|
|
105 |
|
106 |
# Model
|
107 |
model = CFM(
|
108 |
+
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
109 |
mel_spec_kwargs=dict(
|
110 |
n_fft=n_fft,
|
111 |
hop_length=hop_length,
|
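The edit itself is still specified inside `speech_edit.py` by pairing the original and target transcripts with the time spans to regenerate. The span values below are hypothetical and only for illustration; the char-level alignment tooling referenced in the script comments is the intended way to derive them.

```python
# Spans are (start_sec, end_sec) regions of the reference audio to re-synthesize.
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [
    [1.42, 2.44],  # hypothetical span covering "nature"
    [4.04, 4.90],  # hypothetical span covering "mother nature"
]
```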
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -301,19 +301,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
|
|
301 |
)
|
302 |
non_silent_wave = AudioSegment.silent(duration=0)
|
303 |
for non_silent_seg in non_silent_segs:
|
304 |
-
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
305 |
show_info("Audio is over 15s, clipping short. (1)")
|
306 |
break
|
307 |
non_silent_wave += non_silent_seg
|
308 |
|
309 |
# 2. try to find short silence for clipping if 1. failed
|
310 |
-
if len(non_silent_wave) > 15000:
|
311 |
non_silent_segs = silence.split_on_silence(
|
312 |
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
313 |
)
|
314 |
non_silent_wave = AudioSegment.silent(duration=0)
|
315 |
for non_silent_seg in non_silent_segs:
|
316 |
-
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
317 |
show_info("Audio is over 15s, clipping short. (2)")
|
318 |
break
|
319 |
non_silent_wave += non_silent_seg
|
@@ -321,8 +321,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
|
|
321 |
aseg = non_silent_wave
|
322 |
|
323 |
# 3. if no proper silence found for clipping
|
324 |
-
if len(aseg) > 15000:
|
325 |
-
aseg = aseg[:15000]
|
326 |
show_info("Audio is over 15s, clipping short. (3)")
|
327 |
|
328 |
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
@@ -383,7 +383,7 @@ def infer_process(
|
|
383 |
):
|
384 |
# Split the input text into batches
|
385 |
audio, sr = torchaudio.load(ref_audio)
|
386 |
-
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (
|
387 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
388 |
for i, gen_text in enumerate(gen_text_batches):
|
389 |
print(f"gen_text {i}", gen_text)
|
|
|
301 |
)
|
302 |
non_silent_wave = AudioSegment.silent(duration=0)
|
303 |
for non_silent_seg in non_silent_segs:
|
304 |
+
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
305 |
show_info("Audio is over 15s, clipping short. (1)")
|
306 |
break
|
307 |
non_silent_wave += non_silent_seg
|
308 |
|
309 |
# 2. try to find short silence for clipping if 1. failed
|
310 |
+
if len(non_silent_wave) > 12000:
|
311 |
non_silent_segs = silence.split_on_silence(
|
312 |
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
313 |
)
|
314 |
non_silent_wave = AudioSegment.silent(duration=0)
|
315 |
for non_silent_seg in non_silent_segs:
|
316 |
+
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
317 |
show_info("Audio is over 15s, clipping short. (2)")
|
318 |
break
|
319 |
non_silent_wave += non_silent_seg
|
|
|
321 |
aseg = non_silent_wave
|
322 |
|
323 |
# 3. if no proper silence found for clipping
|
324 |
+
if len(aseg) > 12000:
|
325 |
+
aseg = aseg[:12000]
|
326 |
show_info("Audio is over 15s, clipping short. (3)")
|
327 |
|
328 |
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
|
|
383 |
):
|
384 |
# Split the input text into batches
|
385 |
audio, sr = torchaudio.load(ref_audio)
|
386 |
+
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
|
387 |
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
388 |
for i, gen_text in enumerate(gen_text_batches):
|
389 |
print(f"gen_text {i}", gen_text)
|
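The `max_chars` heuristic above estimates how many generated characters fit per batch: the bytes-per-second of the reference transcript, scaled by the time budget left out of roughly 22 seconds. A small worked example of the same arithmetic (the 6-second clip length is made up):

```python
ref_text = "Some call me nature, others call me mother nature."
ref_duration_sec = 6.0  # hypothetical reference clip length

chars_per_sec = len(ref_text.encode("utf-8")) / ref_duration_sec
max_chars = int(chars_per_sec * (22 - ref_duration_sec))
print(max_chars)  # 133 characters per generated batch for this example
```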
src/f5_tts/model/backbones/README.md
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
### unett.py
|
5 |
- flat unet transformer
|
6 |
- structure same as in e2-tts & voicebox paper except using rotary pos emb
|
7 |
-
-
|
8 |
|
9 |
### dit.py
|
10 |
- adaln-zero dit
|
@@ -14,7 +14,7 @@
|
|
14 |
- possible long skip connection (first layer to last layer)
|
15 |
|
16 |
### mmdit.py
|
17 |
-
-
|
18 |
- timestep as condition
|
19 |
- left stream: text embedded and applied a abs pos emb
|
20 |
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
|
|
|
4 |
### unett.py
|
5 |
- flat unet transformer
|
6 |
- structure same as in e2-tts & voicebox paper except using rotary pos emb
|
7 |
+
- possible abs pos emb & convnextv2 blocks for embedded text before concat
|
8 |
|
9 |
### dit.py
|
10 |
- adaln-zero dit
|
|
|
14 |
- possible long skip connection (first layer to last layer)
|
15 |
|
16 |
### mmdit.py
|
17 |
+
- stable diffusion 3 block structure
|
18 |
- timestep as condition
|
19 |
- left stream: text embedded and applied an abs pos emb
|
20 |
- right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
|
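To make the backbone descriptions concrete, here is a rough sketch of instantiating the DiT variant with the Base-sized hyperparameters used throughout this release. Argument names follow the updated `dit.py` signature shown next; `mel_dim` and `text_num_embeds` are left at their defaults, so the exact parameter count is only indicative.

```python
from f5_tts.model import DiT

# Base-sized adaln-zero DiT; text_mask_padding defaults to True in the new signature.
model = DiT(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    conv_layers=4,
)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```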
src/f5_tts/model/backbones/dit.py
CHANGED
@@ -20,7 +20,7 @@ from f5_tts.model.modules import (
|
|
20 |
ConvNeXtV2Block,
|
21 |
ConvPositionEmbedding,
|
22 |
DiTBlock,
|
23 |
-
|
24 |
precompute_freqs_cis,
|
25 |
get_pos_embed_indices,
|
26 |
)
|
@@ -30,10 +30,12 @@ from f5_tts.model.modules import (
|
|
30 |
|
31 |
|
32 |
class TextEmbedding(nn.Module):
|
33 |
-
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
34 |
super().__init__()
|
35 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
36 |
|
|
|
|
|
37 |
if conv_layers > 0:
|
38 |
self.extra_modeling = True
|
39 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
@@ -49,6 +51,8 @@ class TextEmbedding(nn.Module):
|
|
49 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
50 |
batch, text_len = text.shape[0], text.shape[1]
|
51 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
|
|
|
|
52 |
|
53 |
if drop_text: # cfg for text
|
54 |
text = torch.zeros_like(text)
|
@@ -64,7 +68,13 @@ class TextEmbedding(nn.Module):
|
|
64 |
text = text + text_pos_embed
|
65 |
|
66 |
# convnextv2 blocks
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
return text
|
70 |
|
@@ -103,7 +113,10 @@ class DiT(nn.Module):
|
|
103 |
mel_dim=100,
|
104 |
text_num_embeds=256,
|
105 |
text_dim=None,
|
|
|
|
|
106 |
conv_layers=0,
|
|
|
107 |
long_skip_connection=False,
|
108 |
checkpoint_activations=False,
|
109 |
):
|
@@ -112,7 +125,10 @@ class DiT(nn.Module):
|
|
112 |
self.time_embed = TimestepEmbedding(dim)
|
113 |
if text_dim is None:
|
114 |
text_dim = mel_dim
|
115 |
-
self.text_embed = TextEmbedding(
|
|
|
|
|
|
|
116 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
117 |
|
118 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
@@ -121,15 +137,40 @@ class DiT(nn.Module):
|
|
121 |
self.depth = depth
|
122 |
|
123 |
self.transformer_blocks = nn.ModuleList(
|
124 |
-
[
|
125 |
)
|
126 |
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
127 |
|
128 |
-
self.norm_out =
|
129 |
self.proj_out = nn.Linear(dim, mel_dim)
|
130 |
|
131 |
self.checkpoint_activations = checkpoint_activations
|
132 |
|
133 |
def ckpt_wrapper(self, module):
|
134 |
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
|
135 |
def ckpt_forward(*inputs):
|
@@ -138,6 +179,9 @@ class DiT(nn.Module):
|
|
138 |
|
139 |
return ckpt_forward
|
140 |
|
|
|
|
|
|
|
141 |
def forward(
|
142 |
self,
|
143 |
x: float["b n d"], # nosied input audio # noqa: F722
|
@@ -147,14 +191,25 @@ class DiT(nn.Module):
|
|
147 |
drop_audio_cond, # cfg for cond audio
|
148 |
drop_text, # cfg for text
|
149 |
mask: bool["b n"] | None = None, # noqa: F722
|
|
|
150 |
):
|
151 |
batch, seq_len = x.shape[0], x.shape[1]
|
152 |
if time.ndim == 0:
|
153 |
time = time.repeat(batch)
|
154 |
|
155 |
-
# t: conditioning time,
|
156 |
t = self.time_embed(time)
|
157 |
-
|
158 |
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
159 |
|
160 |
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
|
|
20 |
ConvNeXtV2Block,
|
21 |
ConvPositionEmbedding,
|
22 |
DiTBlock,
|
23 |
+
AdaLayerNorm_Final,
|
24 |
precompute_freqs_cis,
|
25 |
get_pos_embed_indices,
|
26 |
)
|
|
|
30 |
|
31 |
|
32 |
class TextEmbedding(nn.Module):
|
33 |
+
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
|
34 |
super().__init__()
|
35 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
36 |
|
37 |
+
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
38 |
+
|
39 |
if conv_layers > 0:
|
40 |
self.extra_modeling = True
|
41 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
|
|
51 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
52 |
batch, text_len = text.shape[0], text.shape[1]
|
53 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
54 |
+
if self.mask_padding:
|
55 |
+
text_mask = text == 0
|
56 |
|
57 |
if drop_text: # cfg for text
|
58 |
text = torch.zeros_like(text)
|
|
|
68 |
text = text + text_pos_embed
|
69 |
|
70 |
# convnextv2 blocks
|
71 |
+
if self.mask_padding:
|
72 |
+
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
73 |
+
for block in self.text_blocks:
|
74 |
+
text = block(text)
|
75 |
+
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
76 |
+
else:
|
77 |
+
text = self.text_blocks(text)
|
78 |
|
79 |
return text
|
80 |
|
|
|
113 |
mel_dim=100,
|
114 |
text_num_embeds=256,
|
115 |
text_dim=None,
|
116 |
+
text_mask_padding=True,
|
117 |
+
qk_norm=None,
|
118 |
conv_layers=0,
|
119 |
+
pe_attn_head=None,
|
120 |
long_skip_connection=False,
|
121 |
checkpoint_activations=False,
|
122 |
):
|
|
|
125 |
self.time_embed = TimestepEmbedding(dim)
|
126 |
if text_dim is None:
|
127 |
text_dim = mel_dim
|
128 |
+
self.text_embed = TextEmbedding(
|
129 |
+
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
|
130 |
+
)
|
131 |
+
self.text_cond, self.text_uncond = None, None # text cache
|
132 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
133 |
|
134 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
|
|
137 |
self.depth = depth
|
138 |
|
139 |
self.transformer_blocks = nn.ModuleList(
|
140 |
+
[
|
141 |
+
DiTBlock(
|
142 |
+
dim=dim,
|
143 |
+
heads=heads,
|
144 |
+
dim_head=dim_head,
|
145 |
+
ff_mult=ff_mult,
|
146 |
+
dropout=dropout,
|
147 |
+
qk_norm=qk_norm,
|
148 |
+
pe_attn_head=pe_attn_head,
|
149 |
+
)
|
150 |
+
for _ in range(depth)
|
151 |
+
]
|
152 |
)
|
153 |
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
154 |
|
155 |
+
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
|
156 |
self.proj_out = nn.Linear(dim, mel_dim)
|
157 |
|
158 |
self.checkpoint_activations = checkpoint_activations
|
159 |
|
160 |
+
self.initialize_weights()
|
161 |
+
|
162 |
+
def initialize_weights(self):
|
163 |
+
# Zero-out AdaLN layers in DiT blocks:
|
164 |
+
for block in self.transformer_blocks:
|
165 |
+
nn.init.constant_(block.attn_norm.linear.weight, 0)
|
166 |
+
nn.init.constant_(block.attn_norm.linear.bias, 0)
|
167 |
+
|
168 |
+
# Zero-out output layers:
|
169 |
+
nn.init.constant_(self.norm_out.linear.weight, 0)
|
170 |
+
nn.init.constant_(self.norm_out.linear.bias, 0)
|
171 |
+
nn.init.constant_(self.proj_out.weight, 0)
|
172 |
+
nn.init.constant_(self.proj_out.bias, 0)
|
173 |
+
|
174 |
def ckpt_wrapper(self, module):
|
175 |
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
|
176 |
def ckpt_forward(*inputs):
|
|
|
179 |
|
180 |
return ckpt_forward
|
181 |
|
182 |
+
def clear_cache(self):
|
183 |
+
self.text_cond, self.text_uncond = None, None
|
184 |
+
|
185 |
def forward(
|
186 |
self,
|
187 |
x: float["b n d"], # nosied input audio # noqa: F722
|
|
|
191 |
drop_audio_cond, # cfg for cond audio
|
192 |
drop_text, # cfg for text
|
193 |
mask: bool["b n"] | None = None, # noqa: F722
|
194 |
+
cache=False,
|
195 |
):
|
196 |
batch, seq_len = x.shape[0], x.shape[1]
|
197 |
if time.ndim == 0:
|
198 |
time = time.repeat(batch)
|
199 |
|
200 |
+
# t: conditioning time, text: text, x: noised audio + cond audio + text
|
201 |
t = self.time_embed(time)
|
202 |
+
if cache:
|
203 |
+
if drop_text:
|
204 |
+
if self.text_uncond is None:
|
205 |
+
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
|
206 |
+
text_embed = self.text_uncond
|
207 |
+
else:
|
208 |
+
if self.text_cond is None:
|
209 |
+
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
|
210 |
+
text_embed = self.text_cond
|
211 |
+
else:
|
212 |
+
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
213 |
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
214 |
|
215 |
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
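The key behavioural change in `dit.py` above is `text_mask_padding`: filler and batch-padding positions (token index 0) are zeroed out before and after the ConvNeXt blocks so padding no longer leaks into the text embedding. A standalone toy illustration of that masking pattern:

```python
import torch

torch.manual_seed(0)
text = torch.tensor([[3, 7, 5, 0, 0]])  # 0 = filler / batch padding token
embed = torch.randn(1, 5, 4)            # pretend this is the embedded text

text_mask = text == 0                    # (b, nt), True at padded positions
masked = embed.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, embed.size(-1)), 0.0)
print(masked[0, 3:])                     # padded rows are all zeros
```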
src/f5_tts/model/backbones/mmdit.py
CHANGED
@@ -18,7 +18,7 @@ from f5_tts.model.modules import (
|
|
18 |
TimestepEmbedding,
|
19 |
ConvPositionEmbedding,
|
20 |
MMDiTBlock,
|
21 |
-
|
22 |
precompute_freqs_cis,
|
23 |
get_pos_embed_indices,
|
24 |
)
|
@@ -28,18 +28,24 @@ from f5_tts.model.modules import (
|
|
28 |
|
29 |
|
30 |
class TextEmbedding(nn.Module):
|
31 |
-
def __init__(self, out_dim, text_num_embeds):
|
32 |
super().__init__()
|
33 |
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
|
34 |
|
|
|
|
|
35 |
self.precompute_max_pos = 1024
|
36 |
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
37 |
|
38 |
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
39 |
-
text = text + 1
|
40 |
-
if
|
|
|
|
|
|
|
41 |
text = torch.zeros_like(text)
|
42 |
-
|
|
|
43 |
|
44 |
# sinus pos emb
|
45 |
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
|
@@ -49,6 +55,9 @@ class TextEmbedding(nn.Module):
|
|
49 |
|
50 |
text = text + text_pos_embed
|
51 |
|
|
|
|
|
|
|
52 |
return text
|
53 |
|
54 |
|
@@ -83,13 +92,16 @@ class MMDiT(nn.Module):
|
|
83 |
dim_head=64,
|
84 |
dropout=0.1,
|
85 |
ff_mult=4,
|
86 |
-
text_num_embeds=256,
|
87 |
mel_dim=100,
|
|
|
|
|
|
|
88 |
):
|
89 |
super().__init__()
|
90 |
|
91 |
self.time_embed = TimestepEmbedding(dim)
|
92 |
-
self.text_embed = TextEmbedding(dim, text_num_embeds)
|
|
|
93 |
self.audio_embed = AudioEmbedding(mel_dim, dim)
|
94 |
|
95 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
@@ -106,13 +118,33 @@ class MMDiT(nn.Module):
|
|
106 |
dropout=dropout,
|
107 |
ff_mult=ff_mult,
|
108 |
context_pre_only=i == depth - 1,
|
|
|
109 |
)
|
110 |
for i in range(depth)
|
111 |
]
|
112 |
)
|
113 |
-
self.norm_out =
|
114 |
self.proj_out = nn.Linear(dim, mel_dim)
|
115 |
|
116 |
def forward(
|
117 |
self,
|
118 |
x: float["b n d"], # nosied input audio # noqa: F722
|
@@ -122,6 +154,7 @@ class MMDiT(nn.Module):
|
|
122 |
drop_audio_cond, # cfg for cond audio
|
123 |
drop_text, # cfg for text
|
124 |
mask: bool["b n"] | None = None, # noqa: F722
|
|
|
125 |
):
|
126 |
batch = x.shape[0]
|
127 |
if time.ndim == 0:
|
@@ -129,7 +162,17 @@ class MMDiT(nn.Module):
|
|
129 |
|
130 |
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
131 |
t = self.time_embed(time)
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
134 |
|
135 |
seq_len = x.shape[1]
|
|
|
18 |
TimestepEmbedding,
|
19 |
ConvPositionEmbedding,
|
20 |
MMDiTBlock,
|
21 |
+
AdaLayerNorm_Final,
|
22 |
precompute_freqs_cis,
|
23 |
get_pos_embed_indices,
|
24 |
)
|
|
|
28 |
|
29 |
|
30 |
class TextEmbedding(nn.Module):
|
31 |
+
def __init__(self, out_dim, text_num_embeds, mask_padding=True):
|
32 |
super().__init__()
|
33 |
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
|
34 |
|
35 |
+
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
36 |
+
|
37 |
self.precompute_max_pos = 1024
|
38 |
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
39 |
|
40 |
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
41 |
+
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
42 |
+
if self.mask_padding:
|
43 |
+
text_mask = text == 0
|
44 |
+
|
45 |
+
if drop_text: # cfg for text
|
46 |
text = torch.zeros_like(text)
|
47 |
+
|
48 |
+
text = self.text_embed(text) # b nt -> b nt d
|
49 |
|
50 |
# sinus pos emb
|
51 |
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
|
|
|
55 |
|
56 |
text = text + text_pos_embed
|
57 |
|
58 |
+
if self.mask_padding:
|
59 |
+
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
60 |
+
|
61 |
return text
|
62 |
|
63 |
|
|
|
92 |
dim_head=64,
|
93 |
dropout=0.1,
|
94 |
ff_mult=4,
|
|
|
95 |
mel_dim=100,
|
96 |
+
text_num_embeds=256,
|
97 |
+
text_mask_padding=True,
|
98 |
+
qk_norm=None,
|
99 |
):
|
100 |
super().__init__()
|
101 |
|
102 |
self.time_embed = TimestepEmbedding(dim)
|
103 |
+
self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
|
104 |
+
self.text_cond, self.text_uncond = None, None # text cache
|
105 |
self.audio_embed = AudioEmbedding(mel_dim, dim)
|
106 |
|
107 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
|
|
118 |
dropout=dropout,
|
119 |
ff_mult=ff_mult,
|
120 |
context_pre_only=i == depth - 1,
|
121 |
+
qk_norm=qk_norm,
|
122 |
)
|
123 |
for i in range(depth)
|
124 |
]
|
125 |
)
|
126 |
+
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
|
127 |
self.proj_out = nn.Linear(dim, mel_dim)
|
128 |
|
129 |
+
self.initialize_weights()
|
130 |
+
|
131 |
+
def initialize_weights(self):
|
132 |
+
# Zero-out AdaLN layers in MMDiT blocks:
|
133 |
+
for block in self.transformer_blocks:
|
134 |
+
nn.init.constant_(block.attn_norm_x.linear.weight, 0)
|
135 |
+
nn.init.constant_(block.attn_norm_x.linear.bias, 0)
|
136 |
+
nn.init.constant_(block.attn_norm_c.linear.weight, 0)
|
137 |
+
nn.init.constant_(block.attn_norm_c.linear.bias, 0)
|
138 |
+
|
139 |
+
# Zero-out output layers:
|
140 |
+
nn.init.constant_(self.norm_out.linear.weight, 0)
|
141 |
+
nn.init.constant_(self.norm_out.linear.bias, 0)
|
142 |
+
nn.init.constant_(self.proj_out.weight, 0)
|
143 |
+
nn.init.constant_(self.proj_out.bias, 0)
|
144 |
+
|
145 |
+
def clear_cache(self):
|
146 |
+
self.text_cond, self.text_uncond = None, None
|
147 |
+
|
148 |
def forward(
|
149 |
self,
|
150 |
x: float["b n d"], # nosied input audio # noqa: F722
|
|
|
154 |
drop_audio_cond, # cfg for cond audio
|
155 |
drop_text, # cfg for text
|
156 |
mask: bool["b n"] | None = None, # noqa: F722
|
157 |
+
cache=False,
|
158 |
):
|
159 |
batch = x.shape[0]
|
160 |
if time.ndim == 0:
|
|
|
162 |
|
163 |
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
164 |
t = self.time_embed(time)
|
165 |
+
if cache:
|
166 |
+
if drop_text:
|
167 |
+
if self.text_uncond is None:
|
168 |
+
self.text_uncond = self.text_embed(text, drop_text=True)
|
169 |
+
c = self.text_uncond
|
170 |
+
else:
|
171 |
+
if self.text_cond is None:
|
172 |
+
self.text_cond = self.text_embed(text, drop_text=False)
|
173 |
+
c = self.text_cond
|
174 |
+
else:
|
175 |
+
c = self.text_embed(text, drop_text=drop_text)
|
176 |
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
177 |
|
178 |
seq_len = x.shape[1]
|
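`initialize_weights` in both `dit.py` and `mmdit.py` now zero-initializes the AdaLN modulation projections and the output head, the usual DiT trick so each block starts out close to an identity mapping. A toy version of the same idea:

```python
import torch
import torch.nn as nn

modulation = nn.Linear(16, 2 * 16)  # produces shift and scale for a 16-dim feature
nn.init.constant_(modulation.weight, 0)
nn.init.constant_(modulation.bias, 0)

t = torch.randn(1, 16)                              # conditioning embedding
shift, scale = modulation(t).chunk(2, dim=-1)        # both start at zero
x = torch.randn(1, 16)
print(torch.allclose(x * (1 + scale) + shift, x))    # True: block starts as identity
```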
src/f5_tts/model/backbones/unett.py
CHANGED
@@ -33,10 +33,12 @@ from f5_tts.model.modules import (
|
|
33 |
|
34 |
|
35 |
class TextEmbedding(nn.Module):
|
36 |
-
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
37 |
super().__init__()
|
38 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
39 |
|
|
|
|
|
40 |
if conv_layers > 0:
|
41 |
self.extra_modeling = True
|
42 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
@@ -52,6 +54,8 @@ class TextEmbedding(nn.Module):
|
|
52 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
53 |
batch, text_len = text.shape[0], text.shape[1]
|
54 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
|
|
|
|
55 |
|
56 |
if drop_text: # cfg for text
|
57 |
text = torch.zeros_like(text)
|
@@ -67,7 +71,13 @@ class TextEmbedding(nn.Module):
|
|
67 |
text = text + text_pos_embed
|
68 |
|
69 |
# convnextv2 blocks
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
return text
|
73 |
|
@@ -106,7 +116,10 @@ class UNetT(nn.Module):
|
|
106 |
mel_dim=100,
|
107 |
text_num_embeds=256,
|
108 |
text_dim=None,
|
|
|
|
|
109 |
conv_layers=0,
|
|
|
110 |
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
111 |
):
|
112 |
super().__init__()
|
@@ -115,7 +128,10 @@ class UNetT(nn.Module):
|
|
115 |
self.time_embed = TimestepEmbedding(dim)
|
116 |
if text_dim is None:
|
117 |
text_dim = mel_dim
|
118 |
-
self.text_embed = TextEmbedding(
|
|
|
|
|
|
|
119 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
120 |
|
121 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
@@ -134,11 +150,12 @@ class UNetT(nn.Module):
|
|
134 |
|
135 |
attn_norm = RMSNorm(dim)
|
136 |
attn = Attention(
|
137 |
-
processor=AttnProcessor(),
|
138 |
dim=dim,
|
139 |
heads=heads,
|
140 |
dim_head=dim_head,
|
141 |
dropout=dropout,
|
|
|
142 |
)
|
143 |
|
144 |
ff_norm = RMSNorm(dim)
|
@@ -161,6 +178,9 @@ class UNetT(nn.Module):
|
|
161 |
self.norm_out = RMSNorm(dim)
|
162 |
self.proj_out = nn.Linear(dim, mel_dim)
|
163 |
|
|
|
|
|
|
|
164 |
def forward(
|
165 |
self,
|
166 |
x: float["b n d"], # nosied input audio # noqa: F722
|
@@ -170,6 +190,7 @@ class UNetT(nn.Module):
|
|
170 |
drop_audio_cond, # cfg for cond audio
|
171 |
drop_text, # cfg for text
|
172 |
mask: bool["b n"] | None = None, # noqa: F722
|
|
|
173 |
):
|
174 |
batch, seq_len = x.shape[0], x.shape[1]
|
175 |
if time.ndim == 0:
|
@@ -177,7 +198,17 @@ class UNetT(nn.Module):
|
|
177 |
|
178 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
179 |
t = self.time_embed(time)
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
182 |
|
183 |
# postfix time t to input x, [b n d] -> [b n+1 d]
|
|
|
33 |
|
34 |
|
35 |
class TextEmbedding(nn.Module):
|
36 |
+
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
|
37 |
super().__init__()
|
38 |
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
39 |
|
40 |
+
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
41 |
+
|
42 |
if conv_layers > 0:
|
43 |
self.extra_modeling = True
|
44 |
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
|
|
54 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
55 |
batch, text_len = text.shape[0], text.shape[1]
|
56 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
57 |
+
if self.mask_padding:
|
58 |
+
text_mask = text == 0
|
59 |
|
60 |
if drop_text: # cfg for text
|
61 |
text = torch.zeros_like(text)
|
|
|
71 |
text = text + text_pos_embed
|
72 |
|
73 |
# convnextv2 blocks
|
74 |
+
if self.mask_padding:
|
75 |
+
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
76 |
+
for block in self.text_blocks:
|
77 |
+
text = block(text)
|
78 |
+
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
79 |
+
else:
|
80 |
+
text = self.text_blocks(text)
|
81 |
|
82 |
return text
|
83 |
|
|
|
116 |
mel_dim=100,
|
117 |
text_num_embeds=256,
|
118 |
text_dim=None,
|
119 |
+
text_mask_padding=True,
|
120 |
+
qk_norm=None,
|
121 |
conv_layers=0,
|
122 |
+
pe_attn_head=None,
|
123 |
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
124 |
):
|
125 |
super().__init__()
|
|
|
128 |
self.time_embed = TimestepEmbedding(dim)
|
129 |
if text_dim is None:
|
130 |
text_dim = mel_dim
|
131 |
+
self.text_embed = TextEmbedding(
|
132 |
+
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
|
133 |
+
)
|
134 |
+
self.text_cond, self.text_uncond = None, None # text cache
|
135 |
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
136 |
|
137 |
self.rotary_embed = RotaryEmbedding(dim_head)
|
|
|
150 |
|
151 |
attn_norm = RMSNorm(dim)
|
152 |
attn = Attention(
|
153 |
+
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
154 |
dim=dim,
|
155 |
heads=heads,
|
156 |
dim_head=dim_head,
|
157 |
dropout=dropout,
|
158 |
+
qk_norm=qk_norm,
|
159 |
)
|
160 |
|
161 |
ff_norm = RMSNorm(dim)
|
|
|
178 |
self.norm_out = RMSNorm(dim)
|
179 |
self.proj_out = nn.Linear(dim, mel_dim)
|
180 |
|
181 |
+
def clear_cache(self):
|
182 |
+
self.text_cond, self.text_uncond = None, None
|
183 |
+
|
184 |
def forward(
|
185 |
self,
|
186 |
x: float["b n d"], # nosied input audio # noqa: F722
|
|
|
190 |
drop_audio_cond, # cfg for cond audio
|
191 |
drop_text, # cfg for text
|
192 |
mask: bool["b n"] | None = None, # noqa: F722
|
193 |
+
cache=False,
|
194 |
):
|
195 |
batch, seq_len = x.shape[0], x.shape[1]
|
196 |
if time.ndim == 0:
|
|
|
198 |
|
199 |
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
200 |
t = self.time_embed(time)
|
201 |
+
if cache:
|
202 |
+
if drop_text:
|
203 |
+
if self.text_uncond is None:
|
204 |
+
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
|
205 |
+
text_embed = self.text_uncond
|
206 |
+
else:
|
207 |
+
if self.text_cond is None:
|
208 |
+
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
|
209 |
+
text_embed = self.text_cond
|
210 |
+
else:
|
211 |
+
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
212 |
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
213 |
|
214 |
# postfix time t to input x, [b n d] -> [b n+1 d]
|
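All three backbones now keep `text_cond` / `text_uncond` around when `cache=True`, so the text embedding, which does not depend on the flow step, is computed once per sampling run for each of the conditional and unconditional branches instead of once per ODE step. A minimal sketch of that caching pattern, independent of the actual modules:

```python
class CachedTextEmbed:
    """Compute an expensive, step-independent embedding at most twice per sampling run."""

    def __init__(self, embed_fn):
        self.embed_fn = embed_fn
        self.cond = None
        self.uncond = None

    def __call__(self, text, drop_text, cache=False):
        if not cache:
            return self.embed_fn(text, drop_text=drop_text)
        if drop_text:
            if self.uncond is None:
                self.uncond = self.embed_fn(text, drop_text=True)
            return self.uncond
        if self.cond is None:
            self.cond = self.embed_fn(text, drop_text=False)
        return self.cond

    def clear_cache(self):
        # mirror of clear_cache() in the backbones: reset between sampling runs
        self.cond, self.uncond = None, None
```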
src/f5_tts/model/cfm.py
CHANGED
@@ -162,13 +162,13 @@ class CFM(nn.Module):
|
|
162 |
|
163 |
# predict flow
|
164 |
pred = self.transformer(
|
165 |
-
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
|
166 |
)
|
167 |
if cfg_strength < 1e-5:
|
168 |
return pred
|
169 |
|
170 |
null_pred = self.transformer(
|
171 |
-
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
|
172 |
)
|
173 |
return pred + (pred - null_pred) * cfg_strength
|
174 |
|
@@ -195,6 +195,7 @@ class CFM(nn.Module):
|
|
195 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
196 |
|
197 |
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
|
|
198 |
|
199 |
sampled = trajectory[-1]
|
200 |
out = sampled
|
|
|
162 |
|
163 |
# predict flow
|
164 |
pred = self.transformer(
|
165 |
+
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
|
166 |
)
|
167 |
if cfg_strength < 1e-5:
|
168 |
return pred
|
169 |
|
170 |
null_pred = self.transformer(
|
171 |
+
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
|
172 |
)
|
173 |
return pred + (pred - null_pred) * cfg_strength
|
174 |
|
|
|
195 |
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
196 |
|
197 |
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
198 |
+
self.transformer.clear_cache()
|
199 |
|
200 |
sampled = trajectory[-1]
|
201 |
out = sampled
|
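For reference, the classifier-free guidance step in `cfm.py` is unchanged apart from the `cache=True` flag: two forward passes per ODE step, then an extrapolation away from the unconditional prediction. In tensor terms:

```python
import torch

cfg_strength = 2.0
pred = torch.randn(1, 128, 100)       # conditional flow prediction (b, n, mel)
null_pred = torch.randn(1, 128, 100)  # prediction with text and cond audio dropped

guided = pred + (pred - null_pred) * cfg_strength  # same combine as in cfm.py
```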
src/f5_tts/model/dataset.py
CHANGED
@@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
|
|
173 |
"""
|
174 |
|
175 |
def __init__(
|
176 |
-
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None,
|
177 |
):
|
178 |
self.sampler = sampler
|
179 |
self.frames_threshold = frames_threshold
|
@@ -208,12 +208,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
|
|
208 |
batch = []
|
209 |
batch_frames = 0
|
210 |
|
211 |
-
if not
|
212 |
batches.append(batch)
|
213 |
|
214 |
del indices
|
215 |
self.batches = batches
|
216 |
|
|
|
|
|
|
|
217 |
def set_epoch(self, epoch: int) -> None:
|
218 |
"""Sets the epoch for this sampler."""
|
219 |
self.epoch = epoch
|
|
|
173 |
"""
|
174 |
|
175 |
def __init__(
|
176 |
+
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
|
177 |
):
|
178 |
self.sampler = sampler
|
179 |
self.frames_threshold = frames_threshold
|
|
|
208 |
batch = []
|
209 |
batch_frames = 0
|
210 |
|
211 |
+
if not drop_residual and len(batch) > 0:
|
212 |
batches.append(batch)
|
213 |
|
214 |
del indices
|
215 |
self.batches = batches
|
216 |
|
217 |
+
# Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
|
218 |
+
self.drop_last = True
|
219 |
+
|
220 |
def set_epoch(self, epoch: int) -> None:
|
221 |
"""Sets the epoch for this sampler."""
|
222 |
self.epoch = epoch
|
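`DynamicBatchSampler` above packs samples until a frame budget is hit; the new `drop_residual` flag controls whether the final, under-filled batch is kept. A simplified standalone version of that packing loop (the real sampler also handles `max_samples` and shuffling):

```python
def pack_by_frames(lengths, frames_threshold, drop_residual=False):
    """Group sample indices so each batch stays under a total frame budget."""
    batches, batch, batch_frames = [], [], 0
    for idx, frames in enumerate(lengths):
        if batch_frames + frames <= frames_threshold:
            batch.append(idx)
            batch_frames += frames
        else:
            if batch:
                batches.append(batch)
            batch, batch_frames = [idx], frames
    if not drop_residual and len(batch) > 0:
        batches.append(batch)  # keep the last partially filled batch
    return batches


print(pack_by_frames([300, 500, 800, 200, 400], frames_threshold=1000))
# [[0, 1], [2, 3], [4]]
```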
src/f5_tts/model/modules.py
CHANGED
@@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module):
|
|
269 |
return residual + x
|
270 |
|
271 |
|
272 |
-
# AdaLayerNormZero
|
273 |
# return with modulated x for attn input, and params for later mlp modulation
|
274 |
|
275 |
|
276 |
-
class AdaLayerNormZero(nn.Module):
|
277 |
def __init__(self, dim):
|
278 |
super().__init__()
|
279 |
|
@@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module):
|
|
290 |
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
291 |
|
292 |
|
293 |
-
# AdaLayerNormZero for final layer
|
294 |
# return only with modulated x for attn input, cuz no more mlp modulation
|
295 |
|
296 |
|
297 |
-
class AdaLayerNormZero_Final(nn.Module):
|
298 |
def __init__(self, dim):
|
299 |
super().__init__()
|
300 |
|
@@ -341,7 +366,8 @@ class Attention(nn.Module):
|
|
341 |
dim_head: int = 64,
|
342 |
dropout: float = 0.0,
|
343 |
context_dim: Optional[int] = None, # if not None -> joint attention
|
344 |
-
context_pre_only=
|
|
|
345 |
):
|
346 |
super().__init__()
|
347 |
|
@@ -362,18 +388,32 @@ class Attention(nn.Module):
|
|
362 |
self.to_k = nn.Linear(dim, self.inner_dim)
|
363 |
self.to_v = nn.Linear(dim, self.inner_dim)
|
364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
if self.context_dim is not None:
|
|
|
366 |
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
367 |
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
368 |
-
if
|
369 |
-
self.
|
|
|
|
|
|
|
|
|
370 |
|
371 |
self.to_out = nn.ModuleList([])
|
372 |
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
373 |
self.to_out.append(nn.Dropout(dropout))
|
374 |
|
375 |
-
if self.
|
376 |
-
self.to_out_c = nn.Linear(self.inner_dim,
|
377 |
|
378 |
def forward(
|
379 |
self,
|
@@ -393,8 +433,11 @@ class Attention(nn.Module):
|
|
393 |
|
394 |
|
395 |
class AttnProcessor:
|
396 |
-
def __init__(
|
397 |
-
|
|
|
|
|
|
|
398 |
|
399 |
def __call__(
|
400 |
self,
|
@@ -405,19 +448,11 @@ class AttnProcessor:
|
|
405 |
) -> torch.FloatTensor:
|
406 |
batch_size = x.shape[0]
|
407 |
|
408 |
-
# `sample` projections
|
409 |
query = attn.to_q(x)
|
410 |
key = attn.to_k(x)
|
411 |
value = attn.to_v(x)
|
412 |
|
413 |
-
# apply rotary position embedding
|
414 |
-
if rope is not None:
|
415 |
-
freqs, xpos_scale = rope
|
416 |
-
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
417 |
-
|
418 |
-
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
419 |
-
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
420 |
-
|
421 |
# attention
|
422 |
inner_dim = key.shape[-1]
|
423 |
head_dim = inner_dim // attn.heads
|
@@ -425,6 +460,25 @@ class AttnProcessor:
|
|
425 |
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
426 |
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
            # mask. e.g. inference got a batch with different target durations, mask out the padding
            if mask is not None:
                attn_mask = mask

...

        return residual + x


+# RMSNorm
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
+
+    def forward(self, x):
+        if self.native_rms_norm:
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                x = x.to(self.weight.dtype)
+            x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
+        else:
+            variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.eps)
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                x = x.to(self.weight.dtype)
+            x = x * self.weight
+
+        return x
+
+
+# AdaLayerNorm
# return with modulated x for attn input, and params for later mlp modulation


+class AdaLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()

...

        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


+# AdaLayerNorm for final layer
# return only with modulated x for attn input, cuz no more mlp modulation


+class AdaLayerNorm_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()

...

        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
+       context_pre_only: bool = False,
+       qk_norm: Optional[str] = None,
    ):
        super().__init__()

...

        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

+       if qk_norm is None:
+           self.q_norm = None
+           self.k_norm = None
+       elif qk_norm == "rms_norm":
+           self.q_norm = RMSNorm(dim_head, eps=1e-6)
+           self.k_norm = RMSNorm(dim_head, eps=1e-6)
+       else:
+           raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
+
        if self.context_dim is not None:
+           self.to_q_c = nn.Linear(context_dim, self.inner_dim)
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
+           if qk_norm is None:
+               self.c_q_norm = None
+               self.c_k_norm = None
+           elif qk_norm == "rms_norm":
+               self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
+               self.c_k_norm = RMSNorm(dim_head, eps=1e-6)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

+       if self.context_dim is not None and not self.context_pre_only:
+           self.to_out_c = nn.Linear(self.inner_dim, context_dim)

    def forward(
        self,

...


class AttnProcessor:
+   def __init__(
+       self,
+       pe_attn_head: int | None = None,  # number of attention head to apply rope, None for all
+   ):
+       self.pe_attn_head = pe_attn_head

    def __call__(
        self,
        ...
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

+       # `sample` projections
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+       # qk norm
+       if attn.q_norm is not None:
+           query = attn.q_norm(query)
+       if attn.k_norm is not None:
+           key = attn.k_norm(key)
+
+       # apply rotary position embedding
+       if rope is not None:
+           freqs, xpos_scale = rope
+           q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+
+           if self.pe_attn_head is not None:
+               pn = self.pe_attn_head
+               query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
+               key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
+           else:
+               query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+               key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+
        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask

@@ -470,16 +524,36 @@ class JointAttnProcessor:

        batch_size = c.shape[0]

        # `sample` projections
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

+       # attention
+       inner_dim = key.shape[-1]
+       head_dim = inner_dim // attn.heads
+       query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+       key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+       value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+       c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+       c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+       c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+       # qk norm
+       if attn.q_norm is not None:
+           query = attn.q_norm(query)
+       if attn.k_norm is not None:
+           key = attn.k_norm(key)
+       if attn.c_q_norm is not None:
+           c_query = attn.c_q_norm(c_query)
+       if attn.c_k_norm is not None:
+           c_key = attn.c_k_norm(c_key)
+
        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope

@@ -492,16 +566,10 @@ class JointAttnProcessor:
        c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
        c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

-       # attention
-       query = torch.cat([query, c_query], dim=
-       key = torch.cat([key, c_key], dim=
-       value = torch.cat([value, c_value], dim=
-       inner_dim = key.shape[-1]
-       head_dim = inner_dim // attn.heads
-       query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-       key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-       value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+       # joint attention
+       query = torch.cat([query, c_query], dim=2)
+       key = torch.cat([key, c_key], dim=2)
+       value = torch.cat([value, c_value], dim=2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:

@@ -540,16 +608,17 @@ class JointAttnProcessor:


class DiTBlock(nn.Module):
-   def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
+   def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
        super().__init__()

-       self.attn_norm =
+       self.attn_norm = AdaLayerNorm(dim)
        self.attn = Attention(
-           processor=AttnProcessor(),
+           processor=AttnProcessor(pe_attn_head=pe_attn_head),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
+           qk_norm=qk_norm,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

@@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module):
        context_pre_only: last layer only do prenorm + modulation cuz no more ffn
    """

-   def __init__(
+   def __init__(
+       self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
+   ):
        super().__init__()
+       if context_dim is None:
+           context_dim = dim
        self.context_pre_only = context_pre_only

-       self.attn_norm_c =
-       self.attn_norm_x =
+       self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
+       self.attn_norm_x = AdaLayerNorm(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
-           context_dim=
+           context_dim=context_dim,
            context_pre_only=context_pre_only,
+           qk_norm=qk_norm,
        )

        if not context_pre_only:
-           self.ff_norm_c = nn.LayerNorm(
-           self.ff_c = FeedForward(dim=
+           self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
+           self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
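For orientation, a minimal usage sketch of the new `qk_norm` and `pe_attn_head` knobs introduced in `modules.py`; the class names follow the diff above, while the tensor sizes and argument values are made up for illustration only.

```python
import torch
from f5_tts.model.modules import RMSNorm, DiTBlock

# Per-head RMSNorm as added above: normalizes over the last (head) dimension.
norm = RMSNorm(dim=64, eps=1e-6)
q = torch.randn(2, 16, 100, 64)  # (batch, heads, seq, dim_head), illustrative sizes
print(norm(q).shape)             # torch.Size([2, 16, 100, 64])

# New DiTBlock arguments: qk_norm="rms_norm" enables the q/k normalization,
# pe_attn_head limits rope to the first N attention heads (None applies it to all).
block = DiTBlock(dim=1024, heads=16, dim_head=64, qk_norm="rms_norm", pe_attn_head=None)
```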
src/f5_tts/model/trainer.py
CHANGED
@@ -32,7 +32,7 @@ class Trainer:
        save_per_updates=1000,
        keep_last_n_checkpoints: int = -1,  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
        checkpoint_path=None,
+       batch_size_per_gpu=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,

@@ -40,7 +40,7 @@ class Trainer:
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
-       wandb_project="
+       wandb_project="test_f5-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str = None,
        log_samples: bool = False,

@@ -51,6 +51,7 @@ class Trainer:
        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
        is_local_vocoder: bool = False,  # use local path vocoder
        local_vocoder_path: str = "",  # local vocoder path
+       cfg_dict: dict = dict(),  # training config
    ):
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

@@ -72,21 +73,23 @@ class Trainer:
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}

-               init_kwargs=init_kwargs,
-               config={
+           if not cfg_dict:
+               cfg_dict = {
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
-                   "
+                   "batch_size_per_gpu": batch_size_per_gpu,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
-                   "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
-               }
+               }
+           cfg_dict["gpus"] = self.accelerator.num_processes
+           self.accelerator.init_trackers(
+               project_name=wandb_project,
+               init_kwargs=init_kwargs,
+               config=cfg_dict,
            )

        elif self.logger == "tensorboard":

@@ -111,9 +114,9 @@ class Trainer:
        self.save_per_updates = save_per_updates
        self.keep_last_n_checkpoints = keep_last_n_checkpoints
        self.last_per_updates = default(last_per_updates, save_per_updates)
-       self.checkpoint_path = default(checkpoint_path, "ckpts/
+       self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")

-       self.
+       self.batch_size_per_gpu = batch_size_per_gpu
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps

@@ -179,7 +182,7 @@ class Trainer:
        if (
            not exists(self.checkpoint_path)
            or not os.path.exists(self.checkpoint_path)
-           or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
+           or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
        ):
            return 0

@@ -191,7 +194,7 @@ class Trainer:
        all_checkpoints = [
            f
            for f in os.listdir(self.checkpoint_path)
-           if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
+           if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
        ]

        # First try to find regular training checkpoints

@@ -205,8 +208,16 @@ class Trainer:
            # If no training checkpoints, use pretrained model
            latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))

+       if latest_checkpoint.endswith(".safetensors"):  # always a pretrained checkpoint
+           from safetensors.torch import load_file
+
+           checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
+           checkpoint = {"ema_model_state_dict": checkpoint}
+       elif latest_checkpoint.endswith(".pt"):
+           # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
+           checkpoint = torch.load(
+               f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
+           )

        # patch for backward compatibility, 305e3ea
        for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:

@@ -271,7 +282,7 @@ class Trainer:
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=True,
-           batch_size=self.
+           batch_size=self.batch_size_per_gpu,
            shuffle=True,
            generator=generator,
        )

@@ -280,10 +291,10 @@ class Trainer:
            sampler = SequentialSampler(train_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler,
-               self.
+               self.batch_size_per_gpu,
                max_samples=self.max_samples,
                random_seed=resumable_with_seed,  # This enables reproducible shuffling
+               drop_residual=False,
            )
            train_dataloader = DataLoader(
                train_dataset,
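A small sketch of the two checkpoint formats the trainer now accepts, mirroring the loading branch above; the file path is hypothetical.

```python
import torch
from safetensors.torch import load_file

ckpt_path = "ckpts/my_project/pretrained_model_1250000.safetensors"  # hypothetical path

# A .safetensors file carries only the EMA weights, so it is wrapped to look like
# a regular training checkpoint dict; a .pt file is loaded as-is with weights_only.
if ckpt_path.endswith(".safetensors"):
    checkpoint = {"ema_model_state_dict": load_file(ckpt_path, device="cpu")}
elif ckpt_path.endswith(".pt"):
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
```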
src/f5_tts/model/utils.py
CHANGED
@@ -133,11 +133,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):

# convert char to pinyin

-jieba.initialize()
-print("Word segmentation module jieba initialized.\n")
-

def convert_char_to_pinyin(text_list, polyphone=True):
+   if jieba.dt.initialized is False:
+       jieba.default_logger.setLevel(50)  # CRITICAL
+       jieba.initialize()
+
    final_text_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
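The same lazy-initialization idea, shown standalone; `ensure_jieba_ready` is a hypothetical helper name used only for this sketch.

```python
import logging
import jieba

def ensure_jieba_ready():
    # Initialize jieba only on first use and silence its startup logging,
    # matching the change above (level 50 == CRITICAL).
    if not jieba.dt.initialized:
        jieba.default_logger.setLevel(logging.CRITICAL)
        jieba.initialize()

ensure_jieba_ready()
print(list(jieba.cut("自然语言处理")))  # example segmentation
```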
src/f5_tts/scripts/count_max_epoch.py
CHANGED
@@ -9,7 +9,7 @@ mel_hop_length = 256
mel_sampling_rate = 24000

# target
-wanted_max_updates =
+wanted_max_updates = 1200000

# train params
gpus = 8
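A worked example of the update-to-epoch arithmetic this script performs; only `wanted_max_updates = 1200000` comes from the diff, the per-GPU frame budget and the dataset size are assumptions for illustration.

```python
# Rough epoch estimate from a target number of optimizer updates.
mel_hop_length = 256
mel_sampling_rate = 24000
wanted_max_updates = 1200000

gpus = 8
frames_per_gpu = 38400   # frame-based batch size per GPU (assumed)
grad_accum = 1
total_hours = 95000      # assumed total dataset duration in hours

mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
print(f"epochs needed ~ {wanted_max_updates / updates_per_epoch:.1f}")
```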
src/f5_tts/socket_client.py
ADDED
@@ -0,0 +1,61 @@
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))

    start_time = time.time()
    first_chunk_time = None

    async def play_audio_stream():
        nonlocal first_chunk_time
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)

        try:
            while True:
                data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
                if not data:
                    break
                if data == b"END":
                    logger.info("End of audio received.")
                    break

                audio_array = np.frombuffer(data, dtype=np.float32)
                stream.write(audio_array.tobytes())

                if first_chunk_time is None:
                    first_chunk_time = time.time()

        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

        logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")

    try:
        data_to_send = f"{text}".encode("utf-8")
        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
        await play_audio_stream()

    except Exception as e:
        logger.error(f"Error in listen_to_F5TTS: {e}")

    finally:
        client_socket.close()


if __name__ == "__main__":
    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"

    asyncio.run(listen_to_F5TTS(text_to_send))
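A minimal way to drive the new client from another script, assuming `socket_server.py` is already running locally on its default port; host, port and text are just example values.

```python
import asyncio
from f5_tts.socket_client import listen_to_F5TTS

# Streams synthesized audio from the local socket server and plays it as it arrives.
asyncio.run(listen_to_F5TTS("Hello from the streaming client.", server_ip="localhost", server_port=9998))
```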
src/f5_tts/socket_server.py
CHANGED
@@ -13,8 +13,9 @@ from importlib.resources import files
import torch
import torchaudio
from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf

-from f5_tts.model.backbones.dit import DiT
+from f5_tts.model.backbones.dit import DiT  # noqa: F401. used for config
from f5_tts.infer.utils_infer import (
    chunk_text,
    preprocess_ref_audio_text,

@@ -68,7 +69,7 @@ class AudioFileWriterThread(threading.Thread):


class TTSStreamingProcessor:
-   def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
+   def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        self.device = device or (
            "cuda"
            if torch.cuda.is_available()

@@ -78,21 +79,24 @@ class TTSStreamingProcessor:
            if torch.backends.mps.is_available()
            else "cpu"
        )
+       model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+       self.model_cls = globals()[model_cfg.model.backbone]
+       self.model_arc = model_cfg.model.arch
+       self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+       self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
+
        self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
        self.vocoder = self.load_vocoder_model()

        self.update_reference(ref_audio, ref_text)
        self._warm_up()
        self.file_writer_thread = None
        self.first_package = True

    def load_ema_model(self, ckpt_file, vocab_file, dtype):
-       model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-       model_cls = DiT
        return load_model(
-           model_cls
+           self.model_cls,
+           self.model_arc,
            ckpt_path=ckpt_file,
            mel_spec_type=self.mel_spec_type,
            vocab_file=vocab_file,

@@ -212,9 +216,14 @@ if __name__ == "__main__":
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", default=9998)

+   parser.add_argument(
+       "--model",
+       default="F5TTS_v1_Base",
+       help="The model name, e.g. F5TTS_v1_Base",
+   )
    parser.add_argument(
        "--ckpt_file",
-       default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="
+       default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
        help="Path to the model checkpoint file",
    )
    parser.add_argument(

@@ -242,6 +251,7 @@ if __name__ == "__main__":
    try:
        # Initialize the processor with the model and vocoder
        processor = TTSStreamingProcessor(
+           model=args.model,
            ckpt_file=args.ckpt_file,
            vocab_file=args.vocab_file,
            ref_audio=args.ref_audio,
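The server now resolves its backbone and mel settings from the packaged YAML config rather than a hard-coded dict; a short sketch of that lookup, with `F5TTS_v1_Base` used as the example model name.

```python
from importlib.resources import files
from omegaconf import OmegaConf

# Same lookup the server performs: the --model name selects a YAML config that
# carries the backbone class name, architecture kwargs and mel-spectrogram settings.
model = "F5TTS_v1_Base"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
print(model_cfg.model.backbone)                    # backbone class name, e.g. DiT
print(model_cfg.model.mel_spec.target_sample_rate) # sampling rate used for streaming
```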
src/f5_tts/train/README.md
CHANGED
@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
accelerate config

# .yaml files are under src/f5_tts/configs directory
-accelerate launch src/f5_tts/train/train.py --config-name
+accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base_train.yaml

# possible to overwrite accelerate and hydra config
-accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name
+accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base_train.yaml ++datasets.batch_size_per_gpu=19200
```

### 2. Finetuning practice

@@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1

The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.

-### 3.
+### 3. W&B Logging

The `wandb/` dir will be created under path you run training/finetuning scripts.

@@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
To turn on wandb logging, you can either:

1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
-2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/
+2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:

On Mac & Linux:

@@ -75,7 +75,7 @@ On Windows:
```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
-Moreover, if you couldn't access
+Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:

```
export WANDB_MODE=offline
src/f5_tts/train/finetune_cli.py
CHANGED
@@ -1,12 +1,13 @@
import argparse
import os
import shutil
+from importlib.resources import files

from cached_path import cached_path
+
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model.dataset import load_dataset
-from importlib.resources import files


# -------------------------- Dataset Settings --------------------------- #

@@ -20,19 +21,14 @@ mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'

# -------------------------- Argument Parsing --------------------------- #
def parse_args():
-   # batch_size_per_gpu = 1000 settting for gpu 8GB
-   # batch_size_per_gpu = 1600 settting for gpu 12GB
-   # batch_size_per_gpu = 2000 settting for gpu 16GB
-   # batch_size_per_gpu = 3200 settting for gpu 24GB
-
-   # num_warmup_updates = 300 for 5000 sample about 10 hours
-
-   # change save_per_updates , last_per_updates change this value what you need ,
-
    parser = argparse.ArgumentParser(description="Train CFM Model")

    parser.add_argument(
-       "--exp_name",
+       "--exp_name",
+       type=str,
+       default="F5TTS_v1_Base",
+       choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
+       help="Experiment name",
    )
    parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")

@@ -88,19 +84,54 @@ def main():
    checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))

    # Model parameters based on experiment name
+
+   if args.exp_name == "F5TTS_v1_Base":
+       wandb_resume_id = None
+       model_cls = DiT
+       model_cfg = dict(
+           dim=1024,
+           depth=22,
+           heads=16,
+           ff_mult=2,
+           text_dim=512,
+           conv_layers=4,
+       )
+       if args.finetune:
+           if args.pretrain is None:
+               ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
+           else:
+               ckpt_path = args.pretrain
+
+   elif args.exp_name == "F5TTS_Base":
        wandb_resume_id = None
        model_cls = DiT
-       model_cfg = dict(
+       model_cfg = dict(
+           dim=1024,
+           depth=22,
+           heads=16,
+           ff_mult=2,
+           text_dim=512,
+           text_mask_padding=False,
+           conv_layers=4,
+           pe_attn_head=1,
+       )
        if args.finetune:
            if args.pretrain is None:
                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
            else:
                ckpt_path = args.pretrain
+
    elif args.exp_name == "E2TTS_Base":
        wandb_resume_id = None
        model_cls = UNetT
-       model_cfg = dict(
+       model_cfg = dict(
+           dim=1024,
+           depth=24,
+           heads=16,
+           ff_mult=4,
+           text_mask_padding=False,
+           pe_attn_head=1,
+       )
        if args.finetune:
            if args.pretrain is None:
                ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))

@@ -120,6 +151,7 @@ def main():
    print("copy checkpoint for finetune")

    # Use the tokenizer and tokenizer_path provided in the command line arguments
+
    tokenizer = args.tokenizer
    if tokenizer == "custom":
        if not args.tokenizer_path:

@@ -156,7 +188,7 @@ def main():
        save_per_updates=args.save_per_updates,
        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
        checkpoint_path=checkpoint_path,
+       batch_size_per_gpu=args.batch_size_per_gpu,
        batch_size_type=args.batch_size_type,
        max_samples=args.max_samples,
        grad_accumulation_steps=args.grad_accumulation_steps,
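A condensed view of the per-experiment settings wired in above; the values are taken from the diff itself, but a plain dict like this is only an illustration, not the CLI's actual structure.

```python
# exp_name -> (backbone name, architecture kwargs), per the branches added above.
EXPERIMENTS = {
    "F5TTS_v1_Base": ("DiT", dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
    "F5TTS_Base": ("DiT", dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512,
                               text_mask_padding=False, conv_layers=4, pe_attn_head=1)),
    "E2TTS_Base": ("UNetT", dict(dim=1024, depth=24, heads=16, ff_mult=4,
                                 text_mask_padding=False, pe_attn_head=1)),
}
backbone, arch = EXPERIMENTS["F5TTS_v1_Base"]
print(backbone, arch)
```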
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -1,36 +1,36 @@
|
|
1 |
-
import threading
|
2 |
-
import queue
|
3 |
-
import re
|
4 |
-
|
5 |
import gc
|
6 |
import json
|
|
|
7 |
import os
|
8 |
import platform
|
9 |
import psutil
|
|
|
10 |
import random
|
|
|
11 |
import signal
|
12 |
import shutil
|
13 |
import subprocess
|
14 |
import sys
|
15 |
import tempfile
|
|
|
16 |
import time
|
17 |
from glob import glob
|
|
|
|
|
18 |
|
19 |
import click
|
20 |
import gradio as gr
|
21 |
import librosa
|
22 |
-
import numpy as np
|
23 |
import torch
|
24 |
import torchaudio
|
|
|
25 |
from datasets import Dataset as Dataset_
|
26 |
from datasets.arrow_writer import ArrowWriter
|
27 |
-
from safetensors.torch import save_file
|
28 |
-
|
29 |
-
from cached_path import cached_path
|
30 |
from f5_tts.api import F5TTS
|
31 |
from f5_tts.model.utils import convert_char_to_pinyin
|
32 |
from f5_tts.infer.utils_infer import transcribe
|
33 |
-
from importlib.resources import files
|
34 |
|
35 |
|
36 |
training_process = None
|
@@ -118,16 +118,16 @@ def load_settings(project_name):
|
|
118 |
|
119 |
# Default settings
|
120 |
default_settings = {
|
121 |
-
"exp_name": "
|
122 |
-
"learning_rate": 1e-
|
123 |
-
"batch_size_per_gpu":
|
124 |
-
"batch_size_type": "
|
125 |
"max_samples": 64,
|
126 |
-
"grad_accumulation_steps":
|
127 |
"max_grad_norm": 1,
|
128 |
"epochs": 100,
|
129 |
-
"num_warmup_updates":
|
130 |
-
"save_per_updates":
|
131 |
"keep_last_n_checkpoints": -1,
|
132 |
"last_per_updates": 100,
|
133 |
"finetune": True,
|
@@ -362,18 +362,18 @@ def terminate_process(pid):
|
|
362 |
|
363 |
def start_training(
|
364 |
dataset_name="",
|
365 |
-
exp_name="
|
366 |
-
learning_rate=1e-
|
367 |
-
batch_size_per_gpu=
|
368 |
-
batch_size_type="
|
369 |
max_samples=64,
|
370 |
-
grad_accumulation_steps=
|
371 |
max_grad_norm=1.0,
|
372 |
-
epochs=
|
373 |
-
num_warmup_updates=
|
374 |
-
save_per_updates=
|
375 |
keep_last_n_checkpoints=-1,
|
376 |
-
last_per_updates=
|
377 |
finetune=True,
|
378 |
file_checkpoint_train="",
|
379 |
tokenizer_type="pinyin",
|
@@ -797,14 +797,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
|
797 |
print(f"Error processing {file_audio}: {e}")
|
798 |
continue
|
799 |
|
800 |
-
if duration < 1 or duration >
|
801 |
-
if duration >
|
802 |
-
error_files.append([file_audio, "duration >
|
803 |
if duration < 1:
|
804 |
error_files.append([file_audio, "duration < 1 sec "])
|
805 |
continue
|
806 |
if len(text) < 3:
|
807 |
-
error_files.append([file_audio, "very
|
808 |
continue
|
809 |
|
810 |
text = clear_text(text)
|
@@ -871,40 +871,37 @@ def check_user(value):
|
|
871 |
|
872 |
def calculate_train(
|
873 |
name_project,
|
|
|
|
|
|
|
874 |
batch_size_type,
|
875 |
max_samples,
|
876 |
-
learning_rate,
|
877 |
num_warmup_updates,
|
878 |
-
save_per_updates,
|
879 |
-
last_per_updates,
|
880 |
finetune,
|
881 |
):
|
882 |
path_project = os.path.join(path_data, name_project)
|
883 |
-
|
|
|
|
|
|
|
884 |
|
885 |
-
if not os.path.isfile(
|
886 |
return (
|
887 |
-
|
|
|
|
|
888 |
max_samples,
|
889 |
num_warmup_updates,
|
890 |
-
save_per_updates,
|
891 |
-
last_per_updates,
|
892 |
"project not found !",
|
893 |
-
learning_rate,
|
894 |
)
|
895 |
|
896 |
-
with open(
|
897 |
data = json.load(file)
|
898 |
|
899 |
duration_list = data["duration"]
|
900 |
-
|
901 |
-
|
902 |
-
|
903 |
-
# if torch.cuda.is_available():
|
904 |
-
# gpu_properties = torch.cuda.get_device_properties(0)
|
905 |
-
# total_memory = gpu_properties.total_memory / (1024**3)
|
906 |
-
# elif torch.backends.mps.is_available():
|
907 |
-
# total_memory = psutil.virtual_memory().available / (1024**3)
|
908 |
|
909 |
if torch.cuda.is_available():
|
910 |
gpu_count = torch.cuda.device_count()
|
@@ -912,64 +909,39 @@ def calculate_train(
|
|
912 |
for i in range(gpu_count):
|
913 |
gpu_properties = torch.cuda.get_device_properties(i)
|
914 |
total_memory += gpu_properties.total_memory / (1024**3) # in GB
|
915 |
-
|
916 |
elif torch.xpu.is_available():
|
917 |
gpu_count = torch.xpu.device_count()
|
918 |
total_memory = 0
|
919 |
for i in range(gpu_count):
|
920 |
gpu_properties = torch.xpu.get_device_properties(i)
|
921 |
total_memory += gpu_properties.total_memory / (1024**3)
|
922 |
-
|
923 |
elif torch.backends.mps.is_available():
|
924 |
gpu_count = 1
|
925 |
total_memory = psutil.virtual_memory().available / (1024**3)
|
926 |
|
|
|
|
|
|
|
927 |
if batch_size_type == "frame":
|
928 |
-
|
929 |
-
|
930 |
-
batch_size_per_gpu = int(
|
931 |
-
else:
|
932 |
-
batch_size_per_gpu = int(total_memory / 8)
|
933 |
-
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
|
934 |
-
batch = batch_size_per_gpu
|
935 |
|
936 |
-
if
|
937 |
-
|
938 |
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
|
943 |
-
|
944 |
-
|
945 |
-
|
946 |
-
|
947 |
-
|
948 |
-
|
949 |
-
|
950 |
-
|
951 |
-
last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
|
952 |
-
if last_per_updates <= 0:
|
953 |
-
last_per_updates = 2
|
954 |
-
|
955 |
-
total_hours = hours
|
956 |
-
mel_hop_length = 256
|
957 |
-
mel_sampling_rate = 24000
|
958 |
-
|
959 |
-
# target
|
960 |
-
wanted_max_updates = 1000000
|
961 |
-
|
962 |
-
# train params
|
963 |
-
gpus = gpu_count
|
964 |
-
frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
|
965 |
-
grad_accum = 1
|
966 |
-
|
967 |
-
# intermediate
|
968 |
-
mini_batch_frames = frames_per_gpu * grad_accum * gpus
|
969 |
-
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
|
970 |
-
updates_per_epoch = total_hours / mini_batch_hours
|
971 |
-
# steps_per_epoch = updates_per_epoch * grad_accum
|
972 |
-
epochs = wanted_max_updates / updates_per_epoch
|
973 |
|
974 |
if finetune:
|
975 |
learning_rate = 1e-5
|
@@ -977,14 +949,12 @@ def calculate_train(
|
|
977 |
learning_rate = 7.5e-5
|
978 |
|
979 |
return (
|
|
|
|
|
980 |
batch_size_per_gpu,
|
981 |
max_samples,
|
982 |
num_warmup_updates,
|
983 |
-
|
984 |
-
last_per_updates,
|
985 |
-
samples,
|
986 |
-
learning_rate,
|
987 |
-
int(epochs),
|
988 |
)
|
989 |
|
990 |
|
@@ -1021,7 +991,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
|
|
1021 |
torch.backends.cudnn.deterministic = True
|
1022 |
torch.backends.cudnn.benchmark = False
|
1023 |
|
1024 |
-
|
|
|
|
|
|
|
|
|
1025 |
|
1026 |
ema_sd = ckpt.get("ema_model_state_dict", {})
|
1027 |
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
|
@@ -1089,9 +1063,11 @@ def vocab_extend(project_name, symbols, model_type):
|
|
1089 |
with open(file_vocab_project, "w", encoding="utf-8") as f:
|
1090 |
f.write("\n".join(vocab))
|
1091 |
|
1092 |
-
if model_type == "
|
|
|
|
|
1093 |
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
1094 |
-
|
1095 |
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
1096 |
|
1097 |
vocab_size_new = len(miss_symbols)
|
@@ -1101,7 +1077,7 @@ def vocab_extend(project_name, symbols, model_type):
|
|
1101 |
os.makedirs(new_ckpt_path, exist_ok=True)
|
1102 |
|
1103 |
# Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
|
1104 |
-
new_ckpt_file = os.path.join(new_ckpt_path, "
|
1105 |
|
1106 |
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
|
1107 |
|
@@ -1231,21 +1207,21 @@ def infer(
|
|
1231 |
vocab_file = os.path.join(path_data, project, "vocab.txt")
|
1232 |
|
1233 |
tts_api = F5TTS(
|
1234 |
-
|
1235 |
)
|
1236 |
|
1237 |
print("update >> ", device_test, file_checkpoint, use_ema)
|
1238 |
|
1239 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
1240 |
tts_api.infer(
|
1241 |
-
gen_text=gen_text.lower().strip(),
|
1242 |
-
ref_text=ref_text.lower().strip(),
|
1243 |
ref_file=ref_audio,
|
|
|
|
|
1244 |
nfe_step=nfe_step,
|
1245 |
-
file_wave=f.name,
|
1246 |
speed=speed,
|
1247 |
-
seed=seed,
|
1248 |
remove_silence=remove_silence,
|
|
|
|
|
1249 |
)
|
1250 |
return f.name, tts_api.device, str(tts_api.seed)
|
1251 |
|
@@ -1404,14 +1380,14 @@ def get_audio_select(file_sample):
|
|
1404 |
with gr.Blocks() as app:
|
1405 |
gr.Markdown(
|
1406 |
"""
|
1407 |
-
#
|
1408 |
|
1409 |
-
This is a local web UI for F5 TTS
|
1410 |
|
1411 |
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
|
1412 |
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
|
1413 |
|
1414 |
-
The checkpoints support English and Chinese.
|
1415 |
|
1416 |
For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
|
1417 |
"""
|
@@ -1488,7 +1464,9 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
|
|
1488 |
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
|
1489 |
```""")
|
1490 |
|
1491 |
-
exp_name_extend = gr.Radio(
|
|
|
|
|
1492 |
|
1493 |
with gr.Row():
|
1494 |
txt_extend = gr.Textbox(
|
@@ -1557,9 +1535,9 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
|
|
1557 |
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
|
1558 |
)
|
1559 |
|
1560 |
-
with gr.TabItem("Train
|
1561 |
gr.Markdown("""```plaintext
|
1562 |
-
The auto-setting is still experimental.
|
1563 |
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
|
1564 |
```""")
|
1565 |
with gr.Row():
|
@@ -1573,11 +1551,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1573 |
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
|
1574 |
|
1575 |
with gr.Row():
|
1576 |
-
exp_name = gr.Radio(
|
|
|
|
|
1577 |
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
|
1578 |
|
1579 |
with gr.Row():
|
1580 |
-
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=
|
1581 |
max_samples = gr.Number(label="Max Samples", value=64)
|
1582 |
|
1583 |
with gr.Row():
|
@@ -1585,23 +1565,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1585 |
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
|
1586 |
|
1587 |
with gr.Row():
|
1588 |
-
epochs = gr.Number(label="Epochs", value=
|
1589 |
-
num_warmup_updates = gr.Number(label="Warmup Updates", value=
|
1590 |
|
1591 |
with gr.Row():
|
1592 |
-
save_per_updates = gr.Number(label="Save per Updates", value=
|
1593 |
keep_last_n_checkpoints = gr.Number(
|
1594 |
label="Keep Last N Checkpoints",
|
1595 |
value=-1,
|
1596 |
step=1,
|
1597 |
precision=0,
|
1598 |
-
info="-1
|
1599 |
)
|
1600 |
last_per_updates = gr.Number(label="Last per Updates", value=100)
|
1601 |
|
1602 |
with gr.Row():
|
1603 |
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
|
1604 |
-
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="
|
1605 |
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
|
1606 |
start_button = gr.Button("Start Training")
|
1607 |
stop_button = gr.Button("Stop Training", interactive=False)
|
@@ -1718,23 +1698,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1718 |
fn=calculate_train,
|
1719 |
inputs=[
|
1720 |
cm_project,
|
|
|
|
|
|
|
1721 |
batch_size_type,
|
1722 |
max_samples,
|
1723 |
-
learning_rate,
|
1724 |
num_warmup_updates,
|
1725 |
-
save_per_updates,
|
1726 |
-
last_per_updates,
|
1727 |
ch_finetune,
|
1728 |
],
|
1729 |
outputs=[
|
|
|
|
|
1730 |
batch_size_per_gpu,
|
1731 |
max_samples,
|
1732 |
num_warmup_updates,
|
1733 |
-
save_per_updates,
|
1734 |
-
last_per_updates,
|
1735 |
lb_samples,
|
1736 |
-
learning_rate,
|
1737 |
-
epochs,
|
1738 |
],
|
1739 |
)
|
1740 |
|
@@ -1744,25 +1722,25 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1744 |
|
1745 |
def setup_load_settings():
|
1746 |
output_components = [
|
1747 |
-
exp_name,
|
1748 |
-
learning_rate,
|
1749 |
-
batch_size_per_gpu,
|
1750 |
-
batch_size_type,
|
1751 |
-
max_samples,
|
1752 |
-
grad_accumulation_steps,
|
1753 |
-
max_grad_norm,
|
1754 |
-
epochs,
|
1755 |
-
num_warmup_updates,
|
1756 |
-
save_per_updates,
|
1757 |
-
keep_last_n_checkpoints,
|
1758 |
-
last_per_updates,
|
1759 |
-
ch_finetune,
|
1760 |
-
file_checkpoint_train,
|
1761 |
-
tokenizer_type,
|
1762 |
-
tokenizer_file,
|
1763 |
-
mixed_precision,
|
1764 |
-
cd_logger,
|
1765 |
-
ch_8bit_adam,
|
1766 |
]
|
1767 |
return output_components
|
1768 |
|
@@ -1784,7 +1762,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1784 |
gr.Markdown("""```plaintext
|
1785 |
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
|
1786 |
```""")
|
1787 |
-
exp_name = gr.Radio(
|
|
|
|
|
1788 |
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
|
1789 |
|
1790 |
with gr.Row():
|
@@ -1838,9 +1818,9 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
|
|
1838 |
bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
|
1839 |
cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
|
1840 |
|
1841 |
-
with gr.TabItem("
|
1842 |
gr.Markdown("""```plaintext
|
1843 |
-
Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or
|
1844 |
```""")
|
1845 |
txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
|
1846 |
txt_path_checkpoint_small = gr.Text(label="Path to Output:")
|
|
|
|
|
|
|
|
|
|
|
1 |
import gc
|
2 |
import json
|
3 |
+
import numpy as np
|
4 |
import os
|
5 |
import platform
|
6 |
import psutil
|
7 |
+
import queue
|
8 |
import random
|
9 |
+
import re
|
10 |
import signal
|
11 |
import shutil
|
12 |
import subprocess
|
13 |
import sys
|
14 |
import tempfile
|
15 |
+
import threading
|
16 |
import time
|
17 |
from glob import glob
|
18 |
+
from importlib.resources import files
|
19 |
+
from scipy.io import wavfile
|
20 |
|
21 |
import click
|
22 |
import gradio as gr
|
23 |
import librosa
|
|
|
24 |
import torch
|
25 |
import torchaudio
|
26 |
+
from cached_path import cached_path
|
27 |
from datasets import Dataset as Dataset_
|
28 |
from datasets.arrow_writer import ArrowWriter
|
29 |
+
from safetensors.torch import load_file, save_file
|
30 |
+
|
|
|
31 |
from f5_tts.api import F5TTS
|
32 |
from f5_tts.model.utils import convert_char_to_pinyin
|
33 |
from f5_tts.infer.utils_infer import transcribe
|
|
|
34 |
|
35 |
|
36 |
training_process = None
|
|
|
118 |
|
119 |
# Default settings
|
120 |
default_settings = {
|
121 |
+
"exp_name": "F5TTS_v1_Base",
|
122 |
+
"learning_rate": 1e-5,
|
123 |
+
"batch_size_per_gpu": 1,
|
124 |
+
"batch_size_type": "sample",
|
125 |
"max_samples": 64,
|
126 |
+
"grad_accumulation_steps": 4,
|
127 |
"max_grad_norm": 1,
|
128 |
"epochs": 100,
|
129 |
+
"num_warmup_updates": 100,
|
130 |
+
"save_per_updates": 500,
|
131 |
"keep_last_n_checkpoints": -1,
|
132 |
"last_per_updates": 100,
|
133 |
"finetune": True,
|
|
|
362 |
|
363 |
def start_training(
|
364 |
dataset_name="",
|
365 |
+
exp_name="F5TTS_v1_Base",
|
366 |
+
learning_rate=1e-5,
|
367 |
+
batch_size_per_gpu=1,
|
368 |
+
batch_size_type="sample",
|
369 |
max_samples=64,
|
370 |
+
grad_accumulation_steps=4,
|
371 |
max_grad_norm=1.0,
|
372 |
+
epochs=100,
|
373 |
+
num_warmup_updates=100,
|
374 |
+
save_per_updates=500,
|
375 |
keep_last_n_checkpoints=-1,
|
376 |
+
last_per_updates=100,
|
377 |
finetune=True,
|
378 |
file_checkpoint_train="",
|
379 |
tokenizer_type="pinyin",
|
|
|
797 |
print(f"Error processing {file_audio}: {e}")
|
798 |
continue
|
799 |
|
800 |
+
if duration < 1 or duration > 30:
|
801 |
+
if duration > 30:
|
802 |
+
error_files.append([file_audio, "duration > 30 sec"])
|
803 |
if duration < 1:
|
804 |
error_files.append([file_audio, "duration < 1 sec "])
|
805 |
continue
|
806 |
if len(text) < 3:
|
807 |
+
error_files.append([file_audio, "very short text length 3"])
|
808 |
continue
|
809 |
|
810 |
text = clear_text(text)
|
|
|
871 |
|
872 |
def calculate_train(
|
873 |
name_project,
|
874 |
+
epochs,
|
875 |
+
learning_rate,
|
876 |
+
batch_size_per_gpu,
|
877 |
batch_size_type,
|
878 |
max_samples,
|
|
|
879 |
num_warmup_updates,
|
|
|
|
|
880 |
finetune,
|
881 |
):
|
882 |
path_project = os.path.join(path_data, name_project)
|
883 |
+
file_duration = os.path.join(path_project, "duration.json")
|
884 |
+
|
885 |
+
hop_length = 256
|
886 |
+
sampling_rate = 24000
|
887 |
|
888 |
+
if not os.path.isfile(file_duration):
|
889 |
return (
|
890 |
+
epochs,
|
891 |
+
learning_rate,
|
892 |
+
batch_size_per_gpu,
|
893 |
max_samples,
|
894 |
num_warmup_updates,
|
|
|
|
|
895 |
"project not found !",
|
|
|
896 |
)
|
897 |
|
898 |
+
with open(file_duration, "r") as file:
|
899 |
data = json.load(file)
|
900 |
|
901 |
duration_list = data["duration"]
|
902 |
+
max_sample_length = max(duration_list) * sampling_rate / hop_length
|
903 |
+
total_samples = len(duration_list)
|
904 |
+
total_duration = sum(duration_list)
|
|
|
|
|
|
|
|
|
|
|
905 |
|
906 |
if torch.cuda.is_available():
|
907 |
gpu_count = torch.cuda.device_count()
|
|
|
909 |
for i in range(gpu_count):
|
910 |
gpu_properties = torch.cuda.get_device_properties(i)
|
911 |
total_memory += gpu_properties.total_memory / (1024**3) # in GB
|
|
|
912 |
elif torch.xpu.is_available():
|
913 |
gpu_count = torch.xpu.device_count()
|
914 |
total_memory = 0
|
915 |
for i in range(gpu_count):
|
916 |
gpu_properties = torch.xpu.get_device_properties(i)
|
917 |
total_memory += gpu_properties.total_memory / (1024**3)
|
|
|
918 |
elif torch.backends.mps.is_available():
|
919 |
gpu_count = 1
|
920 |
total_memory = psutil.virtual_memory().available / (1024**3)
|
921 |
|
922 |
+
avg_gpu_memory = total_memory / gpu_count
|
923 |
+
|
924 |
+
# rough estimate of batch size
|
925 |
if batch_size_type == "frame":
|
926 |
+
batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
|
927 |
+
elif batch_size_type == "sample":
|
928 |
+
batch_size_per_gpu = int(200 / (total_duration / total_samples))
|
|
|
|
|
|
|
|
|
929 |
|
930 |
+
if total_samples < 64:
|
931 |
+
max_samples = int(total_samples * 0.25)
|
932 |
|
933 |
+
num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
|
934 |
+
|
935 |
+
# take 1.2M updates as the maximum
|
936 |
+
max_updates = 1200000
|
937 |
+
|
938 |
+
if batch_size_type == "frame":
|
939 |
+
mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
|
940 |
+
updates_per_epoch = total_duration / mini_batch_duration
|
941 |
+
elif batch_size_type == "sample":
|
942 |
+
updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
|
943 |
+
|
944 |
+
epochs = int(max_updates / updates_per_epoch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
945 |
|
946 |
if finetune:
|
947 |
learning_rate = 1e-5
|
|
|
949 |
learning_rate = 7.5e-5
|
950 |
|
951 |
return (
|
952 |
+
epochs,
|
953 |
+
learning_rate,
|
954 |
batch_size_per_gpu,
|
955 |
max_samples,
|
956 |
num_warmup_updates,
|
957 |
+
total_samples,
|
|
|
|
|
|
|
|
|
958 |
)
|
959 |
|
960 |
|
|
|
991 |
torch.backends.cudnn.deterministic = True
|
992 |
torch.backends.cudnn.benchmark = False
|
993 |
|
994 |
+
if ckpt_path.endswith(".safetensors"):
|
995 |
+
ckpt = load_file(ckpt_path, device="cpu")
|
996 |
+
ckpt = {"ema_model_state_dict": ckpt}
|
997 |
+
elif ckpt_path.endswith(".pt"):
|
998 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
999 |
|
1000 |
ema_sd = ckpt.get("ema_model_state_dict", {})
|
1001 |
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
|
|
|
1063 |
with open(file_vocab_project, "w", encoding="utf-8") as f:
|
1064 |
f.write("\n".join(vocab))
|
1065 |
|
1066 |
+
if model_type == "F5TTS_v1_Base":
|
1067 |
+
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
|
1068 |
+
elif model_type == "F5TTS_Base":
|
1069 |
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
1070 |
+
elif model_type == "E2TTS_Base":
|
1071 |
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
1072 |
|
1073 |
vocab_size_new = len(miss_symbols)
|
|
|
1077 |
os.makedirs(new_ckpt_path, exist_ok=True)
|
1078 |
|
1079 |
# Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
|
1080 |
+
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
|
1081 |
|
1082 |
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
|
1083 |
|
|
|
1207 |
vocab_file = os.path.join(path_data, project, "vocab.txt")
|
1208 |
|
1209 |
tts_api = F5TTS(
|
1210 |
+
model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
|
1211 |
)
|
1212 |
|
1213 |
print("update >> ", device_test, file_checkpoint, use_ema)
|
1214 |
|
1215 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
1216 |
tts_api.infer(
|
|
|
|
|
1217 |
ref_file=ref_audio,
|
1218 |
+
ref_text=ref_text.lower().strip(),
|
1219 |
+
gen_text=gen_text.lower().strip(),
|
1220 |
nfe_step=nfe_step,
|
|
|
1221 |
speed=speed,
|
|
|
1222 |
remove_silence=remove_silence,
|
1223 |
+
file_wave=f.name,
|
1224 |
+
seed=seed,
|
1225 |
)
|
1226 |
return f.name, tts_api.device, str(tts_api.seed)
|
1227 |
|
|
|
1380 |
with gr.Blocks() as app:
|
1381 |
gr.Markdown(
|
1382 |
"""
|
1383 |
+
# F5 TTS Automatic Finetune
|
1384 |
|
1385 |
+
This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models:
|
1386 |
|
1387 |
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
|
1388 |
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
|
1389 |
|
1390 |
+
The pretrained checkpoints support English and Chinese.
|
1391 |
|
1392 |
For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
|
1393 |
"""
|
|
|
1464 |
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
|
1465 |
```""")
|
1466 |
|
1467 |
+
exp_name_extend = gr.Radio(
|
1468 |
+
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
|
1469 |
+
)
|
1470 |
|
1471 |
with gr.Row():
|
1472 |
txt_extend = gr.Textbox(
|
|
|
1535 |
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
|
1536 |
)
|
1537 |
|
1538 |
+
with gr.TabItem("Train Model"):
|
1539 |
gr.Markdown("""```plaintext
|
1540 |
+
The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space.
|
1541 |
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
|
1542 |
```""")
|
1543 |
with gr.Row():
|
|
|
1551 |   file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
1552 |
1553 |   with gr.Row():
1554 | + exp_name = gr.Radio(
1555 | + label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1556 | + )
1557 |   learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
1558 |
1559 |   with gr.Row():
1560 | + batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
1561 |   max_samples = gr.Number(label="Max Samples", value=64)
1562 |
1563 |   with gr.Row():

1565 |   max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
1566 |
1567 |   with gr.Row():
1568 | + epochs = gr.Number(label="Epochs", value=100)
1569 | + num_warmup_updates = gr.Number(label="Warmup Updates", value=100)
1570 |
1571 |   with gr.Row():
1572 | + save_per_updates = gr.Number(label="Save per Updates", value=500)
1573 |   keep_last_n_checkpoints = gr.Number(
1574 |   label="Keep Last N Checkpoints",
1575 |   value=-1,
1576 |   step=1,
1577 |   precision=0,
1578 | + info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
1579 |   )
1580 |   last_per_updates = gr.Number(label="Last per Updates", value=100)
1581 |
1582 |   with gr.Row():
1583 |   ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1584 | + mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
1585 |   cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1586 |   start_button = gr.Button("Start Training")
1587 |   stop_button = gr.Button("Stop Training", interactive=False)
1698 |   fn=calculate_train,
1699 |   inputs=[
1700 |   cm_project,
1701 | + epochs,
1702 | + learning_rate,
1703 | + batch_size_per_gpu,
1704 |   batch_size_type,
1705 |   max_samples,
1706 |   num_warmup_updates,
1707 |   ch_finetune,
1708 |   ],
1709 |   outputs=[
1710 | + epochs,
1711 | + learning_rate,
1712 |   batch_size_per_gpu,
1713 |   max_samples,
1714 |   num_warmup_updates,
1715 |   lb_samples,
1716 |   ],
1717 |   )
1718 |
1722 |
1723 |   def setup_load_settings():
1724 |   output_components = [
1725 | + exp_name,
1726 | + learning_rate,
1727 | + batch_size_per_gpu,
1728 | + batch_size_type,
1729 | + max_samples,
1730 | + grad_accumulation_steps,
1731 | + max_grad_norm,
1732 | + epochs,
1733 | + num_warmup_updates,
1734 | + save_per_updates,
1735 | + keep_last_n_checkpoints,
1736 | + last_per_updates,
1737 | + ch_finetune,
1738 | + file_checkpoint_train,
1739 | + tokenizer_type,
1740 | + tokenizer_file,
1741 | + mixed_precision,
1742 | + cd_logger,
1743 | + ch_8bit_adam,
1744 |   ]
1745 |   return output_components
1746 |
1762 |   gr.Markdown("""```plaintext
1763 |   SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
1764 |   ```""")
1765 | + exp_name = gr.Radio(
1766 | + label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1767 | + )
1768 |   list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1769 |
1770 |   with gr.Row():
1818 |   bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1819 |   cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1820 |
1821 | + with gr.TabItem("Prune Checkpoint"):
1822 |   gr.Markdown("""```plaintext
1823 | + Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
1824 |   ```""")
1825 |   txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
1826 |   txt_path_checkpoint_small = gr.Text(label="Path to Output:")
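The Prune Checkpoint tab shrinks a training checkpoint by dropping everything a finished model no longer needs (optimizer state, scheduler, step counters) and keeping only the weights used for inference or further finetuning. A minimal sketch of that idea is shown below; the key names ("ema_model_state_dict", "model_state_dict") are assumptions for illustration rather than the exact keys the repository's pruning code uses.

```python
# Hedged sketch of checkpoint pruning: keep only model weights, drop training state.
# Key names are assumptions; the repo's actual checkpoint layout may differ.
import torch


def prune_checkpoint(ckpt_path: str, out_path: str, use_ema: bool = True) -> None:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    key = "ema_model_state_dict" if use_ema else "model_state_dict"
    pruned = {"model_state_dict": ckpt[key]}  # optimizer/scheduler/step info is discarded
    torch.save(pruned, out_path)              # much smaller file; cannot resume training from it
```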
src/f5_tts/train/train.py
CHANGED
@@ -4,8 +4,9 @@ import os
4 |   from importlib.resources import files
5 |
6 |   import hydra
7 |
8 | - from f5_tts.model import CFM, DiT,
9 |   from f5_tts.model.dataset import load_dataset
10 |   from f5_tts.model.utils import get_tokenizer
11 |

@@ -14,9 +15,13 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
14 |
15 |   @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
16 |   def main(cfg):
17 |   tokenizer = cfg.model.tokenizer
18 |   mel_spec_type = cfg.model.mel_spec.mel_spec_type
19 |   exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
20 |
21 |   # set text tokenizer
22 |   if tokenizer != "custom":

@@ -26,14 +31,8 @@ def main(cfg):
26 |   vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
27 |
28 |   # set model
29 | - if "F5TTS" in cfg.model.name:
30 | - model_cls = DiT
31 | - elif "E2TTS" in cfg.model.name:
32 | - model_cls = UNetT
33 | - wandb_resume_id = None
34 | -
35 |   model = CFM(
36 | - transformer=model_cls(**
37 |   mel_spec_kwargs=cfg.model.mel_spec,
38 |   vocab_char_map=vocab_char_map,
39 |   )

@@ -45,9 +44,9 @@
45 |   learning_rate=cfg.optim.learning_rate,
46 |   num_warmup_updates=cfg.optim.num_warmup_updates,
47 |   save_per_updates=cfg.ckpts.save_per_updates,
48 | - keep_last_n_checkpoints=
49 |   checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
50 | -
51 |   batch_size_type=cfg.datasets.batch_size_type,
52 |   max_samples=cfg.datasets.max_samples,
53 |   grad_accumulation_steps=cfg.optim.grad_accumulation_steps,

@@ -57,11 +56,12 @@
57 |   wandb_run_name=exp_name,
58 |   wandb_resume_id=wandb_resume_id,
59 |   last_per_updates=cfg.ckpts.last_per_updates,
60 | - log_samples=
61 |   bnb_optimizer=cfg.optim.bnb_optimizer,
62 |   mel_spec_type=mel_spec_type,
63 |   is_local_vocoder=cfg.model.vocoder.is_local,
64 |   local_vocoder_path=cfg.model.vocoder.local_path,
65 |   )
66 |
67 |   train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
4 |   from importlib.resources import files
5 |
6 |   import hydra
7 | + from omegaconf import OmegaConf
8 |
9 | + from f5_tts.model import CFM, DiT, UNetT, Trainer  # noqa: F401. used for config
10 |   from f5_tts.model.dataset import load_dataset
11 |   from f5_tts.model.utils import get_tokenizer
12 |

15 |
16 |   @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
17 |   def main(cfg):
18 | + model_cls = globals()[cfg.model.backbone]
19 | + model_arc = cfg.model.arch
20 |   tokenizer = cfg.model.tokenizer
21 |   mel_spec_type = cfg.model.mel_spec.mel_spec_type
22 | +
23 |   exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
24 | + wandb_resume_id = None
25 |
26 |   # set text tokenizer
27 |   if tokenizer != "custom":

31 |   vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
32 |
33 |   # set model
34 |   model = CFM(
35 | + transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
36 |   mel_spec_kwargs=cfg.model.mel_spec,
37 |   vocab_char_map=vocab_char_map,
38 |   )

44 |   learning_rate=cfg.optim.learning_rate,
45 |   num_warmup_updates=cfg.optim.num_warmup_updates,
46 |   save_per_updates=cfg.ckpts.save_per_updates,
47 | + keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
48 |   checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
49 | + batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
50 |   batch_size_type=cfg.datasets.batch_size_type,
51 |   max_samples=cfg.datasets.max_samples,
52 |   grad_accumulation_steps=cfg.optim.grad_accumulation_steps,

56 |   wandb_run_name=exp_name,
57 |   wandb_resume_id=wandb_resume_id,
58 |   last_per_updates=cfg.ckpts.last_per_updates,
59 | + log_samples=cfg.ckpts.log_samples,
60 |   bnb_optimizer=cfg.optim.bnb_optimizer,
61 |   mel_spec_type=mel_spec_type,
62 |   is_local_vocoder=cfg.model.vocoder.is_local,
63 |   local_vocoder_path=cfg.model.vocoder.local_path,
64 | + cfg_dict=OmegaConf.to_container(cfg, resolve=True),
65 |   )
66 |
67 |   train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
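The refactor above removes the hard-coded `if "F5TTS" in cfg.model.name` branch: the backbone class and its architecture kwargs now come entirely from the Hydra config (`cfg.model.backbone`, `cfg.model.arch`), and the fully resolved config is handed to the Trainer via `cfg_dict`. The sketch below illustrates that lookup in isolation; the `arch` values shown are illustrative placeholders, not copied from F5TTS_v1_Base.yaml.

```python
# Hedged sketch of the config-driven backbone selection used in the new train.py.
# The arch values are placeholders for illustration only.
from omegaconf import OmegaConf

from f5_tts.model import DiT, UNetT  # backbone classes referenced by name in the config

cfg = OmegaConf.create(
    {
        "model": {
            "name": "F5TTS_v1_Base",
            "backbone": "DiT",                   # train.py resolves this via globals()[cfg.model.backbone]
            "arch": {"dim": 1024, "depth": 22},  # partial, illustrative kwargs
        }
    }
)

model_cls = {"DiT": DiT, "UNetT": UNetT}[cfg.model.backbone]  # same effect as the globals() lookup
print(model_cls.__name__, OmegaConf.to_container(cfg, resolve=True))
```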