yuekaiz committed · Commit be03ceb · 1 Parent(s): 0955e96

clean codes

Browse files
- runtime/triton_trtllm/Dockerfile.server  +9 -0
- runtime/triton_trtllm/README.md  +47 -0
- runtime/triton_trtllm/build.sh  +23 -7
- runtime/triton_trtllm/{client.py → client_grpc.py}  +0 -0
- runtime/triton_trtllm/client_http.py  +46 -0
- runtime/triton_trtllm/docker-compose.yml  +20 -0
- runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py  +69 -36
- runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt  +3 -3
- runtime/triton_trtllm/model_repo/spark_tts/1/model.py  +128 -104
- runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt  +5 -5
- runtime/triton_trtllm/model_repo/vocoder/1/model.py  +50 -18
- runtime/triton_trtllm/model_repo/vocoder/config.pbtxt  +3 -3
- runtime/triton_trtllm/{build_engine.sh → scripts/build_engine.sh}  +0 -0
- runtime/triton_trtllm/{convert_checkpoint.py → scripts/convert_checkpoint.py}  +0 -0
- runtime/triton_trtllm/{fill_template.py → scripts/fill_template.py}  +0 -0
runtime/triton_trtllm/Dockerfile.server
ADDED
@@ -0,0 +1,9 @@
+FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
+RUN pip install tritonclient[grpc] librosa
+WORKDIR /workspace
+
+
+
+
+
+
runtime/triton_trtllm/README.md
ADDED
@@ -0,0 +1,47 @@
+## Triton Inference Serving Best Practice for F5 TTS
+
+### Model Training
+See [official F5-TTS](https://github.com/SWivid/F5-TTS) or [Icefall F5-TTS](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech4tts/TTS#f5-tts).
+
+### Quick Start
+Directly launch the service using docker compose.
+```sh
+# VOCODER vocos or bigvgan
+VOCODER=vocos docker compose up
+```
+
+### Build Image
+Build the docker image from scratch.
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
+```
+
+### Create Docker Container
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
+```
+
+### Export Models to TensorRT-LLM and Launch Server
+Inside the docker container, we follow the official TensorRT-LLM guide to build the qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
+
+```sh
+bash build_server.sh
+```
+
+### Benchmark using Dataset
+```sh
+num_task=2
+python3 client.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
+```
+
+### Benchmark Results
+Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
+
+| Model | Note | Concurrency | Avg Latency | RTF |
+|-------|------|-------------|-------------|-----|
+| F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394 |
+
+### Credits
+1. [F5-TTS](https://github.com/SWivid/F5-TTS)
+2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
runtime/triton_trtllm/build.sh
CHANGED
@@ -1,15 +1,31 @@
 
-pip install -r /workspace_yuekai/spark-tts/Spark-TTS/requirements.txt
+# pip install -r /workspace_yuekai/spark-tts/Spark-TTS/requirements.txt
+export PYTHONPATH=/workspace_yuekai/spark-tts/Spark-TTS/
+
 model_repo=./model_repo_test
 rm -rf $model_repo
-
 cp -r ./model_repo $model_repo
 
 ENGINE_PATH=/workspace_yuekai/spark-tts/TensorRT-LLM/examples/qwen/Spark-TTS-0.5B_trt_engines_1gpu_bfloat16
 MAX_QUEUE_DELAY_MICROSECONDS=0
+MODEL_DIR=/workspace_yuekai/spark-tts/Spark-TTS/pretrained_models/Spark-TTS-0.5B
+LLM_TOKENIZER_DIR=/workspace_yuekai/spark-tts/Spark-TTS/pretrained_models/Spark-TTS-0.5B/LLM
+BLS_INSTANCE_NUM=4
+TRITON_MAX_BATCH_SIZE=16
+
+python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+python3 scripts/fill_template.py -i ${model_repo}/spark_tts/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
 
-# enable_context_fmha_fp32_acc:${ENABLE_CONTEXT_FMHA_FP32_ACC}
-export PYTHONPATH=/workspace_yuekai/spark-tts/Spark-TTS/
-CUDA_VISIBLE_DEVICES=${gpu_device_ids} tritonserver --model-repository ${model_repo}
+
+CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository ${model_repo}
+
+
+
+
+
+
+
+
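The four `fill_template.py` calls above substitute the `${...}` placeholders in each `config.pbtxt` with concrete values. The script itself is only renamed in this commit, so its body is not shown here; the snippet below is a minimal sketch of that substitution step (the CLI shape and the `string.Template` approach are assumptions, not the committed script):

```python
# Hypothetical sketch of the placeholder substitution that build.sh relies on.
# The real scripts/fill_template.py may differ in flags and behaviour.
import sys
from string import Template

def fill_template(pbtxt_path: str, substitutions: str) -> None:
    # substitutions are passed as "key1:value1,key2:value2,..."
    pairs = dict(item.split(":", 1) for item in substitutions.split(","))
    with open(pbtxt_path) as f:
        template = Template(f.read())
    # safe_substitute leaves any unknown ${...} placeholders untouched
    with open(pbtxt_path, "w") as f:
        f.write(template.safe_substitute(pairs))

if __name__ == "__main__":
    # e.g. fill_template.py model_repo/vocoder/config.pbtxt model_dir:/models,triton_max_batch_size:16
    fill_template(sys.argv[1], sys.argv[2])
```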
runtime/triton_trtllm/{client.py → client_grpc.py}
RENAMED
File without changes
runtime/triton_trtllm/client_http.py
ADDED
@@ -0,0 +1,46 @@
+import requests
+import soundfile as sf
+import json
+import numpy as np
+
+url = "http://localhost:8000/v2/models/infer_pipeline/infer"
+wav_path = "*********"
+waveform, sr = sf.read(wav_path)
+lang_id = 54
+samples = np.array([waveform], dtype=np.float32)
+lengths = np.array([[len(waveform)]], dtype=np.int32)
+lang_id = np.array([[lang_id]], dtype=np.int8)
+
+data = {
+    "inputs": [
+        {
+            "name": "WAV",
+            "shape": samples.shape,
+            "datatype": "FP32",
+            "data": samples.tolist()
+        },
+        {
+            "name": "WAV_LENS",
+            "shape": lengths.shape,
+            "datatype": "INT32",
+            "data": lengths.tolist(),
+        },
+        {
+            "name": "LANG_ID",
+            "shape": lang_id.shape,
+            "datatype": "INT8",
+            "data": lang_id.tolist()
+        }
+    ]
+}
+rsp = requests.post(
+    url,
+    headers={"Content-Type": "application/json"},
+    json=data,
+    verify=False,
+    params={"request_id": '0'}
+)
+result = rsp.json()
+print(result)
+transcripts = result["outputs"][0]["data"][0]
+print(transcripts)
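The committed `client_http.py` still posts ASR-style inputs (`WAV`, `WAV_LENS`, `LANG_ID`) to an `infer_pipeline` model. For the `spark_tts` pipeline added in this commit, a request would instead carry the tensor names read in `model_repo/spark_tts/1/model.py`; the sketch below assumes the standard Triton HTTP v2 JSON format with FP32/INT32/BYTES datatypes (the datatypes and shapes are assumptions, since the input section of the pbtxt is not visible in this diff):

```python
# Hedged sketch: adapting the HTTP client to the spark_tts BLS model.
# Input names come from model_repo/spark_tts/1/model.py; datatypes and shapes are assumed.
import numpy as np
import requests
import soundfile as sf

url = "http://localhost:8000/v2/models/spark_tts/infer"
waveform, sr = sf.read("prompt.wav")  # hypothetical prompt audio path
samples = np.array([waveform], dtype=np.float32)

payload = {
    "inputs": [
        {"name": "reference_wav", "shape": list(samples.shape), "datatype": "FP32",
         "data": samples.tolist()},
        {"name": "reference_wav_len", "shape": [1, 1], "datatype": "INT32",
         "data": [[len(waveform)]]},
        {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES",
         "data": ["transcript of the prompt audio"]},
        {"name": "target_text", "shape": [1, 1], "datatype": "BYTES",
         "data": ["text to synthesize"]},
    ]
}
rsp = requests.post(url, json=payload)
audio = np.array(rsp.json()["outputs"][0]["data"], dtype=np.float32)
```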
runtime/triton_trtllm/docker-compose.yml
ADDED
@@ -0,0 +1,20 @@
+services:
+  tts:
+    image: soar97/triton-f5-tts:24.12
+    shm_size: '1gb'
+    ports:
+      - "8000:8000"
+      - "8001:8001"
+      - "8002:8002"
+    environment:
+      - PYTHONIOENCODING=utf-8
+      - MODEL_ID=${MODEL_ID}
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
+    command: >
+      /bin/bash -c "rm -rf sherpa && git clone https://github.com/yuekaizhang/sherpa.git -b f5 && cd sherpa/triton/f5_tts/ && bash build_server.sh $VOCODER"
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
CHANGED
@@ -25,80 +25,113 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import json
 import torch
-from torch.nn.utils.rnn import pad_sequence
-import torch.nn.functional as F
-from torch.utils.dlpack import from_dlpack, to_dlpack
+from torch.utils.dlpack import to_dlpack
 
 import triton_python_backend_utils as pb_utils
 
-import math
 import os
-from functools import wraps
 import numpy as np
 
 from sparktts.models.audio_tokenizer import BiCodecTokenizer
 
 class TritonPythonModel:
+    """Triton Python model for audio tokenization.
+
+    This model takes reference audio input and extracts semantic and global tokens
+    using BiCodec tokenizer.
+    """
+
     def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
         parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        # Initialize tokenizer
         self.device = torch.device("cuda")
-        self.audio_tokenizer = BiCodecTokenizer(model_dir,
+        self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
+                                                device=self.device)
 
     def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
-        """
+        """Extract reference audio clip for speaker embedding.
+
+        Args:
+            wav: Input waveform array
+
+        Returns:
+            Reference clip of fixed duration
+        """
+        SAMPLE_RATE = 16000
+        REF_SEGMENT_DURATION = 6  # seconds
+        LATENT_HOP_LENGTH = 320
 
         ref_segment_length = (
+            int(SAMPLE_RATE * REF_SEGMENT_DURATION)
+            // LATENT_HOP_LENGTH
+            * LATENT_HOP_LENGTH
         )
         wav_length = len(wav)
 
         if ref_segment_length > wav_length:
-            # Repeat and truncate
+            # Repeat and truncate if input is too short
+            repeat_times = ref_segment_length // wav_length + 1
+            wav = np.tile(wav, repeat_times)
 
         return wav[:ref_segment_length]
 
     def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing tokenized outputs
+        """
+        reference_wav_list = []
+        reference_wav_ref_clip_list = []
 
+        # Process each request in batch
         for request in requests:
+            # Extract input tensors
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
             wav_len = pb_utils.get_input_tensor_by_name(
                 request, "reference_wav_len").as_numpy().item()
-            # squeeze the first dimension, for the numpy array
+
+            # Prepare inputs
             wav = wav_array[:, :wav_len].squeeze(0)
             reference_wav_list.append(wav)
+
             wav_ref_clip = self.get_ref_clip(wav)
-            print(wav_ref_clip.shape, 2333333333455)
             reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))
+
+        # Batch process through tokenizer
         ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
         wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
+            reference_wav_list)
+
+        audio_tokenizer_input = {
+            "ref_wav": ref_wav_clip_tensor.to(self.device),
             "feat": wav2vec2_features.to(self.device),
         }
         semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
+            audio_tokenizer_input)
 
+        # Prepare responses
         responses = []
         for i in range(len(requests)):
             global_tokens_tensor = pb_utils.Tensor.from_dlpack(
+                "global_tokens", to_dlpack(global_tokens[i]))
+            semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
+                "semantic_tokens", to_dlpack(semantic_tokens[i]))
+
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
             responses.append(inference_response)
 
         return responses
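With the constants introduced in `get_ref_clip` (16 kHz sample rate, 6 s reference segment, hop length 320), the reference clip is always 96000 samples; a shorter prompt is tiled and then truncated. A standalone check of that arithmetic (not part of the committed model) is shown below:

```python
# Standalone check of the get_ref_clip arithmetic from audio_tokenizer/1/model.py.
import numpy as np

SAMPLE_RATE = 16000
REF_SEGMENT_DURATION = 6   # seconds
LATENT_HOP_LENGTH = 320

# 16000 * 6 = 96000 is already a multiple of 320, so the clip length is 96000 samples.
ref_segment_length = int(SAMPLE_RATE * REF_SEGMENT_DURATION) // LATENT_HOP_LENGTH * LATENT_HOP_LENGTH
assert ref_segment_length == 96000

# A 2-second prompt (32000 samples) is tiled 96000 // 32000 + 1 = 4 times, then cut back to 96000.
wav = np.zeros(2 * SAMPLE_RATE, dtype=np.float32)
if ref_segment_length > len(wav):
    wav = np.tile(wav, ref_segment_length // len(wav) + 1)
clip = wav[:ref_segment_length]
assert clip.shape == (96000,)
```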
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt
CHANGED
@@ -14,14 +14,14 @@
 
 name: "audio_tokenizer"
 backend: "python"
-max_batch_size:
+max_batch_size: ${triton_max_batch_size}
 dynamic_batching {
-    max_queue_delay_microseconds:
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
 }
 parameters [
   {
     key: "model_dir",
-    value: {string_value:"
+    value: {string_value:"${model_dir}"}
   }
 ]
 
runtime/triton_trtllm/model_repo/spark_tts/1/model.py
CHANGED
@@ -23,31 +23,23 @@
 # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import json
+import os
+import re
+from typing import Dict, List, Tuple, Optional, Union
+
+import numpy as np
 import torch
-from torch import nn
-from torch.nn.utils.rnn import pad_sequence
-import torch.nn.functional as F
 from torch.utils.dlpack import from_dlpack, to_dlpack
-
 import triton_python_backend_utils as pb_utils
-
-import math
-import os
-from functools import wraps
-
 from transformers import AutoTokenizer
 
-import re
-from typing import Tuple
-
-from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
-
+from sparktts.utils.token_parser import TASK_TOKEN_MAP
 
 def process_prompt(
     text: str,
-    prompt_text: str = None,
+    prompt_text: Optional[str] = None,
     global_token_ids: torch.Tensor = None,
     semantic_token_ids: torch.Tensor = None,
 ) -> Tuple[str, torch.Tensor]:
@@ -55,27 +47,27 @@ def process_prompt(
     Process input for voice cloning.
 
     Args:
+        text: The text input to be converted to speech.
+        prompt_text: Transcript of the prompt audio.
+        global_token_ids: Global token IDs extracted from reference audio.
+        semantic_token_ids: Semantic token IDs extracted from reference audio.
 
+    Returns:
+        Tuple containing the formatted input prompt and global token IDs.
     """
-
-    # global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
-    #     prompt_speech_path
-    # )
+    # Convert global tokens to string format
     global_tokens = "".join(
         [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
     )
+
+
     # Prepare the input tokens for the model
     if prompt_text is not None:
+        # Include semantic tokens when prompt text is provided
         semantic_tokens = "".join(
             [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
         )
+
         inputs = [
             TASK_TOKEN_MAP["tts"],
             "<|start_content|>",
@@ -89,6 +81,7 @@ def process_prompt(
             semantic_tokens,
         ]
     else:
+        # Without prompt text, exclude semantic tokens
         inputs = [
             TASK_TOKEN_MAP["tts"],
             "<|start_content|>",
@@ -99,17 +92,31 @@ def process_prompt(
             "<|end_global_token|>",
         ]
 
+    # Join all input components into a single string
     inputs = "".join(inputs)
-
     return inputs, global_token_ids
 
+
 class TritonPythonModel:
+    """Triton Python model for Spark TTS.
+
+    This model orchestrates the end-to-end TTS pipeline by coordinating
+    between audio tokenizer, LLM, and vocoder components.
+    """
+
     def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
         parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        # Initialize tokenizer
+        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
         self.device = torch.device("cuda")
         self.decoupled = False
 
@@ -140,7 +147,6 @@ class TritonPythonModel:
         """
         # convert input_ids to numpy, with shape [1, sequence_length]
         input_ids = input_ids.cpu().numpy()
-        print(input_ids.shape, 233333333333, "input_ids")
         max_tokens = 512
         input_dict = {
             "request_output_len": np.array([[max_tokens]], dtype=np.int32),
@@ -153,135 +159,153 @@ class TritonPythonModel:
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }
-        # exit()
+
+        # Convert inputs to Triton tensors
         input_tensor_list = [
             pb_utils.Tensor(k, v) for k, v in input_dict.items()
         ]
+
+        # Create and execute inference request
         llm_request = pb_utils.InferenceRequest(
             model_name="tensorrt_llm",
             requested_output_names=["output_ids", "sequence_length"],
             inputs=input_tensor_list,
         )
+
         llm_response = llm_request.exec(decoupled=self.decoupled)
         if llm_response.has_error():
-            raise pb_utils.TritonModelException(
+            raise pb_utils.TritonModelException(llm_response.error().message())
+
+        # Extract and process output
         output_ids = pb_utils.get_output_tensor_by_name(
             llm_response, "output_ids").as_numpy()
         seq_lens = pb_utils.get_output_tensor_by_name(
             llm_response, "sequence_length").as_numpy()
+
+        # Get actual output IDs up to the sequence length
+        actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+
         return actual_output_ids
 
     def forward_audio_tokenizer(self, wav, wav_len):
+        """Forward pass through the audio tokenizer component.
+
+        Args:
+            wav: Input waveform tensor
+            wav_len: Waveform length tensor
+
+        Returns:
+            Tuple of global and semantic tokens
+        """
         inference_request = pb_utils.InferenceRequest(
             model_name='audio_tokenizer',
             requested_output_names=['global_tokens', 'semantic_tokens'],
             inputs=[wav, wav_len]
         )
+
         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
+        global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
+
+        semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
+        semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
+
+        return global_tokens, semantic_tokens
 
-    def forward_vocoder(self, global_token_ids, pred_semantic_ids):
+    def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the vocoder component.
+
+        Args:
+            global_token_ids: Global token IDs tensor
+            pred_semantic_ids: Predicted semantic token IDs tensor
+
+        Returns:
+            Generated waveform tensor
+        """
+        # Convert tensors to Triton format
+        global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
+        pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
+
+        # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='vocoder',
             requested_output_names=['waveform'],
-            inputs=[
+            inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
        )
+
         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output waveform
+        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
+        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
+
+        return waveform
 
     def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated audio
+        """
         responses = []
+
         for request in requests:
+            # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-            wav_len = pb_utils.get_input_tensor_by_name(
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+            # Process reference audio through audio tokenizer
             global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
-            # reference_wav_ref_clip_list.append(wav_ref_clip)
-            reference_text = pb_utils.get_input_tensor_by_name(
-                request, "reference_text").as_numpy()
+
+            # Extract text inputs
+            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')
-            # reference_text_list.append(reference_text)
-            target_text = pb_utils.get_input_tensor_by_name(
-                request, "target_text").as_numpy()
-            target_text = target_text[0][0].decode('utf-8')
-            # target_text_list.append(target_text)
 
-            # audio_tokenizer_input_dict = {
-            #     "ref_wav": ref_wav_clip_tensor, # no padding, spaker encoder
-            #     "feat": wav2vec2_features,
-            # }
+            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+            target_text = target_text[0][0].decode('utf-8')
 
+            # Prepare prompt for LLM
             prompt, global_token_ids = process_prompt(
                 text=target_text,
                 prompt_text=reference_text,
                 global_token_ids=global_tokens,
                 semantic_token_ids=semantic_tokens,
             )
+
+            # Tokenize prompt for LLM
             model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
-            print(model_inputs, "model_inputs")
             input_ids = model_inputs.input_ids.to(torch.int32)
+
+            # Generate semantic tokens with LLM
             generated_ids = self.forward_llm(input_ids)
+
+            # Decode and extract semantic token IDs from generated text
+            predicted_text = self.tokenizer.batch_decode([generated_ids], skip_special_tokens=True)[0]
             pred_semantic_ids = (
-                torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)",
+                torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)])
                 .unsqueeze(0).to(torch.int32)
             )
+
+            # Generate audio with vocoder
             audio = self.forward_vocoder(
                 global_token_ids.to(self.device),
                 pred_semantic_ids.to(self.device),
             )
+
+            # Prepare response
+            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
             responses.append(inference_response)
 
         return responses
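The BLS model recovers codec tokens from the LLM output purely through the `<|bicodec_semantic_N|>` markers in the decoded text. A small illustration of that regex step from `execute()` (the decoded string is made up for the example):

```python
# Illustration of how execute() extracts semantic token IDs from decoded LLM text.
import re
import torch

predicted_text = "<|bicodec_semantic_12|><|bicodec_semantic_7|><|bicodec_semantic_3051|>"  # made-up output
pred_semantic_ids = (
    torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)])
    .unsqueeze(0)
    .to(torch.int32)
)
print(pred_semantic_ids)  # tensor([[  12,    7, 3051]], dtype=torch.int32)
```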
runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt
CHANGED
@@ -14,14 +14,14 @@
 
 name: "spark_tts"
 backend: "python"
-max_batch_size:
+max_batch_size: ${triton_max_batch_size}
 dynamic_batching {
-    max_queue_delay_microseconds:
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
 }
 parameters [
   {
-    key: "
-    value: {string_value:"
+    key: "llm_tokenizer_dir",
+    value: {string_value:"${llm_tokenizer_dir}"}
   }
 ]
 
@@ -59,7 +59,7 @@ output [
 
 instance_group [
   {
-    count:
+    count: ${bls_instance_num}
     kind: KIND_CPU
   }
 ]
runtime/triton_trtllm/model_repo/vocoder/1/model.py
CHANGED
@@ -23,48 +23,80 @@
 # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import json
+import os
+import logging
+from typing import List, Dict
+
 import torch
-from torch.nn.utils.rnn import pad_sequence
-import torch.nn.functional as F
-from torch.utils.dlpack import from_dlpack, to_dlpack
+from torch.utils.dlpack import to_dlpack
 
 import triton_python_backend_utils as pb_utils
 
-import math
-import os
-from functools import wraps
-
 from sparktts.models.bicodec import BiCodec
 
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
 class TritonPythonModel:
+    """Triton Python model for vocoder.
+
+    This model takes global and semantic tokens as input and generates audio waveforms
+    using the BiCodec vocoder.
+    """
+
     def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
         parameters = json.loads(args['model_config'])['parameters']
-        for key, value in parameters.items()
+        model_params = {key: value["string_value"] for key, value in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        # Initialize device and vocoder
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
+
+        self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
+        del self.vocoder.encoder, self.vocoder.postnet
+        self.vocoder.eval().to(self.device)  # Set model to evaluation mode
+
+        logger.info("Vocoder initialized successfully")
 
     def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated waveforms
+        """
         global_tokens_list, semantic_tokens_list = [], []
 
+        # Process each request in batch
         for request in requests:
             global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
             semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
-            # check shape
             global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
             semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))
 
+        # Concatenate tokens for batch processing
         global_tokens = torch.cat(global_tokens_list, dim=0)
         semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
 
+        # Generate waveforms
+        with torch.no_grad():
+            wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
 
+        # Prepare responses
         responses = []
         for i in range(len(requests)):
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
runtime/triton_trtllm/model_repo/vocoder/config.pbtxt
CHANGED
@@ -14,14 +14,14 @@
 
 name: "vocoder"
 backend: "python"
-max_batch_size:
+max_batch_size: ${triton_max_batch_size}
 dynamic_batching {
-    max_queue_delay_microseconds:
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
 }
 parameters [
   {
     key: "model_dir",
-    value: {string_value:"
+    value: {string_value:"${model_dir}"}
   }
 ]
 
runtime/triton_trtllm/{build_engine.sh → scripts/build_engine.sh}
RENAMED
File without changes

runtime/triton_trtllm/{convert_checkpoint.py → scripts/convert_checkpoint.py}
RENAMED
File without changes

runtime/triton_trtllm/{fill_template.py → scripts/fill_template.py}
RENAMED
File without changes