Yushen CHEN committed on
Commit
dee0420
·
unverified ·
2 Parent(s): 7dd3eec 6cbb548

Merge pull request #334 from lpscr/main

Browse files
src/f5_tts/api.py CHANGED
@@ -15,6 +15,9 @@ from f5_tts.infer.utils_infer import (
15
  infer_process,
16
  remove_silence_for_generated_wav,
17
  save_spectrogram,
 
 
 
18
  )
19
 
20
 
@@ -31,10 +34,8 @@ class F5TTS:
31
  ):
32
  # Initialize parameters
33
  self.final_wave = None
34
- self.target_sample_rate = 24000
35
- self.n_mel_channels = 100
36
- self.hop_length = 256
37
- self.target_rms = 0.1
38
  self.seed = -1
39
 
40
  # Set device
@@ -97,6 +98,9 @@ class F5TTS:
97
  seed = random.randint(0, sys.maxsize)
98
  seed_everything(seed)
99
  self.seed = seed
 
 
 
100
  wav, sr, spect = infer_process(
101
  ref_file,
102
  ref_text,
 
15
  infer_process,
16
  remove_silence_for_generated_wav,
17
  save_spectrogram,
18
+ preprocess_ref_audio_text,
19
+ target_sample_rate,
20
+ hop_length,
21
  )
22
 
23
 
 
34
  ):
35
  # Initialize parameters
36
  self.final_wave = None
37
+ self.target_sample_rate = target_sample_rate
38
+ self.hop_length = hop_length
 
 
39
  self.seed = -1
40
 
41
  # Set device
 
98
  seed = random.randint(0, sys.maxsize)
99
  seed_everything(seed)
100
  self.seed = seed
101
+
102
+ ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
103
+
104
  wav, sr, spect = infer_process(
105
  ref_file,
106
  ref_text,
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1216,7 +1216,7 @@ def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe
1216
  else:
1217
  device_test = None
1218
 
1219
- if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema:
1220
  if last_checkpoint != file_checkpoint:
1221
  last_checkpoint = file_checkpoint
1222
 
 
1216
  else:
1217
  device_test = None
1218
 
1219
+ if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None:
1220
  if last_checkpoint != file_checkpoint:
1221
  last_checkpoint = file_checkpoint
1222