# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json import math import os import re from typing import Dict, List, Tuple, Optional, Union import numpy as np import torch from torch.utils.dlpack import from_dlpack, to_dlpack import triton_python_backend_utils as pb_utils from transformers import AutoTokenizer from sparktts.utils.token_parser import TASK_TOKEN_MAP def process_prompt( text: str, prompt_text: Optional[str] = None, global_token_ids: torch.Tensor = None, semantic_token_ids: torch.Tensor = None, ) -> Tuple[str, torch.Tensor]: """ Process input for voice cloning. Args: text: The text input to be converted to speech. prompt_text: Transcript of the prompt audio. global_token_ids: Global token IDs extracted from reference audio. semantic_token_ids: Semantic token IDs extracted from reference audio. Returns: Tuple containing the formatted input prompt and global token IDs. """ # Convert global tokens to string format global_tokens = "".join( [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()] ) # Prepare the input tokens for the model if prompt_text is not None: # Include semantic tokens when prompt text is provided semantic_tokens = "".join( [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()] ) inputs = [ TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>", "<|start_global_token|>", global_tokens, "<|end_global_token|>", "<|start_semantic_token|>", semantic_tokens, ] else: # Without prompt text, exclude semantic tokens inputs = [ TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>", "<|start_global_token|>", global_tokens, "<|end_global_token|>", ] # Join all input components into a single string inputs = "".join(inputs) return inputs, global_token_ids class TritonPythonModel: """Triton Python model for Spark TTS. This model orchestrates the end-to-end TTS pipeline by coordinating between audio tokenizer, LLM, and vocoder components. """ def initialize(self, args): """Initialize the model. Args: args: Dictionary containing model configuration """ self.logger = pb_utils.Logger # Parse model parameters self.model_config = json.loads(args['model_config']) parameters = self.model_config['parameters'] model_params = {k: v["string_value"] for k, v in parameters.items()} self.logger.log_info(f"model_params:{model_params}") # streaming TTS parameters assert ( float(model_params["audio_chunk_duration"]) >= 0.5 ), f"audio_chunk_duration at least 0.5 seconds" self.audio_chunk_duration = float(model_params["audio_chunk_duration"]) self.max_audio_chunk_duration = float(model_params["max_audio_chunk_duration"]) assert ( float(model_params["audio_chunk_size_scale_factor"]) >= 1.0 ), "audio_chunk_size_scale_factor should be greater than 1, change it according to your actual rtf" self.audio_chunk_size_scale_factor = float(model_params["audio_chunk_size_scale_factor"]) # scale speed self.audio_chunk_overlap_duration = float(model_params["audio_chunk_overlap_duration"]) self.audio_tokenizer_frame_rate = int(model_params["audio_tokenizer_frame_rate"]) # Initialize tokenizer llm_tokenizer_dir = model_params["llm_tokenizer_dir"] self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) self.device = torch.device("cuda") self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) def forward_llm(self, input_ids): """ Prepares the response from the language model based on the provided inputs. Creates a `pb_utils.InferenceRequest` object with passed `llm_request_inputs` to send to a decoupled TensorRTLLM model. For each response from the language model: - Checks for errors and raise an exception if any are found. - Extracts the "output_ids" tensor from the response. - Determines the finish reason based on the presence of the end-of-sequence token or reaching the maximum length. - Appends the generated token IDs to `output_ids`. - If the finish reason is determined, decodes the output IDs to text and prepares the final response. The final response includes the generated text, finish reason, completion tokens, prompt tokens, and total tokens. Parameters ---------- - llm_request_inputs (dict): A dictionary containing the inputs for the language model. Returns ------- - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. """ # convert input_ids to numpy, with shape [1, sequence_length] input_ids = input_ids.cpu().numpy() max_tokens = 512 input_dict = { "request_output_len": np.array([[max_tokens]], dtype=np.int32), "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32), "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32), "streaming": np.array([[self.decoupled]], dtype=np.bool_), "runtime_top_p": np.array([[0.95]], dtype=np.float32), "runtime_top_k": np.array([[50]], dtype=np.int32), "temperature": np.array([[0.8]], dtype=np.float32), "input_ids": input_ids, "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), } # Convert inputs to Triton tensors input_tensor_list = [ pb_utils.Tensor(k, v) for k, v in input_dict.items() ] # Create and execute inference request llm_request = pb_utils.InferenceRequest( model_name="tensorrt_llm", requested_output_names=["output_ids", "sequence_length"], inputs=input_tensor_list, ) llm_responses = llm_request.exec(decoupled=self.decoupled) if self.decoupled: for llm_response in llm_responses: if llm_response.has_error(): raise pb_utils.TritonModelException(llm_response.error().message()) # Extract and process output output_ids = pb_utils.get_output_tensor_by_name( llm_response, "output_ids").as_numpy() seq_lens = pb_utils.get_output_tensor_by_name( llm_response, "sequence_length").as_numpy() # Get actual output IDs up to the sequence length actual_output_ids = output_ids[0][0][:seq_lens[0][0]] yield actual_output_ids else: llm_response = llm_responses if llm_response.has_error(): raise pb_utils.TritonModelException(llm_response.error().message()) # Extract and process output output_ids = pb_utils.get_output_tensor_by_name( llm_response, "output_ids").as_numpy() seq_lens = pb_utils.get_output_tensor_by_name( llm_response, "sequence_length").as_numpy() # Get actual output IDs up to the sequence length actual_output_ids = output_ids[0][0][:seq_lens[0][0]] yield actual_output_ids def forward_audio_tokenizer(self, wav, wav_len): """Forward pass through the audio tokenizer component. Args: wav: Input waveform tensor wav_len: Waveform length tensor Returns: Tuple of global and semantic tokens """ inference_request = pb_utils.InferenceRequest( model_name='audio_tokenizer', requested_output_names=['global_tokens', 'semantic_tokens'], inputs=[wav, wav_len] ) inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) # Extract and convert output tensors global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens') global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu() semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens') semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu() return global_tokens, semantic_tokens def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor: """Forward pass through the vocoder component. Args: global_token_ids: Global token IDs tensor pred_semantic_ids: Predicted semantic token IDs tensor Returns: Generated waveform tensor """ # Convert tensors to Triton format global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids)) pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids)) # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='vocoder', requested_output_names=['waveform'], inputs=[global_token_ids_tensor, pred_semantic_ids_tensor] ) inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() return waveform def token2wav(self, generated_token_ids, global_token_ids): # Decode and extract semantic token IDs from generated text predicted_text = self.tokenizer.batch_decode( [generated_token_ids], skip_special_tokens=True, )[0] pred_semantic_ids = ( torch.tensor( [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)] ) .unsqueeze(0) .to(torch.int32) ) # Generate audio with vocoder audio = self.forward_vocoder( global_token_ids.to(self.device), pred_semantic_ids.to(self.device), ) return audio def execute(self, requests): """Execute inference on the batched requests. Args: requests: List of inference requests Returns: List of inference responses containing generated audio """ responses = [] for request in requests: # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") # Process reference audio through audio tokenizer global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len) # Extract text inputs reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') # Prepare prompt for LLM prompt, global_token_ids = process_prompt( text=target_text, prompt_text=reference_text, global_token_ids=global_tokens, semantic_token_ids=semantic_tokens, ) # Tokenize prompt for LLM model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device) input_ids = model_inputs.input_ids.to(torch.int32) # Generate semantic tokens with LLM generated_ids_iter = self.forward_llm(input_ids) if self.decoupled: response_sender = request.get_response_sender() request_id = request.request_id() semantic_token_ids_arr = [] max_chunk_size = math.ceil(self.max_audio_chunk_duration * self.audio_tokenizer_frame_rate) chunk_size = math.ceil(self.audio_chunk_duration * self.audio_tokenizer_frame_rate) overlap_chunk_size = math.ceil(self.audio_chunk_overlap_duration * self.audio_tokenizer_frame_rate) self.logger.log_info( f"[{request_id}] init chunk_size: {chunk_size} max_chunk_size: {max_chunk_size}" ) for generated_ids in generated_ids_iter: if generated_ids is None or len(generated_ids) == 0: break semantic_token_ids_arr.append(generated_ids) if len(semantic_token_ids_arr) >= chunk_size: chunk = semantic_token_ids_arr[:chunk_size] generated_semantic_token_ids = np.hstack(chunk) # Process each chunk sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids) # Prepare response to send audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) semantic_token_ids_arr = semantic_token_ids_arr[chunk_size - overlap_chunk_size:] # increase chunk size for better speech quality chunk_size = min(max_chunk_size, int(chunk_size * self.audio_chunk_size_scale_factor)) self.logger.log_info(f"[{request_id}] increase chunk_size: {chunk_size}") if len(semantic_token_ids_arr) > 0: # end to finalize generated_semantic_token_ids = np.hstack(semantic_token_ids_arr) # Process each chunk sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids) # Prepare response to send audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) self.logger.log_info(f"[{request_id}] last chunk len: {len(semantic_token_ids_arr)}") else: generated_ids = next(generated_ids_iter) if generated_ids is None or len(generated_ids) == 0: raise pb_utils.TritonModelException("Generated IDs is None or empty") audio = self.token2wav(generated_ids, global_token_ids) # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) responses.append(inference_response) if self.decoupled: response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info(f"send tritonserver_response_complete_final to end") if not self.decoupled: return responses