Spaces:
Configuration error
Configuration error
minor fix
Browse files
src/f5_tts/model/trainer.py
CHANGED
@@ -188,8 +188,9 @@ class Trainer:
|
|
188 |
|
189 |
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
190 |
if self.log_samples:
|
191 |
-
from f5_tts.infer.utils_infer import
|
192 |
|
|
|
193 |
target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
|
194 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
195 |
os.makedirs(log_samples_path, exist_ok=True)
|
@@ -314,7 +315,7 @@ class Trainer:
|
|
314 |
self.save_checkpoint(global_step)
|
315 |
|
316 |
if self.log_samples:
|
317 |
-
ref_audio, ref_audio_len =
|
318 |
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
|
319 |
with torch.inference_mode():
|
320 |
generated, _ = self.model.sample(
|
@@ -326,7 +327,7 @@ class Trainer:
|
|
326 |
sway_sampling_coef=sway_sampling_coef,
|
327 |
)
|
328 |
generated = generated.to(torch.float32)
|
329 |
-
gen_audio =
|
330 |
torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
|
331 |
|
332 |
if global_step % self.last_per_steps == 0:
|
|
|
188 |
|
189 |
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
190 |
if self.log_samples:
|
191 |
+
from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
|
192 |
|
193 |
+
vocoder = load_vocoder()
|
194 |
target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
|
195 |
log_samples_path = f"{self.checkpoint_path}/samples"
|
196 |
os.makedirs(log_samples_path, exist_ok=True)
|
|
|
315 |
self.save_checkpoint(global_step)
|
316 |
|
317 |
if self.log_samples:
|
318 |
+
ref_audio, ref_audio_len = vocoder.decode([batch["mel"][0]].cpu()), mel_lengths[0]
|
319 |
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
|
320 |
with torch.inference_mode():
|
321 |
generated, _ = self.model.sample(
|
|
|
327 |
sway_sampling_coef=sway_sampling_coef,
|
328 |
)
|
329 |
generated = generated.to(torch.float32)
|
330 |
+
gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
|
331 |
torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
|
332 |
|
333 |
if global_step % self.last_per_steps == 0:
|