zkniu committed
Commit 18e1ab5 · 1 Parent(s): b180961

refactor: del global params and set vocos as default vocoder, add dtype check

README.md CHANGED
@@ -45,7 +45,7 @@ git clone https://github.com/SWivid/F5-TTS.git
 cd F5-TTS
 pip install -e .
 
-# Init submodule(optional, if you want to change the vocoder from vocos to bigvgan)
+# Init submodule (optional, if you want to change the vocoder from vocos to bigvgan)
 # git submodule update --init --recursive
 # pip install -e .
 ```
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -32,7 +32,6 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 
@@ -49,6 +48,7 @@ def main():
     parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
     parser.add_argument("-n", "--expname", required=True)
     parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
+    parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
 
     parser.add_argument("-nfe", "--nfestep", default=32, type=int)
     parser.add_argument("-o", "--odemethod", default="euler")
@@ -63,6 +63,7 @@
     exp_name = args.expname
     ckpt_step = args.ckptstep
     ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
+    mel_spec_type = args.mel_spec_type
 
     nfe_step = args.nfestep
     ode_method = args.odemethod
@@ -101,7 +102,7 @@ def main():
     output_dir = (
         f"{rel_path}/"
         f"results/{exp_name}_{ckpt_step}/{testset}/"
-        f"seed{seed}_{ode_method}_nfe{nfe_step}"
+        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
         f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
         f"_cfg{cfg_strength}_speed{speed}"
         f"{'_gt-dur' if use_truth_duration else ''}"
@@ -155,10 +156,10 @@ def main():
     supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
     if supports_fp16 and mel_spec_type == "vocos":
         dtype = torch.float16
-    else:
+    elif mel_spec_type == "bigvgan":
         dtype = torch.float32
 
-    model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
+    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
         os.makedirs(output_dir)
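The dtype branch above now keys off the chosen vocoder: fp16 for vocos on capable GPUs, fp32 for bigvgan. Below is a minimal, self-contained sketch of that policy; `select_dtype` is a hypothetical helper, not part of the repo. Note the committed `elif` leaves `dtype` unassigned when vocos is requested on hardware without fp16 support, so the sketch falls back to fp32 explicitly.

```python
import torch

def select_dtype(device: str, mel_spec_type: str) -> torch.dtype:
    # fp16 is only used on the vocos path, and only on CUDA devices with
    # compute capability >= 6.0 (Pascal or newer).
    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
    if supports_fp16 and mel_spec_type == "vocos":
        return torch.float16
    # The diff uses `elif mel_spec_type == "bigvgan"`, which would leave dtype
    # unbound for vocos on non-fp16 hardware; this sketch returns fp32 instead.
    return torch.float32

print(select_dtype("cpu", "vocos"))  # torch.float32
```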
src/f5_tts/infer/infer_cli.py CHANGED
@@ -154,7 +154,7 @@ elif model == "E2-TTS":
 
 
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, args.vocoder_name, vocab_file)
+ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
 
 
 def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
@@ -192,7 +192,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
         audio, final_sample_rate, spectragram = infer_process(
-            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type, speed=speed
+            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
        )
         generated_audio_segments.append(audio)
 
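The CLI change is call-site only: now that `load_model` and `infer_process` give `mel_spec_type` a default, passing it by keyword keeps the call correct even if defaulted parameters are later added or reordered. A toy illustration of the convention (stand-in function, not F5-TTS code):

```python
# Toy stand-in: once a parameter gains a default and sits among other
# defaulted parameters, keyword calls are the safe convention.
def load_model_stub(model_cls, model_cfg, ckpt_path, mel_spec_type="vocos", vocab_file=""):
    return mel_spec_type, vocab_file

# Positional: correct only while the parameter order stays frozen.
print(load_model_stub(None, {}, "ckpt.pt", "bigvgan", "vocab.txt"))
# Keyword, as infer_cli.py now does: robust to reordering and explicit
# about the vocoder choice at the call site.
print(load_model_stub(None, {}, "ckpt.pt", mel_spec_type="bigvgan", vocab_file="vocab.txt"))
```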
src/f5_tts/infer/speech_edit.py CHANGED
@@ -18,7 +18,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 tokenizer = "pinyin"
@@ -114,10 +114,10 @@ model = CFM(
 supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
 if supports_fp16 and mel_spec_type == "vocos":
     dtype = torch.float16
-else:
+elif mel_spec_type == "bigvgan":
     dtype = torch.float32
 
-model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
+model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
 # Audio
 audio, sr = torchaudio.load(audio_to_edit)
src/f5_tts/infer/utils_infer.py CHANGED
@@ -40,7 +40,6 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -133,6 +132,10 @@ def initialize_asr_pipeline(device=device):
 
 
 def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
+    if dtype is None:
+        dtype = (
+            torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        )
     model = model.to(dtype)
 
     ckpt_type = ckpt_path.split(".")[-1]
@@ -169,7 +172,14 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
 
 
 def load_model(
-    model_cls, model_cfg, ckpt_path, mel_spec_type, vocab_file="", ode_method=ode_method, use_ema=True, device=device
+    model_cls,
+    model_cfg,
+    ckpt_path,
+    mel_spec_type="vocos",
+    vocab_file="",
+    ode_method=ode_method,
+    use_ema=True,
+    device=device,
 ):
     if vocab_file == "":
         vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
@@ -199,10 +209,10 @@ def load_model(
     supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
     if supports_fp16 and mel_spec_type == "vocos":
         dtype = torch.float16
-    else:
+    elif mel_spec_type == "bigvgan":
         dtype = torch.float32
 
-    model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
+    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     return model
 
@@ -297,7 +307,7 @@ def infer_process(
     gen_text,
     model_obj,
     vocoder,
-    mel_spec_type,
+    mel_spec_type="vocos",
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -323,7 +333,7 @@
         gen_text_batches,
         model_obj,
         vocoder,
-        mel_spec_type,
+        mel_spec_type=mel_spec_type,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -345,7 +355,7 @@ def infer_batch_process(
     gen_text_batches,
     model_obj,
     vocoder,
-    mel_spec_type,
+    mel_spec_type="vocos",
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
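The `if dtype is None` guard added to `load_checkpoint` is the "dtype check" from the commit title: callers that do not know the vocoder can pass `dtype=None` and let the hardware decide, while an explicit dtype still wins. A runnable sketch of just that guard; `resolve_dtype` is a hypothetical extraction, and the real function goes on to load the checkpoint weights:

```python
from typing import Optional

import torch

def resolve_dtype(device: str, dtype: Optional[torch.dtype]) -> torch.dtype:
    # Mirrors the guard added to load_checkpoint: default to fp16 only on
    # CUDA devices with compute capability >= 6.0, otherwise fp32.
    if dtype is None:
        dtype = (
            torch.float16
            if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
            else torch.float32
        )
    return dtype

print(resolve_dtype("cpu", None))           # torch.float32
print(resolve_dtype("cpu", torch.float16))  # explicit dtype wins: torch.float16
```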