kunci115 committed
Commit 6e24f1e · unverified · 2 Parent(s): 0fe34a8 ea90244

Merge branch 'SWivid:main' into main

src/f5_tts/infer/utils_infer.py CHANGED
@@ -119,12 +119,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
 asr_pipe = None
 
 
-def initialize_asr_pipeline(device=device):
+def initialize_asr_pipeline(device=device, dtype=None):
+    if dtype is None:
+        dtype = (
+            torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        )
     global asr_pipe
     asr_pipe = pipeline(
         "automatic-speech-recognition",
         model="openai/whisper-large-v3-turbo",
-        torch_dtype=torch.float16,
+        torch_dtype=dtype,
         device=device,
     )
 
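In short, this hunk makes the Whisper ASR pipeline's precision configurable: fp16 is now used only when the target device is a CUDA GPU with compute capability 6.0 or higher, and fp32 is used otherwise (CPU, older GPUs). A minimal sketch of that selection logic, assuming only that torch is installed (pick_asr_dtype is a hypothetical name, not part of the repo):

import torch

def pick_asr_dtype(device: str) -> torch.dtype:
    # Hypothetical helper mirroring the new default in initialize_asr_pipeline:
    # half precision only on CUDA devices with compute capability >= 6.0,
    # since older GPUs have poor or missing native fp16 support.
    if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6:
        return torch.float16
    return torch.float32

# The new dtype argument also lets callers pin the precision explicitly, e.g.
# initialize_asr_pipeline(device="cuda", dtype=torch.bfloat16)
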
src/f5_tts/model/trainer.py CHANGED
@@ -325,7 +325,9 @@ class Trainer:
 
                     if self.log_samples and self.accelerator.is_local_main_process:
                         ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                        torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+                        torchaudio.save(
+                            f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
+                        )
                         with torch.inference_mode():
                             generated, _ = self.accelerator.unwrap_model(self.model).sample(
                                 cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +338,12 @@ class Trainer:
                                 sway_sampling_coef=sway_sampling_coef,
                             )
                             generated = generated.to(torch.float32)
-                            gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
-                            torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                            gen_audio = vocoder.decode(
+                                generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                            )
+                            torchaudio.save(
+                                f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
+                            )
 
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)
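
In short, these hunks fix the device handling in the sample-logging path: the vocoder now decodes the generated mel on the accelerator device instead of the CPU, and both waveforms are moved to CPU only when they are written out. A minimal sketch of that pattern, with hypothetical names (vocoder, mel, out_path stand in for the trainer's locals):

import torch
import torchaudio

def save_decoded_sample(vocoder, mel: torch.Tensor, out_path: str, sample_rate: int, device) -> None:
    # Hypothetical sketch of the patched flow: decode on the training device
    # (the patch uses self.accelerator.device) rather than moving the mel to CPU first...
    wav = vocoder.decode(mel.to(device))
    # ...then hand torchaudio.save a CPU tensor of shape (channels, num_samples),
    # which is what it expects when writing a .wav file.
    torchaudio.save(out_path, wav.cpu(), sample_rate)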