import json
import os

import numpy as np
import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

from sparktts.models.audio_tokenizer import BiCodecTokenizer


class TritonPythonModel:
    """Triton Python model for audio tokenization.

    This model takes reference audio as input and extracts semantic and
    global tokens using the BiCodec tokenizer.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Custom parameters from the model config are exposed as string values.
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}

        self.device = torch.device("cuda")
        self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
                                                device=self.device)

    def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
        """Extract a reference audio clip for speaker embedding.

        Args:
            wav: Input waveform array

        Returns:
            Reference clip of fixed duration
        """
        SAMPLE_RATE = 16000
        REF_SEGMENT_DURATION = 6  # seconds
        LATENT_HOP_LENGTH = 320

        # Round the clip length down to a whole multiple of the latent hop length.
        ref_segment_length = (
            int(SAMPLE_RATE * REF_SEGMENT_DURATION)
            // LATENT_HOP_LENGTH
            * LATENT_HOP_LENGTH
        )
        wav_length = len(wav)

        if ref_segment_length > wav_length:
            # Tile short inputs until they cover the full reference length.
            repeat_times = ref_segment_length // wav_length + 1
            wav = np.tile(wav, repeat_times)

        return wav[:ref_segment_length]

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing tokenized outputs
        """
        reference_wav_list = []
        reference_wav_ref_clip_list = []

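        # Collect the trimmed waveform and its reference clip for every request
        # so the tokenizer can run once over the whole batch.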
        for request in requests:
            wav_array = pb_utils.get_input_tensor_by_name(
                request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(
                request, "reference_wav_len").as_numpy().item()

            # Trim padding to the true length and drop the batch dimension.
            wav = wav_array[:, :wav_len].squeeze(0)
            reference_wav_list.append(wav)

            wav_ref_clip = self.get_ref_clip(wav)
            reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))

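        # Batch the fixed-length reference clips and extract wav2vec2 features
        # for all reference waveforms in one pass.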
        ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
        wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
            reference_wav_list)

        audio_tokenizer_input = {
            "ref_wav": ref_wav_clip_tensor.to(self.device),
            "feat": wav2vec2_features.to(self.device),
        }
        semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
            audio_tokenizer_input)

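        # Split the batched token outputs back into one response per request;
        # DLPack hands the tensors to Triton without an intermediate host copy.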
        responses = []
        for i in range(len(requests)):
            global_tokens_tensor = pb_utils.Tensor.from_dlpack(
                "global_tokens", to_dlpack(global_tokens[i]))
            semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
                "semantic_tokens", to_dlpack(semantic_tokens[i]))

            inference_response = pb_utils.InferenceResponse(
                output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
            responses.append(inference_response)

        return responses