thunnai committed on
Commit
9f28641
·
1 Parent(s): 66f7114

Test zero gpu

Browse files
Files changed (1) hide show
  1. webui.py +31 -7
webui.py CHANGED
@@ -23,7 +23,7 @@ from datetime import datetime
23
  from cli.SparkTTS import SparkTTS
24
  from sparktts.utils.token_parser import LEVELS_MAP_UI
25
  from huggingface_hub import snapshot_download
26
-
27
 
28
  def initialize_model(model_dir=None, device="cpu"):
29
  """Load the model once at the beginning."""
@@ -37,6 +37,32 @@ def initialize_model(model_dir=None, device="cpu"):
37
  model = SparkTTS(model_dir, device)
38
  return model
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def run_tts(
42
  text,
@@ -64,17 +90,15 @@ def run_tts(
64
  logging.info("Starting inference...")
65
 
66
  # Perform inference and save the output audio
67
- with torch.no_grad():
68
- wav = model.inference(
69
- text,
70
  prompt_speech,
71
  prompt_text,
72
  gender,
73
  pitch,
74
- speed,
75
- )
76
 
77
- sf.write(save_path, wav, samplerate=16000)
78
 
79
  logging.info(f"Audio saved at: {save_path}")
80
 
 
23
  from cli.SparkTTS import SparkTTS
24
  from sparktts.utils.token_parser import LEVELS_MAP_UI
25
  from huggingface_hub import snapshot_download
26
+ import spaces
27
 
28
  def initialize_model(model_dir=None, device="cpu"):
29
  """Load the model once at the beginning."""
 
37
  model = SparkTTS(model_dir, device)
38
  return model
39
 
40
@spaces.GPU
def generate(
    model,
    text,
    prompt_speech,
    prompt_text,
    gender,
    pitch,
    speed,
):
    """Generate audio from text.

    Decorated with ``spaces.GPU`` (the ZeroGPU decorator exported by the
    ``spaces`` package; the name is upper-case) so the call is scheduled on a
    GPU when the app runs on a ZeroGPU Space.

    Args:
        model: Object exposing ``inference(...)`` and ``to(device)``.
        text: Text to synthesize.
        prompt_speech: Forwarded unchanged to ``model.inference``.
        prompt_text: Forwarded unchanged to ``model.inference``.
        gender: Forwarded unchanged to ``model.inference``.
        pitch: Forwarded unchanged to ``model.inference``.
        speed: Forwarded unchanged to ``model.inference``.

    Returns:
        Whatever ``model.inference`` returns (the caller writes it to disk
        as a waveform).
    """
    # Move the model to the GPU only when one is actually visible.
    # NOTE(review): assumes `model` implements `.to()` and returns the moved
    # object (as torch.nn.Module does) -- confirm for SparkTTS.
    if torch.cuda.is_available():
        model = model.to("cuda")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )

    return wav
65
+
66
 
67
  def run_tts(
68
  text,
 
90
  logging.info("Starting inference...")
91
 
92
  # Perform inference and save the output audio
93
+ wav = generate(model, text,
 
 
94
  prompt_speech,
95
  prompt_text,
96
  gender,
97
  pitch,
98
+ speed,)
99
+
100
 
101
+ sf.write(save_path, wav, samplerate=16000)
102
 
103
  logging.info(f"Audio saved at: {save_path}")
104