Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.
- inference-cli.py +33 -32
inference-cli.py
CHANGED
@@ -175,6 +175,32 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
+if model == "F5-TTS":
+
+    if ckpt_file == "":
+        repo_name= "F5-TTS"
+        exp_name = "F5TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,vocab_file)
+
+elif model == "E2-TTS":
+    if ckpt_file == "":
+        repo_name= "E2-TTS"
+        exp_name = "E2TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,vocab_file)
+
+asr_pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-large-v3-turbo",
+    torch_dtype=torch.float16,
+    device=device,
+)
+
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
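The hunk above hoists checkpoint resolution and the ASR pipeline to module scope, so both are built once at startup instead of on every call. A minimal, self-contained sketch of the cached_path pattern it relies on (assuming the cached-path package is installed; the print is illustrative only):

    # Resolve a Hugging Face Hub file to a local cache path.
    from cached_path import cached_path

    repo_name = "F5-TTS"
    exp_name = "F5TTS_Base"
    ckpt_step = 1200000

    # Downloads on first use, then returns the cached local file path.
    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    print(ckpt_file)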
@@ -206,26 +232,7 @@ def chunk_text(text, max_chars=135):
     #if not Path(ckpt_path).exists():
     #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
-    if model == "F5-TTS":
-
-        if ckpt_file == "":
-            repo_name= "F5-TTS"
-            exp_name = "F5TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,file_vocab)
-
-    elif model == "E2-TTS":
-        if ckpt_file == "":
-            repo_name= "E2-TTS"
-            exp_name = "E2TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,file_vocab)
-
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
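With loading hoisted out, infer_batch now begins directly at audio preprocessing. A self-contained sketch of the mono down-mix shown in the context lines above (torchaudio assumed; "ref.wav" is a hypothetical input file):

    import torch
    import torchaudio

    # Load reference audio as a (channels, samples) tensor.
    audio, sr = torchaudio.load("ref.wav")

    # Average the channels to mono, keeping a (1, samples) shape.
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)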
@@ -342,13 +349,7 @@ def process_voice(ref_audio_orig, ref_text):
 
     if not ref_text.strip():
         print("No reference text provided, transcribing reference audio...")
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-        ref_text = pipe(
+        ref_text = asr_pipe(
             ref_audio,
             chunk_length_s=30,
             batch_size=128,
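The transcription fallback now reuses the shared asr_pipe instead of constructing a fresh pipeline on every call. A standalone sketch of that pipeline (transformers assumed; "ref.wav" and the "cuda" device are placeholders):

    import torch
    from transformers import pipeline

    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=torch.float16,
        device="cuda",  # or "cpu"
    )

    # Long-form transcription in 30-second chunks, as in the hunk above.
    result = asr_pipe("ref.wav", chunk_length_s=30, batch_size=128)
    ref_text = result["text"].strip()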
@@ -360,7 +361,7 @@ def process_voice(ref_audio_orig, ref_text):
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
         if ref_text.endswith("."):
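Signature change aside, the surrounding logic pads the reference text so it ends in ". " (or the CJK full stop "。"). The branch bodies are not shown in this hunk; a sketch of the likely normalization, stated as an assumption:

    def ensure_sentence_ending(ref_text: str) -> str:
        # Guarantee the reference text ends with ". " (ASCII) or "。" (CJK).
        if not ref_text.endswith(". ") and not ref_text.endswith("。"):
            if ref_text.endswith("."):
                ref_text += " "   # "...done." -> "...done. "
            else:
                ref_text += ". "  # "...done"  -> "...done. "
        return ref_text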
@@ -376,10 +377,10 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -407,7 +408,7 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence)
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,remove_silence)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -426,4 +427,4 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
         print(f.name)
 
 
-process(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence)
+process(ref_audio, ref_text, gen_text, model, remove_silence)
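Taken together, checkpoint and vocab plumbing disappears from the public call chain. A hypothetical end-to-end invocation inside the refactored inference-cli.py (every value below stands in for a CLI-parsed argument):

    ref_audio = "tests/ref_audio/main.wav"  # hypothetical path
    ref_text = ""                           # empty -> transcribed via asr_pipe
    gen_text = "Hello from F5-TTS."
    model = "F5-TTS"                        # or "E2-TTS"
    remove_silence = True

    process(ref_audio, ref_text, gen_text, model, remove_silence)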