xcv58 commited on
Commit
105c10a
·
unverified ·
1 Parent(s): 7d742ba

Update webui.py as well

Browse files
Files changed (1) hide show
  1. webui.py +17 -2
webui.py CHANGED
@@ -19,6 +19,7 @@ import soundfile as sf
19
  import logging
20
  import argparse
21
  import gradio as gr
 
22
  from datetime import datetime
23
  from cli.SparkTTS import SparkTTS
24
  from sparktts.utils.token_parser import LEVELS_MAP_UI
@@ -27,7 +28,21 @@ from sparktts.utils.token_parser import LEVELS_MAP_UI
27
  def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
28
  """Load the model once at the beginning."""
29
  logging.info(f"Loading model from: {model_dir}")
30
- device = torch.device(f"cuda:{device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  model = SparkTTS(model_dir, device)
32
  return model
33
 
@@ -76,7 +91,7 @@ def run_tts(
76
 
77
 
78
  def build_ui(model_dir, device=0):
79
-
80
  # Initialize model
81
  model = initialize_model(model_dir, device=device)
82
 
 
19
  import logging
20
  import argparse
21
  import gradio as gr
22
+ import platform
23
  from datetime import datetime
24
  from cli.SparkTTS import SparkTTS
25
  from sparktts.utils.token_parser import LEVELS_MAP_UI
 
28
  def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
29
  """Load the model once at the beginning."""
30
  logging.info(f"Loading model from: {model_dir}")
31
+
32
+ # Determine appropriate device based on platform and availability
33
+ if platform.system() == "Darwin":
34
+ # macOS with MPS support (Apple Silicon)
35
+ device = torch.device(f"mps:{device}")
36
+ logging.info(f"Using MPS device: {device}")
37
+ elif torch.cuda.is_available():
38
+ # System with CUDA support
39
+ device = torch.device(f"cuda:{device}")
40
+ logging.info(f"Using CUDA device: {device}")
41
+ else:
42
+ # Fall back to CPU
43
+ device = torch.device("cpu")
44
+ logging.info("GPU acceleration not available, using CPU")
45
+
46
  model = SparkTTS(model_dir, device)
47
  return model
48
 
 
91
 
92
 
93
  def build_ui(model_dir, device=0):
94
+
95
  # Initialize model
96
  model = initialize_model(model_dir, device=device)
97