Spaces:
Configuration error
Configuration error
finish trainer modification
Browse files- src/f5_tts/model/trainer.py +5 -5
- src/f5_tts/train/train.py +1 -0
src/f5_tts/model/trainer.py
CHANGED
@@ -191,7 +191,7 @@ class Trainer:
|
|
191 |
from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
|
192 |
|
193 |
vocoder = load_vocoder()
|
194 |
-
target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
|
195 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
196 |
os.makedirs(log_samples_path, exist_ok=True)
|
197 |
|
@@ -314,12 +314,12 @@ class Trainer:
|
|
314 |
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
315 |
self.save_checkpoint(global_step)
|
316 |
|
317 |
-
if self.log_samples:
|
318 |
-
ref_audio, ref_audio_len = vocoder.decode(
|
319 |
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
|
320 |
with torch.inference_mode():
|
321 |
-
generated, _ = self.model.sample(
|
322 |
-
cond=
|
323 |
text=[text_inputs[0] + [" "] + text_inputs[0]],
|
324 |
duration=ref_audio_len * 2,
|
325 |
steps=nfe_step,
|
|
|
191 |
from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
|
192 |
|
193 |
vocoder = load_vocoder()
|
194 |
+
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
|
195 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
196 |
os.makedirs(log_samples_path, exist_ok=True)
|
197 |
|
|
|
314 |
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
315 |
self.save_checkpoint(global_step)
|
316 |
|
317 |
+
if self.log_samples and self.accelerator.is_local_main_process:
|
318 |
+
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
|
319 |
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
|
320 |
with torch.inference_mode():
|
321 |
+
generated, _ = self.accelerator.unwrap_model(self.model).sample(
|
322 |
+
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
323 |
text=[text_inputs[0] + [" "] + text_inputs[0]],
|
324 |
duration=ref_audio_len * 2,
|
325 |
steps=nfe_step,
|
src/f5_tts/train/train.py
CHANGED
@@ -83,6 +83,7 @@ def main():
|
|
83 |
wandb_run_name=exp_name,
|
84 |
wandb_resume_id=wandb_resume_id,
|
85 |
last_per_steps=last_per_steps,
|
|
|
86 |
)
|
87 |
|
88 |
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
|
|
83 |
wandb_run_name=exp_name,
|
84 |
wandb_resume_id=wandb_resume_id,
|
85 |
last_per_steps=last_per_steps,
|
86 |
+
log_samples=True,
|
87 |
)
|
88 |
|
89 |
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|