unknown committed on
Commit
2ca1fb7
·
1 Parent(s): 3409192
src/f5_tts/model/trainer.py CHANGED
@@ -19,7 +19,7 @@ from f5_tts.model import CFM
19
  from f5_tts.model.utils import exists, default
20
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
21
  from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
22
- from f5_tts.model.utils import gen_sample
23
 
24
  # trainer
25
 
@@ -315,7 +315,7 @@ class Trainer:
315
  and self.export_samples
316
  and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
317
  ):
318
- wave_org, wave_gen, mel_org, mel_gen = gen_sample(
319
  vocos,
320
  self.model,
321
  self.file_path_samples,
 
19
  from f5_tts.model.utils import exists, default
20
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
21
  from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
22
+ from f5_tts.model.utils import get_sample
23
 
24
  # trainer
25
 
 
315
  and self.export_samples
316
  and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
317
  ):
318
+ wave_org, wave_gen, mel_org, mel_gen = get_sample(
319
  vocos,
320
  self.model,
321
  self.file_path_samples,
src/f5_tts/model/utils.py CHANGED
@@ -205,7 +205,7 @@ def export_mel(mel_colored_hwc, file_out):
205
  plt.imsave(file_out, mel_colored_hwc)
206
 
207
 
208
- def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
209
  audio, sr = torchaudio.load(file_wav_org)
210
  audio = audio.to("cuda")
211
  ref_audio_len = audio.shape[-1] // hop_length
@@ -228,7 +228,7 @@ def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cf
228
  return generated_wave_gen, generated_mel_spec_gen
229
 
230
 
231
- def gen_sample(
232
  vocos,
233
  model,
234
  file_path_samples,
@@ -245,7 +245,7 @@ def gen_sample(
245
  generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
246
  file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
247
  export_audio(file_wav_org, generated_wave_org, target_sample_rate)
248
- generated_wave_gen, generated_mel_spec_gen = get_sample(
249
  model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
250
  )
251
  file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
 
205
  plt.imsave(file_out, mel_colored_hwc)
206
 
207
 
208
+ def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
209
  audio, sr = torchaudio.load(file_wav_org)
210
  audio = audio.to("cuda")
211
  ref_audio_len = audio.shape[-1] // hop_length
 
228
  return generated_wave_gen, generated_mel_spec_gen
229
 
230
 
231
+ def get_sample(
232
  vocos,
233
  model,
234
  file_path_samples,
 
245
  generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
246
  file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
247
  export_audio(file_wav_org, generated_wave_org, target_sample_rate)
248
+ generated_wave_gen, generated_mel_spec_gen = gen_sample(
249
  model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
250
  )
251
  file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")