SWivid committed
Commit 073092d · 2 Parent(s): 28b46d3 765a2ae

Merge branch 'main' of github.com:SWivid/F5-TTS into main

Files changed (1)
  1. inference-cli.py +33 -32
inference-cli.py CHANGED
@@ -174,6 +174,32 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
+if model == "F5-TTS":
+
+    if ckpt_file == "":
+        repo_name= "F5-TTS"
+        exp_name = "F5TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,vocab_file)
+
+elif model == "E2-TTS":
+    if ckpt_file == "":
+        repo_name= "E2-TTS"
+        exp_name = "E2TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,vocab_file)
+
+asr_pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-large-v3-turbo",
+    torch_dtype=torch.float16,
+    device=device,
+)
+
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
@@ -205,26 +231,7 @@ def chunk_text(text, max_chars=135):
 #if not Path(ckpt_path).exists():
 #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
-    if model == "F5-TTS":
-
-        if ckpt_file == "":
-            repo_name= "F5-TTS"
-            exp_name = "F5TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,file_vocab)
-
-    elif model == "E2-TTS":
-        if ckpt_file == "":
-            repo_name= "E2-TTS"
-            exp_name = "E2TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,file_vocab)
-
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -341,13 +348,7 @@ def process_voice(ref_audio_orig, ref_text):
 
     if not ref_text.strip():
         print("No reference text provided, transcribing reference audio...")
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-        ref_text = pipe(
+        ref_text = asr_pipe(
             ref_audio,
             chunk_length_s=30,
             batch_size=128,
@@ -359,7 +360,7 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_sile
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
         if ref_text.endswith("."):
@@ -375,10 +376,10 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_sile
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -406,7 +407,7 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence)
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,remove_silence)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -425,4 +426,4 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
             print(f.name)
 
 
-process(ref_audio, ref_text, gen_text, model,ckpt_file,vocab_file, remove_silence)
+process(ref_audio, ref_text, gen_text, model, remove_silence)
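
Taken together, the merged changes hoist checkpoint resolution, model construction (`ema_model`), and the Whisper transcription pipeline (`asr_pipe`) to module scope, so `infer_batch`, `process_voice`, `infer`, and `process` no longer thread `ckpt_file`/`vocab_file` through every call. Below is a minimal, self-contained sketch of that build-once/reuse pattern for the transcription step only; it assumes `torch` and `transformers` are installed, the `transcribe()` helper and the `ref_audio.wav` file name are hypothetical, and the pipeline arguments mirror the ones visible in the diff.

```python
# Sketch only: build the ASR pipeline once at module scope and reuse it,
# as the refactored inference-cli.py does with asr_pipe.
import torch
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# The script uses torch.float16 unconditionally; this sketch falls back to
# float32 on CPU so it also runs without a GPU.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device=device,
)

def transcribe(ref_audio_path: str) -> str:
    # Hypothetical helper: reuse the module-level pipeline for every
    # reference clip that is supplied without a transcript.
    result = asr_pipe(
        ref_audio_path,
        chunk_length_s=30,
        batch_size=128,
    )
    return result["text"].strip()

if __name__ == "__main__":
    print(transcribe("ref_audio.wav"))  # hypothetical sample clip
```

The same once-per-process idea applies to the checkpoint side of the diff: resolving `cached_path("hf://SWivid/...")` and calling `load_model(...)` at import time means repeated `infer()` calls reuse a single `ema_model` rather than reloading weights inside every `infer_batch` invocation.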