SWivid committed
Commit 87c4f9f · Parent: 381ea0c
Files changed (1):
  src/f5_tts/model/trainer.py (+4 -3)

src/f5_tts/model/trainer.py CHANGED
@@ -188,8 +188,9 @@ class Trainer:
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:
-            from f5_tts.infer.utils_infer import vocos, nfe_step, cfg_strength, sway_sampling_coef
+            from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
 
+            vocoder = load_vocoder()
             target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
             log_samples_path = f"{self.checkpoint_path}/samples"
             os.makedirs(log_samples_path, exist_ok=True)
@@ -314,7 +315,7 @@ class Trainer:
                     self.save_checkpoint(global_step)
 
                     if self.log_samples:
-                        ref_audio, ref_audio_len = vocos.decode([batch["mel"][0]].cpu()), mel_lengths[0]
+                        ref_audio, ref_audio_len = vocoder.decode([batch["mel"][0]].cpu()), mel_lengths[0]
                         torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
                         with torch.inference_mode():
                             generated, _ = self.model.sample(
@@ -326,7 +327,7 @@ class Trainer:
                                 sway_sampling_coef=sway_sampling_coef,
                             )
                             generated = generated.to(torch.float32)
-                        gen_audio = vocos.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
+                        gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
                         torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
 
                     if global_step % self.last_per_steps == 0:
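
For context, the commit replaces the module-level vocos instance previously imported from f5_tts.infer.utils_infer with an explicit load_vocoder() call inside train(). Below is a minimal standalone sketch of that decode pattern, not the trainer itself, assuming f5_tts is installed and that load_vocoder()'s defaults fetch the Vocos mel vocoder; the random mel is a placeholder for a real batch (F5-TTS mels are assumed 100-channel at 24 kHz here).

# Hypothetical standalone usage of the pattern this commit adopts:
# build the vocoder once with load_vocoder(), then decode mels on demand.
import torch
import torchaudio
from f5_tts.infer.utils_infer import load_vocoder

vocoder = load_vocoder()  # defaults load the Vocos mel vocoder used by F5-TTS

# Placeholder mel batch of shape (batch, n_mel_channels, frames);
# 100 channels / 24 kHz are assumptions matching F5-TTS defaults.
mel = torch.randn(1, 100, 256)

with torch.inference_mode():
    audio = vocoder.decode(mel.cpu())  # waveform tensor, shape (batch, samples)

torchaudio.save("sample.wav", audio, 24000)

One consequence of this shape is that vocoder weights are loaded only when sample logging is actually enabled, rather than as a side effect of importing utils_infer.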