feature. allow custom model config for gradio infer
Files changed:

- README.md +1 -1
- src/f5_tts/configs/F5TTS_Base_train.yaml +1 -1
- src/f5_tts/configs/F5TTS_Small_train.yaml +1 -1
- src/f5_tts/infer/SHARED.md +39 -34
- src/f5_tts/infer/infer_cli.py +14 -5
- src/f5_tts/infer/infer_gradio.py +69 -26
- src/f5_tts/model/backbones/dit.py +1 -2
README.md (CHANGED)

```diff
@@ -150,7 +150,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
 - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
-- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN
+- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
 - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
```
src/f5_tts/configs/F5TTS_Base_train.yaml (CHANGED)

```diff
@@ -28,7 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
-    checkpoint_activations: False
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
```
src/f5_tts/configs/F5TTS_Small_train.yaml (CHANGED)

```diff
@@ -28,7 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
-    checkpoint_activations: False
+    checkpoint_activations: False  # recompute activations and save memory for extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
```
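Both configs expose the DiT architecture under `model.arch`, which is exactly what the updated inference code reads. A minimal sketch of that consumption path, assuming `omegaconf` is installed and the packaged config layout above:

```python
# Minimal sketch: resolve the packaged default config and pull out the
# architecture block, mirroring the new code path in infer_cli.py below.
from importlib.resources import files

from omegaconf import OmegaConf

cfg_path = str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml"))
arch = OmegaConf.load(cfg_path).model.arch  # dim, depth, heads, ff_mult, text_dim, conv_layers, ...

print(OmegaConf.to_container(arch))
```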
src/f5_tts/infer/SHARED.md (CHANGED)

````diff
@@ -16,33 +16,34 @@
 <!-- omit in toc -->
 ### Supported Languages
 - [Multilingual](#multilingual)
-  - [F5-TTS Base @
+  - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
 - [English](#english)
 - [Finnish](#finnish)
-  - [
+  - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
 - [French](#french)
-  - [
+  - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
 - [Hindi](#hindi)
-  - [F5-TTS Small @
+  - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
 - [Italian](#italian)
-  - [F5-TTS
+  - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
 - [Japanese](#japanese)
-  - [F5-TTS
+  - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
 - [Mandarin](#mandarin)
 - [Spanish](#spanish)
-  - [F5-TTS
+  - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)
 
 
 ## Multilingual
 
-#### F5-TTS Base @
+#### F5-TTS Base @ zh & en @ F5-TTS
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
 |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
 
 ```bash
-
-
+Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
+Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -53,27 +54,29 @@ VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
 
 ## Finnish
 
-####
+#### F5-TTS Base @ fi @ AsmoKoskinen
 |Model|🤗Hugging Face|Data|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
 
 ```bash
-
-
+Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
+Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 
 ## French
 
-####
+#### F5-TTS Base @ fr @ RASPIAUDIO
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
 
 ```bash
-
-
+Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
+Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -83,31 +86,32 @@ VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
 
 ## Hindi
 
-#### F5-TTS Small @
+#### F5-TTS Small @ hi @ SPRINGLab
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
 |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
 
 ```bash
-
-
+Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
+Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
+Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
-Authors: SPRING Lab, Indian Institute of Technology, Madras
-
-Website: https://asr.iitm.ac.in/
+- Authors: SPRING Lab, Indian Institute of Technology, Madras
+- Website: https://asr.iitm.ac.in/
 
 
 ## Italian
 
-#### F5-TTS
+#### F5-TTS Base @ it @ alien79
 |Model|🤗Hugging Face|Data|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
 
 ```bash
-
-
+Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
+Vocab: hf://alien79/F5-TTS-italian/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -117,14 +121,15 @@ VOCAB_FILE: hf://alien79/F5-TTS-italian/vocab.txt
 
 ## Japanese
 
-#### F5-TTS
+#### F5-TTS Base @ ja @ Jmica
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
 
 ```bash
-
-
+Model: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
+Vocab: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
 
 
@@ -133,9 +138,9 @@ VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
 
 ## Spanish
 
-#### F5-TTS
+#### F5-TTS Base @ es @ jpgallegoar
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
 
 - @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.
````
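Each model card now publishes the same three fields that the updated CLI and Gradio app consume: a checkpoint, a vocab file, and the architecture dict. A hedged sketch of loading one card programmatically, using values copied from the Italian card above; it assumes `load_model(model_cls, model_cfg, ckpt_path, ...)` from `f5_tts.infer.utils_infer`, whose exact signature should be checked against the repo:

```python
# Sketch: turn a SHARED.md card (Model / Vocab / Config) into a loaded model.
import json

from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model  # signature assumed, verify in repo
from f5_tts.model import DiT

ckpt_file = str(cached_path("hf://alien79/F5-TTS-italian/model_159600.safetensors"))
vocab_file = str(cached_path("hf://alien79/F5-TTS-italian/vocab.txt"))
model_cfg = json.loads('{"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}')

ema_model = load_model(DiT, model_cfg, ckpt_file, vocab_file=vocab_file)
```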
src/f5_tts/infer/infer_cli.py (CHANGED)

```diff
@@ -10,6 +10,7 @@ import numpy as np
 import soundfile as sf
 import tomli
 from cached_path import cached_path
+from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
     mel_spec_type,
@@ -51,6 +52,12 @@ parser.add_argument(
     type=str,
     help="The model name: F5-TTS | E2-TTS",
 )
+parser.add_argument(
+    "-mc",
+    "--model_cfg",
+    type=str,
+    help="The path to F5-TTS model config file .yaml",
+)
 parser.add_argument(
     "-p",
     "--ckpt_file",
@@ -166,6 +173,7 @@ config = tomli.load(open(args.config, "rb"))
 # command-line interface parameters
 
 model = args.model or config.get("model", "F5-TTS")
+model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
 ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
 vocab_file = args.vocab_file or config.get("vocab_file", "")
 
@@ -179,9 +187,9 @@ output_file = args.output_file or config.get(
     "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
 )
 
-save_chunk = args.save_chunk
-remove_silence = args.remove_silence
-load_vocoder_from_local = args.load_vocoder_from_local
+save_chunk = args.save_chunk or config.get("save_chunk", False)
+remove_silence = args.remove_silence or config.get("remove_silence", False)
+load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
 
 vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
 target_rms = args.target_rms or config.get("target_rms", target_rms)
@@ -235,7 +243,7 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
 
 if model == "F5-TTS":
     model_cls = DiT
-    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+    model_cfg = OmegaConf.load(model_cfg).model.arch
     if not ckpt_file:  # path not specified, download from repo
         if vocoder_name == "vocos":
             repo_name = "F5-TTS"
@@ -250,7 +258,8 @@ if model == "F5-TTS":
         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
 
 elif model == "E2-TTS":
-    assert
+    assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
+    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
     model_cls = UNetT
     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
     if not ckpt_file:  # path not specified, download from repo
```
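The new `-mc/--model_cfg` flag slots into the same precedence chain as the other CLI options: explicit flag first, then the key in the toml config file, then a hard default (here the packaged Base training yaml). A standalone sketch of that pattern; `resolve_model_cfg` is illustrative, not a function in the repo:

```python
# Illustrative helper showing the flag -> toml -> default fallback chain.
from importlib.resources import files

import tomli


def resolve_model_cfg(cli_value=None, config_path="basic.toml"):
    with open(config_path, "rb") as f:
        config = tomli.load(f)
    default_yaml = str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml"))
    # Falsy CLI values (None, "", False) fall through to the config file, then the default.
    return cli_value or config.get("model_cfg", default_yaml)
```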
src/f5_tts/infer/infer_gradio.py (CHANGED)

```diff
@@ -1,6 +1,7 @@
 # ruff: noqa: E402
 # Above allows ruff to ignore E402: module level import not at top of file
 
+import json
 import re
 import tempfile
 from collections import OrderedDict
@@ -43,6 +44,12 @@ from f5_tts.infer.utils_infer import (
 DEFAULT_TTS_MODEL = "F5-TTS"
 tts_model_choice = DEFAULT_TTS_MODEL
 
+DEFAULT_TTS_MODEL_CFG = [
+    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
+    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
+    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
+]
+
 
 # load models
 
@@ -103,7 +110,15 @@ def generate_response(messages, model, tokenizer):
 
 @gpu_decorator
 def infer(
-    ref_audio_orig,
+    ref_audio_orig,
+    ref_text,
+    gen_text,
+    model,
+    remove_silence,
+    cross_fade_duration=0.15,
+    nfe_step=32,
+    speed=1,
+    show_info=gr.Info,
 ):
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
@@ -120,7 +135,7 @@ def infer(
     global custom_ema_model, pre_custom_path
     if pre_custom_path != model[1]:
         show_info("Loading Custom TTS model...")
-        custom_ema_model = load_custom(model[1], vocab_path=model[2])
+        custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
         pre_custom_path = model[1]
     ema_model = custom_ema_model
 
@@ -131,6 +146,7 @@ def infer(
         ema_model,
         vocoder,
         cross_fade_duration=cross_fade_duration,
+        nfe_step=nfe_step,
         speed=speed,
         show_info=show_info,
         progress=gr.Progress(),
@@ -184,6 +200,14 @@ with gr.Blocks() as app_tts:
             step=0.1,
             info="Adjust the speed of the audio.",
         )
+        nfe_slider = gr.Slider(
+            label="NFE Steps",
+            minimum=4,
+            maximum=64,
+            value=32,
+            step=2,
+            info="Set the number of denoising steps.",
+        )
         cross_fade_duration_slider = gr.Slider(
             label="Cross-Fade Duration (s)",
             minimum=0.0,
@@ -203,6 +227,7 @@ with gr.Blocks() as app_tts:
         gen_text_input,
         remove_silence,
         cross_fade_duration_slider,
+        nfe_slider,
         speed_slider,
     ):
         audio_out, spectrogram_path, ref_text_out = infer(
@@ -211,8 +236,9 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             tts_model_choice,
             remove_silence,
-            cross_fade_duration_slider,
-            speed_slider,
+            cross_fade_duration=cross_fade_duration_slider,
+            nfe_step=nfe_slider,
+            speed=speed_slider,
         )
         return audio_out, spectrogram_path, gr.update(value=ref_text_out)
 
@@ -224,6 +250,7 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             remove_silence,
             cross_fade_duration_slider,
+            nfe_slider,
             speed_slider,
         ],
         outputs=[audio_output, spectrogram_output, ref_text_input],
@@ -744,34 +771,38 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     """
     )
 
-    last_used_custom = files("f5_tts").joinpath("infer/.cache/
+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
 
     def load_last_used_custom():
        try:
-
-
+            custom = []
+            with open(last_used_custom, "r", encoding='utf-8') as f:
+                for line in f:
+                    custom.append(line.strip())
+            return custom
         except FileNotFoundError:
             last_used_custom.parent.mkdir(parents=True, exist_ok=True)
-            return [
-                "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
-                "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
-            ]
+            return DEFAULT_TTS_MODEL_CFG
 
     def switch_tts_model(new_choice):
         global tts_model_choice
         if new_choice == "Custom":  # override in case webpage is refreshed
-            custom_ckpt_path, custom_vocab_path = load_last_used_custom()
-            tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
-            return
+            custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
+            tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+            return (
+                gr.update(visible=True, value=custom_ckpt_path),
+                gr.update(visible=True, value=custom_vocab_path),
+                gr.update(visible=True, value=custom_model_cfg),
+            )
         else:
             tts_model_choice = new_choice
-            return gr.update(visible=False), gr.update(visible=False)
+            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
 
-    def set_custom_model(custom_ckpt_path, custom_vocab_path):
+    def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
         global tts_model_choice
-        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
-        with open(last_used_custom, "w") as f:
-            f.write(
+        tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
+        with open(last_used_custom, "w", encoding='utf-8') as f:
+            f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
 
     with gr.Row():
         if not USING_SPACES:
@@ -783,34 +814,46 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
             choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
         )
         custom_ckpt_path = gr.Dropdown(
-            choices=[
+            choices=[DEFAULT_TTS_MODEL_CFG[0]],
             value=load_last_used_custom()[0],
             allow_custom_value=True,
-            label="
+            label="Model: local_path | hf://user_id/repo_id/model_ckpt",
             visible=False,
         )
         custom_vocab_path = gr.Dropdown(
-            choices=[
+            choices=[DEFAULT_TTS_MODEL_CFG[1]],
             value=load_last_used_custom()[1],
             allow_custom_value=True,
-            label="
+            label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
+            visible=False,
+        )
+        custom_model_cfg = gr.Dropdown(
+            choices=[DEFAULT_TTS_MODEL_CFG[2]],
+            value=load_last_used_custom()[2],
+            allow_custom_value=True,
+            label="Config: in a dictionary form",
             visible=False,
         )
 
     choose_tts_model.change(
         switch_tts_model,
         inputs=[choose_tts_model],
-        outputs=[custom_ckpt_path, custom_vocab_path],
+        outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
     custom_ckpt_path.change(
         set_custom_model,
-        inputs=[custom_ckpt_path, custom_vocab_path],
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
     custom_vocab_path.change(
         set_custom_model,
-        inputs=[custom_ckpt_path, custom_vocab_path],
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+        show_progress="hidden",
+    )
+    custom_model_cfg.change(
+        set_custom_model,
+        inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
         show_progress="hidden",
     )
 
```
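The app persists the last used custom model as three plain lines in a cache file: checkpoint path, vocab path, and the config serialized with `json.dumps`; on reload the config string goes back through `json.loads`. A self-contained sketch of that round-trip (the file path and helper names here are illustrative, not the app's own):

```python
# Illustrative round-trip for the three-line custom-model cache file.
import json
from pathlib import Path

cache = Path(".cache/last_used_custom_model_info.txt")


def save_info(ckpt: str, vocab: str, cfg_json: str) -> None:
    # One field per line, with the architecture dict kept as a JSON string.
    cache.parent.mkdir(parents=True, exist_ok=True)
    cache.write_text(ckpt + "\n" + vocab + "\n" + cfg_json + "\n", encoding="utf-8")


def load_info():
    ckpt, vocab, cfg_json = cache.read_text(encoding="utf-8").splitlines()[:3]
    return ckpt, vocab, json.loads(cfg_json)


save_info(
    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
    json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
)
print(load_info()[2]["dim"])  # 1024
```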
src/f5_tts/model/backbones/dit.py (CHANGED)

```diff
@@ -131,8 +131,7 @@ class DiT(nn.Module):
         self.checkpoint_activations = checkpoint_activations
 
     def ckpt_wrapper(self, module):
-
-
+        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
         def ckpt_forward(*inputs):
             outputs = module(*inputs)
             return outputs
```
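`ckpt_wrapper` follows the fast-DiT pattern linked in the comment: each transformer block is wrapped so `torch.utils.checkpoint` can recompute its activations during backward instead of storing them, which is what the `checkpoint_activations` flag in the training configs toggles. A hedged sketch of the pattern on a toy module (the actual call site in `dit.py` may differ):

```python
# Toy example of the fast-DiT activation-checkpointing pattern, assuming torch is installed.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyDiT(nn.Module):
    def __init__(self, checkpoint_activations=False):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
        self.checkpoint_activations = checkpoint_activations

    def ckpt_wrapper(self, module):
        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

    def forward(self, x):
        for block in self.blocks:
            if self.checkpoint_activations:
                # Recompute this block's activations in backward: extra compute, less memory.
                x = checkpoint(self.ckpt_wrapper(block), x, use_reentrant=False)
            else:
                x = block(x)
        return x


x = torch.randn(2, 16, requires_grad=True)
TinyDiT(checkpoint_activations=True)(x).sum().backward()
```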