thunnai commited on
Commit
66f7114
·
1 Parent(s): d6d4199

default to cpu for testing

Browse files
Files changed (1) hide show
  1. webui.py +5 -5
webui.py CHANGED
@@ -25,7 +25,7 @@ from sparktts.utils.token_parser import LEVELS_MAP_UI
25
  from huggingface_hub import snapshot_download
26
 
27
 
28
- def initialize_model(model_dir=None, device=0):
29
  """Load the model once at the beginning."""
30
 
31
  if model_dir is None:
@@ -33,7 +33,7 @@ def initialize_model(model_dir=None, device=0):
33
  model_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
34
 
35
  logging.info(f"Loading model from: {model_dir}")
36
- device = torch.device(f"cuda:{device}")
37
  model = SparkTTS(model_dir, device)
38
  return model
39
 
@@ -224,9 +224,9 @@ def parse_arguments():
224
  )
225
  parser.add_argument(
226
  "--device",
227
- type=int,
228
- default=0,
229
- help="ID of the GPU device to use (e.g., 0 for cuda:0)."
230
  )
231
  parser.add_argument(
232
  "--server_name",
 
25
  from huggingface_hub import snapshot_download
26
 
27
 
28
+ def initialize_model(model_dir=None, device="cpu"):
29
  """Load the model once at the beginning."""
30
 
31
  if model_dir is None:
 
33
  model_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
34
 
35
  logging.info(f"Loading model from: {model_dir}")
36
+ device = torch.device(device)
37
  model = SparkTTS(model_dir, device)
38
  return model
39
 
 
224
  )
225
  parser.add_argument(
226
  "--device",
227
+ type=str,
228
+ default="cpu",
229
+ help="Device to use (e.g., 'cpu' or 'cuda:0')."
230
  )
231
  parser.add_argument(
232
  "--server_name",