SWivid committed
Commit 7ccc021 · 1 Parent(s): dfe1d95

1.0.0 F5-TTS v1 base model with better training and inference performance

Files changed (38)
  1. .github/workflows/publish-pypi.yaml +66 -0
  2. README.md +3 -2
  3. ckpts/README.md +5 -3
  4. pyproject.toml +1 -2
  5. src/f5_tts/api.py +50 -59
  6. src/f5_tts/configs/{E2TTS_Base_train.yaml → E2TTS_Base.yaml} +11 -7
  7. src/f5_tts/configs/{E2TTS_Small_train.yaml → E2TTS_Small.yaml} +11 -7
  8. src/f5_tts/configs/{F5TTS_Base_train.yaml → F5TTS_Base.yaml} +11 -7
  9. src/f5_tts/configs/{F5TTS_Small_train.yaml → F5TTS_Small.yaml} +11 -7
  10. src/f5_tts/configs/F5TTS_v1_Base.yaml +53 -0
  11. src/f5_tts/eval/eval_infer_batch.py +22 -27
  12. src/f5_tts/eval/eval_infer_batch.sh +11 -6
  13. src/f5_tts/eval/eval_librispeech_test_clean.py +21 -27
  14. src/f5_tts/eval/eval_seedtts_testset.py +21 -27
  15. src/f5_tts/eval/eval_utmos.py +14 -16
  16. src/f5_tts/eval/utils_eval.py +11 -6
  17. src/f5_tts/infer/README.md +20 -85
  18. src/f5_tts/infer/SHARED.md +19 -9
  19. src/f5_tts/infer/infer_cli.py +26 -31
  20. src/f5_tts/infer/infer_gradio.py +36 -11
  21. src/f5_tts/infer/speech_edit.py +25 -26
  22. src/f5_tts/infer/utils_infer.py +6 -6
  23. src/f5_tts/model/backbones/README.md +2 -2
  24. src/f5_tts/model/backbones/dit.py +63 -8
  25. src/f5_tts/model/backbones/mmdit.py +52 -9
  26. src/f5_tts/model/backbones/unett.py +36 -5
  27. src/f5_tts/model/cfm.py +3 -2
  28. src/f5_tts/model/dataset.py +5 -2
  29. src/f5_tts/model/modules.py +115 -42
  30. src/f5_tts/model/trainer.py +29 -18
  31. src/f5_tts/model/utils.py +4 -3
  32. src/f5_tts/scripts/count_max_epoch.py +1 -1
  33. src/f5_tts/socket_client.py +61 -0
  34. src/f5_tts/socket_server.py +19 -9
  35. src/f5_tts/train/README.md +5 -5
  36. src/f5_tts/train/finetune_cli.py +47 -15
  37. src/f5_tts/train/finetune_gradio.py +128 -148
  38. src/f5_tts/train/train.py +11 -11
.github/workflows/publish-pypi.yaml ADDED
@@ -0,0 +1,66 @@
+ # This workflow uses actions that are not certified by GitHub.
+ # They are provided by a third-party and are governed by
+ # separate terms of service, privacy policy, and support
+ # documentation.
+
+ # GitHub recommends pinning actions to a commit SHA.
+ # To get a newer version, you will need to update the SHA.
+ # You can also reference a tag or branch, but the action may change without warning.
+
+ name: Upload Python Package
+
+ on:
+   release:
+     types: [published]
+
+ permissions:
+   contents: read
+
+ jobs:
+   release-build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - uses: actions/setup-python@v5
+         with:
+           python-version: "3.x"
+
+       - name: Build release distributions
+         run: |
+           # NOTE: put your own distribution build steps here.
+           python -m pip install build
+           python -m build
+
+       - name: Upload distributions
+         uses: actions/upload-artifact@v4
+         with:
+           name: release-dists
+           path: dist/
+
+   pypi-publish:
+     runs-on: ubuntu-latest
+
+     needs:
+       - release-build
+
+     permissions:
+       # IMPORTANT: this permission is mandatory for trusted publishing
+       id-token: write
+
+     # Dedicated environments with protections for publishing are strongly recommended.
+     environment:
+       name: pypi
+       # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
+       # url: https://pypi.org/p/YOURPROJECT
+
+     steps:
+       - name: Retrieve release distributions
+         uses: actions/download-artifact@v4
+         with:
+           name: release-dists
+           path: dist/
+
+       - name: Publish release distributions to PyPI
+         uses: pypa/gh-action-pypi-publish@6f7e8d9c0b1a2c3d4e5f6a7b8c9d0e1f2a3b4c5d
README.md CHANGED
@@ -18,6 +18,7 @@
  ### Thanks to all the contributors !
  
  ## News
+ - **2025/03/12**: F5-TTS v1 base model with better training and inference performance.
  - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
  
  ## Installation
@@ -37,7 +38,7 @@ conda activate f5-tts
  
  > ```bash
  > # Install pytorch with your CUDA version, e.g.
- > pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
+ > pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
  > ```
  
  </details>
@@ -159,7 +160,7 @@ volumes:
  # Run with flags
  # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
  f5-tts_infer-cli \
- --model "F5-TTS" \
+ --model "F5-TTS_v1" \
  --ref_audio "ref_audio.wav" \
  --ref_text "The content, subtitle or transcription of reference audio." \
  --gen_text "Some text you want TTS model generate for you."
ckpts/README.md CHANGED
@@ -3,8 +3,10 @@ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
  
  ```
  ckpts/
-     E2TTS_Base/
-         model_1200000.pt
+     F5TTS_v1_Base/
+         model_1250000.safetensors
      F5TTS_Base/
-         model_1200000.pt
+         model_1200000.safetensors
+     E2TTS_Base/
+         model_1200000.safetensors
  ```
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
  
  [project]
  name = "f5-tts"
- version = "0.6.2"
+ version = "1.0.0"
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
  readme = "README.md"
  license = {text = "MIT License"}
@@ -25,7 +25,6 @@ dependencies = [
  "jieba",
  "librosa",
  "matplotlib",
- "nltk",
  "numpy<=1.26.4",
  "pydub",
  "pypinyin",
src/f5_tts/api.py CHANGED
@@ -5,43 +5,43 @@ from importlib.resources import files
  import soundfile as sf
  import tqdm
  from cached_path import cached_path
+ from omegaconf import OmegaConf
  
  from f5_tts.infer.utils_infer import (
-     hop_length,
-     infer_process,
      load_model,
      load_vocoder,
+     transcribe,
      preprocess_ref_audio_text,
+     infer_process,
      remove_silence_for_generated_wav,
      save_spectrogram,
-     transcribe,
-     target_sample_rate,
  )
- from f5_tts.model import DiT, UNetT
+ from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
  from f5_tts.model.utils import seed_everything
  
  
  class F5TTS:
      def __init__(
          self,
-         model_type="F5-TTS",
+         model="F5TTS_v1_Base",
          ckpt_file="",
          vocab_file="",
          ode_method="euler",
          use_ema=True,
-         vocoder_name="vocos",
-         local_path=None,
+         vocoder_local_path=None,
          device=None,
          hf_cache_dir=None,
      ):
-         # Initialize parameters
-         self.final_wave = None
-         self.target_sample_rate = target_sample_rate
-         self.hop_length = hop_length
-         self.seed = -1
-         self.mel_spec_type = vocoder_name
-
-         # Set device
+         model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+         model_cls = globals()[model_cfg.model.backbone]
+         model_arc = model_cfg.model.arch
+
+         self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+         self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+
+         self.ode_method = ode_method
+         self.use_ema = use_ema
+
          if device is not None:
              self.device = device
          else:
@@ -58,39 +58,31 @@ class F5TTS:
          )
  
          # Load models
-         self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
-         self.load_ema_model(
-             model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
+         self.vocoder = load_vocoder(
+             self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
          )
  
-     def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
-         self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
-
-     def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
-         if model_type == "F5-TTS":
-             if not ckpt_file:
-                 if mel_spec_type == "vocos":
-                     ckpt_file = str(
-                         cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
-                     )
-                 elif mel_spec_type == "bigvgan":
-                     ckpt_file = str(
-                         cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
-                     )
-             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-             model_cls = DiT
-         elif model_type == "E2-TTS":
-             if not ckpt_file:
-                 ckpt_file = str(
-                     cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
-                 )
-             model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-             model_cls = UNetT
+         repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
+
+         # override for previous models
+         if model == "F5TTS_Base":
+             if self.mel_spec_type == "vocos":
+                 ckpt_step = 1200000
+             elif self.mel_spec_type == "bigvgan":
+                 model = "F5TTS_Base_bigvgan"
+                 ckpt_type = "pt"
+         elif model == "E2TTS_Base":
+             repo_name = "E2-TTS"
+             ckpt_step = 1200000
          else:
-             raise ValueError(f"Unknown model type: {model_type}")
+             raise ValueError(f"Unknown model type: {model}")
  
+         if not ckpt_file:
+             ckpt_file = str(
+                 cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
+             )
          self.ema_model = load_model(
-             model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
+             model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
          )
  
      def transcribe(self, ref_audio, language=None):
@@ -102,8 +94,8 @@ class F5TTS:
          if remove_silence:
              remove_silence_for_generated_wav(file_wave)
  
-     def export_spectrogram(self, spect, file_spect):
-         save_spectrogram(spect, file_spect)
+     def export_spectrogram(self, spec, file_spec):
+         save_spectrogram(spec, file_spec)
  
      def infer(
          self,
@@ -121,17 +113,16 @@
          fix_duration=None,
          remove_silence=False,
          file_wave=None,
-         file_spect=None,
-         seed=-1,
+         file_spec=None,
+         seed=None,
      ):
-         if seed == -1:
-             seed = random.randint(0, sys.maxsize)
-         seed_everything(seed)
-         self.seed = seed
+         if seed is None:
+             self.seed = random.randint(0, sys.maxsize)
+         seed_everything(self.seed)
  
          ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
  
-         wav, sr, spect = infer_process(
+         wav, sr, spec = infer_process(
              ref_file,
              ref_text,
              gen_text,
@@ -153,22 +144,22 @@
          if file_wave is not None:
              self.export_wav(wav, file_wave, remove_silence)
  
-         if file_spect is not None:
-             self.export_spectrogram(spect, file_spect)
+         if file_spec is not None:
+             self.export_spectrogram(spec, file_spec)
  
-         return wav, sr, spect
+         return wav, sr, spec
  
  
  if __name__ == "__main__":
      f5tts = F5TTS()
  
-     wav, sr, spect = f5tts.infer(
+     wav, sr, spec = f5tts.infer(
          ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
          ref_text="some call me nature, others call me mother nature.",
          gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
          file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
-         file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
-         seed=-1,  # random seed = -1
+         file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+         seed=None,
      )
  
      print("seed :", f5tts.seed)
src/f5_tts/configs/{E2TTS_Base_train.yaml → E2TTS_Base.yaml} RENAMED
@@ -1,16 +1,16 @@
  hydra:
    run:
      dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
-
+
  datasets:
    name: Emilia_ZH_EN # dataset name
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-   batch_size_type: frame # "frame" or "sample"
+   batch_size_type: frame # frame | sample
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
    num_workers: 16
  
  optim:
-   epochs: 15
+   epochs: 11
    learning_rate: 7.5e-5
    num_warmup_updates: 20000 # warmup updates
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
  model:
    name: E2TTS_Base
    tokenizer: pinyin
-   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+   backbone: UNetT
    arch:
      dim: 1024
      depth: 24
      heads: 16
      ff_mult: 4
+     text_mask_padding: False
+     pe_attn_head: 1
    mel_spec:
      target_sample_rate: 24000
      n_mel_channels: 100
      hop_length: 256
      win_length: 1024
      n_fft: 1024
-     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+     mel_spec_type: vocos # vocos | bigvgan
    vocoder:
      is_local: False # use local offline ckpt or not
-     local_path: None # local vocoder path
+     local_path: null # local vocoder path
  
  ckpts:
-   logger: wandb # wandb | tensorboard | None
+   logger: wandb # wandb | tensorboard | null
+   log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
    save_per_updates: 50000 # save checkpoint per updates
    keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
    last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/{E2TTS_Small_train.yaml → E2TTS_Small.yaml} RENAMED
@@ -1,16 +1,16 @@
  hydra:
    run:
      dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
-
+
  datasets:
    name: Emilia_ZH_EN
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-   batch_size_type: frame # "frame" or "sample"
+   batch_size_type: frame # frame | sample
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
    num_workers: 16
  
  optim:
-   epochs: 15
+   epochs: 11
    learning_rate: 7.5e-5
    num_warmup_updates: 20000 # warmup updates
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
  model:
    name: E2TTS_Small
    tokenizer: pinyin
-   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+   backbone: UNetT
    arch:
      dim: 768
      depth: 20
      heads: 12
      ff_mult: 4
+     text_mask_padding: False
+     pe_attn_head: 1
    mel_spec:
      target_sample_rate: 24000
      n_mel_channels: 100
      hop_length: 256
      win_length: 1024
      n_fft: 1024
-     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+     mel_spec_type: vocos # vocos | bigvgan
    vocoder:
      is_local: False # use local offline ckpt or not
-     local_path: None # local vocoder path
+     local_path: null # local vocoder path
  
  ckpts:
-   logger: wandb # wandb | tensorboard | None
+   logger: wandb # wandb | tensorboard | null
+   log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
    save_per_updates: 50000 # save checkpoint per updates
    keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
    last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/{F5TTS_Base_train.yaml → F5TTS_Base.yaml} RENAMED
@@ -1,16 +1,16 @@
  hydra:
    run:
      dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
-
+
  datasets:
    name: Emilia_ZH_EN # dataset name
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-   batch_size_type: frame # "frame" or "sample"
+   batch_size_type: frame # frame | sample
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
    num_workers: 16
  
  optim:
-   epochs: 15
+   epochs: 11
    learning_rate: 7.5e-5
    num_warmup_updates: 20000 # warmup updates
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
  model:
    name: F5TTS_Base # model name
    tokenizer: pinyin # tokenizer type
-   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+   backbone: DiT
    arch:
      dim: 1024
      depth: 22
      heads: 16
      ff_mult: 2
      text_dim: 512
+     text_mask_padding: False
      conv_layers: 4
+     pe_attn_head: 1
      checkpoint_activations: False # recompute activations and save memory for extra compute
    mel_spec:
      target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
      hop_length: 256
      win_length: 1024
      n_fft: 1024
-     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+     mel_spec_type: vocos # vocos | bigvgan
    vocoder:
      is_local: False # use local offline ckpt or not
-     local_path: None # local vocoder path
+     local_path: null # local vocoder path
  
  ckpts:
-   logger: wandb # wandb | tensorboard | None
+   logger: wandb # wandb | tensorboard | null
+   log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
    save_per_updates: 50000 # save checkpoint per updates
    keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
    last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/{F5TTS_Small_train.yaml → F5TTS_Small.yaml} RENAMED
@@ -1,16 +1,16 @@
  hydra:
    run:
      dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
-
+
  datasets:
    name: Emilia_ZH_EN
    batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
-   batch_size_type: frame # "frame" or "sample"
+   batch_size_type: frame # frame | sample
    max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
    num_workers: 16
  
  optim:
-   epochs: 15
+   epochs: 11
    learning_rate: 7.5e-5
    num_warmup_updates: 20000 # warmup updates
    grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
  model:
    name: F5TTS_Small
    tokenizer: pinyin
-   tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+   tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+   backbone: DiT
    arch:
      dim: 768
      depth: 18
      heads: 12
      ff_mult: 2
      text_dim: 512
+     text_mask_padding: False
      conv_layers: 4
+     pe_attn_head: 1
      checkpoint_activations: False # recompute activations and save memory for extra compute
    mel_spec:
      target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
      hop_length: 256
      win_length: 1024
      n_fft: 1024
-     mel_spec_type: vocos # 'vocos' or 'bigvgan'
+     mel_spec_type: vocos # vocos | bigvgan
    vocoder:
      is_local: False # use local offline ckpt or not
-     local_path: None # local vocoder path
+     local_path: null # local vocoder path
  
  ckpts:
-   logger: wandb # wandb | tensorboard | None
+   logger: wandb # wandb | tensorboard | null
+   log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
    save_per_updates: 50000 # save checkpoint per updates
    keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
    last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/F5TTS_v1_Base.yaml ADDED
@@ -0,0 +1,53 @@
+ hydra:
+   run:
+     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+ datasets:
+   name: Emilia_ZH_EN # dataset name
+   batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
+   batch_size_type: frame # frame | sample
+   max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+   num_workers: 16
+
+ optim:
+   epochs: 11
+   learning_rate: 7.5e-5
+   num_warmup_updates: 20000 # warmup updates
+   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
+   max_grad_norm: 1.0 # gradient clipping
+   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
+
+ model:
+   name: F5TTS_v1_Base # model name
+   tokenizer: pinyin # tokenizer type
+   tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
+   backbone: DiT
+   arch:
+     dim: 1024
+     depth: 22
+     heads: 16
+     ff_mult: 2
+     text_dim: 512
+     text_mask_padding: True
+     qk_norm: null # null | rms_norm
+     conv_layers: 4
+     pe_attn_head: null
+     checkpoint_activations: False # recompute activations and save memory for extra compute
+   mel_spec:
+     target_sample_rate: 24000
+     n_mel_channels: 100
+     hop_length: 256
+     win_length: 1024
+     n_fft: 1024
+     mel_spec_type: vocos # vocos | bigvgan
+   vocoder:
+     is_local: False # use local offline ckpt or not
+     local_path: null # local vocoder path
+
+ ckpts:
+   logger: wandb # wandb | tensorboard | null
+   log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
+   save_per_updates: 50000 # save checkpoint per updates
+   keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+   last_per_updates: 5000 # save last checkpoint per updates
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
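The new `backbone` and `arch` keys are what the refactored scripts read at load time. Below is a short sketch of that pattern, mirroring `api.py` and `eval_infer_batch.py` in this commit; the explicit class map stands in for the `globals()` lookup those scripts use.

```python
from importlib.resources import files

from omegaconf import OmegaConf

from f5_tts.model import DiT, UNetT  # backbones referenced by name in the configs

model_cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml")))
model_cls = {"DiT": DiT, "UNetT": UNetT}[model_cfg.model.backbone]  # the scripts do globals()[...] here
model_arc = model_cfg.model.arch  # dim, depth, heads, ff_mult, text_dim, qk_norm, ...

print(model_cls.__name__, OmegaConf.to_container(model_arc))
```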
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -10,6 +10,7 @@ from importlib.resources import files
  import torch
  import torchaudio
  from accelerate import Accelerator
+ from omegaconf import OmegaConf
  from tqdm import tqdm
  
  from f5_tts.eval.utils_eval import (
@@ -18,36 +19,26 @@ from f5_tts.eval.utils_eval import (
      get_seedtts_testset_metainfo,
  )
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
- from f5_tts.model import CFM, DiT, UNetT
+ from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
  from f5_tts.model.utils import get_tokenizer
  
  accelerator = Accelerator()
  device = f"cuda:{accelerator.process_index}"
  
  
- # --------------------- Dataset Settings -------------------- #
-
- target_sample_rate = 24000
- n_mel_channels = 100
- hop_length = 256
- win_length = 1024
- n_fft = 1024
+ use_ema = True
  target_rms = 0.1
  
+
  rel_path = str(files("f5_tts").joinpath("../../"))
  
  
  def main():
-     # ---------------------- infer setting ---------------------- #
-
      parser = argparse.ArgumentParser(description="batch inference")
  
      parser.add_argument("-s", "--seed", default=None, type=int)
-     parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
      parser.add_argument("-n", "--expname", required=True)
-     parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
-     parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
-     parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
+     parser.add_argument("-c", "--ckptstep", default=1250000, type=int)
  
      parser.add_argument("-nfe", "--nfestep", default=32, type=int)
      parser.add_argument("-o", "--odemethod", default="euler")
@@ -58,12 +49,8 @@ def main():
      args = parser.parse_args()
  
      seed = args.seed
-     dataset_name = args.dataset
      exp_name = args.expname
      ckpt_step = args.ckptstep
-     ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
-     mel_spec_type = args.mel_spec_type
-     tokenizer = args.tokenizer
  
      nfe_step = args.nfestep
      ode_method = args.odemethod
@@ -77,13 +64,19 @@ def main():
      use_truth_duration = False
      no_ref_audio = False
  
-     if exp_name == "F5TTS_Base":
-         model_cls = DiT
-         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-
-     elif exp_name == "E2TTS_Base":
-         model_cls = UNetT
-         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+     model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
+     model_cls = globals()[model_cfg.model.backbone]
+     model_arc = model_cfg.model.arch
+
+     dataset_name = model_cfg.datasets.name
+     tokenizer = model_cfg.model.tokenizer
+
+     mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+     target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+     n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
+     hop_length = model_cfg.model.mel_spec.hop_length
+     win_length = model_cfg.model.mel_spec.win_length
+     n_fft = model_cfg.model.mel_spec.n_fft
  
      if testset == "ls_pc_test_clean":
          metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -111,8 +104,6 @@ def main():
  
      # -------------------------------------------------#
  
-     use_ema = True
-
      prompts_all = get_inference_prompt(
          metainfo,
          speed=speed,
@@ -139,7 +130,7 @@ def main():
  
      # Model
      model = CFM(
-         transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+         transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
          mel_spec_kwargs=dict(
              n_fft=n_fft,
              hop_length=hop_length,
@@ -154,6 +145,10 @@ def main():
          vocab_char_map=vocab_char_map,
      ).to(device)
  
+     ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
+     if not os.path.exists(ckpt_path):
+         print("Loading from self-organized training checkpoints rather than released pretrained.")
+         ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
      dtype = torch.float32 if mel_spec_type == "bigvgan" else None
      model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
  
src/f5_tts/eval/eval_infer_batch.sh CHANGED
@@ -1,13 +1,18 @@
  #!/bin/bash
  
  # e.g. F5-TTS, 16 NFE
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
  
  # e.g. Vanilla E2 TTS, 32 NFE
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
- accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
+ accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
+
+ # e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
+ python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+ python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
+ python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
  
  # etc.
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -53,43 +53,37 @@ def main():
      asr_ckpt_dir = "" # auto download to cache dir
      wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
  
-     # --------------------------- WER ---------------------------
+     # --------------------------------------------------------------------------
  
-     if eval_task == "wer":
-         wer_results = []
-         wers = []
+     full_results = []
+     metrics = []
  
+     if eval_task == "wer":
          with mp.Pool(processes=len(gpus)) as pool:
              args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
              results = pool.map(run_asr_wer, args)
              for r in results:
-                 wer_results.extend(r)
-
-         wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
-         with open(wer_result_path, "w") as f:
-             for line in wer_results:
-                 wers.append(line["wer"])
-                 json_line = json.dumps(line, ensure_ascii=False)
-                 f.write(json_line + "\n")
-
-         wer = round(np.mean(wers) * 100, 3)
-         print(f"\nTotal {len(wers)} samples")
-         print(f"WER : {wer}%")
-         print(f"Results have been saved to {wer_result_path}")
-
-     # --------------------------- SIM ---------------------------
-
-     if eval_task == "sim":
-         sims = []
+                 full_results.extend(r)
+     elif eval_task == "sim":
          with mp.Pool(processes=len(gpus)) as pool:
              args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
              results = pool.map(run_sim, args)
              for r in results:
-                 sims.extend(r)
-
-         sim = round(sum(sims) / len(sims), 3)
-         print(f"\nTotal {len(sims)} samples")
-         print(f"SIM : {sim}")
+                 full_results.extend(r)
+     else:
+         raise ValueError(f"Unknown metric type: {eval_task}")
+
+     result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+     with open(result_path, "w") as f:
+         for line in full_results:
+             metrics.append(line[eval_task])
+             f.write(json.dumps(line, ensure_ascii=False) + "\n")
+         metric = round(np.mean(metrics), 5)
+         f.write(f"\n{eval_task.upper()}: {metric}\n")
+
+     print(f"\nTotal {len(metrics)} samples")
+     print(f"{eval_task.upper()}: {metric}")
+     print(f"{eval_task.upper()} results saved to {result_path}")
  
  
  if __name__ == "__main__":
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -52,43 +52,37 @@ def main():
      asr_ckpt_dir = "" # auto download to cache dir
      wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
  
-     # --------------------------- WER ---------------------------
+     # --------------------------------------------------------------------------
  
-     if eval_task == "wer":
-         wer_results = []
-         wers = []
+     full_results = []
+     metrics = []
  
+     if eval_task == "wer":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for r in results:
-                 wer_results.extend(r)
-
-         wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
-         with open(wer_result_path, "w") as f:
-             for line in wer_results:
-                 wers.append(line["wer"])
-                 json_line = json.dumps(line, ensure_ascii=False)
-                 f.write(json_line + "\n")
-
-         wer = round(np.mean(wers) * 100, 3)
-         print(f"\nTotal {len(wers)} samples")
-         print(f"WER : {wer}%")
-         print(f"Results have been saved to {wer_result_path}")
-
-     # --------------------------- SIM ---------------------------
-
-     if eval_task == "sim":
-         sims = []
+                 full_results.extend(r)
+     elif eval_task == "sim":
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
             for r in results:
-                 sims.extend(r)
-
-         sim = round(sum(sims) / len(sims), 3)
-         print(f"\nTotal {len(sims)} samples")
-         print(f"SIM : {sim}")
+                 full_results.extend(r)
+     else:
+         raise ValueError(f"Unknown metric type: {eval_task}")
+
+     result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
+     with open(result_path, "w") as f:
+         for line in full_results:
+             metrics.append(line[eval_task])
+             f.write(json.dumps(line, ensure_ascii=False) + "\n")
+         metric = round(np.mean(metrics), 5)
+         f.write(f"\n{eval_task.upper()}: {metric}\n")
+
+     print(f"\nTotal {len(metrics)} samples")
+     print(f"{eval_task.upper()}: {metric}")
+     print(f"{eval_task.upper()} results saved to {result_path}")
  
  
  if __name__ == "__main__":
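Both refactored eval scripts now write one JSON object per sample, followed by a blank line and a `WER: ...` / `SIM: ...` summary, into `_{wer|sim}_results.jsonl` inside `--gen_wav_dir`. A hedged sketch of reading such a file back; the path below is just the example directory used in `eval_infer_batch.sh`.

```python
import json

# Example path matching the eval_infer_batch.sh example above (adjust as needed).
result_path = "results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0/_wer_results.jsonl"

scores = []
with open(result_path, encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line.startswith("{"):
            continue  # skip the blank spacer and the trailing summary line
        scores.append(json.loads(line)["wer"])

print(f"{len(scores)} samples, mean WER = {sum(scores) / len(scores):.5f}")
```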
src/f5_tts/eval/eval_utmos.py CHANGED
@@ -19,25 +19,23 @@ def main():
      predictor = predictor.to(device)
  
      audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
-     utmos_results = {}
      utmos_score = 0
  
-     for audio_path in tqdm(audio_paths, desc="Processing"):
-         wav_name = audio_path.stem
-         wav, sr = librosa.load(audio_path, sr=None, mono=True)
-         wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
-         score = predictor(wav_tensor, sr)
-         utmos_results[str(wav_name)] = score.item()
-         utmos_score += score.item()
-
-     avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
-     print(f"UTMOS: {avg_score}")
-
-     utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+     utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
      with open(utmos_result_path, "w", encoding="utf-8") as f:
-         json.dump(utmos_results, f, ensure_ascii=False, indent=4)
-
-     print(f"Results have been saved to {utmos_result_path}")
+         for audio_path in tqdm(audio_paths, desc="Processing"):
+             wav, sr = librosa.load(audio_path, sr=None, mono=True)
+             wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+             score = predictor(wav_tensor, sr)
+             line = {}
+             line["wav"], line["utmos"] = str(audio_path.stem), score.item()
+             utmos_score += score.item()
+             f.write(json.dumps(line, ensure_ascii=False) + "\n")
+         avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+         f.write(f"\nUTMOS: {avg_score:.4f}\n")
+
+     print(f"UTMOS: {avg_score:.4f}")
+     print(f"UTMOS results saved to {utmos_result_path}")
  
  
  if __name__ == "__main__":
src/f5_tts/eval/utils_eval.py CHANGED
@@ -389,10 +389,10 @@ def run_sim(args):
      model = model.cuda(device)
      model.eval()
  
-     sims = []
-     for wav1, wav2, truth in tqdm(test_set):
-         wav1, sr1 = torchaudio.load(wav1)
-         wav2, sr2 = torchaudio.load(wav2)
+     sim_results = []
+     for gen_wav, prompt_wav, truth in tqdm(test_set):
+         wav1, sr1 = torchaudio.load(gen_wav)
+         wav2, sr2 = torchaudio.load(prompt_wav)
  
          resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
          resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,6 +408,11 @@
  
          sim = F.cosine_similarity(emb1, emb2)[0].item()
          # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
-         sims.append(sim)
+         sim_results.append(
+             {
+                 "wav": Path(gen_wav).stem,
+                 "sim": sim,
+             }
+         )
  
-     return sims
+     return sim_results
src/f5_tts/infer/README.md CHANGED
@@ -68,14 +68,16 @@ Basically you can inference with flags:
  ```bash
  # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
  f5-tts_infer-cli \
- --model "F5-TTS" \
+ --model F5TTS_v1_Base \
  --ref_audio "ref_audio.wav" \
  --ref_text "The content, subtitle or transcription of reference audio." \
  --gen_text "Some text you want TTS model generate for you."
  
- # Choose Vocoder
- f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
- f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
+ # Use BigVGAN as vocoder. Currently only support F5TTS_Base.
+ f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
+
+ # Use custom path checkpoint, e.g.
+ f5-tts_infer-cli --ckpt_file ckpts/F5TTS_Base/model_1200000.safetensors
  
  # More instructions
  f5-tts_infer-cli --help
@@ -90,8 +92,8 @@ f5-tts_infer-cli -c custom.toml
  For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
  
  ```toml
- # F5-TTS | E2-TTS
- model = "F5-TTS"
+ # F5TTS_v1_Base | E2TTS_Base
+ model = "F5TTS_v1_Base"
  ref_audio = "infer/examples/basic/basic_ref_en.wav"
  # If an empty "", transcribes the reference audio automatically.
  ref_text = "Some call me nature, others call me mother nature."
@@ -105,8 +107,8 @@ output_dir = "tests"
  You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
  
  ```toml
- # F5-TTS | E2-TTS
- model = "F5-TTS"
+ # F5TTS_v1_Base | E2TTS_Base
+ model = "F5TTS_v1_Base"
  ref_audio = "infer/examples/multi/main.flac"
  # If an empty "", transcribes the reference audio automatically.
  ref_text = ""
@@ -126,94 +128,27 @@ ref_text = ""
  ```
  You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
  
- ## Speech Editing
-
- To test speech editing capabilities, use the following command:
-
- ```bash
- python src/f5_tts/infer/speech_edit.py
- ```
+ ## Socket Real-time Service
  
- ## Socket Realtime Client
+ Real-time voice output with chunk stream:
  
- To communicate with socket server you need to run
  ```bash
+ # Start socket server
  python src/f5_tts/socket_server.py
- ```
-
- <details>
- <summary>Then create client to communicate</summary>
  
- ```bash
  # If PyAudio not installed
  sudo apt-get install portaudio19-dev
  pip install pyaudio
- ```
-
- ``` python
- # Create the socket_client.py
- import socket
- import asyncio
- import pyaudio
- import numpy as np
- import logging
- import time
-
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
-
- async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
-     client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-     await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
-
-     start_time = time.time()
-     first_chunk_time = None
-
-     async def play_audio_stream():
-         nonlocal first_chunk_time
-         p = pyaudio.PyAudio()
-         stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
-
-         try:
-             while True:
-                 data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
-                 if not data:
-                     break
-                 if data == b"END":
-                     logger.info("End of audio received.")
-                     break
-
-                 audio_array = np.frombuffer(data, dtype=np.float32)
-                 stream.write(audio_array.tobytes())
-
-                 if first_chunk_time is None:
-                     first_chunk_time = time.time()
-
-         finally:
-             stream.stop_stream()
-             stream.close()
-             p.terminate()
-
-         logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
-
-     try:
-         data_to_send = f"{text}".encode("utf-8")
-         await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
-         await play_audio_stream()
-
-     except Exception as e:
-         logger.error(f"Error in listen_to_F5TTS: {e}")
-
-     finally:
-         client_socket.close()
-
-
- if __name__ == "__main__":
-     text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
-
-     asyncio.run(listen_to_F5TTS(text_to_send))
+
+ # Communicate with socket client
+ python src/f5_tts/socket_client.py
  ```
  
- </details>
-
+ ## Speech Editing
+
+ To test speech editing capabilities, use the following command:
+
+ ```bash
+ python src/f5_tts/infer/speech_edit.py
+ ```
  
src/f5_tts/infer/SHARED.md CHANGED
@@ -16,7 +16,7 @@
  <!-- omit in toc -->
  ### Supported Languages
  - [Multilingual](#multilingual)
-   - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
+   - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
  - [English](#english)
  - [Finnish](#finnish)
    - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,7 +37,17 @@
  
  ## Multilingual
  
- #### F5-TTS Base @ zh & en @ F5-TTS
+ #### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
+ |Model|🤗Hugging Face|Data (Hours)|Model License|
+ |:---:|:------------:|:-----------:|:-------------:|
+ |F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
+
+ ```bash
+ Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
+ Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+ ```
+
  |Model|🤗Hugging Face|Data (Hours)|Model License|
  |:---:|:------------:|:-----------:|:-------------:|
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
@@ -45,7 +55,7 @@
  ```bash
  Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
  Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
  ```
  
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -64,7 +74,7 @@
  ```bash
  Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
  Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
  ```
  
  
@@ -78,7 +88,7 @@
  ```bash
  Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
  Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
  ```
  
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -96,7 +106,7 @@
  ```bash
  Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
  Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
- Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
  ```
  
  - Authors: SPRING Lab, Indian Institute of Technology, Madras
@@ -113,7 +123,7 @@
  ```bash
  Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
  Vocab: hf://alien79/F5-TTS-italian/vocab.txt
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
  ```
  
  - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -131,7 +141,7 @@
  ```bash
  Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
  Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
  ```
  
  
@@ -148,7 +158,7 @@
  ```bash
  Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
  Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
- Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
  ```
  - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
  - Any improvements are welcome
 
55
  ```bash
56
  Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
57
  Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
58
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
59
  ```
60
 
61
  *Other info, e.g. author info, GitHub repo, links to sampled results, usage instructions, tutorials (blog, video, etc.) ...*
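
For reference, the Model / Vocab / Config triplets listed on this page can also be consumed directly from Python (a minimal sketch using the helpers in `src/f5_tts/infer/utils_infer.py`; normally the Gradio app's custom-model fields or the `infer_cli` flags do this for you):

```python
from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT

# Values taken from the F5-TTS v1 Base entry above.
ckpt = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
vocab = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt"))
arch = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

vocoder = load_vocoder()  # vocos by default
ema_model = load_model(DiT, arch, ckpt, vocab_file=vocab)
```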
 
74
  ```bash
75
  Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
76
  Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
77
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
78
  ```
79
 
80
 
 
88
  ```bash
89
  Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
90
  Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
91
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
92
  ```
93
 
94
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
 
106
  ```bash
107
  Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
108
  Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
109
+ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
110
  ```
111
 
112
  - Authors: SPRING Lab, Indian Institute of Technology, Madras
 
123
  ```bash
124
  Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
125
  Vocab: hf://alien79/F5-TTS-italian/vocab.txt
126
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
127
  ```
128
 
129
  - Trained by [Mithril Man](https://github.com/MithrilMan)
 
141
  ```bash
142
  Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
143
  Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
144
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
145
  ```
146
 
147
 
 
158
  ```bash
159
  Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
160
  Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
161
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
162
  ```
163
  - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
164
  - Any improvements are welcome
src/f5_tts/infer/infer_cli.py CHANGED
@@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import (
27
  preprocess_ref_audio_text,
28
  remove_silence_for_generated_wav,
29
  )
30
- from f5_tts.model import DiT, UNetT
31
 
32
 
33
  parser = argparse.ArgumentParser(
@@ -50,7 +50,7 @@ parser.add_argument(
50
  "-m",
51
  "--model",
52
  type=str,
53
- help="The model name: F5-TTS | E2-TTS",
54
  )
55
  parser.add_argument(
56
  "-mc",
@@ -172,8 +172,7 @@ config = tomli.load(open(args.config, "rb"))
172
 
173
  # command-line interface parameters
174
 
175
- model = args.model or config.get("model", "F5-TTS")
176
- model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
177
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
178
  vocab_file = args.vocab_file or config.get("vocab_file", "")
179
 
@@ -245,36 +244,32 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
245
 
246
  # load TTS model
247
 
248
- if model == "F5-TTS":
249
- model_cls = DiT
250
- model_cfg = OmegaConf.load(model_cfg).model.arch
251
- if not ckpt_file: # path not specified, download from repo
252
- if vocoder_name == "vocos":
253
- repo_name = "F5-TTS"
254
- exp_name = "F5TTS_Base"
255
- ckpt_step = 1200000
256
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
257
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
258
- elif vocoder_name == "bigvgan":
259
- repo_name = "F5-TTS"
260
- exp_name = "F5TTS_Base_bigvgan"
261
- ckpt_step = 1250000
262
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
263
-
264
- elif model == "E2-TTS":
265
- assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
266
- assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
267
- model_cls = UNetT
268
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
269
- if not ckpt_file: # path not specified, download from repo
270
- repo_name = "E2-TTS"
271
- exp_name = "E2TTS_Base"
272
  ckpt_step = 1200000
273
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
274
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
 
 
 
 
 
 
 
275
 
276
  print(f"Using {model}...")
277
- ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
278
 
279
 
280
  # inference process
 
27
  preprocess_ref_audio_text,
28
  remove_silence_for_generated_wav,
29
  )
30
+ from f5_tts.model import DiT, UNetT # noqa: F401. used for config
31
 
32
 
33
  parser = argparse.ArgumentParser(
 
50
  "-m",
51
  "--model",
52
  type=str,
53
+ help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
54
  )
55
  parser.add_argument(
56
  "-mc",
 
172
 
173
  # command-line interface parameters
174
 
175
+ model = args.model or config.get("model", "F5TTS_v1_Base")
 
176
  ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
177
  vocab_file = args.vocab_file or config.get("vocab_file", "")
178
 
 
244
 
245
  # load TTS model
246
 
247
+ model_cfg = OmegaConf.load(
248
+ args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
249
+ ).model
250
+ model_cls = globals()[model_cfg.backbone]
251
+
252
+ repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
253
+
254
+ if model != "F5TTS_Base":
255
+ assert vocoder_name == model_cfg.mel_spec.mel_spec_type
256
+
257
+ # override for previous models
258
+ if model == "F5TTS_Base":
259
+ if vocoder_name == "vocos":
 
 
 
 
 
 
 
 
 
 
 
260
  ckpt_step = 1200000
261
+ elif vocoder_name == "bigvgan":
262
+ model = "F5TTS_Base_bigvgan"
263
+ ckpt_type = "pt"
264
+ elif model == "E2TTS_Base":
265
+ repo_name = "E2-TTS"
266
+ ckpt_step = 1200000
267
+
268
+ if not ckpt_file:
269
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
270
 
271
  print(f"Using {model}...")
272
+ ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
273
 
274
 
275
  # inference process
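
Outside the CLI, the same config-driven loading can be reproduced in a few lines (a sketch assuming the packaged `configs/<model>.yaml` files shipped with this release; the backbone class is resolved by the name stored in the config, as above):

```python
from importlib.resources import files

from cached_path import cached_path
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT, UNetT

model = "F5TTS_v1_Base"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))).model
model_cls = {"DiT": DiT, "UNetT": UNetT}[model_cfg.backbone]

ckpt_file = str(cached_path(f"hf://SWivid/F5-TTS/{model}/model_1250000.safetensors"))
ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=model_cfg.mel_spec.mel_spec_type)
```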
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
41
  )
42
 
43
 
44
- DEFAULT_TTS_MODEL = "F5-TTS"
45
  tts_model_choice = DEFAULT_TTS_MODEL
46
 
47
  DEFAULT_TTS_MODEL_CFG = [
48
- "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
49
- "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
50
  json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
51
  ]
52
 
@@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [
56
  vocoder = load_vocoder()
57
 
58
 
59
- def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
60
- F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
61
  return load_model(DiT, F5TTS_model_cfg, ckpt_path)
62
 
63
 
64
- def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
65
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
66
  return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
67
 
68
 
@@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
73
  if vocab_path.startswith("hf://"):
74
  vocab_path = str(cached_path(vocab_path))
75
  if model_cfg is None:
76
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
77
  return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
78
 
79
 
@@ -130,7 +132,7 @@ def infer(
130
 
131
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
132
 
133
- if model == "F5-TTS":
134
  ema_model = F5TTS_ema_model
135
  elif model == "E2-TTS":
136
  global E2TTS_ema_model
@@ -762,7 +764,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
762
  """
763
  )
764
 
765
- last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
766
 
767
  def load_last_used_custom():
768
  try:
@@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
821
  custom_model_cfg = gr.Dropdown(
822
  choices=[
823
  DEFAULT_TTS_MODEL_CFG[2],
824
- json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
 
825
  ],
826
  value=load_last_used_custom()[2],
827
  allow_custom_value=True,
 
41
  )
42
 
43
 
44
+ DEFAULT_TTS_MODEL = "F5-TTS_v1"
45
  tts_model_choice = DEFAULT_TTS_MODEL
46
 
47
  DEFAULT_TTS_MODEL_CFG = [
48
+ "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
49
+ "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
50
  json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
51
  ]
52
 
 
56
  vocoder = load_vocoder()
57
 
58
 
59
+ def load_f5tts():
60
+ ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
61
+ F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
62
  return load_model(DiT, F5TTS_model_cfg, ckpt_path)
63
 
64
 
65
+ def load_e2tts():
66
+ ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
67
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
68
  return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
69
 
70
 
 
75
  if vocab_path.startswith("hf://"):
76
  vocab_path = str(cached_path(vocab_path))
77
  if model_cfg is None:
78
+ model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
79
  return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
80
 
81
 
 
132
 
133
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
134
 
135
+ if model == DEFAULT_TTS_MODEL:
136
  ema_model = F5TTS_ema_model
137
  elif model == "E2-TTS":
138
  global E2TTS_ema_model
 
764
  """
765
  )
766
 
767
+ last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
768
 
769
  def load_last_used_custom():
770
  try:
 
823
  custom_model_cfg = gr.Dropdown(
824
  choices=[
825
  DEFAULT_TTS_MODEL_CFG[2],
826
+ json.dumps(
827
+ dict(
828
+ dim=1024,
829
+ depth=22,
830
+ heads=16,
831
+ ff_mult=2,
832
+ text_dim=512,
833
+ text_mask_padding=False,
834
+ conv_layers=4,
835
+ pe_attn_head=1,
836
+ )
837
+ ),
838
+ json.dumps(
839
+ dict(
840
+ dim=768,
841
+ depth=18,
842
+ heads=12,
843
+ ff_mult=2,
844
+ text_dim=512,
845
+ text_mask_padding=False,
846
+ conv_layers=4,
847
+ pe_attn_head=1,
848
+ )
849
+ ),
850
  ],
851
  value=load_last_used_custom()[2],
852
  allow_custom_value=True,
src/f5_tts/infer/speech_edit.py CHANGED
@@ -2,12 +2,15 @@ import os
2
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
4
 
 
 
5
  import torch
6
  import torch.nn.functional as F
7
  import torchaudio
 
8
 
9
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
10
- from f5_tts.model import CFM, DiT, UNetT
11
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
12
 
13
  device = (
@@ -21,44 +24,40 @@ device = (
21
  )
22
 
23
 
24
- # --------------------- Dataset Settings -------------------- #
25
-
26
- target_sample_rate = 24000
27
- n_mel_channels = 100
28
- hop_length = 256
29
- win_length = 1024
30
- n_fft = 1024
31
- mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
32
- target_rms = 0.1
33
-
34
- tokenizer = "pinyin"
35
- dataset_name = "Emilia_ZH_EN"
36
-
37
-
38
  # ---------------------- infer setting ---------------------- #
39
 
40
  seed = None # int | None
41
 
42
- exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
43
- ckpt_step = 1200000
44
 
45
  nfe_step = 32 # 16, 32
46
  cfg_strength = 2.0
47
  ode_method = "euler" # euler | midpoint
48
  sway_sampling_coef = -1.0
49
  speed = 1.0
 
 
 
 
 
 
50
 
51
- if exp_name == "F5TTS_Base":
52
- model_cls = DiT
53
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
54
 
55
- elif exp_name == "E2TTS_Base":
56
- model_cls = UNetT
57
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 
58
 
59
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 
60
  output_dir = "tests"
61
 
 
62
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
63
  # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
64
  # [write the origin_text into a file, e.g. tests/test_edit.txt]
@@ -67,7 +66,7 @@ output_dir = "tests"
67
  # [--language "zho" for Chinese, "eng" for English]
68
  # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
69
 
70
- audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
71
  origin_text = "Some call me nature, others call me mother nature."
72
  target_text = "Some call me optimist, others call me realist."
73
  parts_to_edit = [
@@ -106,7 +105,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
106
 
107
  # Model
108
  model = CFM(
109
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
110
  mel_spec_kwargs=dict(
111
  n_fft=n_fft,
112
  hop_length=hop_length,
 
2
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
4
 
5
+ from importlib.resources import files
6
+
7
  import torch
8
  import torch.nn.functional as F
9
  import torchaudio
10
+ from omegaconf import OmegaConf
11
 
12
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
13
+ from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
14
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
15
 
16
  device = (
 
24
  )
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # ---------------------- infer setting ---------------------- #
28
 
29
  seed = None # int | None
30
 
31
+ exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
32
+ ckpt_step = 1250000
33
 
34
  nfe_step = 32 # 16, 32
35
  cfg_strength = 2.0
36
  ode_method = "euler" # euler | midpoint
37
  sway_sampling_coef = -1.0
38
  speed = 1.0
39
+ target_rms = 0.1
40
+
41
+
42
+ model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
43
+ model_cls = globals()[model_cfg.model.backbone]
44
+ model_arc = model_cfg.model.arch
45
 
46
+ dataset_name = model_cfg.datasets.name
47
+ tokenizer = model_cfg.model.tokenizer
 
48
 
49
+ mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
50
+ target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
51
+ n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
52
+ hop_length = model_cfg.model.mel_spec.hop_length
53
+ win_length = model_cfg.model.mel_spec.win_length
54
+ n_fft = model_cfg.model.mel_spec.n_fft
55
 
56
+
57
+ ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
58
  output_dir = "tests"
59
 
60
+
61
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
62
  # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
63
  # [write the origin_text into a file, e.g. tests/test_edit.txt]
 
66
  # [--language "zho" for Chinese, "eng" for English]
67
  # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
68
 
69
+ audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
70
  origin_text = "Some call me nature, others call me mother nature."
71
  target_text = "Some call me optimist, others call me realist."
72
  parts_to_edit = [
 
105
 
106
  # Model
107
  model = CFM(
108
+ transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
109
  mel_spec_kwargs=dict(
110
  n_fft=n_fft,
111
  hop_length=hop_length,
src/f5_tts/infer/utils_infer.py CHANGED
@@ -301,19 +301,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
301
  )
302
  non_silent_wave = AudioSegment.silent(duration=0)
303
  for non_silent_seg in non_silent_segs:
304
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
305
  show_info("Audio is over 15s, clipping short. (1)")
306
  break
307
  non_silent_wave += non_silent_seg
308
 
309
  # 2. try to find short silence for clipping if 1. failed
310
- if len(non_silent_wave) > 15000:
311
  non_silent_segs = silence.split_on_silence(
312
  aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
313
  )
314
  non_silent_wave = AudioSegment.silent(duration=0)
315
  for non_silent_seg in non_silent_segs:
316
- if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
317
  show_info("Audio is over 15s, clipping short. (2)")
318
  break
319
  non_silent_wave += non_silent_seg
@@ -321,8 +321,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
321
  aseg = non_silent_wave
322
 
323
  # 3. if no proper silence found for clipping
324
- if len(aseg) > 15000:
325
- aseg = aseg[:15000]
326
  show_info("Audio is over 15s, clipping short. (3)")
327
 
328
  aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
@@ -383,7 +383,7 @@ def infer_process(
383
  ):
384
  # Split the input text into batches
385
  audio, sr = torchaudio.load(ref_audio)
386
- max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
387
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
388
  for i, gen_text in enumerate(gen_text_batches):
389
  print(f"gen_text {i}", gen_text)
 
301
  )
302
  non_silent_wave = AudioSegment.silent(duration=0)
303
  for non_silent_seg in non_silent_segs:
304
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
305
  show_info("Audio is over 15s, clipping short. (1)")
306
  break
307
  non_silent_wave += non_silent_seg
308
 
309
  # 2. try to find short silence for clipping if 1. failed
310
+ if len(non_silent_wave) > 12000:
311
  non_silent_segs = silence.split_on_silence(
312
  aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
313
  )
314
  non_silent_wave = AudioSegment.silent(duration=0)
315
  for non_silent_seg in non_silent_segs:
316
+ if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
317
  show_info("Audio is over 15s, clipping short. (2)")
318
  break
319
  non_silent_wave += non_silent_seg
 
321
  aseg = non_silent_wave
322
 
323
  # 3. if no proper silence found for clipping
324
+ if len(aseg) > 12000:
325
+ aseg = aseg[:12000]
326
  show_info("Audio is over 15s, clipping short. (3)")
327
 
328
  aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
 
383
  ):
384
  # Split the input text into batches
385
  audio, sr = torchaudio.load(ref_audio)
386
+ max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
387
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
388
  for i, gen_text in enumerate(gen_text_batches):
389
  print(f"gen_text {i}", gen_text)
src/f5_tts/model/backbones/README.md CHANGED
@@ -4,7 +4,7 @@
4
  ### unett.py
5
  - flat unet transformer
6
  - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
- - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
 
9
  ### dit.py
10
  - adaln-zero dit
@@ -14,7 +14,7 @@
14
  - possible long skip connection (first layer to last layer)
15
 
16
  ### mmdit.py
17
- - sd3 structure
18
  - timestep as condition
19
  - left stream: text embedded with an abs pos emb applied
20
  - right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
 
4
  ### unett.py
5
  - flat unet transformer
6
  - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
8
 
9
  ### dit.py
10
  - adaln-zero dit
 
14
  - possible long skip connection (first layer to last layer)
15
 
16
  ### mmdit.py
17
+ - stable diffusion 3 block structure
18
  - timestep as condition
19
  - left stream: text embedded with an abs pos emb applied
20
  - right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
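
As a quick orientation for the backbone options introduced below, here is a sketch of instantiating the DiT backbone with the Base (v0) hyperparameters from `SHARED.md`; `text_mask_padding`, `pe_attn_head`, and `qk_norm` are the knobs added in this release, and the values shown reproduce the old behaviour:

```python
from f5_tts.model import DiT

backbone = DiT(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4,
    mel_dim=100,              # text_num_embeds should match the vocab.txt in use
    text_mask_padding=False,  # v0 checkpoints were trained without masking filler/padding tokens
    pe_attn_head=1,           # v0 applies rope on the first attention head only
    qk_norm=None,             # no query/key normalization, as in v0
)
```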
src/f5_tts/model/backbones/dit.py CHANGED
@@ -20,7 +20,7 @@ from f5_tts.model.modules import (
20
  ConvNeXtV2Block,
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
- AdaLayerNormZero_Final,
24
  precompute_freqs_cis,
25
  get_pos_embed_indices,
26
  )
@@ -30,10 +30,12 @@ from f5_tts.model.modules import (
30
 
31
 
32
  class TextEmbedding(nn.Module):
33
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
 
 
37
  if conv_layers > 0:
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -49,6 +51,8 @@ class TextEmbedding(nn.Module):
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
  batch, text_len = text.shape[0], text.shape[1]
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
 
 
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
@@ -64,7 +68,13 @@ class TextEmbedding(nn.Module):
64
  text = text + text_pos_embed
65
 
66
  # convnextv2 blocks
67
- text = self.text_blocks(text)
 
 
 
 
 
 
68
 
69
  return text
70
 
@@ -103,7 +113,10 @@ class DiT(nn.Module):
103
  mel_dim=100,
104
  text_num_embeds=256,
105
  text_dim=None,
 
 
106
  conv_layers=0,
 
107
  long_skip_connection=False,
108
  checkpoint_activations=False,
109
  ):
@@ -112,7 +125,10 @@ class DiT(nn.Module):
112
  self.time_embed = TimestepEmbedding(dim)
113
  if text_dim is None:
114
  text_dim = mel_dim
115
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
 
116
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
117
 
118
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -121,15 +137,40 @@ class DiT(nn.Module):
121
  self.depth = depth
122
 
123
  self.transformer_blocks = nn.ModuleList(
124
- [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
 
 
 
 
 
 
 
 
125
  )
126
  self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
127
 
128
- self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
129
  self.proj_out = nn.Linear(dim, mel_dim)
130
 
131
  self.checkpoint_activations = checkpoint_activations
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def ckpt_wrapper(self, module):
134
  # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
135
  def ckpt_forward(*inputs):
@@ -138,6 +179,9 @@ class DiT(nn.Module):
138
 
139
  return ckpt_forward
140
 
 
 
 
141
  def forward(
142
  self,
143
  x: float["b n d"], # noised input audio # noqa: F722
@@ -147,14 +191,25 @@ class DiT(nn.Module):
147
  drop_audio_cond, # cfg for cond audio
148
  drop_text, # cfg for text
149
  mask: bool["b n"] | None = None, # noqa: F722
 
150
  ):
151
  batch, seq_len = x.shape[0], x.shape[1]
152
  if time.ndim == 0:
153
  time = time.repeat(batch)
154
 
155
- # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
156
  t = self.time_embed(time)
157
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
 
 
 
 
 
 
 
 
 
158
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
159
 
160
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
 
20
  ConvNeXtV2Block,
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
+ AdaLayerNorm_Final,
24
  precompute_freqs_cis,
25
  get_pos_embed_indices,
26
  )
 
30
 
31
 
32
  class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
37
+ self.mask_padding = mask_padding # mask filler and batch padding tokens or not
38
+
39
  if conv_layers > 0:
40
  self.extra_modeling = True
41
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
 
51
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
  batch, text_len = text.shape[0], text.shape[1]
53
  text = F.pad(text, (0, seq_len - text_len), value=0)
54
+ if self.mask_padding:
55
+ text_mask = text == 0
56
 
57
  if drop_text: # cfg for text
58
  text = torch.zeros_like(text)
 
68
  text = text + text_pos_embed
69
 
70
  # convnextv2 blocks
71
+ if self.mask_padding:
72
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
73
+ for block in self.text_blocks:
74
+ text = block(text)
75
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
76
+ else:
77
+ text = self.text_blocks(text)
78
 
79
  return text
80
 
 
113
  mel_dim=100,
114
  text_num_embeds=256,
115
  text_dim=None,
116
+ text_mask_padding=True,
117
+ qk_norm=None,
118
  conv_layers=0,
119
+ pe_attn_head=None,
120
  long_skip_connection=False,
121
  checkpoint_activations=False,
122
  ):
 
125
  self.time_embed = TimestepEmbedding(dim)
126
  if text_dim is None:
127
  text_dim = mel_dim
128
+ self.text_embed = TextEmbedding(
129
+ text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
130
+ )
131
+ self.text_cond, self.text_uncond = None, None # text cache
132
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
133
 
134
  self.rotary_embed = RotaryEmbedding(dim_head)
 
137
  self.depth = depth
138
 
139
  self.transformer_blocks = nn.ModuleList(
140
+ [
141
+ DiTBlock(
142
+ dim=dim,
143
+ heads=heads,
144
+ dim_head=dim_head,
145
+ ff_mult=ff_mult,
146
+ dropout=dropout,
147
+ qk_norm=qk_norm,
148
+ pe_attn_head=pe_attn_head,
149
+ )
150
+ for _ in range(depth)
151
+ ]
152
  )
153
  self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
154
 
155
+ self.norm_out = AdaLayerNorm_Final(dim) # final modulation
156
  self.proj_out = nn.Linear(dim, mel_dim)
157
 
158
  self.checkpoint_activations = checkpoint_activations
159
 
160
+ self.initialize_weights()
161
+
162
+ def initialize_weights(self):
163
+ # Zero-out AdaLN layers in DiT blocks:
164
+ for block in self.transformer_blocks:
165
+ nn.init.constant_(block.attn_norm.linear.weight, 0)
166
+ nn.init.constant_(block.attn_norm.linear.bias, 0)
167
+
168
+ # Zero-out output layers:
169
+ nn.init.constant_(self.norm_out.linear.weight, 0)
170
+ nn.init.constant_(self.norm_out.linear.bias, 0)
171
+ nn.init.constant_(self.proj_out.weight, 0)
172
+ nn.init.constant_(self.proj_out.bias, 0)
173
+
174
  def ckpt_wrapper(self, module):
175
  # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
176
  def ckpt_forward(*inputs):
 
179
 
180
  return ckpt_forward
181
 
182
+ def clear_cache(self):
183
+ self.text_cond, self.text_uncond = None, None
184
+
185
  def forward(
186
  self,
187
  x: float["b n d"], # noised input audio # noqa: F722
 
191
  drop_audio_cond, # cfg for cond audio
192
  drop_text, # cfg for text
193
  mask: bool["b n"] | None = None, # noqa: F722
194
+ cache=False,
195
  ):
196
  batch, seq_len = x.shape[0], x.shape[1]
197
  if time.ndim == 0:
198
  time = time.repeat(batch)
199
 
200
+ # t: conditioning time, text: text, x: noised audio + cond audio + text
201
  t = self.time_embed(time)
202
+ if cache:
203
+ if drop_text:
204
+ if self.text_uncond is None:
205
+ self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
206
+ text_embed = self.text_uncond
207
+ else:
208
+ if self.text_cond is None:
209
+ self.text_cond = self.text_embed(text, seq_len, drop_text=False)
210
+ text_embed = self.text_cond
211
+ else:
212
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
213
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
214
 
215
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
src/f5_tts/model/backbones/mmdit.py CHANGED
@@ -18,7 +18,7 @@ from f5_tts.model.modules import (
18
  TimestepEmbedding,
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
- AdaLayerNormZero_Final,
22
  precompute_freqs_cis,
23
  get_pos_embed_indices,
24
  )
@@ -28,18 +28,24 @@ from f5_tts.model.modules import (
28
 
29
 
30
  class TextEmbedding(nn.Module):
31
- def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
33
  self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
 
 
 
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
  def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
- text = text + 1
40
- if drop_text:
 
 
 
41
  text = torch.zeros_like(text)
42
- text = self.text_embed(text)
 
43
 
44
  # sinus pos emb
45
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
@@ -49,6 +55,9 @@ class TextEmbedding(nn.Module):
49
 
50
  text = text + text_pos_embed
51
 
 
 
 
52
  return text
53
 
54
 
@@ -83,13 +92,16 @@ class MMDiT(nn.Module):
83
  dim_head=64,
84
  dropout=0.1,
85
  ff_mult=4,
86
- text_num_embeds=256,
87
  mel_dim=100,
 
 
 
88
  ):
89
  super().__init__()
90
 
91
  self.time_embed = TimestepEmbedding(dim)
92
- self.text_embed = TextEmbedding(dim, text_num_embeds)
 
93
  self.audio_embed = AudioEmbedding(mel_dim, dim)
94
 
95
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -106,13 +118,33 @@ class MMDiT(nn.Module):
106
  dropout=dropout,
107
  ff_mult=ff_mult,
108
  context_pre_only=i == depth - 1,
 
109
  )
110
  for i in range(depth)
111
  ]
112
  )
113
- self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
114
  self.proj_out = nn.Linear(dim, mel_dim)
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def forward(
117
  self,
118
  x: float["b n d"], # noised input audio # noqa: F722
@@ -122,6 +154,7 @@ class MMDiT(nn.Module):
122
  drop_audio_cond, # cfg for cond audio
123
  drop_text, # cfg for text
124
  mask: bool["b n"] | None = None, # noqa: F722
 
125
  ):
126
  batch = x.shape[0]
127
  if time.ndim == 0:
@@ -129,7 +162,17 @@ class MMDiT(nn.Module):
129
 
130
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
  t = self.time_embed(time)
132
- c = self.text_embed(text, drop_text=drop_text)
 
 
 
 
 
 
 
 
 
 
133
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
 
135
  seq_len = x.shape[1]
 
18
  TimestepEmbedding,
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
+ AdaLayerNorm_Final,
22
  precompute_freqs_cis,
23
  get_pos_embed_indices,
24
  )
 
28
 
29
 
30
  class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds, mask_padding=True):
32
  super().__init__()
33
  self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
 
35
+ self.mask_padding = mask_padding # mask filler and batch padding tokens or not
36
+
37
  self.precompute_max_pos = 1024
38
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
39
 
40
  def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
41
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
42
+ if self.mask_padding:
43
+ text_mask = text == 0
44
+
45
+ if drop_text: # cfg for text
46
  text = torch.zeros_like(text)
47
+
48
+ text = self.text_embed(text) # b nt -> b nt d
49
 
50
  # sinus pos emb
51
  batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
 
55
 
56
  text = text + text_pos_embed
57
 
58
+ if self.mask_padding:
59
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
60
+
61
  return text
62
 
63
 
 
92
  dim_head=64,
93
  dropout=0.1,
94
  ff_mult=4,
 
95
  mel_dim=100,
96
+ text_num_embeds=256,
97
+ text_mask_padding=True,
98
+ qk_norm=None,
99
  ):
100
  super().__init__()
101
 
102
  self.time_embed = TimestepEmbedding(dim)
103
+ self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
104
+ self.text_cond, self.text_uncond = None, None # text cache
105
  self.audio_embed = AudioEmbedding(mel_dim, dim)
106
 
107
  self.rotary_embed = RotaryEmbedding(dim_head)
 
118
  dropout=dropout,
119
  ff_mult=ff_mult,
120
  context_pre_only=i == depth - 1,
121
+ qk_norm=qk_norm,
122
  )
123
  for i in range(depth)
124
  ]
125
  )
126
+ self.norm_out = AdaLayerNorm_Final(dim) # final modulation
127
  self.proj_out = nn.Linear(dim, mel_dim)
128
 
129
+ self.initialize_weights()
130
+
131
+ def initialize_weights(self):
132
+ # Zero-out AdaLN layers in MMDiT blocks:
133
+ for block in self.transformer_blocks:
134
+ nn.init.constant_(block.attn_norm_x.linear.weight, 0)
135
+ nn.init.constant_(block.attn_norm_x.linear.bias, 0)
136
+ nn.init.constant_(block.attn_norm_c.linear.weight, 0)
137
+ nn.init.constant_(block.attn_norm_c.linear.bias, 0)
138
+
139
+ # Zero-out output layers:
140
+ nn.init.constant_(self.norm_out.linear.weight, 0)
141
+ nn.init.constant_(self.norm_out.linear.bias, 0)
142
+ nn.init.constant_(self.proj_out.weight, 0)
143
+ nn.init.constant_(self.proj_out.bias, 0)
144
+
145
+ def clear_cache(self):
146
+ self.text_cond, self.text_uncond = None, None
147
+
148
  def forward(
149
  self,
150
  x: float["b n d"], # noised input audio # noqa: F722
 
154
  drop_audio_cond, # cfg for cond audio
155
  drop_text, # cfg for text
156
  mask: bool["b n"] | None = None, # noqa: F722
157
+ cache=False,
158
  ):
159
  batch = x.shape[0]
160
  if time.ndim == 0:
 
162
 
163
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
164
  t = self.time_embed(time)
165
+ if cache:
166
+ if drop_text:
167
+ if self.text_uncond is None:
168
+ self.text_uncond = self.text_embed(text, drop_text=True)
169
+ c = self.text_uncond
170
+ else:
171
+ if self.text_cond is None:
172
+ self.text_cond = self.text_embed(text, drop_text=False)
173
+ c = self.text_cond
174
+ else:
175
+ c = self.text_embed(text, drop_text=drop_text)
176
  x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
177
 
178
  seq_len = x.shape[1]
src/f5_tts/model/backbones/unett.py CHANGED
@@ -33,10 +33,12 @@ from f5_tts.model.modules import (
33
 
34
 
35
  class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
 
 
40
  if conv_layers > 0:
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -52,6 +54,8 @@ class TextEmbedding(nn.Module):
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
  batch, text_len = text.shape[0], text.shape[1]
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
 
 
55
 
56
  if drop_text: # cfg for text
57
  text = torch.zeros_like(text)
@@ -67,7 +71,13 @@ class TextEmbedding(nn.Module):
67
  text = text + text_pos_embed
68
 
69
  # convnextv2 blocks
70
- text = self.text_blocks(text)
 
 
 
 
 
 
71
 
72
  return text
73
 
@@ -106,7 +116,10 @@ class UNetT(nn.Module):
106
  mel_dim=100,
107
  text_num_embeds=256,
108
  text_dim=None,
 
 
109
  conv_layers=0,
 
110
  skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
  ):
112
  super().__init__()
@@ -115,7 +128,10 @@ class UNetT(nn.Module):
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
 
 
 
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -134,11 +150,12 @@ class UNetT(nn.Module):
134
 
135
  attn_norm = RMSNorm(dim)
136
  attn = Attention(
137
- processor=AttnProcessor(),
138
  dim=dim,
139
  heads=heads,
140
  dim_head=dim_head,
141
  dropout=dropout,
 
142
  )
143
 
144
  ff_norm = RMSNorm(dim)
@@ -161,6 +178,9 @@ class UNetT(nn.Module):
161
  self.norm_out = RMSNorm(dim)
162
  self.proj_out = nn.Linear(dim, mel_dim)
163
 
 
 
 
164
  def forward(
165
  self,
166
  x: float["b n d"], # noised input audio # noqa: F722
@@ -170,6 +190,7 @@ class UNetT(nn.Module):
170
  drop_audio_cond, # cfg for cond audio
171
  drop_text, # cfg for text
172
  mask: bool["b n"] | None = None, # noqa: F722
 
173
  ):
174
  batch, seq_len = x.shape[0], x.shape[1]
175
  if time.ndim == 0:
@@ -177,7 +198,17 @@ class UNetT(nn.Module):
177
 
178
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
  t = self.time_embed(time)
180
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
 
 
 
 
 
 
 
 
 
181
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
 
183
  # postfix time t to input x, [b n d] -> [b n+1 d]
 
33
 
34
 
35
  class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
40
+ self.mask_padding = mask_padding # mask filler and batch padding tokens or not
41
+
42
  if conv_layers > 0:
43
  self.extra_modeling = True
44
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
 
54
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
55
  batch, text_len = text.shape[0], text.shape[1]
56
  text = F.pad(text, (0, seq_len - text_len), value=0)
57
+ if self.mask_padding:
58
+ text_mask = text == 0
59
 
60
  if drop_text: # cfg for text
61
  text = torch.zeros_like(text)
 
71
  text = text + text_pos_embed
72
 
73
  # convnextv2 blocks
74
+ if self.mask_padding:
75
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
76
+ for block in self.text_blocks:
77
+ text = block(text)
78
+ text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
79
+ else:
80
+ text = self.text_blocks(text)
81
 
82
  return text
83
 
 
116
  mel_dim=100,
117
  text_num_embeds=256,
118
  text_dim=None,
119
+ text_mask_padding=True,
120
+ qk_norm=None,
121
  conv_layers=0,
122
+ pe_attn_head=None,
123
  skip_connect_type: Literal["add", "concat", "none"] = "concat",
124
  ):
125
  super().__init__()
 
128
  self.time_embed = TimestepEmbedding(dim)
129
  if text_dim is None:
130
  text_dim = mel_dim
131
+ self.text_embed = TextEmbedding(
132
+ text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
133
+ )
134
+ self.text_cond, self.text_uncond = None, None # text cache
135
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
136
 
137
  self.rotary_embed = RotaryEmbedding(dim_head)
 
150
 
151
  attn_norm = RMSNorm(dim)
152
  attn = Attention(
153
+ processor=AttnProcessor(pe_attn_head=pe_attn_head),
154
  dim=dim,
155
  heads=heads,
156
  dim_head=dim_head,
157
  dropout=dropout,
158
+ qk_norm=qk_norm,
159
  )
160
 
161
  ff_norm = RMSNorm(dim)
 
178
  self.norm_out = RMSNorm(dim)
179
  self.proj_out = nn.Linear(dim, mel_dim)
180
 
181
+ def clear_cache(self):
182
+ self.text_cond, self.text_uncond = None, None
183
+
184
  def forward(
185
  self,
186
  x: float["b n d"], # noised input audio # noqa: F722
 
190
  drop_audio_cond, # cfg for cond audio
191
  drop_text, # cfg for text
192
  mask: bool["b n"] | None = None, # noqa: F722
193
+ cache=False,
194
  ):
195
  batch, seq_len = x.shape[0], x.shape[1]
196
  if time.ndim == 0:
 
198
 
199
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
200
  t = self.time_embed(time)
201
+ if cache:
202
+ if drop_text:
203
+ if self.text_uncond is None:
204
+ self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
205
+ text_embed = self.text_uncond
206
+ else:
207
+ if self.text_cond is None:
208
+ self.text_cond = self.text_embed(text, seq_len, drop_text=False)
209
+ text_embed = self.text_cond
210
+ else:
211
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
212
  x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
213
 
214
  # postfix time t to input x, [b n d] -> [b n+1 d]
src/f5_tts/model/cfm.py CHANGED
@@ -162,13 +162,13 @@ class CFM(nn.Module):
162
 
163
  # predict flow
164
  pred = self.transformer(
165
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
166
  )
167
  if cfg_strength < 1e-5:
168
  return pred
169
 
170
  null_pred = self.transformer(
171
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
172
  )
173
  return pred + (pred - null_pred) * cfg_strength
174
 
@@ -195,6 +195,7 @@ class CFM(nn.Module):
195
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
196
 
197
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
 
198
 
199
  sampled = trajectory[-1]
200
  out = sampled
 
162
 
163
  # predict flow
164
  pred = self.transformer(
165
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
166
  )
167
  if cfg_strength < 1e-5:
168
  return pred
169
 
170
  null_pred = self.transformer(
171
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
172
  )
173
  return pred + (pred - null_pred) * cfg_strength
174
 
 
195
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
196
 
197
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
198
+ self.transformer.clear_cache()
199
 
200
  sampled = trajectory[-1]
201
  out = sampled
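
The `cache=True` flag threaded through the backbones exists because the text embedding is independent of the flow-matching step: within a single `sample()` call it is identical at every ODE step for both the conditional and the null (dropped-text) branch, so it is computed once per branch and cleared afterwards. A rough sketch of the pattern (not the library code itself):

```python
class CachedTextEmbed:
    """Memoize the conditional and unconditional text embeddings across ODE steps."""

    def __init__(self, embed_fn):
        self.embed_fn = embed_fn
        self.cond = None
        self.uncond = None

    def __call__(self, text, seq_len, drop_text):
        if drop_text:
            if self.uncond is None:
                self.uncond = self.embed_fn(text, seq_len, drop_text=True)
            return self.uncond
        if self.cond is None:
            self.cond = self.embed_fn(text, seq_len, drop_text=False)
        return self.cond

    def clear(self):
        self.cond = self.uncond = None


embed = CachedTextEmbed(lambda text, seq_len, drop_text: f"emb({text!r}, drop={drop_text})")
for _ in range(32):  # e.g. 32 NFE steps, each with a cond and an uncond pass
    cond, uncond = embed("hello", 128, False), embed("hello", 128, True)
embed.clear()  # done sampling, free the cache
```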
src/f5_tts/model/dataset.py CHANGED
@@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
173
  """
174
 
175
  def __init__(
176
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
177
  ):
178
  self.sampler = sampler
179
  self.frames_threshold = frames_threshold
@@ -208,12 +208,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
208
  batch = []
209
  batch_frames = 0
210
 
211
- if not drop_last and len(batch) > 0:
212
  batches.append(batch)
213
 
214
  del indices
215
  self.batches = batches
216
 
 
 
 
217
  def set_epoch(self, epoch: int) -> None:
218
  """Sets the epoch for this sampler."""
219
  self.epoch = epoch
 
173
  """
174
 
175
  def __init__(
176
+ self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
177
  ):
178
  self.sampler = sampler
179
  self.frames_threshold = frames_threshold
 
208
  batch = []
209
  batch_frames = 0
210
 
211
+ if not drop_residual and len(batch) > 0:
212
  batches.append(batch)
213
 
214
  del indices
215
  self.batches = batches
216
 
217
+ # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
218
+ self.drop_last = True
219
+
220
  def set_epoch(self, epoch: int) -> None:
221
  """Sets the epoch for this sampler."""
222
  self.epoch = epoch
src/f5_tts/model/modules.py CHANGED
@@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module):
269
  return residual + x
270
 
271
 
272
- # AdaLayerNormZero
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  # return with modulated x for attn input, and params for later mlp modulation
274
 
275
 
276
- class AdaLayerNormZero(nn.Module):
277
  def __init__(self, dim):
278
  super().__init__()
279
 
@@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module):
290
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
291
 
292
 
293
- # AdaLayerNormZero for final layer
294
  # return only with modulated x for attn input, cuz no more mlp modulation
295
 
296
 
297
- class AdaLayerNormZero_Final(nn.Module):
298
  def __init__(self, dim):
299
  super().__init__()
300
 
@@ -341,7 +366,8 @@ class Attention(nn.Module):
341
  dim_head: int = 64,
342
  dropout: float = 0.0,
343
  context_dim: Optional[int] = None, # if not None -> joint attention
344
- context_pre_only=None,
 
345
  ):
346
  super().__init__()
347
 
@@ -362,18 +388,32 @@ class Attention(nn.Module):
362
  self.to_k = nn.Linear(dim, self.inner_dim)
363
  self.to_v = nn.Linear(dim, self.inner_dim)
364
 
 
 
 
 
 
 
 
 
 
365
  if self.context_dim is not None:
 
366
  self.to_k_c = nn.Linear(context_dim, self.inner_dim)
367
  self.to_v_c = nn.Linear(context_dim, self.inner_dim)
368
- if self.context_pre_only is not None:
369
- self.to_q_c = nn.Linear(context_dim, self.inner_dim)
 
 
 
 
370
 
371
  self.to_out = nn.ModuleList([])
372
  self.to_out.append(nn.Linear(self.inner_dim, dim))
373
  self.to_out.append(nn.Dropout(dropout))
374
 
375
- if self.context_pre_only is not None and not self.context_pre_only:
376
- self.to_out_c = nn.Linear(self.inner_dim, dim)
377
 
378
  def forward(
379
  self,
@@ -393,8 +433,11 @@ class Attention(nn.Module):
393
 
394
 
395
  class AttnProcessor:
396
- def __init__(self):
397
- pass
 
 
 
398
 
399
  def __call__(
400
  self,
@@ -405,19 +448,11 @@ class AttnProcessor:
405
  ) -> torch.FloatTensor:
406
  batch_size = x.shape[0]
407
 
408
- # `sample` projections.
409
  query = attn.to_q(x)
410
  key = attn.to_k(x)
411
  value = attn.to_v(x)
412
 
413
- # apply rotary position embedding
414
- if rope is not None:
415
- freqs, xpos_scale = rope
416
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
417
-
418
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
419
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
420
-
421
  # attention
422
  inner_dim = key.shape[-1]
423
  head_dim = inner_dim // attn.heads
@@ -425,6 +460,25 @@ class AttnProcessor:
425
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  # mask. e.g. inference got a batch with different target durations, mask out the padding
429
  if mask is not None:
430
  attn_mask = mask
@@ -470,16 +524,36 @@ class JointAttnProcessor:
470
 
471
  batch_size = c.shape[0]
472
 
473
- # `sample` projections.
474
  query = attn.to_q(x)
475
  key = attn.to_k(x)
476
  value = attn.to_v(x)
477
 
478
- # `context` projections.
479
  c_query = attn.to_q_c(c)
480
  c_key = attn.to_k_c(c)
481
  c_value = attn.to_v_c(c)
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  # apply rope for context and noised input independently
484
  if rope is not None:
485
  freqs, xpos_scale = rope
@@ -492,16 +566,10 @@ class JointAttnProcessor:
492
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
493
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
494
 
495
- # attention
496
- query = torch.cat([query, c_query], dim=1)
497
- key = torch.cat([key, c_key], dim=1)
498
- value = torch.cat([value, c_value], dim=1)
499
-
500
- inner_dim = key.shape[-1]
501
- head_dim = inner_dim // attn.heads
502
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
503
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
504
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
505
 
506
  # mask. e.g. inference got a batch with different target durations, mask out the padding
507
  if mask is not None:
@@ -540,16 +608,17 @@ class JointAttnProcessor:
540
 
541
 
542
  class DiTBlock(nn.Module):
543
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
544
  super().__init__()
545
 
546
- self.attn_norm = AdaLayerNormZero(dim)
547
  self.attn = Attention(
548
- processor=AttnProcessor(),
549
  dim=dim,
550
  heads=heads,
551
  dim_head=dim_head,
552
  dropout=dropout,
 
553
  )
554
 
555
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module):
585
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
586
  """
587
 
588
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
 
 
589
  super().__init__()
590
-
 
591
  self.context_pre_only = context_pre_only
592
 
593
- self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
594
- self.attn_norm_x = AdaLayerNormZero(dim)
595
  self.attn = Attention(
596
  processor=JointAttnProcessor(),
597
  dim=dim,
598
  heads=heads,
599
  dim_head=dim_head,
600
  dropout=dropout,
601
- context_dim=dim,
602
  context_pre_only=context_pre_only,
 
603
  )
604
 
605
  if not context_pre_only:
606
- self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
607
- self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
608
  else:
609
  self.ff_norm_c = None
610
  self.ff_c = None
 
269
  return residual + x
270
 
271
 
272
+ # RMSNorm
273
+
274
+
275
+ class RMSNorm(nn.Module):
276
+ def __init__(self, dim: int, eps: float):
277
+ super().__init__()
278
+ self.eps = eps
279
+ self.weight = nn.Parameter(torch.ones(dim))
280
+ self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
281
+
282
+ def forward(self, x):
283
+ if self.native_rms_norm:
284
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
285
+ x = x.to(self.weight.dtype)
286
+ x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
287
+ else:
288
+ variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
289
+ x = x * torch.rsqrt(variance + self.eps)
290
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
291
+ x = x.to(self.weight.dtype)
292
+ x = x * self.weight
293
+
294
+ return x
295
+
296
+
297
+ # AdaLayerNorm
298
  # return with modulated x for attn input, and params for later mlp modulation
299
 
300
 
301
+ class AdaLayerNorm(nn.Module):
302
  def __init__(self, dim):
303
  super().__init__()
304
 
 
315
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
316
 
317
 
318
+ # AdaLayerNorm for final layer
319
  # return only with modulated x for attn input, cuz no more mlp modulation
320
 
321
 
322
+ class AdaLayerNorm_Final(nn.Module):
323
  def __init__(self, dim):
324
  super().__init__()
325
 
 
366
  dim_head: int = 64,
367
  dropout: float = 0.0,
368
  context_dim: Optional[int] = None, # if not None -> joint attention
369
+ context_pre_only: bool = False,
370
+ qk_norm: Optional[str] = None,
371
  ):
372
  super().__init__()
373
 
 
388
  self.to_k = nn.Linear(dim, self.inner_dim)
389
  self.to_v = nn.Linear(dim, self.inner_dim)
390
 
391
+ if qk_norm is None:
392
+ self.q_norm = None
393
+ self.k_norm = None
394
+ elif qk_norm == "rms_norm":
395
+ self.q_norm = RMSNorm(dim_head, eps=1e-6)
396
+ self.k_norm = RMSNorm(dim_head, eps=1e-6)
397
+ else:
398
+ raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
399
+
400
  if self.context_dim is not None:
401
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
402
  self.to_k_c = nn.Linear(context_dim, self.inner_dim)
403
  self.to_v_c = nn.Linear(context_dim, self.inner_dim)
404
+ if qk_norm is None:
405
+ self.c_q_norm = None
406
+ self.c_k_norm = None
407
+ elif qk_norm == "rms_norm":
408
+ self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
409
+ self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
410
 
411
  self.to_out = nn.ModuleList([])
412
  self.to_out.append(nn.Linear(self.inner_dim, dim))
413
  self.to_out.append(nn.Dropout(dropout))
414
 
415
+ if self.context_dim is not None and not self.context_pre_only:
416
+ self.to_out_c = nn.Linear(self.inner_dim, context_dim)
417
 
418
  def forward(
419
  self,
 
433
 
434
 
435
  class AttnProcessor:
436
+ def __init__(
437
+ self,
438
+ pe_attn_head: int | None = None, # number of attention heads to apply rope to, None for all
439
+ ):
440
+ self.pe_attn_head = pe_attn_head
441
 
442
  def __call__(
443
  self,
 
448
  ) -> torch.FloatTensor:
449
  batch_size = x.shape[0]
450
 
451
+ # `sample` projections
452
  query = attn.to_q(x)
453
  key = attn.to_k(x)
454
  value = attn.to_v(x)
455
 
 
 
 
 
 
 
 
 
456
  # attention
457
  inner_dim = key.shape[-1]
458
  head_dim = inner_dim // attn.heads
 
460
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
461
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
462
 
463
+ # qk norm
464
+ if attn.q_norm is not None:
465
+ query = attn.q_norm(query)
466
+ if attn.k_norm is not None:
467
+ key = attn.k_norm(key)
468
+
469
+ # apply rotary position embedding
470
+ if rope is not None:
471
+ freqs, xpos_scale = rope
472
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
473
+
474
+ if self.pe_attn_head is not None:
475
+ pn = self.pe_attn_head
476
+ query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
477
+ key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
478
+ else:
479
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
480
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
481
+
482
  # mask. e.g. inference got a batch with different target durations, mask out the padding
483
  if mask is not None:
484
  attn_mask = mask
 
524
 
525
  batch_size = c.shape[0]
526
 
527
+ # `sample` projections
528
  query = attn.to_q(x)
529
  key = attn.to_k(x)
530
  value = attn.to_v(x)
531
 
532
+ # `context` projections
533
  c_query = attn.to_q_c(c)
534
  c_key = attn.to_k_c(c)
535
  c_value = attn.to_v_c(c)
536
 
537
+ # attention
538
+ inner_dim = key.shape[-1]
539
+ head_dim = inner_dim // attn.heads
540
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
541
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
542
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
543
+ c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
544
+ c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
545
+ c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
546
+
547
+ # qk norm
548
+ if attn.q_norm is not None:
549
+ query = attn.q_norm(query)
550
+ if attn.k_norm is not None:
551
+ key = attn.k_norm(key)
552
+ if attn.c_q_norm is not None:
553
+ c_query = attn.c_q_norm(c_query)
554
+ if attn.c_k_norm is not None:
555
+ c_key = attn.c_k_norm(c_key)
556
+
557
  # apply rope for context and noised input independently
558
  if rope is not None:
559
  freqs, xpos_scale = rope
 
566
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
567
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
568
 
569
+ # joint attention
570
+ query = torch.cat([query, c_query], dim=2)
571
+ key = torch.cat([key, c_key], dim=2)
572
+ value = torch.cat([value, c_value], dim=2)
 
 
 
 
 
 
573
 
574
  # mask. e.g. inference got a batch with different target durations, mask out the padding
575
  if mask is not None:
 
608
 
609
 
610
  class DiTBlock(nn.Module):
611
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
612
  super().__init__()
613
 
614
+ self.attn_norm = AdaLayerNorm(dim)
615
  self.attn = Attention(
616
+ processor=AttnProcessor(pe_attn_head=pe_attn_head),
617
  dim=dim,
618
  heads=heads,
619
  dim_head=dim_head,
620
  dropout=dropout,
621
+ qk_norm=qk_norm,
622
  )
623
 
624
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
 
654
  context_pre_only: last layer only does prenorm + modulation since there is no more ffn
655
  """
656
 
657
+ def __init__(
658
+ self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
659
+ ):
660
  super().__init__()
661
+ if context_dim is None:
662
+ context_dim = dim
663
  self.context_pre_only = context_pre_only
664
 
665
+ self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
666
+ self.attn_norm_x = AdaLayerNorm(dim)
667
  self.attn = Attention(
668
  processor=JointAttnProcessor(),
669
  dim=dim,
670
  heads=heads,
671
  dim_head=dim_head,
672
  dropout=dropout,
673
+ context_dim=context_dim,
674
  context_pre_only=context_pre_only,
675
+ qk_norm=qk_norm,
676
  )
677
 
678
  if not context_pre_only:
679
+ self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
680
+ self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
681
  else:
682
  self.ff_norm_c = None
683
  self.ff_c = None
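For reference, a minimal sketch (not part of the diff; assumes only `torch`) of what the manual fallback branch of the new `RMSNorm` computes, checked against the native `F.rms_norm` path used on PyTorch >= 2.4:

```python
import torch
import torch.nn.functional as F

def rms_norm_manual(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # mirrors the fallback branch: scale by the reciprocal root-mean-square over the last dim
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x.to(weight.dtype) * weight

x = torch.randn(2, 16, 64)
w = torch.ones(64)
# F.rms_norm is only available on PyTorch >= 2.4, hence the version check in the module
ref = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=w, eps=1e-6)
assert torch.allclose(rms_norm_manual(x, w), ref, atol=1e-5)
```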
src/f5_tts/model/trainer.py CHANGED
@@ -32,7 +32,7 @@ class Trainer:
32
  save_per_updates=1000,
33
  keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
34
  checkpoint_path=None,
35
- batch_size=32,
36
  batch_size_type: str = "sample",
37
  max_samples=32,
38
  grad_accumulation_steps=1,
@@ -40,7 +40,7 @@ class Trainer:
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
  logger: str | None = "wandb", # "wandb" | "tensorboard" | None
43
- wandb_project="test_e2-tts",
44
  wandb_run_name="test_run",
45
  wandb_resume_id: str = None,
46
  log_samples: bool = False,
@@ -51,6 +51,7 @@ class Trainer:
51
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
52
  is_local_vocoder: bool = False, # use local path vocoder
53
  local_vocoder_path: str = "", # local vocoder path
 
54
  ):
55
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
56
 
@@ -72,21 +73,23 @@ class Trainer:
72
  else:
73
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
74
 
75
- self.accelerator.init_trackers(
76
- project_name=wandb_project,
77
- init_kwargs=init_kwargs,
78
- config={
79
  "epochs": epochs,
80
  "learning_rate": learning_rate,
81
  "num_warmup_updates": num_warmup_updates,
82
- "batch_size": batch_size,
83
  "batch_size_type": batch_size_type,
84
  "max_samples": max_samples,
85
  "grad_accumulation_steps": grad_accumulation_steps,
86
  "max_grad_norm": max_grad_norm,
87
- "gpus": self.accelerator.num_processes,
88
  "noise_scheduler": noise_scheduler,
89
- },
 
 
 
 
 
90
  )
91
 
92
  elif self.logger == "tensorboard":
@@ -111,9 +114,9 @@ class Trainer:
111
  self.save_per_updates = save_per_updates
112
  self.keep_last_n_checkpoints = keep_last_n_checkpoints
113
  self.last_per_updates = default(last_per_updates, save_per_updates)
114
- self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
115
 
116
- self.batch_size = batch_size
117
  self.batch_size_type = batch_size_type
118
  self.max_samples = max_samples
119
  self.grad_accumulation_steps = grad_accumulation_steps
@@ -179,7 +182,7 @@ class Trainer:
179
  if (
180
  not exists(self.checkpoint_path)
181
  or not os.path.exists(self.checkpoint_path)
182
- or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
183
  ):
184
  return 0
185
 
@@ -191,7 +194,7 @@ class Trainer:
191
  all_checkpoints = [
192
  f
193
  for f in os.listdir(self.checkpoint_path)
194
- if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
195
  ]
196
 
197
  # First try to find regular training checkpoints
@@ -205,8 +208,16 @@ class Trainer:
205
  # If no training checkpoints, use pretrained model
206
  latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
207
 
208
- # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
209
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
 
 
 
 
 
 
 
210
 
211
  # patch for backward compatibility, 305e3ea
212
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
@@ -271,7 +282,7 @@ class Trainer:
271
  num_workers=num_workers,
272
  pin_memory=True,
273
  persistent_workers=True,
274
- batch_size=self.batch_size,
275
  shuffle=True,
276
  generator=generator,
277
  )
@@ -280,10 +291,10 @@ class Trainer:
280
  sampler = SequentialSampler(train_dataset)
281
  batch_sampler = DynamicBatchSampler(
282
  sampler,
283
- self.batch_size,
284
  max_samples=self.max_samples,
285
  random_seed=resumable_with_seed, # This enables reproducible shuffling
286
- drop_last=False,
287
  )
288
  train_dataloader = DataLoader(
289
  train_dataset,
 
32
  save_per_updates=1000,
33
  keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
34
  checkpoint_path=None,
35
+ batch_size_per_gpu=32,
36
  batch_size_type: str = "sample",
37
  max_samples=32,
38
  grad_accumulation_steps=1,
 
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
  logger: str | None = "wandb", # "wandb" | "tensorboard" | None
43
+ wandb_project="test_f5-tts",
44
  wandb_run_name="test_run",
45
  wandb_resume_id: str = None,
46
  log_samples: bool = False,
 
51
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
52
  is_local_vocoder: bool = False, # use local path vocoder
53
  local_vocoder_path: str = "", # local vocoder path
54
+ cfg_dict: dict = dict(), # training config
55
  ):
56
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
57
 
 
73
  else:
74
  init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
75
 
76
+ if not cfg_dict:
77
+ cfg_dict = {
 
 
78
  "epochs": epochs,
79
  "learning_rate": learning_rate,
80
  "num_warmup_updates": num_warmup_updates,
81
+ "batch_size_per_gpu": batch_size_per_gpu,
82
  "batch_size_type": batch_size_type,
83
  "max_samples": max_samples,
84
  "grad_accumulation_steps": grad_accumulation_steps,
85
  "max_grad_norm": max_grad_norm,
 
86
  "noise_scheduler": noise_scheduler,
87
+ }
88
+ cfg_dict["gpus"] = self.accelerator.num_processes
89
+ self.accelerator.init_trackers(
90
+ project_name=wandb_project,
91
+ init_kwargs=init_kwargs,
92
+ config=cfg_dict,
93
  )
94
 
95
  elif self.logger == "tensorboard":
 
114
  self.save_per_updates = save_per_updates
115
  self.keep_last_n_checkpoints = keep_last_n_checkpoints
116
  self.last_per_updates = default(last_per_updates, save_per_updates)
117
+ self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
118
 
119
+ self.batch_size_per_gpu = batch_size_per_gpu
120
  self.batch_size_type = batch_size_type
121
  self.max_samples = max_samples
122
  self.grad_accumulation_steps = grad_accumulation_steps
 
182
  if (
183
  not exists(self.checkpoint_path)
184
  or not os.path.exists(self.checkpoint_path)
185
+ or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
186
  ):
187
  return 0
188
 
 
194
  all_checkpoints = [
195
  f
196
  for f in os.listdir(self.checkpoint_path)
197
+ if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
198
  ]
199
 
200
  # First try to find regular training checkpoints
 
208
  # If no training checkpoints, use pretrained model
209
  latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
210
 
211
+ if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
212
+ from safetensors.torch import load_file
213
+
214
+ checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
215
+ checkpoint = {"ema_model_state_dict": checkpoint}
216
+ elif latest_checkpoint.endswith(".pt"):
217
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
218
+ checkpoint = torch.load(
219
+ f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
220
+ )
221
 
222
  # patch for backward compatibility, 305e3ea
223
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
 
282
  num_workers=num_workers,
283
  pin_memory=True,
284
  persistent_workers=True,
285
+ batch_size=self.batch_size_per_gpu,
286
  shuffle=True,
287
  generator=generator,
288
  )
 
291
  sampler = SequentialSampler(train_dataset)
292
  batch_sampler = DynamicBatchSampler(
293
  sampler,
294
+ self.batch_size_per_gpu,
295
  max_samples=self.max_samples,
296
  random_seed=resumable_with_seed, # This enables reproducible shuffling
297
+ drop_residual=False,
298
  )
299
  train_dataloader = DataLoader(
300
  train_dataset,
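A standalone sketch of the dual-format checkpoint dispatch the resume logic above performs (the example path is hypothetical):

```python
import torch
from safetensors.torch import load_file

def load_any_checkpoint(path: str) -> dict:
    if path.endswith(".safetensors"):
        # .safetensors files hold bare tensors, so wrap them in the layout the trainer expects
        state = load_file(path, device="cpu")
        return {"ema_model_state_dict": state}
    elif path.endswith(".pt"):
        # full training checkpoints already contain model/EMA/optimizer state dicts
        return torch.load(path, weights_only=True, map_location="cpu")
    raise ValueError(f"unsupported checkpoint format: {path}")

# e.g. ckpt = load_any_checkpoint("ckpts/my_project/pretrained_model_1250000.safetensors")  # hypothetical path
```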
src/f5_tts/model/utils.py CHANGED
@@ -133,11 +133,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
133
 
134
  # convert char to pinyin
135
 
136
- jieba.initialize()
137
- print("Word segmentation module jieba initialized.\n")
138
-
139
 
140
  def convert_char_to_pinyin(text_list, polyphone=True):
 
 
 
 
141
  final_text_list = []
142
  custom_trans = str.maketrans(
143
  {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
 
133
 
134
  # convert char to pinyin
135
 
 
 
 
136
 
137
  def convert_char_to_pinyin(text_list, polyphone=True):
138
+ if jieba.dt.initialized is False:
139
+ jieba.default_logger.setLevel(50) # CRITICAL
140
+ jieba.initialize()
141
+
142
  final_text_list = []
143
  custom_trans = str.maketrans(
144
  {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
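With the lazy, silenced initialization above, jieba is only loaded on the first call; a rough usage sketch (the exact token output is illustrative):

```python
from f5_tts.model.utils import convert_char_to_pinyin

# the first call triggers jieba.initialize() with its logger set to CRITICAL,
# so no "Building prefix dict" chatter is printed
tokens = convert_char_to_pinyin(["你好, this is a mixed sentence."], polyphone=True)
print(tokens[0])  # per-character tokens, with Chinese characters converted to pinyin
```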
src/f5_tts/scripts/count_max_epoch.py CHANGED
@@ -9,7 +9,7 @@ mel_hop_length = 256
9
  mel_sampling_rate = 24000
10
 
11
  # target
12
- wanted_max_updates = 1000000
13
 
14
  # train params
15
  gpus = 8
 
9
  mel_sampling_rate = 24000
10
 
11
  # target
12
+ wanted_max_updates = 1200000
13
 
14
  # train params
15
  gpus = 8
src/f5_tts/socket_client.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import socket
2
+ import asyncio
3
+ import pyaudio
4
+ import numpy as np
5
+ import logging
6
+ import time
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
13
+ client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
14
+ await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
15
+
16
+ start_time = time.time()
17
+ first_chunk_time = None
18
+
19
+ async def play_audio_stream():
20
+ nonlocal first_chunk_time
21
+ p = pyaudio.PyAudio()
22
+ stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
23
+
24
+ try:
25
+ while True:
26
+ data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
27
+ if not data:
28
+ break
29
+ if data == b"END":
30
+ logger.info("End of audio received.")
31
+ break
32
+
33
+ audio_array = np.frombuffer(data, dtype=np.float32)
34
+ stream.write(audio_array.tobytes())
35
+
36
+ if first_chunk_time is None:
37
+ first_chunk_time = time.time()
38
+
39
+ finally:
40
+ stream.stop_stream()
41
+ stream.close()
42
+ p.terminate()
43
+
44
+ logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
45
+
46
+ try:
47
+ data_to_send = f"{text}".encode("utf-8")
48
+ await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
49
+ await play_audio_stream()
50
+
51
+ except Exception as e:
52
+ logger.error(f"Error in listen_to_F5TTS: {e}")
53
+
54
+ finally:
55
+ client_socket.close()
56
+
57
+
58
+ if __name__ == "__main__":
59
+ text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
60
+
61
+ asyncio.run(listen_to_F5TTS(text_to_send))
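The client above plays the stream as it arrives; as a hedged variant under the same wire protocol (raw float32 PCM chunks at 24 kHz, terminated by `b"END"`), the chunks can instead be collected and written to a file (`soundfile` is an assumed extra dependency):

```python
import socket
import numpy as np
import soundfile as sf

def save_f5tts_stream(text: str, out_path: str, host: str = "localhost", port: int = 9998):
    chunks = []
    with socket.create_connection((host, port)) as sock:
        sock.sendall(text.encode("utf-8"))
        while True:
            data = sock.recv(8192)
            if not data or data == b"END":  # server signals completion with b"END"
                break
            chunks.append(np.frombuffer(data, dtype=np.float32))
    if chunks:
        sf.write(out_path, np.concatenate(chunks), 24000)

# save_f5tts_stream("Hello from the streaming server.", "out.wav")
```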
src/f5_tts/socket_server.py CHANGED
@@ -13,8 +13,9 @@ from importlib.resources import files
13
  import torch
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
 
16
 
17
- from f5_tts.model.backbones.dit import DiT
18
  from f5_tts.infer.utils_infer import (
19
  chunk_text,
20
  preprocess_ref_audio_text,
@@ -68,7 +69,7 @@ class AudioFileWriterThread(threading.Thread):
68
 
69
 
70
  class TTSStreamingProcessor:
71
- def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
72
  self.device = device or (
73
  "cuda"
74
  if torch.cuda.is_available()
@@ -78,21 +79,24 @@ class TTSStreamingProcessor:
78
  if torch.backends.mps.is_available()
79
  else "cpu"
80
  )
81
- self.mel_spec_type = "vocos"
 
 
 
 
 
82
  self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
83
  self.vocoder = self.load_vocoder_model()
84
- self.sampling_rate = 24000
85
  self.update_reference(ref_audio, ref_text)
86
  self._warm_up()
87
  self.file_writer_thread = None
88
  self.first_package = True
89
 
90
  def load_ema_model(self, ckpt_file, vocab_file, dtype):
91
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
92
- model_cls = DiT
93
  return load_model(
94
- model_cls=model_cls,
95
- model_cfg=model_cfg,
96
  ckpt_path=ckpt_file,
97
  mel_spec_type=self.mel_spec_type,
98
  vocab_file=vocab_file,
@@ -212,9 +216,14 @@ if __name__ == "__main__":
212
  parser.add_argument("--host", default="0.0.0.0")
213
  parser.add_argument("--port", default=9998)
214
 
 
 
 
 
 
215
  parser.add_argument(
216
  "--ckpt_file",
217
- default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_Base/model_1200000.safetensors")),
218
  help="Path to the model checkpoint file",
219
  )
220
  parser.add_argument(
@@ -242,6 +251,7 @@ if __name__ == "__main__":
242
  try:
243
  # Initialize the processor with the model and vocoder
244
  processor = TTSStreamingProcessor(
 
245
  ckpt_file=args.ckpt_file,
246
  vocab_file=args.vocab_file,
247
  ref_audio=args.ref_audio,
 
13
  import torch
14
  import torchaudio
15
  from huggingface_hub import hf_hub_download
16
+ from omegaconf import OmegaConf
17
 
18
+ from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
19
  from f5_tts.infer.utils_infer import (
20
  chunk_text,
21
  preprocess_ref_audio_text,
 
69
 
70
 
71
  class TTSStreamingProcessor:
72
+ def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
73
  self.device = device or (
74
  "cuda"
75
  if torch.cuda.is_available()
 
79
  if torch.backends.mps.is_available()
80
  else "cpu"
81
  )
82
+ model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
83
+ self.model_cls = globals()[model_cfg.model.backbone]
84
+ self.model_arc = model_cfg.model.arch
85
+ self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
86
+ self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
87
+
88
  self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
89
  self.vocoder = self.load_vocoder_model()
90
+
91
  self.update_reference(ref_audio, ref_text)
92
  self._warm_up()
93
  self.file_writer_thread = None
94
  self.first_package = True
95
 
96
  def load_ema_model(self, ckpt_file, vocab_file, dtype):
 
 
97
  return load_model(
98
+ self.model_cls,
99
+ self.model_arc,
100
  ckpt_path=ckpt_file,
101
  mel_spec_type=self.mel_spec_type,
102
  vocab_file=vocab_file,
 
216
  parser.add_argument("--host", default="0.0.0.0")
217
  parser.add_argument("--port", default=9998)
218
 
219
+ parser.add_argument(
220
+ "--model",
221
+ default="F5TTS_v1_Base",
222
+ help="The model name, e.g. F5TTS_v1_Base",
223
+ )
224
  parser.add_argument(
225
  "--ckpt_file",
226
+ default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
227
  help="Path to the model checkpoint file",
228
  )
229
  parser.add_argument(
 
251
  try:
252
  # Initialize the processor with the model and vocoder
253
  processor = TTSStreamingProcessor(
254
+ model=args.model,
255
  ckpt_file=args.ckpt_file,
256
  vocab_file=args.vocab_file,
257
  ref_audio=args.ref_audio,
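A standalone sketch of the config-driven lookup the server now performs (field names are the ones read in the diff; the backbone class is resolved by name from the packaged YAML):

```python
from importlib.resources import files
from omegaconf import OmegaConf

from f5_tts.model.backbones.dit import DiT  # noqa: F401, looked up by name below

model_name = "F5TTS_v1_Base"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model_name}.yaml")))

model_cls = globals()[model_cfg.model.backbone]            # e.g. DiT
model_arc = model_cfg.model.arch                           # backbone hyperparameters
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type     # e.g. "vocos"
sampling_rate = model_cfg.model.mel_spec.target_sample_rate
print(model_cls.__name__, mel_spec_type, sampling_rate)
```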
src/f5_tts/train/README.md CHANGED
@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
40
  accelerate config
41
 
42
  # .yaml files are under src/f5_tts/configs directory
43
- accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
44
 
45
  # possible to overwrite accelerate and hydra config
46
- accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
47
  ```
48
 
49
  ### 2. Finetuning practice
@@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1
53
 
54
  The `use_ema = True` setting is harmful for early-stage finetuned checkpoints (which have gone through only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
55
 
56
- ### 3. Wandb Logging
57
 
58
  The `wandb/` dir will be created under the path from which you run the training/finetuning scripts.
59
 
@@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
62
  To turn on wandb logging, you can either:
63
 
64
  1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
65
- 2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
66
 
67
  On Mac & Linux:
68
 
@@ -75,7 +75,7 @@ On Windows:
75
  ```
76
  set WANDB_API_KEY=<YOUR WANDB API KEY>
77
  ```
78
- Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows:
79
 
80
  ```
81
  export WANDB_MODE=offline
 
40
  accelerate config
41
 
42
  # .yaml files are under src/f5_tts/configs directory
43
+ accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
44
 
45
  # possible to overwrite accelerate and hydra config
46
+ accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
47
  ```
48
 
49
  ### 2. Finetuning practice
 
53
 
54
  The `use_ema = True` setting is harmful for early-stage finetuned checkpoints (which have gone through only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
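As a hedged illustration (using the `F5TTS` API surface that appears in `finetune_gradio.py`; checkpoint and vocab paths are hypothetical), loading an early finetuned checkpoint with EMA weights disabled could look like:

```python
from f5_tts.api import F5TTS

# use_ema=False loads the raw training weights instead of the EMA-smoothed ones,
# which tends to work better for checkpoints trained for only a few updates
tts = F5TTS(
    model="F5TTS_v1_Base",
    ckpt_file="ckpts/my_project/model_500.pt",  # hypothetical finetuned checkpoint
    vocab_file="data/my_project/vocab.txt",     # hypothetical vocab path
    use_ema=False,
)
```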
55
 
56
+ ### 3. W&B Logging
57
 
58
  The `wandb/` dir will be created under the path from which you run the training/finetuning scripts.
59
 
 
62
  To turn on wandb logging, you can either:
63
 
64
  1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
65
+ 2. Log in programmatically by setting an environment variable: get an API key at https://wandb.ai/authorize and set it as follows:
66
 
67
  On Mac & Linux:
68
 
 
75
  ```
76
  set WANDB_API_KEY=<YOUR WANDB API KEY>
77
  ```
78
+ Moreover, if you can't access W&B and want to log metrics offline, set the environment variable as follows:
79
 
80
  ```
81
  export WANDB_MODE=offline
src/f5_tts/train/finetune_cli.py CHANGED
@@ -1,12 +1,13 @@
1
  import argparse
2
  import os
3
  import shutil
 
4
 
5
  from cached_path import cached_path
 
6
  from f5_tts.model import CFM, UNetT, DiT, Trainer
7
  from f5_tts.model.utils import get_tokenizer
8
  from f5_tts.model.dataset import load_dataset
9
- from importlib.resources import files
10
 
11
 
12
  # -------------------------- Dataset Settings --------------------------- #
@@ -20,19 +21,14 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
20
 
21
  # -------------------------- Argument Parsing --------------------------- #
22
  def parse_args():
23
- # batch_size_per_gpu = 1000 settting for gpu 8GB
24
- # batch_size_per_gpu = 1600 settting for gpu 12GB
25
- # batch_size_per_gpu = 2000 settting for gpu 16GB
26
- # batch_size_per_gpu = 3200 settting for gpu 24GB
27
-
28
- # num_warmup_updates = 300 for 5000 sample about 10 hours
29
-
30
- # change save_per_updates , last_per_updates change this value what you need ,
31
-
32
  parser = argparse.ArgumentParser(description="Train CFM Model")
33
 
34
  parser.add_argument(
35
- "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
 
 
 
 
36
  )
37
  parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
38
  parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
@@ -88,19 +84,54 @@ def main():
88
  checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
89
 
90
  # Model parameters based on experiment name
91
- if args.exp_name == "F5TTS_Base":
 
92
  wandb_resume_id = None
93
  model_cls = DiT
94
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  if args.finetune:
96
  if args.pretrain is None:
97
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
98
  else:
99
  ckpt_path = args.pretrain
 
100
  elif args.exp_name == "E2TTS_Base":
101
  wandb_resume_id = None
102
  model_cls = UNetT
103
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 
 
 
 
 
104
  if args.finetune:
105
  if args.pretrain is None:
106
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
@@ -120,6 +151,7 @@ def main():
120
  print("copy checkpoint for finetune")
121
 
122
  # Use the tokenizer and tokenizer_path provided in the command line arguments
 
123
  tokenizer = args.tokenizer
124
  if tokenizer == "custom":
125
  if not args.tokenizer_path:
@@ -156,7 +188,7 @@ def main():
156
  save_per_updates=args.save_per_updates,
157
  keep_last_n_checkpoints=args.keep_last_n_checkpoints,
158
  checkpoint_path=checkpoint_path,
159
- batch_size=args.batch_size_per_gpu,
160
  batch_size_type=args.batch_size_type,
161
  max_samples=args.max_samples,
162
  grad_accumulation_steps=args.grad_accumulation_steps,
 
1
  import argparse
2
  import os
3
  import shutil
4
+ from importlib.resources import files
5
 
6
  from cached_path import cached_path
7
+
8
  from f5_tts.model import CFM, UNetT, DiT, Trainer
9
  from f5_tts.model.utils import get_tokenizer
10
  from f5_tts.model.dataset import load_dataset
 
11
 
12
 
13
  # -------------------------- Dataset Settings --------------------------- #
 
21
 
22
  # -------------------------- Argument Parsing --------------------------- #
23
  def parse_args():
 
 
 
 
 
 
 
 
 
24
  parser = argparse.ArgumentParser(description="Train CFM Model")
25
 
26
  parser.add_argument(
27
+ "--exp_name",
28
+ type=str,
29
+ default="F5TTS_v1_Base",
30
+ choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
31
+ help="Experiment name",
32
  )
33
  parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
34
  parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
 
84
  checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
85
 
86
  # Model parameters based on experiment name
87
+
88
+ if args.exp_name == "F5TTS_v1_Base":
89
  wandb_resume_id = None
90
  model_cls = DiT
91
+ model_cfg = dict(
92
+ dim=1024,
93
+ depth=22,
94
+ heads=16,
95
+ ff_mult=2,
96
+ text_dim=512,
97
+ conv_layers=4,
98
+ )
99
+ if args.finetune:
100
+ if args.pretrain is None:
101
+ ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
102
+ else:
103
+ ckpt_path = args.pretrain
104
+
105
+ elif args.exp_name == "F5TTS_Base":
106
+ wandb_resume_id = None
107
+ model_cls = DiT
108
+ model_cfg = dict(
109
+ dim=1024,
110
+ depth=22,
111
+ heads=16,
112
+ ff_mult=2,
113
+ text_dim=512,
114
+ text_mask_padding=False,
115
+ conv_layers=4,
116
+ pe_attn_head=1,
117
+ )
118
  if args.finetune:
119
  if args.pretrain is None:
120
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
121
  else:
122
  ckpt_path = args.pretrain
123
+
124
  elif args.exp_name == "E2TTS_Base":
125
  wandb_resume_id = None
126
  model_cls = UNetT
127
+ model_cfg = dict(
128
+ dim=1024,
129
+ depth=24,
130
+ heads=16,
131
+ ff_mult=4,
132
+ text_mask_padding=False,
133
+ pe_attn_head=1,
134
+ )
135
  if args.finetune:
136
  if args.pretrain is None:
137
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
 
151
  print("copy checkpoint for finetune")
152
 
153
  # Use the tokenizer and tokenizer_path provided in the command line arguments
154
+
155
  tokenizer = args.tokenizer
156
  if tokenizer == "custom":
157
  if not args.tokenizer_path:
 
188
  save_per_updates=args.save_per_updates,
189
  keep_last_n_checkpoints=args.keep_last_n_checkpoints,
190
  checkpoint_path=checkpoint_path,
191
+ batch_size_per_gpu=args.batch_size_per_gpu,
192
  batch_size_type=args.batch_size_type,
193
  max_samples=args.max_samples,
194
  grad_accumulation_steps=args.grad_accumulation_steps,
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1,36 +1,36 @@
1
- import threading
2
- import queue
3
- import re
4
-
5
  import gc
6
  import json
 
7
  import os
8
  import platform
9
  import psutil
 
10
  import random
 
11
  import signal
12
  import shutil
13
  import subprocess
14
  import sys
15
  import tempfile
 
16
  import time
17
  from glob import glob
 
 
18
 
19
  import click
20
  import gradio as gr
21
  import librosa
22
- import numpy as np
23
  import torch
24
  import torchaudio
 
25
  from datasets import Dataset as Dataset_
26
  from datasets.arrow_writer import ArrowWriter
27
- from safetensors.torch import save_file
28
- from scipy.io import wavfile
29
- from cached_path import cached_path
30
  from f5_tts.api import F5TTS
31
  from f5_tts.model.utils import convert_char_to_pinyin
32
  from f5_tts.infer.utils_infer import transcribe
33
- from importlib.resources import files
34
 
35
 
36
  training_process = None
@@ -118,16 +118,16 @@ def load_settings(project_name):
118
 
119
  # Default settings
120
  default_settings = {
121
- "exp_name": "F5TTS_Base",
122
- "learning_rate": 1e-05,
123
- "batch_size_per_gpu": 1000,
124
- "batch_size_type": "frame",
125
  "max_samples": 64,
126
- "grad_accumulation_steps": 1,
127
  "max_grad_norm": 1,
128
  "epochs": 100,
129
- "num_warmup_updates": 2,
130
- "save_per_updates": 300,
131
  "keep_last_n_checkpoints": -1,
132
  "last_per_updates": 100,
133
  "finetune": True,
@@ -362,18 +362,18 @@ def terminate_process(pid):
362
 
363
  def start_training(
364
  dataset_name="",
365
- exp_name="F5TTS_Base",
366
- learning_rate=1e-4,
367
- batch_size_per_gpu=400,
368
- batch_size_type="frame",
369
  max_samples=64,
370
- grad_accumulation_steps=1,
371
  max_grad_norm=1.0,
372
- epochs=11,
373
- num_warmup_updates=200,
374
- save_per_updates=400,
375
  keep_last_n_checkpoints=-1,
376
- last_per_updates=800,
377
  finetune=True,
378
  file_checkpoint_train="",
379
  tokenizer_type="pinyin",
@@ -797,14 +797,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
797
  print(f"Error processing {file_audio}: {e}")
798
  continue
799
 
800
- if duration < 1 or duration > 25:
801
- if duration > 25:
802
- error_files.append([file_audio, "duration > 25 sec"])
803
  if duration < 1:
804
  error_files.append([file_audio, "duration < 1 sec "])
805
  continue
806
  if len(text) < 3:
807
- error_files.append([file_audio, "very small text len 3"])
808
  continue
809
 
810
  text = clear_text(text)
@@ -871,40 +871,37 @@ def check_user(value):
871
 
872
  def calculate_train(
873
  name_project,
 
 
 
874
  batch_size_type,
875
  max_samples,
876
- learning_rate,
877
  num_warmup_updates,
878
- save_per_updates,
879
- last_per_updates,
880
  finetune,
881
  ):
882
  path_project = os.path.join(path_data, name_project)
883
- file_duraction = os.path.join(path_project, "duration.json")
 
 
 
884
 
885
- if not os.path.isfile(file_duraction):
886
  return (
887
- 1000,
 
 
888
  max_samples,
889
  num_warmup_updates,
890
- save_per_updates,
891
- last_per_updates,
892
  "project not found !",
893
- learning_rate,
894
  )
895
 
896
- with open(file_duraction, "r") as file:
897
  data = json.load(file)
898
 
899
  duration_list = data["duration"]
900
- samples = len(duration_list)
901
- hours = sum(duration_list) / 3600
902
-
903
- # if torch.cuda.is_available():
904
- # gpu_properties = torch.cuda.get_device_properties(0)
905
- # total_memory = gpu_properties.total_memory / (1024**3)
906
- # elif torch.backends.mps.is_available():
907
- # total_memory = psutil.virtual_memory().available / (1024**3)
908
 
909
  if torch.cuda.is_available():
910
  gpu_count = torch.cuda.device_count()
@@ -912,64 +909,39 @@ def calculate_train(
912
  for i in range(gpu_count):
913
  gpu_properties = torch.cuda.get_device_properties(i)
914
  total_memory += gpu_properties.total_memory / (1024**3) # in GB
915
-
916
  elif torch.xpu.is_available():
917
  gpu_count = torch.xpu.device_count()
918
  total_memory = 0
919
  for i in range(gpu_count):
920
  gpu_properties = torch.xpu.get_device_properties(i)
921
  total_memory += gpu_properties.total_memory / (1024**3)
922
-
923
  elif torch.backends.mps.is_available():
924
  gpu_count = 1
925
  total_memory = psutil.virtual_memory().available / (1024**3)
926
 
 
 
 
927
  if batch_size_type == "frame":
928
- batch = int(total_memory * 0.5)
929
- batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
930
- batch_size_per_gpu = int(38400 / batch)
931
- else:
932
- batch_size_per_gpu = int(total_memory / 8)
933
- batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
934
- batch = batch_size_per_gpu
935
 
936
- if batch_size_per_gpu <= 0:
937
- batch_size_per_gpu = 1
938
 
939
- if samples < 64:
940
- max_samples = int(samples * 0.25)
941
- else:
942
- max_samples = 64
943
-
944
- num_warmup_updates = int(samples * 0.05)
945
- save_per_updates = int(samples * 0.10)
946
- last_per_updates = int(save_per_updates * 0.25)
947
-
948
- max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
949
- num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
950
- save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
951
- last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
952
- if last_per_updates <= 0:
953
- last_per_updates = 2
954
-
955
- total_hours = hours
956
- mel_hop_length = 256
957
- mel_sampling_rate = 24000
958
-
959
- # target
960
- wanted_max_updates = 1000000
961
-
962
- # train params
963
- gpus = gpu_count
964
- frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
965
- grad_accum = 1
966
-
967
- # intermediate
968
- mini_batch_frames = frames_per_gpu * grad_accum * gpus
969
- mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
970
- updates_per_epoch = total_hours / mini_batch_hours
971
- # steps_per_epoch = updates_per_epoch * grad_accum
972
- epochs = wanted_max_updates / updates_per_epoch
973
 
974
  if finetune:
975
  learning_rate = 1e-5
@@ -977,14 +949,12 @@ def calculate_train(
977
  learning_rate = 7.5e-5
978
 
979
  return (
 
 
980
  batch_size_per_gpu,
981
  max_samples,
982
  num_warmup_updates,
983
- save_per_updates,
984
- last_per_updates,
985
- samples,
986
- learning_rate,
987
- int(epochs),
988
  )
989
 
990
 
@@ -1021,7 +991,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
1021
  torch.backends.cudnn.deterministic = True
1022
  torch.backends.cudnn.benchmark = False
1023
 
1024
- ckpt = torch.load(ckpt_path, map_location="cpu")
 
 
 
 
1025
 
1026
  ema_sd = ckpt.get("ema_model_state_dict", {})
1027
  embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
@@ -1089,9 +1063,11 @@ def vocab_extend(project_name, symbols, model_type):
1089
  with open(file_vocab_project, "w", encoding="utf-8") as f:
1090
  f.write("\n".join(vocab))
1091
 
1092
- if model_type == "F5-TTS":
 
 
1093
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
1094
- else:
1095
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
1096
 
1097
  vocab_size_new = len(miss_symbols)
@@ -1101,7 +1077,7 @@ def vocab_extend(project_name, symbols, model_type):
1101
  os.makedirs(new_ckpt_path, exist_ok=True)
1102
 
1103
  # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
1104
- new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
1105
 
1106
  size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
1107
 
@@ -1231,21 +1207,21 @@ def infer(
1231
  vocab_file = os.path.join(path_data, project, "vocab.txt")
1232
 
1233
  tts_api = F5TTS(
1234
- model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
1235
  )
1236
 
1237
  print("update >> ", device_test, file_checkpoint, use_ema)
1238
 
1239
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1240
  tts_api.infer(
1241
- gen_text=gen_text.lower().strip(),
1242
- ref_text=ref_text.lower().strip(),
1243
  ref_file=ref_audio,
 
 
1244
  nfe_step=nfe_step,
1245
- file_wave=f.name,
1246
  speed=speed,
1247
- seed=seed,
1248
  remove_silence=remove_silence,
 
 
1249
  )
1250
  return f.name, tts_api.device, str(tts_api.seed)
1251
 
@@ -1404,14 +1380,14 @@ def get_audio_select(file_sample):
1404
  with gr.Blocks() as app:
1405
  gr.Markdown(
1406
  """
1407
- # E2/F5 TTS Automatic Finetune
1408
 
1409
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
1410
 
1411
  * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
1412
  * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
1413
 
1414
- The checkpoints support English and Chinese.
1415
 
1416
  For tutorials and updates, check here: https://github.com/SWivid/F5-TTS/discussions/143
1417
  """
@@ -1488,7 +1464,9 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
1488
  Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
1489
  ```""")
1490
 
1491
- exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
 
 
1492
 
1493
  with gr.Row():
1494
  txt_extend = gr.Textbox(
@@ -1557,9 +1535,9 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
1557
  fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1558
  )
1559
 
1560
- with gr.TabItem("Train Data"):
1561
  gr.Markdown("""```plaintext
1562
- The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
1563
  If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1564
  ```""")
1565
  with gr.Row():
@@ -1573,11 +1551,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1573
  file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
1574
 
1575
  with gr.Row():
1576
- exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
 
 
1577
  learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
1578
 
1579
  with gr.Row():
1580
- batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
1581
  max_samples = gr.Number(label="Max Samples", value=64)
1582
 
1583
  with gr.Row():
@@ -1585,23 +1565,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1585
  max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
1586
 
1587
  with gr.Row():
1588
- epochs = gr.Number(label="Epochs", value=10)
1589
- num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
1590
 
1591
  with gr.Row():
1592
- save_per_updates = gr.Number(label="Save per Updates", value=300)
1593
  keep_last_n_checkpoints = gr.Number(
1594
  label="Keep Last N Checkpoints",
1595
  value=-1,
1596
  step=1,
1597
  precision=0,
1598
- info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
1599
  )
1600
  last_per_updates = gr.Number(label="Last per Updates", value=100)
1601
 
1602
  with gr.Row():
1603
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1604
- mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
1605
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1606
  start_button = gr.Button("Start Training")
1607
  stop_button = gr.Button("Stop Training", interactive=False)
@@ -1718,23 +1698,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1718
  fn=calculate_train,
1719
  inputs=[
1720
  cm_project,
 
 
 
1721
  batch_size_type,
1722
  max_samples,
1723
- learning_rate,
1724
  num_warmup_updates,
1725
- save_per_updates,
1726
- last_per_updates,
1727
  ch_finetune,
1728
  ],
1729
  outputs=[
 
 
1730
  batch_size_per_gpu,
1731
  max_samples,
1732
  num_warmup_updates,
1733
- save_per_updates,
1734
- last_per_updates,
1735
  lb_samples,
1736
- learning_rate,
1737
- epochs,
1738
  ],
1739
  )
1740
 
@@ -1744,25 +1722,25 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1744
 
1745
  def setup_load_settings():
1746
  output_components = [
1747
- exp_name, # 1
1748
- learning_rate, # 2
1749
- batch_size_per_gpu, # 3
1750
- batch_size_type, # 4
1751
- max_samples, # 5
1752
- grad_accumulation_steps, # 6
1753
- max_grad_norm, # 7
1754
- epochs, # 8
1755
- num_warmup_updates, # 9
1756
- save_per_updates, # 10
1757
- keep_last_n_checkpoints, # 11
1758
- last_per_updates, # 12
1759
- ch_finetune, # 13
1760
- file_checkpoint_train, # 14
1761
- tokenizer_type, # 15
1762
- tokenizer_file, # 16
1763
- mixed_precision, # 17
1764
- cd_logger, # 18
1765
- ch_8bit_adam, # 19
1766
  ]
1767
  return output_components
1768
 
@@ -1784,7 +1762,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1784
  gr.Markdown("""```plaintext
1785
  SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
1786
  ```""")
1787
- exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
 
 
1788
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1789
 
1790
  with gr.Row():
@@ -1838,9 +1818,9 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
1838
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1839
  cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1840
 
1841
- with gr.TabItem("Reduce Checkpoint"):
1842
  gr.Markdown("""```plaintext
1843
- Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training.
1844
  ```""")
1845
  txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
1846
  txt_path_checkpoint_small = gr.Text(label="Path to Output:")
 
 
 
 
 
1
  import gc
2
  import json
3
+ import numpy as np
4
  import os
5
  import platform
6
  import psutil
7
+ import queue
8
  import random
9
+ import re
10
  import signal
11
  import shutil
12
  import subprocess
13
  import sys
14
  import tempfile
15
+ import threading
16
  import time
17
  from glob import glob
18
+ from importlib.resources import files
19
+ from scipy.io import wavfile
20
 
21
  import click
22
  import gradio as gr
23
  import librosa
 
24
  import torch
25
  import torchaudio
26
+ from cached_path import cached_path
27
  from datasets import Dataset as Dataset_
28
  from datasets.arrow_writer import ArrowWriter
29
+ from safetensors.torch import load_file, save_file
30
+
 
31
  from f5_tts.api import F5TTS
32
  from f5_tts.model.utils import convert_char_to_pinyin
33
  from f5_tts.infer.utils_infer import transcribe
 
34
 
35
 
36
  training_process = None
 
118
 
119
  # Default settings
120
  default_settings = {
121
+ "exp_name": "F5TTS_v1_Base",
122
+ "learning_rate": 1e-5,
123
+ "batch_size_per_gpu": 1,
124
+ "batch_size_type": "sample",
125
  "max_samples": 64,
126
+ "grad_accumulation_steps": 4,
127
  "max_grad_norm": 1,
128
  "epochs": 100,
129
+ "num_warmup_updates": 100,
130
+ "save_per_updates": 500,
131
  "keep_last_n_checkpoints": -1,
132
  "last_per_updates": 100,
133
  "finetune": True,
 
362
 
363
  def start_training(
364
  dataset_name="",
365
+ exp_name="F5TTS_v1_Base",
366
+ learning_rate=1e-5,
367
+ batch_size_per_gpu=1,
368
+ batch_size_type="sample",
369
  max_samples=64,
370
+ grad_accumulation_steps=4,
371
  max_grad_norm=1.0,
372
+ epochs=100,
373
+ num_warmup_updates=100,
374
+ save_per_updates=500,
375
  keep_last_n_checkpoints=-1,
376
+ last_per_updates=100,
377
  finetune=True,
378
  file_checkpoint_train="",
379
  tokenizer_type="pinyin",
 
797
  print(f"Error processing {file_audio}: {e}")
798
  continue
799
 
800
+ if duration < 1 or duration > 30:
801
+ if duration > 30:
802
+ error_files.append([file_audio, "duration > 30 sec"])
803
  if duration < 1:
804
  error_files.append([file_audio, "duration < 1 sec "])
805
  continue
806
  if len(text) < 3:
807
+ error_files.append([file_audio, "text length < 3"])
808
  continue
809
 
810
  text = clear_text(text)
 
871
 
872
  def calculate_train(
873
  name_project,
874
+ epochs,
875
+ learning_rate,
876
+ batch_size_per_gpu,
877
  batch_size_type,
878
  max_samples,
 
879
  num_warmup_updates,
 
 
880
  finetune,
881
  ):
882
  path_project = os.path.join(path_data, name_project)
883
+ file_duration = os.path.join(path_project, "duration.json")
884
+
885
+ hop_length = 256
886
+ sampling_rate = 24000
887
 
888
+ if not os.path.isfile(file_duration):
889
  return (
890
+ epochs,
891
+ learning_rate,
892
+ batch_size_per_gpu,
893
  max_samples,
894
  num_warmup_updates,
 
 
895
  "project not found !",
 
896
  )
897
 
898
+ with open(file_duration, "r") as file:
899
  data = json.load(file)
900
 
901
  duration_list = data["duration"]
902
+ max_sample_length = max(duration_list) * sampling_rate / hop_length
903
+ total_samples = len(duration_list)
904
+ total_duration = sum(duration_list)
 
 
 
 
 
905
 
906
  if torch.cuda.is_available():
907
  gpu_count = torch.cuda.device_count()
 
909
  for i in range(gpu_count):
910
  gpu_properties = torch.cuda.get_device_properties(i)
911
  total_memory += gpu_properties.total_memory / (1024**3) # in GB
 
912
  elif torch.xpu.is_available():
913
  gpu_count = torch.xpu.device_count()
914
  total_memory = 0
915
  for i in range(gpu_count):
916
  gpu_properties = torch.xpu.get_device_properties(i)
917
  total_memory += gpu_properties.total_memory / (1024**3)
 
918
  elif torch.backends.mps.is_available():
919
  gpu_count = 1
920
  total_memory = psutil.virtual_memory().available / (1024**3)
921
 
922
+ avg_gpu_memory = total_memory / gpu_count
923
+
924
+ # rough estimate of batch size
925
  if batch_size_type == "frame":
926
+ batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
927
+ elif batch_size_type == "sample":
928
+ batch_size_per_gpu = int(200 / (total_duration / total_samples))
 
 
 
 
929
 
930
+ if total_samples < 64:
931
+ max_samples = int(total_samples * 0.25)
932
 
933
+ num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
934
+
935
+ # take 1.2M updates as the maximum
936
+ max_updates = 1200000
937
+
938
+ if batch_size_type == "frame":
939
+ mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
940
+ updates_per_epoch = total_duration / mini_batch_duration
941
+ elif batch_size_type == "sample":
942
+ updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
943
+
944
+ epochs = int(max_updates / updates_per_epoch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
945
 
946
  if finetune:
947
  learning_rate = 1e-5
 
949
  learning_rate = 7.5e-5
950
 
951
  return (
952
+ epochs,
953
+ learning_rate,
954
  batch_size_per_gpu,
955
  max_samples,
956
  num_warmup_updates,
957
+ total_samples,
 
 
 
 
958
  )
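To make the heuristics in `calculate_train` concrete, a worked example under assumed figures (a single 24 GB GPU and a 10-hour, 5,000-clip dataset; the formulas are copied from the frame-based branch above):

```python
# assumed hardware / dataset figures, purely illustrative
avg_gpu_memory = 24.0                      # GB on a single GPU
gpu_count = 1
total_samples = 5000
total_duration = 10 * 3600                 # seconds of audio
hop_length, sampling_rate = 256, 24000
max_sample_length = 30 * sampling_rate / hop_length   # longest clip (30 s) in mel frames

# frame-based batch size: scale 38400 frames by the headroom above a ~5 GB baseline,
# but never go below the longest sample in the dataset
batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
print(batch_size_per_gpu)                  # -> 9728 frames per GPU

# epochs needed to reach the 1.2M-update target
mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
updates_per_epoch = total_duration / mini_batch_duration
print(int(1_200_000 / updates_per_epoch))  # -> 3458 epochs for such a small dataset; cap it manually
```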
959
 
960
 
 
991
  torch.backends.cudnn.deterministic = True
992
  torch.backends.cudnn.benchmark = False
993
 
994
+ if ckpt_path.endswith(".safetensors"):
995
+ ckpt = load_file(ckpt_path, device="cpu")
996
+ ckpt = {"ema_model_state_dict": ckpt}
997
+ elif ckpt_path.endswith(".pt"):
998
+ ckpt = torch.load(ckpt_path, map_location="cpu")
999
 
1000
  ema_sd = ckpt.get("ema_model_state_dict", {})
1001
  embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
 
1063
  with open(file_vocab_project, "w", encoding="utf-8") as f:
1064
  f.write("\n".join(vocab))
1065
 
1066
+ if model_type == "F5TTS_v1_Base":
1067
+ ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
1068
+ elif model_type == "F5TTS_Base":
1069
  ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
1070
+ elif model_type == "E2TTS_Base":
1071
  ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
1072
 
1073
  vocab_size_new = len(miss_symbols)
 
1077
  os.makedirs(new_ckpt_path, exist_ok=True)
1078
 
1079
  # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
1080
+ new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
1081
 
1082
  size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
1083
 
 
1207
  vocab_file = os.path.join(path_data, project, "vocab.txt")
1208
 
1209
  tts_api = F5TTS(
1210
+ model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
1211
  )
1212
 
1213
  print("update >> ", device_test, file_checkpoint, use_ema)
1214
 
1215
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1216
  tts_api.infer(
 
 
1217
  ref_file=ref_audio,
1218
+ ref_text=ref_text.lower().strip(),
1219
+ gen_text=gen_text.lower().strip(),
1220
  nfe_step=nfe_step,
 
1221
  speed=speed,
 
1222
  remove_silence=remove_silence,
1223
+ file_wave=f.name,
1224
+ seed=seed,
1225
  )
1226
  return f.name, tts_api.device, str(tts_api.seed)
1227
 
 
1380
  with gr.Blocks() as app:
1381
  gr.Markdown(
1382
  """
1383
+ # F5 TTS Automatic Finetune
1384
 
1385
+ This is a local web UI for F5 TTS finetuning. It supports the following TTS models:
1386
 
1387
  * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
1388
  * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
1389
 
1390
+ The pretrained checkpoints support English and Chinese.
1391
 
1392
  For tutorials and updates, check here: https://github.com/SWivid/F5-TTS/discussions/143
1393
  """
 
1464
  Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
1465
  ```""")
1466
 
1467
+ exp_name_extend = gr.Radio(
1468
+ label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1469
+ )
1470
 
1471
  with gr.Row():
1472
  txt_extend = gr.Textbox(
 
1535
  fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1536
  )
1537
 
1538
+ with gr.TabItem("Train Model"):
1539
  gr.Markdown("""```plaintext
1540
+ The auto-setting is still experimental. Set a large epoch value if unsure, and keep only the last N checkpoints if disk space is limited.
1541
  If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1542
  ```""")
1543
  with gr.Row():
 
1551
  file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
1552
 
1553
  with gr.Row():
1554
+ exp_name = gr.Radio(
1555
+ label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1556
+ )
1557
  learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
1558
 
1559
  with gr.Row():
1560
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
1561
  max_samples = gr.Number(label="Max Samples", value=64)
1562
 
1563
  with gr.Row():
 
1565
  max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
1566
 
1567
  with gr.Row():
1568
+ epochs = gr.Number(label="Epochs", value=100)
1569
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=100)
1570
 
1571
  with gr.Row():
1572
+ save_per_updates = gr.Number(label="Save per Updates", value=500)
1573
  keep_last_n_checkpoints = gr.Number(
1574
  label="Keep Last N Checkpoints",
1575
  value=-1,
1576
  step=1,
1577
  precision=0,
1578
+ info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
1579
  )
1580
  last_per_updates = gr.Number(label="Last per Updates", value=100)
1581
 
1582
  with gr.Row():
1583
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1584
+ mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
1585
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1586
  start_button = gr.Button("Start Training")
1587
  stop_button = gr.Button("Stop Training", interactive=False)
 
1698
  fn=calculate_train,
1699
  inputs=[
1700
  cm_project,
1701
+ epochs,
1702
+ learning_rate,
1703
+ batch_size_per_gpu,
1704
  batch_size_type,
1705
  max_samples,
 
1706
  num_warmup_updates,
 
 
1707
  ch_finetune,
1708
  ],
1709
  outputs=[
1710
+ epochs,
1711
+ learning_rate,
1712
  batch_size_per_gpu,
1713
  max_samples,
1714
  num_warmup_updates,
 
 
1715
  lb_samples,
 
 
1716
  ],
1717
  )
1718
 
 
1722
 
1723
  def setup_load_settings():
1724
  output_components = [
1725
+ exp_name,
1726
+ learning_rate,
1727
+ batch_size_per_gpu,
1728
+ batch_size_type,
1729
+ max_samples,
1730
+ grad_accumulation_steps,
1731
+ max_grad_norm,
1732
+ epochs,
1733
+ num_warmup_updates,
1734
+ save_per_updates,
1735
+ keep_last_n_checkpoints,
1736
+ last_per_updates,
1737
+ ch_finetune,
1738
+ file_checkpoint_train,
1739
+ tokenizer_type,
1740
+ tokenizer_file,
1741
+ mixed_precision,
1742
+ cd_logger,
1743
+ ch_8bit_adam,
1744
  ]
1745
  return output_components
1746
 
 
1762
  gr.Markdown("""```plaintext
1763
  SOS: Check the use_ema setting (True or False) for your model to see what works best for you. Use seed -1 for a random seed.
1764
  ```""")
1765
+ exp_name = gr.Radio(
1766
+ label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1767
+ )
1768
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1769
 
1770
  with gr.Row():
 
1818
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1819
  cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
1820
 
1821
+ with gr.TabItem("Prune Checkpoint"):
1822
  gr.Markdown("""```plaintext
1823
+ Reduce the Base model size from 5GB to 1.3GB. The pruned checkpoint drops the optimizer states and other training-only data; it can still be used for inference or finetuning afterward, but cannot be used to resume pretraining (a rough pruning sketch follows below).
1824
  ```""")
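The `reduce_checkpoint` helper behind this tab is not part of this hunk; as a rough, hedged idea of what such pruning amounts to, it keeps only the EMA weights and drops the optimizer/scheduler state that makes the full training checkpoint ~5 GB:

```python
import torch

def prune_checkpoint_sketch(ckpt_path: str, out_path: str):
    # illustrative only: keep the EMA state dict, drop everything else
    ckpt = torch.load(ckpt_path, map_location="cpu")
    slim = {"ema_model_state_dict": ckpt["ema_model_state_dict"]}
    torch.save(slim, out_path)

# prune_checkpoint_sketch("ckpts/my_project/model_last.pt", "ckpts/my_project/model_last_pruned.pt")
```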
1825
  txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
1826
  txt_path_checkpoint_small = gr.Text(label="Path to Output:")
src/f5_tts/train/train.py CHANGED
@@ -4,8 +4,9 @@ import os
4
  from importlib.resources import files
5
 
6
  import hydra
 
7
 
8
- from f5_tts.model import CFM, DiT, Trainer, UNetT
9
  from f5_tts.model.dataset import load_dataset
10
  from f5_tts.model.utils import get_tokenizer
11
 
@@ -14,9 +15,13 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
14
 
15
  @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
16
  def main(cfg):
 
 
17
  tokenizer = cfg.model.tokenizer
18
  mel_spec_type = cfg.model.mel_spec.mel_spec_type
 
19
  exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
 
20
 
21
  # set text tokenizer
22
  if tokenizer != "custom":
@@ -26,14 +31,8 @@ def main(cfg):
26
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
27
 
28
  # set model
29
- if "F5TTS" in cfg.model.name:
30
- model_cls = DiT
31
- elif "E2TTS" in cfg.model.name:
32
- model_cls = UNetT
33
- wandb_resume_id = None
34
-
35
  model = CFM(
36
- transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
37
  mel_spec_kwargs=cfg.model.mel_spec,
38
  vocab_char_map=vocab_char_map,
39
  )
@@ -45,9 +44,9 @@ def main(cfg):
45
  learning_rate=cfg.optim.learning_rate,
46
  num_warmup_updates=cfg.optim.num_warmup_updates,
47
  save_per_updates=cfg.ckpts.save_per_updates,
48
- keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),
49
  checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
50
- batch_size=cfg.datasets.batch_size_per_gpu,
51
  batch_size_type=cfg.datasets.batch_size_type,
52
  max_samples=cfg.datasets.max_samples,
53
  grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
@@ -57,11 +56,12 @@ def main(cfg):
57
  wandb_run_name=exp_name,
58
  wandb_resume_id=wandb_resume_id,
59
  last_per_updates=cfg.ckpts.last_per_updates,
60
- log_samples=True,
61
  bnb_optimizer=cfg.optim.bnb_optimizer,
62
  mel_spec_type=mel_spec_type,
63
  is_local_vocoder=cfg.model.vocoder.is_local,
64
  local_vocoder_path=cfg.model.vocoder.local_path,
 
65
  )
66
 
67
  train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
 
4
  from importlib.resources import files
5
 
6
  import hydra
7
+ from omegaconf import OmegaConf
8
 
9
+ from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config
10
  from f5_tts.model.dataset import load_dataset
11
  from f5_tts.model.utils import get_tokenizer
12
 
 
15
 
16
  @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
17
  def main(cfg):
18
+ model_cls = globals()[cfg.model.backbone]
19
+ model_arc = cfg.model.arch
20
  tokenizer = cfg.model.tokenizer
21
  mel_spec_type = cfg.model.mel_spec.mel_spec_type
22
+
23
  exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
24
+ wandb_resume_id = None
25
 
26
  # set text tokenizer
27
  if tokenizer != "custom":
 
31
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
32
 
33
  # set model
 
 
 
 
 
 
34
  model = CFM(
35
+ transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
36
  mel_spec_kwargs=cfg.model.mel_spec,
37
  vocab_char_map=vocab_char_map,
38
  )
 
44
  learning_rate=cfg.optim.learning_rate,
45
  num_warmup_updates=cfg.optim.num_warmup_updates,
46
  save_per_updates=cfg.ckpts.save_per_updates,
47
+ keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
48
  checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
49
+ batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
50
  batch_size_type=cfg.datasets.batch_size_type,
51
  max_samples=cfg.datasets.max_samples,
52
  grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
 
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
  last_per_updates=cfg.ckpts.last_per_updates,
59
+ log_samples=cfg.ckpts.log_samples,
60
  bnb_optimizer=cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,
62
  is_local_vocoder=cfg.model.vocoder.is_local,
63
  local_vocoder_path=cfg.model.vocoder.local_path,
64
+ cfg_dict=OmegaConf.to_container(cfg, resolve=True),
65
  )
66
 
67
  train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)