import json
import math
import re
from typing import Optional, Tuple

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer

from sparktts.utils.token_parser import TASK_TOKEN_MAP


def process_prompt(
    text: str,
    prompt_text: Optional[str] = None,
    global_token_ids: Optional[torch.Tensor] = None,
    semantic_token_ids: Optional[torch.Tensor] = None,
) -> Tuple[str, torch.Tensor]:
    """
    Build the LLM input prompt for voice cloning.

    Args:
        text: The text input to be converted to speech.
        prompt_text: Transcript of the prompt audio.
        global_token_ids: Global token IDs extracted from reference audio.
        semantic_token_ids: Semantic token IDs extracted from reference audio.

    Returns:
        Tuple containing the formatted input prompt and the global token IDs.
    """
    # Serialize the global (speaker/timbre) tokens into tag form.
    global_tokens = "".join(
        [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
    )

    if prompt_text is not None:
        # With a reference transcript, also include the reference semantic
        # tokens so the LLM continues the reference speech into the target text.
        semantic_tokens = "".join(
            [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
        )
        inputs = [
            TASK_TOKEN_MAP["tts"],
            "<|start_content|>",
            prompt_text,
            text,
            "<|end_content|>",
            "<|start_global_token|>",
            global_tokens,
            "<|end_global_token|>",
            "<|start_semantic_token|>",
            semantic_tokens,
        ]
    else:
        inputs = [
            TASK_TOKEN_MAP["tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_global_token|>",
            global_tokens,
            "<|end_global_token|>",
        ]

    inputs = "".join(inputs)
    return inputs, global_token_ids
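

# For illustration only (hypothetical token IDs): with text="Hello",
# prompt_text=None, and global_token_ids=torch.tensor([[1, 2]]), the prompt is
# the single unbroken string
#   <task><|start_content|>Hello<|end_content|>
#   <|start_global_token|><|bicodec_global_1|><|bicodec_global_2|><|end_global_token|>
# (shown on two lines here), where <task> stands for whatever
# TASK_TOKEN_MAP["tts"] maps to.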


class TritonPythonModel:
    """Triton Python model for Spark TTS.

    This model orchestrates the end-to-end TTS pipeline, coordinating the
    audio tokenizer, LLM, and vocoder components via BLS requests.
    """
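
    # Rough data flow (a sketch of the BLS calls below, not additional API):
    #   reference_wav    -> audio_tokenizer -> global + semantic tokens
    #   formatted prompt -> tensorrt_llm    -> generated semantic tokens
    #   global + generated semantic tokens -> vocoder -> waveform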

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration.
        """
        self.logger = pb_utils.Logger

        self.model_config = json.loads(args["model_config"])
        parameters = self.model_config["parameters"]
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")

        # Streaming chunk schedule: start small for low first-chunk latency,
        # then grow each chunk by the scale factor up to the maximum.
        assert (
            float(model_params["audio_chunk_duration"]) >= 0.5
        ), "audio_chunk_duration must be at least 0.5 seconds"
        self.audio_chunk_duration = float(model_params["audio_chunk_duration"])
        self.max_audio_chunk_duration = float(model_params["max_audio_chunk_duration"])
        assert (
            float(model_params["audio_chunk_size_scale_factor"]) >= 1.0
        ), "audio_chunk_size_scale_factor must be at least 1.0; tune it to your actual RTF"
        self.audio_chunk_size_scale_factor = float(model_params["audio_chunk_size_scale_factor"])
        self.audio_chunk_overlap_duration = float(model_params["audio_chunk_overlap_duration"])
        self.audio_tokenizer_frame_rate = int(model_params["audio_tokenizer_frame_rate"])

        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
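
    # A sketch of the config.pbtxt parameters initialize() reads; the values
    # below are assumptions for illustration, not the shipped defaults:
    #
    #   parameters [
    #     {key: "audio_chunk_duration",          value: {string_value: "1.0"}},
    #     {key: "max_audio_chunk_duration",      value: {string_value: "30.0"}},
    #     {key: "audio_chunk_size_scale_factor", value: {string_value: "2.0"}},
    #     {key: "audio_chunk_overlap_duration",  value: {string_value: "0.1"}},
    #     {key: "audio_tokenizer_frame_rate",    value: {string_value: "50"}},
    #     {key: "llm_tokenizer_dir",             value: {string_value: "..."}}
    #   ]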

    def forward_llm(self, input_ids):
        """Generate semantic tokens with the TensorRT-LLM model.

        Builds a `pb_utils.InferenceRequest` from `input_ids` and sends it to
        the decoupled `tensorrt_llm` model. For each response from the
        language model:
        - Checks for errors and raises an exception if any are found.
        - Extracts the "output_ids" and "sequence_length" tensors.
        - Trims the output to the actual sequence length and yields it.

        In decoupled (streaming) mode this yields once per response; otherwise
        it yields the single complete result.

        Parameters
        ----------
        input_ids : torch.Tensor
            Tokenized input prompt of shape (1, seq_len).

        Yields
        ------
        numpy.ndarray
            The generated token IDs, trimmed to the actual sequence length.
        """
        input_ids = input_ids.cpu().numpy()
        max_tokens = 512
        # Fixed sampling hyperparameters for the LLM request.
        input_dict = {
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
            "runtime_top_k": np.array([[50]], dtype=np.int32),
            "temperature": np.array([[0.8]], dtype=np.float32),
            "input_ids": input_ids,
            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
        }

        input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]

        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=input_tensor_list,
        )

        llm_responses = llm_request.exec(decoupled=self.decoupled)
        if self.decoupled:
            for llm_response in llm_responses:
                if llm_response.has_error():
                    raise pb_utils.TritonModelException(llm_response.error().message())

                output_ids = pb_utils.get_output_tensor_by_name(llm_response, "output_ids")
                seq_lens = pb_utils.get_output_tensor_by_name(llm_response, "sequence_length")
                # The final, flag-only response of a decoupled stream may carry
                # no output tensors; skip it rather than dereferencing None.
                if output_ids is None or seq_lens is None:
                    continue
                output_ids = output_ids.as_numpy()
                seq_lens = seq_lens.as_numpy()

                # Trim batch/beam dims and padding to the true sequence length.
                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
                yield actual_output_ids
        else:
            llm_response = llm_responses
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())

            output_ids = pb_utils.get_output_tensor_by_name(llm_response, "output_ids").as_numpy()
            seq_lens = pb_utils.get_output_tensor_by_name(llm_response, "sequence_length").as_numpy()

            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
            yield actual_output_ids
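
    # Note: the tensor names above ("request_output_len", "end_id", ...) follow
    # the TensorRT-LLM backend's inflight-batcher model contract; if your
    # tensorrt_llm config differs, adjust the names and dtypes accordingly.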

    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor.
            wav_len: Waveform length tensor.

        Returns:
            Tuple of global and semantic tokens.
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['global_tokens', 'semantic_tokens'],
            inputs=[wav, wav_len]
        )

        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Convert the outputs to torch tensors via DLPack, then move to CPU.
        global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
        global_tokens = from_dlpack(global_tokens.to_dlpack()).cpu()

        semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
        semantic_tokens = from_dlpack(semantic_tokens.to_dlpack()).cpu()

        return global_tokens, semantic_tokens

    def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            global_token_ids: Global token IDs tensor.
            pred_semantic_ids: Predicted semantic token IDs tensor.

        Returns:
            Generated waveform tensor.
        """
        # Wrap the torch tensors as Triton tensors via DLPack (zero-copy).
        global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
        pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))

        inference_request = pb_utils.InferenceRequest(
            model_name='vocoder',
            requested_output_names=['waveform'],
            inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
        )

        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = from_dlpack(waveform.to_dlpack()).cpu()

        return waveform

    def token2wav(self, generated_token_ids, global_token_ids):
        """Decode generated LLM tokens into a waveform.

        Extracts semantic token IDs from the decoded text and runs the vocoder.
        """
        predicted_text = self.tokenizer.batch_decode(
            [generated_token_ids],
            skip_special_tokens=True,
        )[0]
        # Recover semantic token IDs from tags like <|bicodec_semantic_123|>.
        pred_semantic_ids = (
            torch.tensor(
                [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)]
            )
            .unsqueeze(0)
            .to(torch.int32)
        )

        audio = self.forward_vocoder(
            global_token_ids.to(self.device),
            pred_semantic_ids.to(self.device),
        )

        return audio
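
    # Illustration (hypothetical decode output): a decoded string such as
    #   "...<|bicodec_semantic_12|><|bicodec_semantic_7|>..."
    # produces pred_semantic_ids == tensor([[12, 7]], dtype=torch.int32).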

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests.

        Returns:
            List of inference responses containing the generated audio
            (non-decoupled mode only).
        """
        responses = []

        for request in requests:
            # Extract the reference audio and tokenize it.
            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
            global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)

            # Extract the reference transcript and the target text.
            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')

            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode('utf-8')

            # Build the LLM prompt from the target text and reference tokens.
            prompt, global_token_ids = process_prompt(
                text=target_text,
                prompt_text=reference_text,
                global_token_ids=global_tokens,
                semantic_token_ids=semantic_tokens,
            )

            # Tokenize the prompt and generate semantic tokens with the LLM.
            model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            input_ids = model_inputs.input_ids.to(torch.int32)
            generated_ids_iter = self.forward_llm(input_ids)

            if self.decoupled:
                response_sender = request.get_response_sender()
                request_id = request.request_id()
                semantic_token_ids_arr = []
                max_chunk_size = math.ceil(self.max_audio_chunk_duration * self.audio_tokenizer_frame_rate)
                chunk_size = math.ceil(self.audio_chunk_duration * self.audio_tokenizer_frame_rate)
                overlap_chunk_size = math.ceil(self.audio_chunk_overlap_duration * self.audio_tokenizer_frame_rate)
                self.logger.log_info(
                    f"[{request_id}] init chunk_size: {chunk_size} max_chunk_size: {max_chunk_size}"
                )
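
                # Chunk scheduling sketch: with, say, a 50 Hz tokenizer frame
                # rate, 0.5 s initial duration, and scale factor 2.0, chunks
                # flush after 25, 50, 100, ... tokens (capped at max_chunk_size),
                # trading first-chunk latency against per-chunk overhead.
                # The numbers here are illustrative, not defaults.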
|
                for generated_ids in generated_ids_iter:
                    if generated_ids is None or len(generated_ids) == 0:
                        break

                    semantic_token_ids_arr.append(generated_ids)
                    if len(semantic_token_ids_arr) >= chunk_size:
                        chunk = semantic_token_ids_arr[:chunk_size]
                        generated_semantic_token_ids = np.hstack(chunk)

                        # Synthesize and stream this chunk immediately.
                        sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)

                        # Keep the trailing overlap tokens so consecutive
                        # chunks share context at the boundary.
                        semantic_token_ids_arr = semantic_token_ids_arr[chunk_size - overlap_chunk_size:]

                        # Grow the chunk size for the next flush.
                        chunk_size = min(max_chunk_size, int(chunk_size * self.audio_chunk_size_scale_factor))
                        self.logger.log_info(f"[{request_id}] increase chunk_size: {chunk_size}")

                if len(semantic_token_ids_arr) > 0:
                    # Flush whatever tokens remain as the final chunk.
                    generated_semantic_token_ids = np.hstack(semantic_token_ids_arr)
                    sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
                    audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                    inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                    response_sender.send(inference_response)
                    self.logger.log_info(f"[{request_id}] last chunk len: {len(semantic_token_ids_arr)}")

            else:
                generated_ids = next(generated_ids_iter)
                if generated_ids is None or len(generated_ids) == 0:
                    raise pb_utils.TritonModelException("Generated IDs is None or empty")

                audio = self.token2wav(generated_ids, global_token_ids)
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                responses.append(inference_response)

            if self.decoupled:
                # Signal the end of the response stream for this request.
                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                self.logger.log_info(f"[{request_id}] sent TRITONSERVER_RESPONSE_COMPLETE_FINAL")

        if not self.decoupled:
            return responses