Spaces:
Configuration error
Configuration error
run pre-commit
Browse files
- src/f5_tts/api.py +9 -3
- src/f5_tts/infer/utils_infer.py +6 -2
src/f5_tts/api.py
CHANGED
@@ -56,14 +56,20 @@ class F5TTS:
|
|
56 |
if model_type == "F5-TTS":
|
57 |
if not ckpt_file:
|
58 |
if mel_spec_type == "vocos":
|
59 |
-
ckpt_file = str(
|
|
|
|
|
60 |
elif mel_spec_type == "bigvgan":
|
61 |
-
ckpt_file = str(
|
|
|
|
|
62 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
63 |
model_cls = DiT
|
64 |
elif model_type == "E2-TTS":
|
65 |
if not ckpt_file:
|
66 |
-
ckpt_file = str(
|
|
|
|
|
67 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
68 |
model_cls = UNetT
|
69 |
else:
|
|
|
56 |
if model_type == "F5-TTS":
|
57 |
if not ckpt_file:
|
58 |
if mel_spec_type == "vocos":
|
59 |
+
ckpt_file = str(
|
60 |
+
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
|
61 |
+
)
|
62 |
elif mel_spec_type == "bigvgan":
|
63 |
+
ckpt_file = str(
|
64 |
+
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
|
65 |
+
)
|
66 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
67 |
model_cls = DiT
|
68 |
elif model_type == "E2-TTS":
|
69 |
if not ckpt_file:
|
70 |
+
ckpt_file = str(
|
71 |
+
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
|
72 |
+
)
|
73 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
74 |
model_cls = UNetT
|
75 |
else:
|
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -96,8 +96,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
|
|
96 |
print(f"Load vocos from local path {local_path}")
|
97 |
repo_id = "charactr/vocos-mel-24khz"
|
98 |
revision = None
|
99 |
-
config_path = hf_hub_download(
|
100 |
-
|
|
|
|
|
|
|
|
|
101 |
vocoder = Vocos.from_hparams(config_path=config_path)
|
102 |
state_dict = torch.load(model_path, map_location="cpu")
|
103 |
vocoder.load_state_dict(state_dict)
|
|
|
96 |
print(f"Load vocos from local path {local_path}")
|
97 |
repo_id = "charactr/vocos-mel-24khz"
|
98 |
revision = None
|
99 |
+
config_path = hf_hub_download(
|
100 |
+
repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
|
101 |
+
)
|
102 |
+
model_path = hf_hub_download(
|
103 |
+
repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
|
104 |
+
)
|
105 |
vocoder = Vocos.from_hparams(config_path=config_path)
|
106 |
state_dict = torch.load(model_path, map_location="cpu")
|
107 |
vocoder.load_state_dict(state_dict)
|