SWivid committed on
Commit
aaa92f6
·
1 Parent(s): 87c4f9f

finish trainer modification

Browse files
src/f5_tts/model/trainer.py CHANGED
@@ -191,7 +191,7 @@ class Trainer:
191
  from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
192
 
193
  vocoder = load_vocoder()
194
- target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
195
  log_samples_path = f"{self.checkpoint_path}/samples"
196
  os.makedirs(log_samples_path, exist_ok=True)
197
 
@@ -314,12 +314,12 @@ class Trainer:
314
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
315
  self.save_checkpoint(global_step)
316
 
317
- if self.log_samples:
318
- ref_audio, ref_audio_len = vocoder.decode([batch["mel"][0]].cpu()), mel_lengths[0]
319
  torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
320
  with torch.inference_mode():
321
- generated, _ = self.model.sample(
322
- cond=[mel_spec[0][:ref_audio_len]],
323
  text=[text_inputs[0] + [" "] + text_inputs[0]],
324
  duration=ref_audio_len * 2,
325
  steps=nfe_step,
 
191
  from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
192
 
193
  vocoder = load_vocoder()
194
+ target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
195
  log_samples_path = f"{self.checkpoint_path}/samples"
196
  os.makedirs(log_samples_path, exist_ok=True)
197
 
 
314
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
315
  self.save_checkpoint(global_step)
316
 
317
+ if self.log_samples and self.accelerator.is_local_main_process:
318
+ ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
319
  torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
320
  with torch.inference_mode():
321
+ generated, _ = self.accelerator.unwrap_model(self.model).sample(
322
+ cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
323
  text=[text_inputs[0] + [" "] + text_inputs[0]],
324
  duration=ref_audio_len * 2,
325
  steps=nfe_step,
src/f5_tts/train/train.py CHANGED
@@ -83,6 +83,7 @@ def main():
83
  wandb_run_name=exp_name,
84
  wandb_resume_id=wandb_resume_id,
85
  last_per_steps=last_per_steps,
 
86
  )
87
 
88
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
 
83
  wandb_run_name=exp_name,
84
  wandb_resume_id=wandb_resume_id,
85
  last_per_steps=last_per_steps,
86
+ log_samples=True,
87
  )
88
 
89
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)