thunnai commited on
Commit
d6d4199
·
1 Parent(s): a025623

download model from HF

Browse files
Files changed (1) hide show
  1. webui.py +10 -4
webui.py CHANGED
@@ -22,10 +22,16 @@ import gradio as gr
22
  from datetime import datetime
23
  from cli.SparkTTS import SparkTTS
24
  from sparktts.utils.token_parser import LEVELS_MAP_UI
 
25
 
26
 
27
- def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
28
  """Load the model once at the beginning."""
 
 
 
 
 
29
  logging.info(f"Loading model from: {model_dir}")
30
  device = torch.device(f"cuda:{device}")
31
  model = SparkTTS(model_dir, device)
@@ -213,7 +219,7 @@ def parse_arguments():
213
  parser.add_argument(
214
  "--model_dir",
215
  type=str,
216
- default="pretrained_models/Spark-TTS-0.5B",
217
  help="Path to the model directory."
218
  )
219
  parser.add_argument(
@@ -225,13 +231,13 @@ def parse_arguments():
225
  parser.add_argument(
226
  "--server_name",
227
  type=str,
228
- default="0.0.0.0",
229
  help="Server host/IP for Gradio app."
230
  )
231
  parser.add_argument(
232
  "--server_port",
233
  type=int,
234
- default=7860,
235
  help="Server port for Gradio app."
236
  )
237
  return parser.parse_args()
 
22
  from datetime import datetime
23
  from cli.SparkTTS import SparkTTS
24
  from sparktts.utils.token_parser import LEVELS_MAP_UI
25
+ from huggingface_hub import snapshot_download
26
 
27
 
28
+ def initialize_model(model_dir=None, device=0):
29
  """Load the model once at the beginning."""
30
+
31
+ if model_dir is None:
32
+ logging.info(f"Downloading model to: {model_dir}")
33
+ model_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
34
+
35
  logging.info(f"Loading model from: {model_dir}")
36
  device = torch.device(f"cuda:{device}")
37
  model = SparkTTS(model_dir, device)
 
219
  parser.add_argument(
220
  "--model_dir",
221
  type=str,
222
+ default=None,
223
  help="Path to the model directory."
224
  )
225
  parser.add_argument(
 
231
  parser.add_argument(
232
  "--server_name",
233
  type=str,
234
+ default=None,
235
  help="Server host/IP for Gradio app."
236
  )
237
  parser.add_argument(
238
  "--server_port",
239
  type=int,
240
+ default=None,
241
  help="Server port for Gradio app."
242
  )
243
  return parser.parse_args()