Spaces: Configuration error
fix vocoder loading
Browse files
src/f5_tts/api.py
CHANGED
@@ -47,7 +47,7 @@ class F5TTS:
|
|
47 |
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
|
48 |
|
49 |
def load_vocoder_model(self, local_path):
|
50 |
-
self.
|
51 |
|
52 |
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
|
53 |
if model_type == "F5-TTS":
|
@@ -102,6 +102,7 @@ class F5TTS:
|
|
102 |
ref_text,
|
103 |
gen_text,
|
104 |
self.ema_model,
|
|
|
105 |
show_info=show_info,
|
106 |
progress=progress,
|
107 |
target_rms=target_rms,
|
|
|
47 |
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
|
48 |
|
49 |
def load_vocoder_model(self, local_path):
|
50 |
+
self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
|
51 |
|
52 |
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
|
53 |
if model_type == "F5-TTS":
|
|
|
102 |
ref_text,
|
103 |
gen_text,
|
104 |
self.ema_model,
|
105 |
+
self.vocoder,
|
106 |
show_info=show_info,
|
107 |
progress=progress,
|
108 |
target_rms=target_rms,
|
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
|
|
113 |
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
114 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
115 |
|
116 |
-
|
117 |
|
118 |
|
119 |
# load models
|
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
|
|
175 |
ref_audio = voices[voice]["ref_audio"]
|
176 |
ref_text = voices[voice]["ref_text"]
|
177 |
print(f"Voice: {voice}")
|
178 |
-
audio, final_sample_rate, spectragram = infer_process(
|
|
|
|
|
179 |
generated_audio_segments.append(audio)
|
180 |
|
181 |
if generated_audio_segments:
|
|
|
113 |
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
114 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
115 |
|
116 |
+
vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
|
117 |
|
118 |
|
119 |
# load models
|
|
|
175 |
ref_audio = voices[voice]["ref_audio"]
|
176 |
ref_text = voices[voice]["ref_text"]
|
177 |
print(f"Voice: {voice}")
|
178 |
+
audio, final_sample_rate, spectragram = infer_process(
|
179 |
+
ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
|
180 |
+
)
|
181 |
generated_audio_segments.append(audio)
|
182 |
|
183 |
if generated_audio_segments:
|
src/f5_tts/infer/infer_gradio.py
CHANGED
@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
|
|
37 |
save_spectrogram,
|
38 |
)
|
39 |
|
40 |
-
|
41 |
|
42 |
|
43 |
# load models
|
@@ -94,6 +94,7 @@ def infer(
|
|
94 |
ref_text,
|
95 |
gen_text,
|
96 |
ema_model,
|
|
|
97 |
cross_fade_duration=cross_fade_duration,
|
98 |
speed=speed,
|
99 |
show_info=show_info,
|
|
|
37 |
save_spectrogram,
|
38 |
)
|
39 |
|
40 |
+
vocoder = load_vocoder()
|
41 |
|
42 |
|
43 |
# load models
|
|
|
94 |
ref_text,
|
95 |
gen_text,
|
96 |
ema_model,
|
97 |
+
vocoder,
|
98 |
cross_fade_duration=cross_fade_duration,
|
99 |
speed=speed,
|
100 |
show_info=show_info,
|
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -29,9 +29,6 @@ _ref_audio_cache = {}
|
|
29 |
|
30 |
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
31 |
|
32 |
-
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
33 |
-
|
34 |
-
|
35 |
# -----------------------------------------
|
36 |
|
37 |
target_sample_rate = 24000
|
@@ -263,6 +260,7 @@ def infer_process(
|
|
263 |
ref_text,
|
264 |
gen_text,
|
265 |
model_obj,
|
|
|
266 |
show_info=print,
|
267 |
progress=tqdm,
|
268 |
target_rms=target_rms,
|
@@ -287,6 +285,7 @@ def infer_process(
|
|
287 |
ref_text,
|
288 |
gen_text_batches,
|
289 |
model_obj,
|
|
|
290 |
progress=progress,
|
291 |
target_rms=target_rms,
|
292 |
cross_fade_duration=cross_fade_duration,
|
@@ -307,6 +306,7 @@ def infer_batch_process(
|
|
307 |
ref_text,
|
308 |
gen_text_batches,
|
309 |
model_obj,
|
|
|
310 |
progress=tqdm,
|
311 |
target_rms=0.1,
|
312 |
cross_fade_duration=0.15,
|
@@ -362,7 +362,7 @@ def infer_batch_process(
|
|
362 |
generated = generated.to(torch.float32)
|
363 |
generated = generated[:, ref_audio_len:, :]
|
364 |
generated_mel_spec = generated.permute(0, 2, 1)
|
365 |
-
generated_wave =
|
366 |
if rms < target_rms:
|
367 |
generated_wave = generated_wave * rms / target_rms
|
368 |
|
|
|
29 |
|
30 |
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
31 |
|
|
|
|
|
|
|
32 |
# -----------------------------------------
|
33 |
|
34 |
target_sample_rate = 24000
|
|
|
260 |
ref_text,
|
261 |
gen_text,
|
262 |
model_obj,
|
263 |
+
vocoder,
|
264 |
show_info=print,
|
265 |
progress=tqdm,
|
266 |
target_rms=target_rms,
|
|
|
285 |
ref_text,
|
286 |
gen_text_batches,
|
287 |
model_obj,
|
288 |
+
vocoder,
|
289 |
progress=progress,
|
290 |
target_rms=target_rms,
|
291 |
cross_fade_duration=cross_fade_duration,
|
|
|
306 |
ref_text,
|
307 |
gen_text_batches,
|
308 |
model_obj,
|
309 |
+
vocoder,
|
310 |
progress=tqdm,
|
311 |
target_rms=0.1,
|
312 |
cross_fade_duration=0.15,
|
|
|
362 |
generated = generated.to(torch.float32)
|
363 |
generated = generated[:, ref_audio_len:, :]
|
364 |
generated_mel_spec = generated.permute(0, 2, 1)
|
365 |
+
generated_wave = vocoder.decode(generated_mel_spec.cpu())
|
366 |
if rms < target_rms:
|
367 |
generated_wave = generated_wave * rms / target_rms
|
368 |
|