SWivid committed
Commit 381ea0c · 1 Parent(s): da1b409

fix vocoder loading
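This commit removes the implicit module-level vocos global from utils_infer.py and instead threads an explicitly loaded vocoder through the call chain (infer_process -> infer_batch_process -> vocoder.decode). A minimal sketch of the resulting calling convention, assuming load_vocoder and infer_process are importable from f5_tts.infer.utils_infer (as the import block in infer_gradio.py suggests); the paths, texts, and model loading below are placeholders, not part of this commit:

from f5_tts.infer.utils_infer import load_vocoder, infer_process

# Load the vocoder once up front; with is_local=False it is fetched from the hub,
# with is_local=True it is read from local_path.
vocoder = load_vocoder(is_local=False, local_path="../checkpoints/charactr/vocos-mel-24khz")

# model_obj: an already-loaded F5-TTS/E2-TTS model (loading elided in this sketch).
audio, final_sample_rate, spectrogram = infer_process(
    "ref.wav",               # placeholder reference clip
    "reference transcript",  # placeholder reference text
    "text to synthesize",    # placeholder generation text
    model_obj,
    vocoder,                 # new: the vocoder is now an explicit argument
    speed=1.0,
)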

src/f5_tts/api.py CHANGED
@@ -47,7 +47,7 @@ class F5TTS:
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
 
     def load_vocoder_model(self, local_path):
-        self.vocos = load_vocoder(local_path is not None, local_path, self.device)
+        self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
 
     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
@@ -102,6 +102,7 @@ class F5TTS:
             ref_text,
             gen_text,
             self.ema_model,
+            self.vocoder,
             show_info=show_info,
             progress=progress,
             target_rms=target_rms,
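A note on the positional call above: the first argument to load_vocoder is the is_local flag, so passing local_path is not None means "load from disk exactly when a path was supplied." The keyword form used in infer_cli.py spells out the same parameters:

# api.py: local loading toggled by whether a path was given
self.vocoder = load_vocoder(local_path is not None, local_path, self.device)

# infer_cli.py: the same flag and path passed by keyword
vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)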
src/f5_tts/infer/infer_cli.py CHANGED
@@ -113,7 +113,7 @@ wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 
-vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
+vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
 
 
 # load models
@@ -175,7 +175,9 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
         ref_audio = voices[voice]["ref_audio"]
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
-        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
+        audio, final_sample_rate, spectragram = infer_process(
+            ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
+        )
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
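Worth noting for the CLI: main_process reads vocoder as a module-level name, so load_vocoder must run before the per-voice loop executes, which is why the assignment sits next to the model loading near the top of the script. Schematically (a simplified sketch of the script structure, not the full file):

vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)

def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed):
    ...
    # inside the per-voice loop, the module-level vocoder is captured here
    audio, final_sample_rate, spectragram = infer_process(
        ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
    )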
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -37,7 +37,7 @@ from f5_tts.infer.utils_infer import (
     save_spectrogram,
 )
 
-vocos = load_vocoder()
+vocoder = load_vocoder()
 
 
 # load models
@@ -94,6 +94,7 @@ def infer(
         ref_text,
         gen_text,
         ema_model,
+        vocoder,
         cross_fade_duration=cross_fade_duration,
         speed=speed,
         show_info=show_info,
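The Gradio app follows the same pattern at start-up: load_vocoder() runs once at import time (called with no arguments, so it presumably falls back to the hub copy of charactr/vocos-mel-24khz that the removed module global used to load), and the resulting module-level vocoder is forwarded on every request through the infer_process call shown above.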
src/f5_tts/infer/utils_infer.py CHANGED
@@ -29,9 +29,6 @@ _ref_audio_cache = {}
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
-vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-
-
 # -----------------------------------------
 
 target_sample_rate = 24000
@@ -263,6 +260,7 @@ def infer_process(
     ref_text,
     gen_text,
     model_obj,
+    vocoder,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -287,6 +285,7 @@ def infer_process(
         ref_text,
         gen_text_batches,
         model_obj,
+        vocoder,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -307,6 +306,7 @@ def infer_batch_process(
     ref_text,
     gen_text_batches,
     model_obj,
+    vocoder,
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
@@ -362,7 +362,7 @@ def infer_batch_process(
         generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = generated.permute(0, 2, 1)
-        generated_wave = vocos.decode(generated_mel_spec.cpu())
+        generated_wave = vocoder.decode(generated_mel_spec.cpu())
         if rms < target_rms:
             generated_wave = generated_wave * rms / target_rms
 
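Net effect in utils_infer.py: the vocoder is no longer a hidden module global but an explicit parameter flowing from infer_process through infer_batch_process down to the decode call, so any object exposing a Vocos-style decode(mel) method will do. A minimal sketch mirroring the removed line (the dummy mel tensor is for shape illustration only; the charactr/vocos-mel-24khz checkpoint uses 100 mel bins):

import torch
from vocos import Vocos

# What callers now construct themselves (e.g. via load_vocoder) instead of relying on
# the old module-level global.
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")

mel = torch.randn(1, 100, 256)  # [batch, n_mels, frames] dummy mel spectrogram
wave = vocoder.decode(mel)      # -> [1, samples] waveform at 24 kHz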