zkniu committed on
Commit
36a4aad
·
1 Parent(s): 712d527

Change some inference functions to support two vocoders

Browse files
README.md CHANGED
@@ -44,20 +44,18 @@ pip install git+https://github.com/SWivid/F5-TTS.git
44
  git clone https://github.com/SWivid/F5-TTS.git
45
  cd F5-TTS
46
  pip install -e .
47
- ```
48
-
49
- ### 3. Init submodule( optional, if you want to change the vocoder from vocos to bigvgan)
50
 
51
- ```bash
52
  git submodule update --init --recursive
53
  ```
54
- After that, you need to change the `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
 
55
  ```python
56
  import sys
57
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
58
  ```
59
 
60
- ### 4. Docker usage
61
  ```bash
62
  # Build from Dockerfile
63
  docker build -t f5tts:v1 .
@@ -106,6 +104,10 @@ f5-tts_infer-cli -c custom.toml
106
 
107
  # Multi voice. See src/f5_tts/infer/README.md
108
  f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
 
 
 
 
109
  ```
110
 
111
  ### 3. More instructions
 
44
  git clone https://github.com/SWivid/F5-TTS.git
45
  cd F5-TTS
46
  pip install -e .
 
 
 
47
 
48
+ # Init submodule (optional; only needed if you want to switch the vocoder from vocos to bigvgan)
49
  git submodule update --init --recursive
50
  ```
51
+
52
+ After initializing the submodule, you need to modify `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file (make sure `os` is imported before these lines, since the snippet uses `os.path`).
53
  ```python
54
  import sys
55
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
56
  ```
57
 
58
+ ### 3. Docker usage
59
  ```bash
60
  # Build from Dockerfile
61
  docker build -t f5tts:v1 .
 
104
 
105
  # Multi voice. See src/f5_tts/infer/README.md
106
  f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
107
+
108
+ # Choose Vocoder
109
+ f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, e.g. ckpts/model_1250000.pt>
110
+ f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, e.g. ckpts/F5TTS_Base/model_1200000.safetensors>
111
  ```
112
 
113
  ### 3. More instructions
src/f5_tts/api.py CHANGED
@@ -7,10 +7,16 @@ import torch
7
  import tqdm
8
  from cached_path import cached_path
9
 
10
- from f5_tts.infer.utils_infer import (hop_length, infer_process, load_model,
11
- load_vocoder, preprocess_ref_audio_text,
12
- remove_silence_for_generated_wav,
13
- save_spectrogram, target_sample_rate)
 
 
 
 
 
 
14
  from f5_tts.model import DiT, UNetT
15
  from f5_tts.model.utils import seed_everything
16
 
@@ -32,6 +38,7 @@ class F5TTS:
32
  self.target_sample_rate = target_sample_rate
33
  self.hop_length = hop_length
34
  self.seed = -1
 
35
 
36
  # Set device
37
  self.device = device or (
@@ -40,12 +47,12 @@ class F5TTS:
40
 
41
  # Load models
42
  self.load_vocoder_model(vocoder_name, local_path)
43
- self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
44
 
45
  def load_vocoder_model(self, vocoder_name, local_path):
46
  self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
47
 
48
- def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
49
  if model_type == "F5-TTS":
50
  if not ckpt_file:
51
  ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
@@ -59,7 +66,9 @@ class F5TTS:
59
  else:
60
  raise ValueError(f"Unknown model type: {model_type}")
61
 
62
- self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
 
 
63
 
64
  def export_wav(self, wav, file_wave, remove_silence=False):
65
  sf.write(file_wave, wav, self.target_sample_rate)
@@ -102,6 +111,7 @@ class F5TTS:
102
  gen_text,
103
  self.ema_model,
104
  self.vocoder,
 
105
  show_info=show_info,
106
  progress=progress,
107
  target_rms=target_rms,
 
7
  import tqdm
8
  from cached_path import cached_path
9
 
10
+ from f5_tts.infer.utils_infer import (
11
+ hop_length,
12
+ infer_process,
13
+ load_model,
14
+ load_vocoder,
15
+ preprocess_ref_audio_text,
16
+ remove_silence_for_generated_wav,
17
+ save_spectrogram,
18
+ target_sample_rate,
19
+ )
20
  from f5_tts.model import DiT, UNetT
21
  from f5_tts.model.utils import seed_everything
22
 
 
38
  self.target_sample_rate = target_sample_rate
39
  self.hop_length = hop_length
40
  self.seed = -1
41
+ self.extract_backend = vocoder_name
42
 
43
  # Set device
44
  self.device = device or (
 
47
 
48
  # Load models
49
  self.load_vocoder_model(vocoder_name, local_path)
50
+ self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
51
 
52
  def load_vocoder_model(self, vocoder_name, local_path):
53
  self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
54
 
55
+ def load_ema_model(self, model_type, ckpt_file, extract_backend, vocab_file, ode_method, use_ema):
56
  if model_type == "F5-TTS":
57
  if not ckpt_file:
58
  ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
 
66
  else:
67
  raise ValueError(f"Unknown model type: {model_type}")
68
 
69
+ self.ema_model = load_model(
70
+ model_cls, model_cfg, ckpt_file, extract_backend, vocab_file, ode_method, use_ema, self.device
71
+ )
72
 
73
  def export_wav(self, wav, file_wave, remove_silence=False):
74
  sf.write(file_wave, wav, self.target_sample_rate)
 
111
  gen_text,
112
  self.ema_model,
113
  self.vocoder,
114
+ self.extract_backend,
115
  show_info=show_info,
116
  progress=progress,
117
  target_rms=target_rms,
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -12,9 +12,11 @@ import torchaudio
12
  from accelerate import Accelerator
13
  from tqdm import tqdm
14
 
15
- from f5_tts.eval.utils_eval import (get_inference_prompt,
16
- get_librispeech_test_clean_metainfo,
17
- get_seedtts_testset_metainfo)
 
 
18
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
19
  from f5_tts.model import CFM, DiT, UNetT
20
  from f5_tts.model.utils import get_tokenizer
@@ -185,7 +187,7 @@ def main():
185
  gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
186
  gen_mel_spec = gen.permute(0, 2, 1)
187
  if extract_backend == "vocos":
188
- generated_wave = vocoder.decode(gen_mel_spec.cpu())
189
  elif extract_backend == "bigvgan":
190
  generated_wave = vocoder(gen_mel_spec)
191
 
 
12
  from accelerate import Accelerator
13
  from tqdm import tqdm
14
 
15
+ from f5_tts.eval.utils_eval import (
16
+ get_inference_prompt,
17
+ get_librispeech_test_clean_metainfo,
18
+ get_seedtts_testset_metainfo,
19
+ )
20
  from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
21
  from f5_tts.model import CFM, DiT, UNetT
22
  from f5_tts.model.utils import get_tokenizer
 
187
  gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
188
  gen_mel_spec = gen.permute(0, 2, 1)
189
  if extract_backend == "vocos":
190
+ generated_wave = vocoder.decode(gen_mel_spec)
191
  elif extract_backend == "bigvgan":
192
  generated_wave = vocoder(gen_mel_spec)
193
 
src/f5_tts/infer/infer_cli.py CHANGED
@@ -10,9 +10,13 @@ import soundfile as sf
10
  import tomli
11
  from cached_path import cached_path
12
 
13
- from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
14
- preprocess_ref_audio_text,
15
- remove_silence_for_generated_wav)
 
 
 
 
16
  from f5_tts.model import DiT, UNetT
17
 
18
  parser = argparse.ArgumentParser(
@@ -108,12 +112,13 @@ speed = args.speed
108
  wave_path = Path(output_dir) / "infer_cli_out.wav"
109
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
110
  if args.vocoder_name == "vocos":
111
- vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
112
  elif args.vocoder_name == "bigvgan":
113
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
 
114
 
115
  vocoder = load_vocoder(
116
- vocoder_name=args.vocoder_name, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
117
  )
118
 
119
 
@@ -122,11 +127,17 @@ if model == "F5-TTS":
122
  model_cls = DiT
123
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
124
  if ckpt_file == "":
125
- repo_name = "F5-TTS"
126
- exp_name = "F5TTS_Base"
127
- ckpt_step = 1200000
128
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
129
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
 
 
 
 
 
 
130
 
131
  elif model == "E2-TTS":
132
  model_cls = UNetT
@@ -145,10 +156,10 @@ elif model == "E2-TTS":
145
 
146
 
147
  print(f"Using {model}...")
148
- ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
149
 
150
 
151
- def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed):
152
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
153
  if "voices" not in config:
154
  voices = {"main": main_voice}
@@ -183,7 +194,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
183
  ref_text = voices[voice]["ref_text"]
184
  print(f"Voice: {voice}")
185
  audio, final_sample_rate, spectragram = infer_process(
186
- ref_audio, ref_text, gen_text, model_obj, vocoder, speed=speed
187
  )
188
  generated_audio_segments.append(audio)
189
 
@@ -202,7 +213,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed
202
 
203
 
204
  def main():
205
- main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence, speed)
206
 
207
 
208
  if __name__ == "__main__":
 
10
  import tomli
11
  from cached_path import cached_path
12
 
13
+ from f5_tts.infer.utils_infer import (
14
+ infer_process,
15
+ load_model,
16
+ load_vocoder,
17
+ preprocess_ref_audio_text,
18
+ remove_silence_for_generated_wav,
19
+ )
20
  from f5_tts.model import DiT, UNetT
21
 
22
  parser = argparse.ArgumentParser(
 
112
  wave_path = Path(output_dir) / "infer_cli_out.wav"
113
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
114
  if args.vocoder_name == "vocos":
115
+ vocoder_local_path = "../checkpoints/vocos-mel-24khz"
116
  elif args.vocoder_name == "bigvgan":
117
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
118
+ extract_backend = args.vocoder_name
119
 
120
  vocoder = load_vocoder(
121
+ vocoder_name=extract_backend, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
122
  )
123
 
124
 
 
127
  model_cls = DiT
128
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
129
  if ckpt_file == "":
130
+ if args.vocoder_name == "vocos":
131
+ repo_name = "F5-TTS"
132
+ exp_name = "F5TTS_Base"
133
+ ckpt_step = 1200000
134
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
135
+ # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
136
+ elif args.vocoder_name == "bigvgan":
137
+ repo_name = "F5-TTS"
138
+ exp_name = "F5TTS_Base_bigvgan"
139
+ ckpt_step = 1250000
140
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
141
 
142
  elif model == "E2-TTS":
143
  model_cls = UNetT
 
156
 
157
 
158
  print(f"Using {model}...")
159
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, args.vocoder_name, vocab_file)
160
 
161
 
162
+ def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remove_silence, speed):
163
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
164
  if "voices" not in config:
165
  voices = {"main": main_voice}
 
194
  ref_text = voices[voice]["ref_text"]
195
  print(f"Voice: {voice}")
196
  audio, final_sample_rate, spectragram = infer_process(
197
+ ref_audio, ref_text, gen_text, model_obj, vocoder, extract_backend, speed=speed
198
  )
199
  generated_audio_segments.append(audio)
200
 
 
213
 
214
 
215
  def main():
216
+ main_process(ref_audio, ref_text, gen_text, ema_model, extract_backend, remove_silence, speed)
217
 
218
 
219
  if __name__ == "__main__":
src/f5_tts/infer/speech_edit.py CHANGED
@@ -4,8 +4,7 @@ import torch
4
  import torch.nn.functional as F
5
  import torchaudio
6
 
7
- from f5_tts.infer.utils_infer import (load_checkpoint, load_vocoder,
8
- save_spectrogram)
9
  from f5_tts.model import CFM, DiT, UNetT
10
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
11
 
@@ -173,20 +172,20 @@ with torch.inference_mode():
173
  seed=seed,
174
  edit_mask=edit_mask,
175
  )
176
- print(f"Generated mel: {generated.shape}")
177
-
178
- # Final result
179
- generated = generated.to(torch.float32)
180
- generated = generated[:, ref_audio_len:, :]
181
- gen_mel_spec = generated.permute(0, 2, 1)
182
- if extract_backend == "vocos":
183
- generated_wave = vocoder.decode(gen_mel_spec.cpu())
184
- elif extract_backend == "bigvgan":
185
- generated_wave = vocoder(gen_mel_spec)
186
-
187
- if rms < target_rms:
188
- generated_wave = generated_wave * rms / target_rms
189
-
190
- save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
191
- torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
192
- print(f"Generated wav: {generated_wave.shape}")
 
4
  import torch.nn.functional as F
5
  import torchaudio
6
 
7
+ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
 
8
  from f5_tts.model import CFM, DiT, UNetT
9
  from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
10
 
 
172
  seed=seed,
173
  edit_mask=edit_mask,
174
  )
175
+ print(f"Generated mel: {generated.shape}")
176
+
177
+ # Final result
178
+ generated = generated.to(torch.float32)
179
+ generated = generated[:, ref_audio_len:, :]
180
+ gen_mel_spec = generated.permute(0, 2, 1)
181
+ if extract_backend == "vocos":
182
+ generated_wave = vocoder.decode(gen_mel_spec)
183
+ elif extract_backend == "bigvgan":
184
+ generated_wave = vocoder(gen_mel_spec)
185
+
186
+ if rms < target_rms:
187
+ generated_wave = generated_wave * rms / target_rms
188
+
189
+ save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
190
+ torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
191
+ print(f"Generated wav: {generated_wave.shape}")
src/f5_tts/infer/utils_infer.py CHANGED
@@ -94,7 +94,6 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
94
  vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
95
  state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
96
  vocoder.load_state_dict(state_dict)
97
- vocoder.eval()
98
  vocoder = vocoder.eval().to(device)
99
  else:
100
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
@@ -148,6 +147,11 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
148
  for k, v in checkpoint["ema_model_state_dict"].items()
149
  if k not in ["initted", "step"]
150
  }
 
 
 
 
 
151
  model.load_state_dict(checkpoint["model_state_dict"])
152
  else:
153
  if ckpt_type == "safetensors":
@@ -160,7 +164,9 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
160
  # load model for inference
161
 
162
 
163
- def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
 
 
164
  if vocab_file == "":
165
  vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
166
  tokenizer = "custom"
@@ -282,6 +288,7 @@ def infer_process(
282
  gen_text,
283
  model_obj,
284
  vocoder,
 
285
  show_info=print,
286
  progress=tqdm,
287
  target_rms=target_rms,
@@ -307,6 +314,7 @@ def infer_process(
307
  gen_text_batches,
308
  model_obj,
309
  vocoder,
 
310
  progress=progress,
311
  target_rms=target_rms,
312
  cross_fade_duration=cross_fade_duration,
@@ -328,6 +336,7 @@ def infer_batch_process(
328
  gen_text_batches,
329
  model_obj,
330
  vocoder,
 
331
  progress=tqdm,
332
  target_rms=0.1,
333
  cross_fade_duration=0.15,
@@ -384,7 +393,7 @@ def infer_batch_process(
384
  generated = generated[:, ref_audio_len:, :]
385
  generated_mel_spec = generated.permute(0, 2, 1)
386
  if extract_backend == "vocos":
387
- generated_wave = vocoder.decode(generated_mel_spec.cpu())
388
  elif extract_backend == "bigvgan":
389
  generated_wave = vocoder(generated_mel_spec)
390
  if rms < target_rms:
 
94
  vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
95
  state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
96
  vocoder.load_state_dict(state_dict)
 
97
  vocoder = vocoder.eval().to(device)
98
  else:
99
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
 
147
  for k, v in checkpoint["ema_model_state_dict"].items()
148
  if k not in ["initted", "step"]
149
  }
150
+
151
+ for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
152
+ if key in checkpoint["model_state_dict"]:
153
+ del checkpoint["model_state_dict"][key]
154
+
155
  model.load_state_dict(checkpoint["model_state_dict"])
156
  else:
157
  if ckpt_type == "safetensors":
 
164
  # load model for inference
165
 
166
 
167
+ def load_model(
168
+ model_cls, model_cfg, ckpt_path, extract_backend, vocab_file="", ode_method=ode_method, use_ema=True, device=device
169
+ ):
170
  if vocab_file == "":
171
  vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
172
  tokenizer = "custom"
 
288
  gen_text,
289
  model_obj,
290
  vocoder,
291
+ extract_backend,
292
  show_info=print,
293
  progress=tqdm,
294
  target_rms=target_rms,
 
314
  gen_text_batches,
315
  model_obj,
316
  vocoder,
317
+ extract_backend,
318
  progress=progress,
319
  target_rms=target_rms,
320
  cross_fade_duration=cross_fade_duration,
 
336
  gen_text_batches,
337
  model_obj,
338
  vocoder,
339
+ extract_backend,
340
  progress=tqdm,
341
  target_rms=0.1,
342
  cross_fade_duration=0.15,
 
393
  generated = generated[:, ref_audio_len:, :]
394
  generated_mel_spec = generated.permute(0, 2, 1)
395
  if extract_backend == "vocos":
396
+ generated_wave = vocoder.decode(generated_mel_spec)
397
  elif extract_backend == "bigvgan":
398
  generated_wave = vocoder(generated_mel_spec)
399
  if rms < target_rms:
src/f5_tts/model/cfm.py CHANGED
@@ -19,8 +19,14 @@ from torch.nn.utils.rnn import pad_sequence
19
  from torchdiffeq import odeint
20
 
21
  from f5_tts.model.modules import MelSpec
22
- from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
23
- list_str_to_tensor, mask_from_frac_lengths)
 
 
 
 
 
 
24
 
25
 
26
  class CFM(nn.Module):
@@ -92,12 +98,6 @@ class CFM(nn.Module):
92
  edit_mask=None,
93
  ):
94
  self.eval()
95
-
96
- assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print(
97
- "Only support fp16 and fp32 inference currently"
98
- )
99
- cond = cond.to(next(self.parameters()).dtype)
100
-
101
  # raw wave
102
 
103
  if cond.ndim == 2:
@@ -105,6 +105,11 @@ class CFM(nn.Module):
105
  cond = cond.permute(0, 2, 1)
106
  assert cond.shape[-1] == self.num_channels
107
 
 
 
 
 
 
108
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
109
  if not exists(lens):
110
  lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
 
19
  from torchdiffeq import odeint
20
 
21
  from f5_tts.model.modules import MelSpec
22
+ from f5_tts.model.utils import (
23
+ default,
24
+ exists,
25
+ lens_to_mask,
26
+ list_str_to_idx,
27
+ list_str_to_tensor,
28
+ mask_from_frac_lengths,
29
+ )
30
 
31
 
32
  class CFM(nn.Module):
 
98
  edit_mask=None,
99
  ):
100
  self.eval()
 
 
 
 
 
 
101
  # raw wave
102
 
103
  if cond.ndim == 2:
 
105
  cond = cond.permute(0, 2, 1)
106
  assert cond.shape[-1] == self.num_channels
107
 
108
+ assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print(
109
+ "Only support fp16 and fp32 inference currently"
110
+ )
111
+ cond = cond.to(next(self.parameters()).dtype)
112
+
113
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
114
  if not exists(lens):
115
  lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
src/f5_tts/model/modules.py CHANGED
@@ -123,7 +123,7 @@ def get_vocos_mel_spectrogram(
123
  center=True,
124
  normalized=False,
125
  norm=None,
126
- )
127
  if len(waveform.shape) == 3:
128
  waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
129
 
 
123
  center=True,
124
  normalized=False,
125
  norm=None,
126
+ ).to(waveform.device)
127
  if len(waveform.shape) == 3:
128
  waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
129
 
src/f5_tts/model/trainer.py CHANGED
@@ -187,8 +187,7 @@ class Trainer:
187
 
188
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
189
  if self.log_samples:
190
- from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
191
- nfe_step, sway_sampling_coef)
192
 
193
  vocoder = load_vocoder(vocoder_name=self.vocoder_name)
194
  target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
@@ -315,7 +314,7 @@ class Trainer:
315
  self.save_checkpoint(global_step)
316
 
317
  if self.log_samples and self.accelerator.is_local_main_process:
318
- ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
319
  torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
320
  with torch.inference_mode():
321
  generated, _ = self.accelerator.unwrap_model(self.model).sample(
 
187
 
188
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
189
  if self.log_samples:
190
+ from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
 
191
 
192
  vocoder = load_vocoder(vocoder_name=self.vocoder_name)
193
  target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
 
314
  self.save_checkpoint(global_step)
315
 
316
  if self.log_samples and self.accelerator.is_local_main_process:
317
+ ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
318
  torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
319
  with torch.inference_mode():
320
  generated, _ = self.accelerator.unwrap_model(self.model).sample(