minor fix
- README.md +1 -1
- gradio_app.py +9 -15
- requirements_gradio.txt +3 -2
- test_infer_single_edit.py +1 -1
README.md CHANGED
@@ -92,7 +92,7 @@ First, make sure you have the dependencies installed (`pip install -r requirements_gradio.txt`)
 pip install -r requirements_gradio.txt
 ```
 
-After installing the dependencies, launch the app:
+After installing the dependencies, launch the app (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`):
 
 ```bash
 python gradio_app.py
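The parenthetical added above refers to the checkpoint-resolution logic in `load_model` (see the `gradio_app.py` diff below). As a short illustrative sketch: `cached_path` downloads the file from the Hugging Face Hub on first launch and returns its local cache path, while pointing `ckpt_path` at a local file skips the download entirely.

```python
# Illustrative sketch of the two checkpoint options the README line refers to.
# The hf:// URL matches the default in gradio_app.py's load_model below.
from cached_path import cached_path

# Option 1: fetch from the Hugging Face Hub (cached locally after first run)
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))

# Option 2: use a local checkpoint instead (path is an example)
# ckpt_path = "ckpts/F5TTS_Base/model_1200000.pt"
```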
gradio_app.py CHANGED
@@ -6,12 +6,12 @@ import gradio as gr
 import numpy as np
 import tempfile
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 from pydub import AudioSegment
 from model import CFM, UNetT, DiT, MMDiT
 from cached_path import cached_path
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -51,10 +51,8 @@ fix_duration = None
 
 
 def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    checkpoint = torch.load(
-        str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")),
-        map_location=device,
-    )
+    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
         transformer=model_cls(
@@ -71,11 +69,9 @@ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    ema_model = EMA(model, include_online_model=False).to(device)
-    ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
-    ema_model.copy_params_from_ema_to_model()
+    model = load_checkpoint(model, ckpt_path, device, use_ema=True)
 
-    return ema_model, model
+    return model
 
 
 # load models
@@ -84,10 +80,10 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-F5TTS_ema_model, F5TTS_base_model = load_model(
+F5TTS_ema_model = load_model(
     "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
 )
-E2TTS_ema_model, E2TTS_base_model = load_model(
+E2TTS_ema_model = load_model(
     "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
@@ -107,10 +103,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
         ref_audio = f.name
     if exp_name == "F5-TTS":
         ema_model = F5TTS_ema_model
-        base_model = F5TTS_base_model
     elif exp_name == "E2-TTS":
         ema_model = E2TTS_ema_model
-        base_model = E2TTS_base_model
 
     if not ref_text.strip():
         gr.Info("No reference text provided, transcribing reference audio...")
@@ -151,7 +145,7 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
     # inference
     gr.Info(f"Generating audio using {exp_name}")
     with torch.inference_mode():
-        generated, _ = base_model.sample(
+        generated, _ = ema_model.sample(
             cond=audio,
             text=final_text_list,
             duration=duration,
@@ -243,7 +237,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
 
 
 @click.command()
-@click.option("--port", "-p", default=None, help="Port to run the app on")
+@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
 @click.option("--host", "-H", default=None, help="Host to run the app on")
 @click.option(
     "--share",
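The main change here swaps the inline `torch.load` plus `EMA` reconstruction for a single `load_checkpoint` helper imported from `model.utils`, which also allows `.safetensors` checkpoints. The helper itself is not part of this diff; the following is a minimal sketch of what it plausibly does, with the key names (`ema_model_state_dict`, the `ema_model.` prefix, `initted`/`step`) assumed from the removed EMA code:

```python
# Sketch only: the real helper lives in model/utils.py and is not shown in
# this commit. Checkpoint key names are assumptions inferred from the
# removed ema_pytorch loading code.
import torch

def load_checkpoint(model, ckpt_path, device, use_ema=True):
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file
        state = load_file(ckpt_path, device=device)
        checkpoint = {"ema_model_state_dict": state}
    else:
        checkpoint = torch.load(ckpt_path, map_location=device)

    if use_ema:
        ema_state = checkpoint["ema_model_state_dict"]
        # Drop EMA bookkeeping buffers and strip the wrapper prefix so the
        # keys line up with the bare model's parameter names.
        model_state = {
            k.replace("ema_model.", ""): v
            for k, v in ema_state.items()
            if k not in ("initted", "step")
        }
        model.load_state_dict(model_state)
    else:
        model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
```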
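The other fix is `type=int` on `--port`: without it, click hands the option through as a string, which an integer-expecting consumer such as gradio's `launch(server_port=...)` would reject. A standalone sketch of the parsing behavior (not the app's real entry point):

```python
# Minimal standalone demonstration of the --port fix; not the app's actual
# CLI. Without type=int, click would deliver "7860" as a str.
import click

@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
def main(port):
    # With type=int: port is either None or an int such as 7860.
    print(port, type(port))

if __name__ == "__main__":
    main()
```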
requirements_gradio.txt CHANGED
@@ -1,3 +1,4 @@
 cached_path
-
-
+click
+gradio
+pydub
test_infer_single_edit.py CHANGED
@@ -14,7 +14,7 @@ from model.utils import (
     save_spectrogram,
 )
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 # --------------------- Dataset Settings -------------------- #
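The updated device line chains two conditional expressions so Apple-silicon machines fall back to MPS before CPU. An unrolled equivalent (the `pick_device` helper is illustrative, not part of the repo) makes the order explicit:

```python
import torch

def pick_device() -> str:
    # Unrolled equivalent of:
    # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():  # Apple-silicon Metal backend, torch >= 1.12
        return "mps"
    return "cpu"

device = pick_device()
```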