import json
import logging

import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

from sparktts.models.bicodec import BiCodec


logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class TritonPythonModel:
    """Triton Python model for the vocoder.

    This model takes global and semantic tokens as input and generates audio
    waveforms using the BiCodec vocoder.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing the model configuration
        """
        parameters = json.loads(args["model_config"])["parameters"]
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
        del self.vocoder.encoder, self.vocoder.postnet
        self.vocoder.eval().to(self.device)

        logger.info("Vocoder initialized successfully")

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated waveforms
        """
|
global_tokens_list, semantic_tokens_list = [], [] |
|
|
|
|
|
for request in requests: |
|
global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy() |
|
semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy() |
|
global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device)) |
|
semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device)) |
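
        # Concatenate along the batch dimension. This assumes every request in
        # the batch carries token sequences of the same length; mismatched
        # lengths would make torch.cat raise.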
        global_tokens = torch.cat(global_tokens_list, dim=0)
        semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
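
        # `detokenize` appears to expect the global tokens with an extra
        # singleton dimension, hence the unsqueeze before decoding the batch.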
        with torch.no_grad():
            wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
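
        # Split the batched output back into one response per request; DLPack
        # lets Triton take each tensor without copying it through host memory.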
        responses = []
        for i in range(len(requests)):
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
            responses.append(inference_response)

        return responses