SWivid committed
Commit c93462f · 1 Parent(s): 649b46e

feature. allow custom model config for gradio infer

README.md CHANGED
@@ -150,7 +150,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
 - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
-- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN/tree/main) as vocoder
+- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
 - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
src/f5_tts/configs/F5TTS_Base_train.yaml CHANGED
@@ -28,7 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
-    checkpoint_activations: False  # recompute activations and save memory for extra compute
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
src/f5_tts/configs/F5TTS_Small_train.yaml CHANGED
@@ -28,7 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
-    checkpoint_activations: False  # recompute activations and save memory for extra compute
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
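
Both training configs gate the same DiT option, which `infer_cli.py` below starts reading at inference time. For orientation, a minimal sketch of pulling the flag back out with OmegaConf (the `model.arch` nesting is inferred from how `infer_cli.py` consumes these files; the path is illustrative):

```python
from omegaconf import OmegaConf

# Load a train config and inspect the architecture block that infer_cli.py
# hands to DiT; checkpoint_activations sits alongside dim/depth/heads/etc.
cfg = OmegaConf.load("src/f5_tts/configs/F5TTS_Base_train.yaml")
arch = OmegaConf.to_container(cfg.model.arch)
print(arch.get("checkpoint_activations"))  # False: store activations; True: recompute in backward
```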
src/f5_tts/infer/SHARED.md CHANGED
@@ -16,33 +16,34 @@
 <!-- omit in toc -->
 ### Supported Languages
 - [Multilingual](#multilingual)
-  - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
+  - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
 - [English](#english)
 - [Finnish](#finnish)
-  - [Finnish Common\_Voice Vox\_Populi @ finetune @ fi](#finnish-common_voice-vox_populi--finetune--fi)
+  - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
 - [French](#french)
-  - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
+  - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
 - [Hindi](#hindi)
-  - [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi)
+  - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
 - [Italian](#italian)
-  - [F5-TTS Italian @ finetune @ it](#f5-tts-italian--finetune--it)
+  - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
 - [Japanese](#japanese)
-  - [F5-TTS Japanese @ pretrain/finetune @ ja](#f5-tts-japanese--pretrainfinetune--ja)
+  - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
 - [Mandarin](#mandarin)
 - [Spanish](#spanish)
-  - [F5-TTS Spanish @ pretrain/finetune @ es](#f5-tts-spanish--pretrainfinetune--es)
+  - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)
 
 
 ## Multilingual
 
-#### F5-TTS Base @ pretrain @ zh & en
+#### F5-TTS Base @ zh & en @ F5-TTS
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
 |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
 
 ```bash
-MODEL_CKPT: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
-VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
+Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
+Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -53,27 +54,29 @@ VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
 
 ## Finnish
 
-#### Finnish Common_Voice Vox_Populi @ finetune @ fi
+#### F5-TTS Base @ fi @ AsmoKoskinen
 |Model|🤗Hugging Face|Data|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS Finnish|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
 
 ```bash
-MODEL_CKPT: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
-VOCAB_FILE: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
+Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
+Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 
 ## French
 
-#### French LibriVox @ finetune @ fr
+#### F5-TTS Base @ fr @ RASPIAUDIO
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS French|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
 
 ```bash
-MODEL_CKPT: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
-VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
+Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
+Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -83,31 +86,32 @@ VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
 
 ## Hindi
 
-#### F5-TTS Small @ pretrain @ hi
+#### F5-TTS Small @ hi @ SPRINGLab
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
 |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
 
 ```bash
-MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
-VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
+Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
+Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
+Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
-Authors: SPRING Lab, Indian Institute of Technology, Madras
-<br>
-Website: https://asr.iitm.ac.in/
+- Authors: SPRING Lab, Indian Institute of Technology, Madras
+- Website: https://asr.iitm.ac.in/
 
 
 ## Italian
 
-#### F5-TTS Italian @ finetune @ it
+#### F5-TTS Base @ it @ alien79
 |Model|🤗Hugging Face|Data|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS Italian|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
 
 ```bash
-MODEL_CKPT: hf://alien79/F5-TTS-italian/model_159600.safetensors
-VOCAB_FILE: hf://alien79/F5-TTS-italian/vocab.txt
+Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
+Vocab: hf://alien79/F5-TTS-italian/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -117,14 +121,15 @@ VOCAB_FILE: hf://alien79/F5-TTS-italian/vocab.txt
 
 ## Japanese
 
-#### F5-TTS Japanese @ pretrain/finetune @ ja
+#### F5-TTS Base @ ja @ Jmica
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS Japanese|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
 
 ```bash
-MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
-VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
+Model: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
+Vocab: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 
@@ -133,9 +138,9 @@ VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
 
 ## Spanish
 
-#### F5-TTS Spanish @ pretrain/finetune @ es
+#### F5-TTS Base @ es @ jpgallegoar
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS Spanish|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
 
 - @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.
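
Each model card now ships a `Config:` line, a JSON dict of architecture kwargs, so third-party checkpoints with non-Base shapes (like the Hindi Small model) can be loaded without code changes. A hedged sketch of consuming it, mirroring the `json.loads` call the Gradio diff below applies to the new Config dropdown:

```python
import json

# Parse the Config line from the Hindi card; the resulting dict supplies the
# DiT constructor kwargs matching that checkpoint's smaller architecture.
config_line = '{"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}'
model_cfg = json.loads(config_line)
assert model_cfg["dim"] == 768 and model_cfg["depth"] == 18
```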
src/f5_tts/infer/infer_cli.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
 import soundfile as sf
 import tomli
 from cached_path import cached_path
+from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
     mel_spec_type,
@@ -51,6 +52,12 @@ parser.add_argument(
     type=str,
     help="The model name: F5-TTS | E2-TTS",
 )
+parser.add_argument(
+    "-mc",
+    "--model_cfg",
+    type=str,
+    help="The path to F5-TTS model config file .yaml",
+)
 parser.add_argument(
     "-p",
     "--ckpt_file",
@@ -166,6 +173,7 @@ config = tomli.load(open(args.config, "rb"))
 # command-line interface parameters
 
 model = args.model or config.get("model", "F5-TTS")
+model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
 ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
 vocab_file = args.vocab_file or config.get("vocab_file", "")
 
@@ -179,9 +187,9 @@ output_file = args.output_file or config.get(
     "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
 )
 
-save_chunk = args.save_chunk
-remove_silence = args.remove_silence
-load_vocoder_from_local = args.load_vocoder_from_local
+save_chunk = args.save_chunk or config.get("save_chunk", False)
+remove_silence = args.remove_silence or config.get("remove_silence", False)
+load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
 
 vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
 target_rms = args.target_rms or config.get("target_rms", target_rms)
@@ -235,7 +243,7 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
 
 if model == "F5-TTS":
     model_cls = DiT
-    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+    model_cfg = OmegaConf.load(model_cfg).model.arch
    if not ckpt_file:  # path not specified, download from repo
        if vocoder_name == "vocos":
            repo_name = "F5-TTS"
@@ -250,7 +258,8 @@ if model == "F5-TTS":
            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
 
 elif model == "E2-TTS":
-    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
+    assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
+    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
     model_cls = UNetT
     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
    if not ckpt_file:  # path not specified, download from repo
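
The CLI thus swaps a hard-coded architecture dict for a YAML lookup with a three-way precedence: command-line flag, then the TOML config entry, then the packaged default. A minimal sketch of the resolution (the `f5-tts_infer-cli` entry point name is assumed from the repo's packaging; values are illustrative):

```python
# Rough equivalent of the new lookup, e.g. for:
#   f5-tts_infer-cli -mc src/f5_tts/configs/F5TTS_Small_train.yaml ...
from importlib.resources import files

from omegaconf import OmegaConf

args_model_cfg = None  # stand-in for args.model_cfg / config.get("model_cfg")
default_cfg = str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml"))
model_cfg = OmegaConf.load(args_model_cfg or default_cfg).model.arch  # kwargs for DiT
```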
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -1,6 +1,7 @@
 # ruff: noqa: E402
 # Above allows ruff to ignore E402: module level import not at top of file
 
+import json
 import re
 import tempfile
 from collections import OrderedDict
@@ -43,6 +44,12 @@ from f5_tts.infer.utils_infer import (
 DEFAULT_TTS_MODEL = "F5-TTS"
 tts_model_choice = DEFAULT_TTS_MODEL
 
+DEFAULT_TTS_MODEL_CFG = [
+    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
+    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
+    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
+]
+
 
 # load models
 
@@ -103,7 +110,15 @@ def generate_response(messages, model, tokenizer):
 
 @gpu_decorator
 def infer(
-    ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
+    ref_audio_orig,
+    ref_text,
+    gen_text,
+    model,
+    remove_silence,
+    cross_fade_duration=0.15,
+    nfe_step=32,
+    speed=1,
+    show_info=gr.Info,
 ):
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
@@ -120,7 +135,7 @@
         global custom_ema_model, pre_custom_path
         if pre_custom_path != model[1]:
             show_info("Loading Custom TTS model...")
-            custom_ema_model = load_custom(model[1], vocab_path=model[2])
+            custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
             pre_custom_path = model[1]
         ema_model = custom_ema_model
 
@@ -131,6 +146,7 @@
         ema_model,
         vocoder,
         cross_fade_duration=cross_fade_duration,
+        nfe_step=nfe_step,
         speed=speed,
         show_info=show_info,
         progress=gr.Progress(),
@@ -184,6 +200,14 @@ with gr.Blocks() as app_tts:
             step=0.1,
             info="Adjust the speed of the audio.",
         )
+        nfe_slider = gr.Slider(
+            label="NFE Steps",
+            minimum=4,
+            maximum=64,
+            value=32,
+            step=2,
+            info="Set the number of denoising steps.",
+        )
         cross_fade_duration_slider = gr.Slider(
             label="Cross-Fade Duration (s)",
             minimum=0.0,
@@ -203,6 +227,7 @@ with gr.Blocks() as app_tts:
         gen_text_input,
         remove_silence,
         cross_fade_duration_slider,
+        nfe_slider,
         speed_slider,
     ):
         audio_out, spectrogram_path, ref_text_out = infer(
@@ -211,8 +236,9 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             tts_model_choice,
             remove_silence,
-            cross_fade_duration_slider,
-            speed_slider,
+            cross_fade_duration=cross_fade_duration_slider,
+            nfe_step=nfe_slider,
+            speed=speed_slider,
         )
         return audio_out, spectrogram_path, gr.update(value=ref_text_out)
 
@@ -224,6 +250,7 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             remove_silence,
             cross_fade_duration_slider,
+            nfe_slider,
             speed_slider,
         ],
         outputs=[audio_output, spectrogram_output, ref_text_input],
@@ -744,34 +771,38 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     """
     )
 
-    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom.txt")
+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
 
     def load_last_used_custom():
         try:
-            with open(last_used_custom, "r") as f:
-                return f.read().split(",")
+            custom = []
+            with open(last_used_custom, "r", encoding="utf-8") as f:
+                for line in f:
+                    custom.append(line.strip())
+            return custom
         except FileNotFoundError:
             last_used_custom.parent.mkdir(parents=True, exist_ok=True)
-            return [
-                "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
-                "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
-            ]
+            return DEFAULT_TTS_MODEL_CFG
 
     def switch_tts_model(new_choice):
         global tts_model_choice
         if new_choice == "Custom":  # override in case webpage is refreshed
-            custom_ckpt_path, custom_vocab_path = load_last_used_custom()
-            tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
-            return gr.update(visible=True, value=custom_ckpt_path), gr.update(visible=True, value=custom_vocab_path)
+            custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
+            tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+            return (
+                gr.update(visible=True, value=custom_ckpt_path),
+                gr.update(visible=True, value=custom_vocab_path),
+                gr.update(visible=True, value=custom_model_cfg),
+            )
         else:
             tts_model_choice = new_choice
-            return gr.update(visible=False), gr.update(visible=False)
+            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
 
-    def set_custom_model(custom_ckpt_path, custom_vocab_path):
+    def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
         global tts_model_choice
-        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
-        with open(last_used_custom, "w") as f:
-            f.write(f"{custom_ckpt_path},{custom_vocab_path}")
+        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+        with open(last_used_custom, "w", encoding="utf-8") as f:
+            f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
 
     with gr.Row():
         if not USING_SPACES:
@@ -783,34 +814,46 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
                 choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
             )
         custom_ckpt_path = gr.Dropdown(
-            choices=["hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"],
+            choices=[DEFAULT_TTS_MODEL_CFG[0]],
             value=load_last_used_custom()[0],
             allow_custom_value=True,
-            label="MODEL CKPT: local_path | hf://user_id/repo_id/model_ckpt",
+            label="Model: local_path | hf://user_id/repo_id/model_ckpt",
             visible=False,
         )
         custom_vocab_path = gr.Dropdown(
-            choices=["hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"],
+            choices=[DEFAULT_TTS_MODEL_CFG[1]],
            value=load_last_used_custom()[1],
            allow_custom_value=True,
-            label="VOCAB FILE: local_path | hf://user_id/repo_id/vocab_file",
+            label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
+            visible=False,
+        )
+        custom_model_cfg = gr.Dropdown(
+            choices=[DEFAULT_TTS_MODEL_CFG[2]],
+            value=load_last_used_custom()[2],
+            allow_custom_value=True,
+            label="Config: in a dictionary form",
             visible=False,
         )
 
     choose_tts_model.change(
         switch_tts_model,
         inputs=[choose_tts_model],
-        outputs=[custom_ckpt_path, custom_vocab_path],
+        outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
     custom_ckpt_path.change(
         set_custom_model,
-        inputs=[custom_ckpt_path, custom_vocab_path],
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
     custom_vocab_path.change(
         set_custom_model,
-        inputs=[custom_ckpt_path, custom_vocab_path],
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+        show_progress="hidden",
+    )
+    custom_model_cfg.change(
+        set_custom_model,
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
        show_progress="hidden",
    )
 
src/f5_tts/model/backbones/dit.py CHANGED
@@ -131,8 +131,7 @@ class DiT(nn.Module):
         self.checkpoint_activations = checkpoint_activations
 
     def ckpt_wrapper(self, module):
-        """Code from https://github.com/chuanyangjin/fast-DiT/blob/1a8ecce58f346f877749f2dc67cdb190d295e4dc/models.py#L233-L237"""
-
+        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
         def ckpt_forward(*inputs):
             outputs = module(*inputs)
             return outputs
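
For context, `ckpt_wrapper` exists so each transformer block can be routed through PyTorch activation checkpointing when `checkpoint_activations: True` is set in the configs above. A self-contained sketch of the pattern with a toy module (not the repo's block):

```python
import torch
from torch.utils.checkpoint import checkpoint

def ckpt_wrapper(module):
    # Same shape as DiT.ckpt_wrapper: expose the module as a plain function
    # over tensors so torch.utils.checkpoint can re-run it during backward,
    # recomputing activations instead of storing them.
    def ckpt_forward(*inputs):
        return module(*inputs)
    return ckpt_forward

block = torch.nn.Linear(8, 8)  # toy stand-in for a DiT block
x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(ckpt_wrapper(block), x, use_reentrant=False)
y.sum().backward()  # forward is recomputed here: less memory, extra compute
```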