download model from HF
Browse files
webui.py
CHANGED
@@ -22,10 +22,16 @@ import gradio as gr
|
|
22 |
from datetime import datetime
|
23 |
from cli.SparkTTS import SparkTTS
|
24 |
from sparktts.utils.token_parser import LEVELS_MAP_UI
|
|
|
25 |
|
26 |
|
27 |
-
def initialize_model(model_dir=
|
28 |
"""Load the model once at the beginning."""
|
|
|
|
|
|
|
|
|
|
|
29 |
logging.info(f"Loading model from: {model_dir}")
|
30 |
device = torch.device(f"cuda:{device}")
|
31 |
model = SparkTTS(model_dir, device)
|
@@ -213,7 +219,7 @@ def parse_arguments():
|
|
213 |
parser.add_argument(
|
214 |
"--model_dir",
|
215 |
type=str,
|
216 |
-
default=
|
217 |
help="Path to the model directory."
|
218 |
)
|
219 |
parser.add_argument(
|
@@ -225,13 +231,13 @@ def parse_arguments():
|
|
225 |
parser.add_argument(
|
226 |
"--server_name",
|
227 |
type=str,
|
228 |
-
default=
|
229 |
help="Server host/IP for Gradio app."
|
230 |
)
|
231 |
parser.add_argument(
|
232 |
"--server_port",
|
233 |
type=int,
|
234 |
-
default=
|
235 |
help="Server port for Gradio app."
|
236 |
)
|
237 |
return parser.parse_args()
|
|
|
22 |
from datetime import datetime
|
23 |
from cli.SparkTTS import SparkTTS
|
24 |
from sparktts.utils.token_parser import LEVELS_MAP_UI
|
25 |
+
from huggingface_hub import snapshot_download
|
26 |
|
27 |
|
28 |
+
def initialize_model(model_dir=None, device=0):
|
29 |
"""Load the model once at the beginning."""
|
30 |
+
|
31 |
+
if model_dir is None:
|
32 |
+
logging.info(f"Downloading model to: {model_dir}")
|
33 |
+
model_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
|
34 |
+
|
35 |
logging.info(f"Loading model from: {model_dir}")
|
36 |
device = torch.device(f"cuda:{device}")
|
37 |
model = SparkTTS(model_dir, device)
|
|
|
219 |
parser.add_argument(
|
220 |
"--model_dir",
|
221 |
type=str,
|
222 |
+
default=None,
|
223 |
help="Path to the model directory."
|
224 |
)
|
225 |
parser.add_argument(
|
|
|
231 |
parser.add_argument(
|
232 |
"--server_name",
|
233 |
type=str,
|
234 |
+
default=None,
|
235 |
help="Server host/IP for Gradio app."
|
236 |
)
|
237 |
parser.add_argument(
|
238 |
"--server_port",
|
239 |
type=int,
|
240 |
+
default=None,
|
241 |
help="Server port for Gradio app."
|
242 |
)
|
243 |
return parser.parse_args()
|