spark-tts committed
Commit 6f15685 · Parent: 832ac1a

support voice creation

Files changed (2):
  1. cli/SparkTTS.py   +121 −20
  2. cli/inference.py  +38 −10
cli/SparkTTS.py CHANGED

@@ -15,12 +15,13 @@
 
 import re
 import torch
+from typing import Tuple
 from pathlib import Path
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from sparktts.utils.file import load_config
 from sparktts.models.audio_tokenizer import BiCodecTokenizer
-from sparktts.utils.token_parser import TASK_TOKEN_MAP
+from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
 
 
 class SparkTTS:
@@ -49,36 +50,36 @@ class SparkTTS:
         self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
         self.model.to(self.device)
 
-    @torch.no_grad()
-    def inference(
+    def process_prompt(
         self,
         text: str,
         prompt_speech_path: Path,
         prompt_text: str = None,
-        temperature: float = 0.8,
-        top_k: float = 50,
-        top_p: float = 0.95,
-    ) -> torch.Tensor:
+    ) -> Tuple[str, torch.Tensor]:
         """
-        Performs inference to generate speech from text, incorporating prompt audio and/or text.
+        Process input for voice cloning.
 
         Args:
             text (str): The text input to be converted to speech.
             prompt_speech_path (Path): Path to the audio file used as a prompt.
             prompt_text (str, optional): Transcript of the prompt audio.
-            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
-            top_k (float, optional): Top-k sampling parameter. Default is 50.
-            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
 
-        Returns:
-            torch.Tensor: Generated waveform as a tensor.
+        Return:
+            Tuple[str, torch.Tensor]: The input prompt and the prompt's global token IDs.
         """
-        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(prompt_speech_path)
-        global_tokens = "".join([f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()])
+        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
+            prompt_speech_path
+        )
+        global_tokens = "".join(
+            [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
+        )
 
         # Prepare the input tokens for the model
         if prompt_text is not None:
-            semantic_tokens = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()])
+            semantic_tokens = "".join(
+                [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
+            )
             inputs = [
                 TASK_TOKEN_MAP["tts"],
                 "<|start_content|>",
@@ -103,7 +104,94 @@ class SparkTTS:
         ]
 
         inputs = "".join(inputs)
-        model_inputs = self.tokenizer([inputs], return_tensors="pt").to(self.device)
+
+        return inputs, global_token_ids
+
+    def process_prompt_control(
+        self,
+        gender: str,
+        pitch: str,
+        speed: str,
+        text: str,
+    ):
+        """
+        Process input for voice creation.
+
+        Args:
+            gender (str): female | male.
+            pitch (str): very_low | low | moderate | high | very_high
+            speed (str): very_low | low | moderate | high | very_high
+            text (str): The text input to be converted to speech.
+
+        Return:
+            str: Input prompt.
+        """
+        assert gender in GENDER_MAP.keys()
+        assert pitch in LEVELS_MAP.keys()
+        assert speed in LEVELS_MAP.keys()
+
+        gender_id = GENDER_MAP[gender]
+        pitch_level_id = LEVELS_MAP[pitch]
+        speed_level_id = LEVELS_MAP[speed]
+
+        pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
+        speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
+        gender_tokens = f"<|gender_{gender_id}|>"
+
+        attribute_tokens = "".join(
+            [gender_tokens, pitch_label_tokens, speed_label_tokens]
+        )
+
+        control_tts_inputs = [
+            TASK_TOKEN_MAP["controllable_tts"],
+            "<|start_content|>",
+            text,
+            "<|end_content|>",
+            "<|start_style_label|>",
+            attribute_tokens,
+            "<|end_style_label|>",
+        ]
+
+        return "".join(control_tts_inputs)
+
+    @torch.no_grad()
+    def inference(
+        self,
+        text: str,
+        prompt_speech_path: Path = None,
+        prompt_text: str = None,
+        gender: str = None,
+        pitch: str = None,
+        speed: str = None,
+        temperature: float = 0.8,
+        top_k: int = 50,
+        top_p: float = 0.95,
+    ) -> torch.Tensor:
+        """
+        Performs inference to generate speech from text, incorporating prompt audio and/or text.
+
+        Args:
+            text (str): The text input to be converted to speech.
+            prompt_speech_path (Path): Path to the audio file used as a prompt.
+            prompt_text (str, optional): Transcript of the prompt audio.
+            gender (str): female | male.
+            pitch (str): very_low | low | moderate | high | very_high
+            speed (str): very_low | low | moderate | high | very_high
+            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
+            top_k (int, optional): Top-k sampling parameter. Default is 50.
+            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
+
+        Returns:
+            torch.Tensor: Generated waveform as a tensor.
+        """
+        if gender is not None:
+            prompt = self.process_prompt_control(gender, pitch, speed, text)
+
+        else:
+            prompt, global_token_ids = self.process_prompt(
+                text, prompt_speech_path, prompt_text
+            )
+        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
 
         # Generate speech using the model
         generated_ids = self.model.generate(
@@ -117,14 +205,27 @@
 
         # Trim the output tokens to remove the input tokens
         generated_ids = [
-            output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+            output_ids[len(input_ids) :]
+            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
         ]
 
         # Decode the generated tokens into text
         predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
         # Extract semantic token IDs from the generated text
-        pred_semantic_ids = torch.tensor([int(token) for token in re.findall(r"\d+", predicts)]).long().unsqueeze(0)
+        pred_semantic_ids = (
+            torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
+            .long()
+            .unsqueeze(0)
+        )
+
+        if gender is not None:
+            global_token_ids = (
+                torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
+                .long()
+                .unsqueeze(0)
+                .unsqueeze(0)
+            )
 
         # Convert semantic tokens back to waveform
         wav = self.audio_tokenizer.detokenize(
@@ -132,4 +233,4 @@ class SparkTTS:
             pred_semantic_ids.to(self.device),
         )
 
-        return wav
+        return wav
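
A minimal sketch of how the new voice-creation path is driven from Python. It assumes the constructor is SparkTTS(model_dir, device), matching the defaults in cli/inference.py; file names are illustrative. When gender is set, inference builds a controllable-TTS prompt via process_prompt_control and recovers the speaker's global tokens from the model output instead of from a reference clip:

    import torch
    import soundfile as sf

    from cli.SparkTTS import SparkTTS

    # Assumed constructor: SparkTTS(model_dir, device), per the defaults
    # in cli/inference.py.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SparkTTS("pretrained_models/Spark-TTS-0.5B", device)

    # Voice creation: no prompt audio; the voice is specified entirely by the
    # <|gender_*|>, <|pitch_label_*|>, and <|speed_label_*|> attribute tokens.
    wav = model.inference(
        "Hello, this is a brand-new voice.",
        gender="female",
        pitch="moderate",
        speed="moderate",
    )
    sf.write("created_voice.wav", wav, samplerate=16000)

    # Voice cloning still works as before, via a reference clip.
    wav = model.inference(
        "Hello in a cloned voice.",
        prompt_speech_path="example/prompt.wav",  # illustrative path
        prompt_text="Transcript of the prompt audio.",
    )
    sf.write("cloned_voice.wav", wav, samplerate=16000)
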
cli/inference.py CHANGED

@@ -12,16 +12,35 @@ def parse_args():
     """Parse command-line arguments."""
     parser = argparse.ArgumentParser(description="Run TTS inference.")
 
-    parser.add_argument("--model_dir", type=str, default="pretrained_models/Spark-TTS-0.5B",
-                        help="Path to the model directory")
-    parser.add_argument("--save_dir", type=str, default="example/results",
-                        help="Directory to save generated audio files")
+    parser.add_argument(
+        "--model_dir",
+        type=str,
+        default="pretrained_models/Spark-TTS-0.5B",
+        help="Path to the model directory",
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="example/results",
+        help="Directory to save generated audio files",
+    )
     parser.add_argument("--device", type=int, default=0, help="CUDA device number")
-    parser.add_argument("--text", type=str, required=True, help="Text for TTS generation")
+    parser.add_argument(
+        "--text", type=str, required=True, help="Text for TTS generation"
+    )
     parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
-    parser.add_argument("--prompt_speech_path", type=str, required=True,
-                        help="Path to the prompt audio file")
-
+    parser.add_argument(
+        "--prompt_speech_path",
+        type=str,
+        help="Path to the prompt audio file",
+    )
+    parser.add_argument("--gender", choices=["male", "female"])
+    parser.add_argument(
+        "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
+    )
+    parser.add_argument(
+        "--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
+    )
     return parser.parse_args()
 
 
@@ -47,14 +66,23 @@ def run_tts(args):
 
     # Perform inference and save the output audio
     with torch.no_grad():
-        wav = model.inference(args.text, args.prompt_speech_path, prompt_text=args.prompt_text)
+        wav = model.inference(
+            args.text,
+            args.prompt_speech_path,
+            prompt_text=args.prompt_text,
+            gender=args.gender,
+            pitch=args.pitch,
+            speed=args.speed,
+        )
         sf.write(save_path, wav, samplerate=16000)
 
         logging.info(f"Audio saved at: {save_path}")
 
 
 if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+    logging.basicConfig(
+        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+    )
 
     args = parse_args()
     run_tts(args)
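
Taken together, the CLI now covers both modes. A hypothetical smoke test, driven from Python with the flags defined in parse_args above (file paths are illustrative):

    import subprocess

    # Voice cloning: condition on a reference clip and its transcript.
    subprocess.run(
        [
            "python", "cli/inference.py",
            "--text", "Text to synthesize.",
            "--prompt_speech_path", "example/prompt.wav",
            "--prompt_text", "Transcript of the prompt audio.",
        ],
        check=True,
    )

    # Voice creation: describe the voice with --gender/--pitch/--speed
    # instead of passing prompt audio.
    subprocess.run(
        [
            "python", "cli/inference.py",
            "--text", "Text to synthesize.",
            "--gender", "female",
            "--pitch", "high",
            "--speed", "moderate",
        ],
        check=True,
    )
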