spark-tts committed
Commit 6f15685 · 1 Parent(s): 832ac1a

support voice creation

Files changed:
- cli/SparkTTS.py +121 -20
- cli/inference.py +38 -10
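
In short: cli/SparkTTS.py splits the old cloning-only inference() into a voice-cloning prompt builder (process_prompt) plus a new attribute-driven builder (process_prompt_control), and cli/inference.py gains --gender/--pitch/--speed flags to drive the new path.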
cli/SparkTTS.py CHANGED
@@ -15,12 +15,13 @@
 
 import re
 import torch
+from typing import Tuple
 from pathlib import Path
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from sparktts.utils.file import load_config
 from sparktts.models.audio_tokenizer import BiCodecTokenizer
-from sparktts.utils.token_parser import TASK_TOKEN_MAP
+from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
 
 
 class SparkTTS:
@@ -49,36 +50,36 @@ class SparkTTS:
         self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
         self.model.to(self.device)
 
-    @torch.no_grad()
-    def inference(
+    def process_prompt(
         self,
         text: str,
         prompt_speech_path: Path,
         prompt_text: str = None,
-        temperature: float = 0.8,
-        top_k: float = 50,
-        top_p: float = 0.95,
-    ) -> torch.Tensor:
+    ) -> Tuple[str, torch.Tensor]:
         """
-        Performs inference to generate speech from text, incorporating prompt audio and/or text.
+        Process input for voice cloning.
 
         Args:
             text (str): The text input to be converted to speech.
             prompt_speech_path (Path): Path to the audio file used as a prompt.
             prompt_text (str, optional): Transcript of the prompt audio.
-            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
-            top_k (float, optional): Top-k sampling parameter. Default is 50.
-            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
 
-        Returns:
-            torch.Tensor:
+        Return:
+            Tuple[str, torch.Tensor]: Input prompt; global tokens
         """
-        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(prompt_speech_path)
-        global_tokens = "".join([f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()])
+
+        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
+            prompt_speech_path
+        )
+        global_tokens = "".join(
+            [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
+        )
 
         # Prepare the input tokens for the model
         if prompt_text is not None:
-            semantic_tokens = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()])
+            semantic_tokens = "".join(
+                [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
+            )
             inputs = [
                 TASK_TOKEN_MAP["tts"],
                 "<|start_content|>",
@@ -103,7 +104,94 @@ class SparkTTS:
         ]
 
         inputs = "".join(inputs)
-        model_inputs = self.tokenizer([inputs], return_tensors="pt").to(self.device)
+
+        return inputs, global_token_ids
+
+    def process_prompt_control(
+        self,
+        gender: str,
+        pitch: str,
+        speed: str,
+        text: str,
+    ):
+        """
+        Process input for voice creation.
+
+        Args:
+            gender (str): female | male.
+            pitch (str): very_low | low | moderate | high | very_high
+            speed (str): very_low | low | moderate | high | very_high
+            text (str): The text input to be converted to speech.
+
+        Return:
+            str: Input prompt
+        """
+        assert gender in GENDER_MAP.keys()
+        assert pitch in LEVELS_MAP.keys()
+        assert speed in LEVELS_MAP.keys()
+
+        gender_id = GENDER_MAP[gender]
+        pitch_level_id = LEVELS_MAP[pitch]
+        speed_level_id = LEVELS_MAP[speed]
+
+        pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
+        speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
+        gender_tokens = f"<|gender_{gender_id}|>"
+
+        attribute_tokens = "".join(
+            [gender_tokens, pitch_label_tokens, speed_label_tokens]
+        )
+
+        control_tts_inputs = [
+            TASK_TOKEN_MAP["controllable_tts"],
+            "<|start_content|>",
+            text,
+            "<|end_content|>",
+            "<|start_style_label|>",
+            attribute_tokens,
+            "<|end_style_label|>",
+        ]
+
+        return "".join(control_tts_inputs)
+
+    @torch.no_grad()
+    def inference(
+        self,
+        text: str,
+        prompt_speech_path: Path = None,
+        prompt_text: str = None,
+        gender: str = None,
+        pitch: str = None,
+        speed: str = None,
+        temperature: float = 0.8,
+        top_k: float = 50,
+        top_p: float = 0.95,
+    ) -> torch.Tensor:
+        """
+        Performs inference to generate speech from text, incorporating prompt audio and/or text.
+
+        Args:
+            text (str): The text input to be converted to speech.
+            prompt_speech_path (Path): Path to the audio file used as a prompt.
+            prompt_text (str, optional): Transcript of the prompt audio.
+            gender (str): female | male.
+            pitch (str): very_low | low | moderate | high | very_high
+            speed (str): very_low | low | moderate | high | very_high
+            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
+            top_k (float, optional): Top-k sampling parameter. Default is 50.
+            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
+
+        Returns:
+            torch.Tensor: Generated waveform as a tensor.
+        """
+        if gender is not None:
+            prompt = self.process_prompt_control(gender, pitch, speed, text)
+
+        else:
+            prompt, global_token_ids = self.process_prompt(
+                text, prompt_speech_path, prompt_text
+            )
+        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
 
         # Generate speech using the model
         generated_ids = self.model.generate(
@@ -117,14 +205,27 @@ class SparkTTS:
 
         # Trim the output tokens to remove the input tokens
         generated_ids = [
-            output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+            output_ids[len(input_ids) :]
+            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
         ]
 
         # Decode the generated tokens into text
         predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
         # Extract semantic token IDs from the generated text
-        pred_semantic_ids = torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)]).long().unsqueeze(0)
+        pred_semantic_ids = (
+            torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
+            .long()
+            .unsqueeze(0)
+        )
+
+        if gender is not None:
+            global_token_ids = (
+                torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
+                .long()
+                .unsqueeze(0)
+                .unsqueeze(0)
+            )
 
         # Convert semantic tokens back to waveform
         wav = self.audio_tokenizer.detokenize(
@@ -132,4 +233,4 @@ class SparkTTS:
             pred_semantic_ids.to(self.device),
         )
 
-        return wav
+        return wav
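
For orientation, here is a minimal usage sketch of the two modes this commit distinguishes. It is an illustration only: the import path, constructor signature, and file paths are assumptions not shown in this diff; only the inference() keyword arguments come from the commit.

import torch
import soundfile as sf
from pathlib import Path
from cli.SparkTTS import SparkTTS  # assumed import path

# Assumed constructor (model directory + device); the diff only shows that
# SparkTTS keeps self.model_dir and self.device internally.
model = SparkTTS(Path("pretrained_models/Spark-TTS-0.5B"), torch.device("cuda:0"))

# Voice cloning: the prompt audio (and optionally its transcript) supplies the
# voice; process_prompt() extracts global speaker tokens from the prompt wav.
wav = model.inference(
    "Text to synthesize.",
    prompt_speech_path=Path("example/prompt.wav"),
    prompt_text="Transcript of the prompt audio.",
)

# Voice creation (new in this commit): gender/pitch/speed style tokens supply
# the voice, and the global speaker tokens are parsed back out of the model's
# own generated text instead of coming from a prompt recording.
wav = model.inference(
    "Text to synthesize.",
    gender="female",
    pitch="moderate",
    speed="moderate",
)

sf.write("output.wav", wav, samplerate=16000)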
cli/inference.py CHANGED
@@ -12,16 +12,35 @@ def parse_args():
     """Parse command-line arguments."""
     parser = argparse.ArgumentParser(description="Run TTS inference.")
 
-    parser.add_argument(
-        "--model_dir", type=str, default="pretrained_models/Spark-TTS-0.5B"
-    )
-    parser.add_argument("--save_dir", type=str, default="example/results")
+    parser.add_argument(
+        "--model_dir",
+        type=str,
+        default="pretrained_models/Spark-TTS-0.5B",
+        help="Path to the model directory",
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="example/results",
+        help="Directory to save generated audio files",
+    )
     parser.add_argument("--device", type=int, default=0, help="CUDA device number")
-    parser.add_argument("--text", type=str, required=True, help="Text for TTS generation")
+    parser.add_argument(
+        "--text", type=str, required=True, help="Text for TTS generation"
+    )
     parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
-    parser.add_argument(
-        "--prompt_speech_path", type=str, help="Path to the prompt audio file"
-    )
+    parser.add_argument(
+        "--prompt_speech_path",
+        type=str,
+        help="Path to the prompt audio file",
+    )
+    parser.add_argument("--gender", choices=["male", "female"])
+    parser.add_argument(
+        "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
+    )
+    parser.add_argument(
+        "--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
+    )
     return parser.parse_args()
 
 
@@ -47,14 +66,23 @@ def run_tts(args):
 
     # Perform inference and save the output audio
     with torch.no_grad():
-        wav = model.inference(args.text, args.prompt_speech_path, prompt_text=args.prompt_text)
+        wav = model.inference(
+            args.text,
+            args.prompt_speech_path,
+            prompt_text=args.prompt_text,
+            gender=args.gender,
+            pitch=args.pitch,
+            speed=args.speed,
+        )
     sf.write(save_path, wav, samplerate=16000)
 
     logging.info(f"Audio saved at: {save_path}")
 
 
 if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+    logging.basicConfig(
+        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+    )
 
     args = parse_args()
     run_tts(args)
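
With the new flags, the script can be driven in either mode from the command line; illustrative invocations (paths and values are examples, not from the commit):

    python cli/inference.py --text "Hello there." --prompt_speech_path example/prompt.wav --prompt_text "Transcript of the prompt."
    python cli/inference.py --text "Hello there." --gender female --pitch moderate --speed high

Note that passing --gender selects the voice-creation path in inference(); when it is given, --pitch and --speed must also be set, or the asserts in process_prompt_control will fail.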