Spaces:
Configuration error
Configuration error
unknown
committed on
Commit
·
2ca1fb7
1
Parent(s):
3409192
update
Browse files
- src/f5_tts/model/trainer.py +2 -2
- src/f5_tts/model/utils.py +3 -3
src/f5_tts/model/trainer.py
CHANGED
@@ -19,7 +19,7 @@ from f5_tts.model import CFM
|
|
19 |
from f5_tts.model.utils import exists, default
|
20 |
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
21 |
from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
|
22 |
-
from f5_tts.model.utils import
|
23 |
|
24 |
# trainer
|
25 |
|
@@ -315,7 +315,7 @@ class Trainer:
|
|
315 |
and self.export_samples
|
316 |
and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
|
317 |
):
|
318 |
-
wave_org, wave_gen, mel_org, mel_gen =
|
319 |
vocos,
|
320 |
self.model,
|
321 |
self.file_path_samples,
|
|
|
19 |
from f5_tts.model.utils import exists, default
|
20 |
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
21 |
from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
|
22 |
+
from f5_tts.model.utils import get_sample
|
23 |
|
24 |
# trainer
|
25 |
|
|
|
315 |
and self.export_samples
|
316 |
and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
|
317 |
):
|
318 |
+
wave_org, wave_gen, mel_org, mel_gen = get_sample(
|
319 |
vocos,
|
320 |
self.model,
|
321 |
self.file_path_samples,
|
src/f5_tts/model/utils.py
CHANGED
@@ -205,7 +205,7 @@ def export_mel(mel_colored_hwc, file_out):
|
|
205 |
plt.imsave(file_out, mel_colored_hwc)
|
206 |
|
207 |
|
208 |
-
def
|
209 |
audio, sr = torchaudio.load(file_wav_org)
|
210 |
audio = audio.to("cuda")
|
211 |
ref_audio_len = audio.shape[-1] // hop_length
|
@@ -228,7 +228,7 @@ def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cf
|
|
228 |
return generated_wave_gen, generated_mel_spec_gen
|
229 |
|
230 |
|
231 |
-
def
|
232 |
vocos,
|
233 |
model,
|
234 |
file_path_samples,
|
@@ -245,7 +245,7 @@ def gen_sample(
|
|
245 |
generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
|
246 |
file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
|
247 |
export_audio(file_wav_org, generated_wave_org, target_sample_rate)
|
248 |
-
generated_wave_gen, generated_mel_spec_gen =
|
249 |
model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
|
250 |
)
|
251 |
file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
|
|
|
205 |
plt.imsave(file_out, mel_colored_hwc)
|
206 |
|
207 |
|
208 |
+
def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
|
209 |
audio, sr = torchaudio.load(file_wav_org)
|
210 |
audio = audio.to("cuda")
|
211 |
ref_audio_len = audio.shape[-1] // hop_length
|
|
|
228 |
return generated_wave_gen, generated_mel_spec_gen
|
229 |
|
230 |
|
231 |
+
def get_sample(
|
232 |
vocos,
|
233 |
model,
|
234 |
file_path_samples,
|
|
|
245 |
generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
|
246 |
file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
|
247 |
export_audio(file_wav_org, generated_wave_org, target_sample_rate)
|
248 |
+
generated_wave_gen, generated_mel_spec_gen = gen_sample(
|
249 |
model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
|
250 |
)
|
251 |
file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
|