refactor: more details about bigvgan, clear function definition

This commit renames the vocoder-selection parameter from `extract_backend` to `mel_spec_type` across the codebase, documents the BigVGAN submodule setup in the READMEs, makes the BigVGAN import lazy (with a pointer to that setup on failure), and gates fp16 inference on actual GPU capability.

Files changed:

- README.md (+3 −5)
- src/f5_tts/api.py (+8 −5)
- src/f5_tts/eval/eval_infer_batch.py (+13 −8)
- src/f5_tts/eval/utils_eval.py (+2 −2)
- src/f5_tts/infer/README.md (+4 −0)
- src/f5_tts/infer/infer_cli.py (+5 −7)
- src/f5_tts/infer/speech_edit.py (+13 −8)
- src/f5_tts/infer/utils_infer.py (+20 −11)
- src/f5_tts/model/cfm.py (+0 −3)
- src/f5_tts/model/dataset.py (+5 −5)
- src/f5_tts/model/modules.py (+4 −6)
- src/f5_tts/model/trainer.py (+2 −2)
- src/f5_tts/train/train.py (+3 −3)
README.md

````diff
@@ -46,11 +46,13 @@ cd F5-TTS
 pip install -e .
 
 # Init submodule(optional, if you want to change the vocoder from vocos to bigvgan)
-git submodule update --init --recursive
+# git submodule update --init --recursive
+# pip install -e .
 ```
 
 After init submodule, you need to change the `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
 ```python
+import os
 import sys
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 ```
@@ -104,10 +106,6 @@ f5-tts_infer-cli -c custom.toml
 
 # Multi voice. See src/f5_tts/infer/README.md
 f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
-
-# Choose Vocoder
-f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/model_1250000.pt >
-f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors >
 ```
 
 ### 3. More instructions
````
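The patch above matters because BigVGAN's sources import sibling modules by bare name (e.g. `import activations`), which only resolves when the file's own directory is on `sys.path`; the original snippet also called `os.path.dirname` without importing `os`, hence the added `import os`. A sketch of what the top of the patched file ends up looking like (only the first two lines plus the `sys.path.append` are new; the rest of `bigvgan.py` is unchanged):

```python
# src/third_party/BigVGAN/bigvgan.py, after applying the README patch.
# Prepending the file's own directory lets BigVGAN's bare-name imports
# resolve even when bigvgan.py is imported from the f5_tts package.
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
```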
src/f5_tts/api.py

```diff
@@ -38,7 +38,7 @@ class F5TTS:
         self.target_sample_rate = target_sample_rate
         self.hop_length = hop_length
         self.seed = -1
-        self.extract_backend = vocoder_name
+        self.mel_spec_type = vocoder_name
 
         # Set device
         self.device = device or (
@@ -52,10 +52,13 @@ class F5TTS:
     def load_vocoder_model(self, vocoder_name, local_path):
         self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
 
-    def load_ema_model(self, model_type, ckpt_file, extract_backend, vocab_file, ode_method, use_ema):
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
             if not ckpt_file:
-                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                if mel_spec_type == "vocos":
+                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                elif mel_spec_type == "bigvgan":
+                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
             model_cls = DiT
         elif model_type == "E2-TTS":
@@ -67,7 +70,7 @@ class F5TTS:
             raise ValueError(f"Unknown model type: {model_type}")
 
         self.ema_model = load_model(
-            model_cls, model_cfg, ckpt_file, extract_backend, vocab_file, ode_method, use_ema, self.device
+            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
         )
 
     def export_wav(self, wav, file_wave, remove_silence=False):
@@ -111,7 +114,7 @@ class F5TTS:
             gen_text,
             self.ema_model,
             self.vocoder,
-            self.extract_backend,
+            self.mel_spec_type,
             show_info=show_info,
             progress=progress,
             target_rms=target_rms,
```
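For context, a minimal usage sketch of the refactored API. The `infer` keyword names here are assumptions; the constructor's `vocoder_name` and the per-vocoder checkpoint defaults come straight from the diff above.

```python
from f5_tts.api import F5TTS

# vocoder_name now selects both the vocoder and the default checkpoint:
#   "vocos"   -> hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
#   "bigvgan" -> hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
tts = F5TTS(vocoder_name="bigvgan")

# Hypothetical call; argument names beyond the diff are assumed.
wav, sr, spect = tts.infer(
    ref_file="ref_audio.wav",
    ref_text="The content, subtitle or transcription of reference audio.",
    gen_text="Some text you want TTS model generate for you.",
)
```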
src/f5_tts/eval/eval_infer_batch.py

```diff
@@ -32,7 +32,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 
@@ -126,11 +126,11 @@ def main():
 
     # Vocoder model
     local = False
-    if extract_backend == "vocos":
+    if mel_spec_type == "vocos":
         vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-    elif extract_backend == "bigvgan":
+    elif mel_spec_type == "bigvgan":
         vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-    vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
+    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
     # Tokenizer
     vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -144,7 +144,7 @@ def main():
             win_length=win_length,
             n_mel_channels=n_mel_channels,
             target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
         ),
         odeint_kwargs=dict(
             method=ode_method,
@@ -152,7 +152,12 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    dtype = torch.float32 if extract_backend == "bigvgan" else torch.float16
+    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+    if supports_fp16 and mel_spec_type == "vocos":
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
     model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
@@ -186,9 +191,9 @@ def main():
         for i, gen in enumerate(generated):
             gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
             gen_mel_spec = gen.permute(0, 2, 1)
-            if extract_backend == "vocos":
+            if mel_spec_type == "vocos":
                 generated_wave = vocoder.decode(gen_mel_spec)
-            elif extract_backend == "bigvgan":
+            elif mel_spec_type == "bigvgan":
                 generated_wave = vocoder(gen_mel_spec)
 
             if ref_rms_list[i] < target_rms:
```
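The new dtype logic is worth spelling out: fp16 is used only when the CUDA device reports compute capability 6.0 or higher (the first generation with usable fp16 arithmetic) and the vocos pipeline is selected; BigVGAN inference always stays in fp32. The same check, as a standalone sketch:

```python
import torch

def select_inference_dtype(device: str, mel_spec_type: str) -> torch.dtype:
    # Mirrors the check added in this commit: fp16 only on CUDA devices with
    # compute capability >= 6, and only for the vocos path; BigVGAN stays fp32.
    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
    if supports_fp16 and mel_spec_type == "vocos":
        return torch.float16
    return torch.float32
```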
src/f5_tts/eval/utils_eval.py

```diff
@@ -78,7 +78,7 @@ def get_inference_prompt(
     win_length=1024,
     n_mel_channels=100,
     hop_length=256,
-    extract_backend="bigvgan",
+    mel_spec_type="bigvgan",
     target_rms=0.1,
     use_truth_duration=False,
     infer_batch_size=1,
@@ -102,7 +102,7 @@ def get_inference_prompt(
         win_length=win_length,
         n_mel_channels=n_mel_channels,
         target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     )
 
     for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
```
src/f5_tts/infer/README.md

````diff
@@ -56,6 +56,10 @@ f5-tts_infer-cli \
 --ref_audio "ref_audio.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
+
+# Choose Vocoder
+f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/model_1250000.pt >
+f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors >
 ```
 
 And a `.toml` file would help with more flexible usage.
````
src/f5_tts/infer/infer_cli.py

```diff
@@ -115,11 +115,9 @@ if args.vocoder_name == "vocos":
     vocoder_local_path = "../checkpoints/vocos-mel-24khz"
 elif args.vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-extract_backend = args.vocoder_name
+mel_spec_type = args.vocoder_name
 
-vocoder = load_vocoder(
-    vocoder_name=extract_backend, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
-)
+vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
 
 
 # load models
@@ -159,7 +157,7 @@ print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, args.vocoder_name, vocab_file)
 
 
-def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remove_silence, speed):
+def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -194,7 +192,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remove_silence, speed):
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
         audio, final_sample_rate, spectragram = infer_process(
-            ref_audio, ref_text, gen_text, model_obj, vocoder, extract_backend, speed=speed
+            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type, speed=speed
        )
         generated_audio_segments.append(audio)
 
@@ -213,7 +211,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remove_silence, speed):
 
 
 def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, extract_backend, remove_silence, speed)
+    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
 
 
 if __name__ == "__main__":
```
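The local-path selection above is now repeated in several entry points (infer_cli.py, speech_edit.py, eval_infer_batch.py). A hypothetical helper, not part of the commit, that captures the same mapping with the paths used in the diffs:

```python
def default_vocoder_local_path(vocoder_name: str) -> str:
    # Relative checkpoint locations assumed by the CLI when
    # --load_vocoder_from_local is set.
    if vocoder_name == "vocos":
        return "../checkpoints/vocos-mel-24khz"
    if vocoder_name == "bigvgan":
        return "../checkpoints/bigvgan_v2_24khz_100band_256x"
    raise ValueError(f"Unsupported vocoder: {vocoder_name}")
```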
src/f5_tts/infer/speech_edit.py

```diff
@@ -18,7 +18,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 tokenizer = "pinyin"
@@ -85,11 +85,11 @@ if not os.path.exists(output_dir):
 
 # Vocoder model
 local = False
-if extract_backend == "vocos":
+if mel_spec_type == "vocos":
     vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-elif extract_backend == "bigvgan":
+elif mel_spec_type == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
+vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
 # Tokenizer
 vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -103,7 +103,7 @@ model = CFM(
         win_length=win_length,
         n_mel_channels=n_mel_channels,
         target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     ),
     odeint_kwargs=dict(
         method=ode_method,
@@ -111,7 +111,12 @@ model = CFM(
     vocab_char_map=vocab_char_map,
 ).to(device)
 
-dtype = torch.float32 if extract_backend == "bigvgan" else torch.float16
+supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+if supports_fp16 and mel_spec_type == "vocos":
+    dtype = torch.float16
+else:
+    dtype = torch.float32
+
 model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
 # Audio
@@ -178,9 +183,9 @@ with torch.inference_mode():
     generated = generated.to(torch.float32)
     generated = generated[:, ref_audio_len:, :]
     gen_mel_spec = generated.permute(0, 2, 1)
-    if extract_backend == "vocos":
+    if mel_spec_type == "vocos":
         generated_wave = vocoder.decode(gen_mel_spec)
-    elif extract_backend == "bigvgan":
+    elif mel_spec_type == "bigvgan":
         generated_wave = vocoder(gen_mel_spec)
 
     if rms < target_rms:
```
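The two vocoders are invoked differently, which is why this `if/elif` appears at every synthesis site: Vocos exposes an explicit `decode()` method, while the BigVGAN module is a plain `nn.Module` called directly. A small wrapper, not part of the commit, that captures the dispatch:

```python
import torch

def decode_mel(vocoder, mel: torch.Tensor, mel_spec_type: str) -> torch.Tensor:
    # `mel` is expected as (batch, n_mel_channels, frames); hence the
    # permute(0, 2, 1) calls that precede this branch in the diffs above.
    if mel_spec_type == "vocos":
        return vocoder.decode(mel)
    if mel_spec_type == "bigvgan":
        return vocoder(mel)
    raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type}")
```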
src/f5_tts/infer/utils_infer.py

```diff
@@ -4,7 +4,7 @@ import os
 import sys
 
 sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
-from third_party.BigVGAN import bigvgan
+
 import hashlib
 import re
 import tempfile
@@ -40,7 +40,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -97,8 +97,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device):
             vocoder = vocoder.eval().to(device)
         else:
             print("Download Vocos from huggingface charactr/vocos-mel-24khz")
-            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
     elif vocoder_name == "bigvgan":
+        try:
+            from third_party.BigVGAN import bigvgan
+        except ImportError:
+            print("You need to follow the README to init submodule and change the BigVGAN source code.")
         if is_local:
             """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
@@ -165,7 +169,7 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
 
 
 def load_model(
-    model_cls, model_cfg, ckpt_path, extract_backend, vocab_file="", ode_method=ode_method, use_ema=True, device=device
+    model_cls, model_cfg, ckpt_path, mel_spec_type, vocab_file="", ode_method=ode_method, use_ema=True, device=device
 ):
     if vocab_file == "":
         vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
@@ -184,7 +188,7 @@ def load_model(
             win_length=win_length,
             n_mel_channels=n_mel_channels,
             target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
         ),
         odeint_kwargs=dict(
             method=ode_method,
@@ -192,7 +196,12 @@ def load_model(
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    dtype = torch.float32 if extract_backend == "bigvgan" else torch.float16
+    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+    if supports_fp16 and mel_spec_type == "vocos":
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
     model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     return model
@@ -288,7 +297,7 @@ def infer_process(
     gen_text,
     model_obj,
    vocoder,
-    extract_backend,
+    mel_spec_type,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -314,7 +323,7 @@ def infer_process(
         gen_text_batches,
         model_obj,
         vocoder,
-        extract_backend,
+        mel_spec_type,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -336,7 +345,7 @@ def infer_batch_process(
     gen_text_batches,
     model_obj,
     vocoder,
-    extract_backend,
+    mel_spec_type,
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
@@ -392,9 +401,9 @@ def infer_batch_process(
         generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = generated.permute(0, 2, 1)
-        if extract_backend == "vocos":
+        if mel_spec_type == "vocos":
            generated_wave = vocoder.decode(generated_mel_spec)
-        elif extract_backend == "bigvgan":
+        elif mel_spec_type == "bigvgan":
            generated_wave = vocoder(generated_mel_spec)
         if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms
```
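Putting the refactored signatures together, an end-to-end sketch. Import paths and the checkpoint location are assumptions; the `load_vocoder`, `load_model`, and `infer_process` argument orders are the ones visible in the diff above.

```python
from f5_tts.infer.utils_infer import load_vocoder, load_model, infer_process
from f5_tts.model import DiT  # assumed import path, mirroring api.py

mel_spec_type = "bigvgan"

# Vocoder loaded by name; local_path only matters when is_local=True.
vocoder = load_vocoder(
    vocoder_name=mel_spec_type,
    is_local=True,
    local_path="../checkpoints/bigvgan_v2_24khz_100band_256x",
)

# Same F5-TTS base config as in api.py; mel_spec_type is now the 4th positional arg.
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ema_model = load_model(DiT, model_cfg, "ckpts/model_1250000.pt", mel_spec_type)

audio, sample_rate, spectrogram = infer_process(
    "ref_audio.wav", "reference transcript", "text to generate",
    ema_model, vocoder, mel_spec_type,
)
```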
src/f5_tts/model/cfm.py

```diff
@@ -105,9 +105,6 @@ class CFM(nn.Module):
         cond = cond.permute(0, 2, 1)
         assert cond.shape[-1] == self.num_channels
 
-        assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print(
-            "Only support fp16 and fp32 inference currently"
-        )
         cond = cond.to(next(self.parameters()).dtype)
 
         batch, cond_seq_len, device = *cond.shape[:2], cond.device
```
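The deleted assert was redundant: the very next line already coerces the conditioning input to whatever dtype the model parameters carry, so the cast, not the assert, is what keeps the forward pass consistent. The idiom in isolation:

```python
import torch
from torch import nn

# Cast inputs to the parameters' dtype instead of asserting fp16/fp32 up front.
model = nn.Linear(4, 4)                        # fp32 here, but this also works
param_dtype = next(model.parameters()).dtype   # for fp16/bf16 checkpoints
cond = torch.randn(2, 4, dtype=torch.float64)  # input arriving in another dtype
cond = cond.to(param_dtype)                    # match the parameters
assert cond.dtype == param_dtype
```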
src/f5_tts/model/dataset.py

```diff
@@ -24,7 +24,7 @@ class HFDataset(Dataset):
         hop_length=256,
         n_fft=1024,
         win_length=1024,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
     ):
         self.data = hf_dataset
         self.target_sample_rate = target_sample_rate
@@ -36,7 +36,7 @@ class HFDataset(Dataset):
             win_length=win_length,
             n_mel_channels=n_mel_channels,
             target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
         )
 
     def get_frame_len(self, index):
@@ -90,7 +90,7 @@ class CustomDataset(Dataset):
         n_mel_channels=100,
         n_fft=1024,
         win_length=1024,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
         preprocessed_mel=False,
         mel_spec_module: nn.Module | None = None,
     ):
@@ -100,7 +100,7 @@ class CustomDataset(Dataset):
         self.hop_length = hop_length
         self.n_fft = n_fft
         self.win_length = win_length
-        self.extract_backend = extract_backend
+        self.mel_spec_type = mel_spec_type
         self.preprocessed_mel = preprocessed_mel
 
         if not preprocessed_mel:
@@ -112,7 +112,7 @@ class CustomDataset(Dataset):
                 win_length=win_length,
                 n_mel_channels=n_mel_channels,
                 target_sample_rate=target_sample_rate,
-                extract_backend=extract_backend,
+                mel_spec_type=mel_spec_type,
             ),
         )
```
src/f5_tts/model/modules.py

```diff
@@ -142,12 +142,10 @@ class MelSpec(nn.Module):
         win_length=1024,
         n_mel_channels=100,
         target_sample_rate=24_000,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
     ):
         super().__init__()
-        assert extract_backend in ["vocos", "bigvgan"], print(
-            "We only support two extract mel backend: vocos or bigvgan"
-        )
+        assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan")
 
         self.n_fft = n_fft
         self.hop_length = hop_length
@@ -155,9 +153,9 @@ class MelSpec(nn.Module):
         self.n_mel_channels = n_mel_channels
         self.target_sample_rate = target_sample_rate
 
-        if extract_backend == "vocos":
+        if mel_spec_type == "vocos":
             self.extractor = get_vocos_mel_spectrogram
-        elif extract_backend == "bigvgan":
+        elif mel_spec_type == "bigvgan":
             self.extractor = get_bigvgan_mel_spectrogram
 
         self.register_buffer("dummy", torch.tensor(0), persistent=False)
```
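With the rename, `MelSpec` picks its extractor function once at construction time, so the rest of the pipeline never branches on the backend for feature extraction. A usage sketch; the constructor defaults are from the diff, while the forward signature (waveform in, mel out) is an assumption:

```python
import torch
from f5_tts.model.modules import MelSpec

mel_spec = MelSpec(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24_000,
    mel_spec_type="bigvgan",  # routes to get_bigvgan_mel_spectrogram
)
audio = torch.randn(1, 24_000)  # one second of audio at 24 kHz
mel = mel_spec(audio)           # assumed to return (1, 100, frames)
```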
src/f5_tts/model/trainer.py

```diff
@@ -46,7 +46,7 @@ class Trainer:
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
-        extract_backend: str = "vocos",
+        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
@@ -108,7 +108,7 @@ class Trainer:
         self.max_samples = max_samples
         self.grad_accumulation_steps = grad_accumulation_steps
         self.max_grad_norm = max_grad_norm
-        self.vocoder_name = extract_backend
+        self.vocoder_name = mel_spec_type
 
         self.noise_scheduler = noise_scheduler
```
src/f5_tts/train/train.py

```diff
@@ -13,7 +13,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 
 tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
 tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
@@ -63,7 +63,7 @@ def main():
         win_length=win_length,
         n_mel_channels=n_mel_channels,
         target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     )
 
     model = CFM(
@@ -89,7 +89,7 @@ def main():
         wandb_resume_id=wandb_resume_id,
         last_per_steps=last_per_steps,
         log_samples=True,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
    )
 
     train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
```
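In training, the same name now flows through three places: the `MelSpec` kwargs consumed by the dataset, the CFM model's mel frontend, and the Trainer, which stores it as `self.vocoder_name`, presumably so that samples logged with `log_samples=True` are decoded with the matching vocoder. A condensed sketch of the wiring; `n_fft`, `hop_length`, and `target_sample_rate` values are taken from the module-level constants and are assumptions where not shown in the diff:

```python
mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'

# The dataset's mel extractor must match the vocoder used at inference time,
# otherwise the model is trained on features the vocoder cannot invert.
mel_spec_kwargs = dict(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    mel_spec_type=mel_spec_type,
)

# trainer = Trainer(..., log_samples=True, mel_spec_type=mel_spec_type)
# train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
```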