Merge branch 'SWivid:main' into main
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -119,12 +119,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
 asr_pipe = None


-def initialize_asr_pipeline(device=device):
+def initialize_asr_pipeline(device=device, dtype=None):
+    if dtype is None:
+        dtype = (
+            torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        )
     global asr_pipe
     asr_pipe = pipeline(
         "automatic-speech-recognition",
         model="openai/whisper-large-v3-turbo",
-        torch_dtype=
+        torch_dtype=dtype,
         device=device,
     )

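In short, `initialize_asr_pipeline` now accepts an explicit `dtype` and, when none is given, picks `torch.float16` only on CUDA devices with compute capability 6.0 or newer, falling back to `torch.float32` elsewhere. A minimal usage sketch, assuming the module is importable as `f5_tts.infer.utils_infer`; the wav path and transcription keyword arguments at the end are illustrative, not taken from the repo:

```python
import torch

from f5_tts.infer import utils_infer

# Default behavior: float16 on a sufficiently new CUDA GPU, float32 otherwise.
utils_infer.initialize_asr_pipeline(device="cuda" if torch.cuda.is_available() else "cpu")

# Explicit override, e.g. force full precision regardless of the device:
# utils_infer.initialize_asr_pipeline(device="cuda", dtype=torch.float32)

# The module-level `asr_pipe` is a transformers ASR pipeline; the file path and
# generate_kwargs below are illustrative only.
result = utils_infer.asr_pipe("ref_audio.wav", generate_kwargs={"task": "transcribe"})
print(result["text"])
```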
src/f5_tts/model/trainer.py
CHANGED
@@ -325,7 +325,9 @@ class Trainer:

                 if self.log_samples and self.accelerator.is_local_main_process:
                     ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                    torchaudio.save(
+                    torchaudio.save(
+                        f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
+                    )
                     with torch.inference_mode():
                         generated, _ = self.accelerator.unwrap_model(self.model).sample(
                             cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +338,12 @@ class Trainer:
                             sway_sampling_coef=sway_sampling_coef,
                         )
                         generated = generated.to(torch.float32)
-                        gen_audio = vocoder.decode(
-
+                        gen_audio = vocoder.decode(
+                            generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                        )
+                        torchaudio.save(
+                            f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
+                        )

                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)
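Taken together, the trainer change decodes both the reference prompt and the generated continuation through the vocoder and writes them to disk whenever sample logging is enabled, moving each waveform to CPU before `torchaudio.save`. A standalone sketch of that flow is below; the helper name, argument names, and mel shape conventions (channels-first for the vocoder input, frames-first for the sampler output) are assumptions inferred from the diff, not code from the repo.

```python
import torchaudio


def save_sample_pair(vocoder, ref_mel, generated_mel, ref_len, out_dir, step, sample_rate, device):
    """Illustrative helper mirroring the sample-logging block in Trainer (not part of the repo).

    ref_mel:       (1, n_mels, ref_len)   mel of the reference prompt, channels-first
    generated_mel: (1, total_len, n_mels) mel returned by the sampler, frames-first
    """
    # Reference audio: the vocoder is assumed to take (batch, n_mels, frames).
    ref_audio = vocoder.decode(ref_mel)
    torchaudio.save(f"{out_dir}/step_{step}_ref.wav", ref_audio.cpu(), sample_rate)

    # Generated audio: drop the prompt frames, permute to (batch, n_mels, frames),
    # decode on the training device, and move to CPU only for the file write.
    gen_mel = generated_mel[:, ref_len:, :].permute(0, 2, 1).to(device)
    gen_audio = vocoder.decode(gen_mel)
    torchaudio.save(f"{out_dir}/step_{step}_gen.wav", gen_audio.cpu(), sample_rate)
```

Moving the tensors to CPU before saving presumably avoids handing `torchaudio.save` a CUDA tensor, which appears to be the point of the added `.cpu()` calls.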