Spaces:
Configuration error
Configuration error
redirect to split hf ckpt repos
Browse files
- gradio_app.py +4 -4
- inference-cli.py +12 -12
gradio_app.py
CHANGED
@@ -62,8 +62,8 @@ speed = 1.0
|
|
62 |
fix_duration = None
|
63 |
|
64 |
|
65 |
-
def load_model(exp_name, model_cls, model_cfg, ckpt_step):
|
66 |
-
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
|
67 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
68 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
69 |
model = CFM(
|
@@ -93,10 +93,10 @@ F5TTS_model_cfg = dict(
|
|
93 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
94 |
|
95 |
F5TTS_ema_model = load_model(
|
96 |
-
"F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
|
97 |
)
|
98 |
E2TTS_ema_model = load_model(
|
99 |
-
"E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
|
100 |
)
|
101 |
|
102 |
def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
|
|
62 |
fix_duration = None
|
63 |
|
64 |
|
65 |
+
def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
66 |
+
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
67 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
68 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
69 |
model = CFM(
|
|
|
93 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
94 |
|
95 |
F5TTS_ema_model = load_model(
|
96 |
+
"F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
|
97 |
)
|
98 |
E2TTS_ema_model = load_model(
|
99 |
+
"E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
|
100 |
)
|
101 |
|
102 |
def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
inference-cli.py
CHANGED
@@ -74,7 +74,7 @@ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
|
|
74 |
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
|
75 |
gen_text = args.gen_text if args.gen_text else config["gen_text"]
|
76 |
output_dir = args.output_dir if args.output_dir else config["output_dir"]
|
77 |
-
|
78 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
79 |
wave_path = Path(output_dir)/"out.wav"
|
80 |
spectrogram_path = Path(output_dir)/"out.png"
|
@@ -112,8 +112,8 @@ speed = 1.0
|
|
112 |
# fix_duration = 27 # None or float (duration in seconds)
|
113 |
fix_duration = None
|
114 |
|
115 |
-
def load_model(exp_name, model_cls, model_cfg, ckpt_step):
|
116 |
-
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
|
117 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
118 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
119 |
model = CFM(
|
@@ -238,11 +238,11 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
|
238 |
|
239 |
return batches
|
240 |
|
241 |
-
def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
|
242 |
-
if exp_name == "F5-TTS":
|
243 |
-
ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
244 |
-
elif exp_name == "E2-TTS":
|
245 |
-
ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
|
246 |
|
247 |
audio, sr = ref_audio
|
248 |
if audio.shape[0] > 1:
|
@@ -320,7 +320,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
|
|
320 |
print(spectrogram_path)
|
321 |
|
322 |
|
323 |
-
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
|
324 |
if not custom_split_words.strip():
|
325 |
custom_words = [word.strip() for word in custom_split_words.split(',')]
|
326 |
global SPLIT_WORDS
|
@@ -372,8 +372,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
|
|
372 |
for i, gen_text in enumerate(gen_text_batches):
|
373 |
print(f'gen_text {i}', gen_text)
|
374 |
|
375 |
-
print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches, loading models...")
|
376 |
-
return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
|
377 |
|
378 |
|
379 |
-
infer(ref_audio, ref_text, gen_text, exp_name, remove_silence, ",".join(SPLIT_WORDS))
|
|
|
74 |
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
|
75 |
gen_text = args.gen_text if args.gen_text else config["gen_text"]
|
76 |
output_dir = args.output_dir if args.output_dir else config["output_dir"]
|
77 |
+
model = args.model if args.model else config["model"]
|
78 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
79 |
wave_path = Path(output_dir)/"out.wav"
|
80 |
spectrogram_path = Path(output_dir)/"out.png"
|
|
|
112 |
# fix_duration = 27 # None or float (duration in seconds)
|
113 |
fix_duration = None
|
114 |
|
115 |
+
def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
116 |
+
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
117 |
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
118 |
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
119 |
model = CFM(
|
|
|
238 |
|
239 |
return batches
|
240 |
|
241 |
+
def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
242 |
+
if model == "F5-TTS":
|
243 |
+
ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
244 |
+
elif model == "E2-TTS":
|
245 |
+
ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
|
246 |
|
247 |
audio, sr = ref_audio
|
248 |
if audio.shape[0] > 1:
|
|
|
320 |
print(spectrogram_path)
|
321 |
|
322 |
|
323 |
+
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
|
324 |
if not custom_split_words.strip():
|
325 |
custom_words = [word.strip() for word in custom_split_words.split(',')]
|
326 |
global SPLIT_WORDS
|
|
|
372 |
for i, gen_text in enumerate(gen_text_batches):
|
373 |
print(f'gen_text {i}', gen_text)
|
374 |
|
375 |
+
print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
|
376 |
+
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
|
377 |
|
378 |
|
379 |
+
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
|