yuekaiz committed
Commit 82f7b02 · Parent: be03ceb

update http client; launch script
runtime/triton_trtllm/README.md CHANGED
@@ -1,38 +1,38 @@
- ## Triton Inference Serving Best Practice for F5 TTS
-
- ### Model Training
- See [official F5-TTS](https://github.com/SWivid/F5-TTS) or [Icefall F5-TTS](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech4tts/TTS#f5-tts).

  ### Quick Start
  Directly launch the service using docker compose.
  ```sh
- # VOCODER vocos or bigvgan
- VOCODER=vocos docker compose up
  ```

  ### Build Image
  Build the docker image from scratch.
  ```sh
- docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
  ```

  ### Create Docker Container
  ```sh
  your_mount_dir=/mnt:/mnt
- docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
  ```

  ### Export Models to TensorRT-LLM and Launch Server
- Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).

  ```sh
- bash build_server.sh
  ```

  ### Benchmark using Dataset
  ```sh
  num_task=2
- python3 client.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
  ```

  ### Benchmark Results
@@ -40,8 +40,4 @@ Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.

  | Model | Note | Concurrency | Avg Latency | RTF |
  |-------|-----------|-----------------------|---------|--|
- | F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394|
-
- ### Credits
- 1. [F5-TTS](https://github.com/SWivid/F5-TTS)
- 2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

+ ## NVIDIA Triton Inference Serving Best Practice for Spark TTS

  ### Quick Start
  Directly launch the service using docker compose.
  ```sh
+ docker compose up
  ```

  ### Build Image
  Build the docker image from scratch.
  ```sh
+ docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
  ```

  ### Create Docker Container
  ```sh
  your_mount_dir=/mnt:/mnt
+ docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
  ```

  ### Export Models to TensorRT-LLM and Launch Server
+ Inside the docker container, we follow the official TensorRT-LLM guide to build the TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen).

  ```sh
+ bash run.sh 0 3
+ ```
+ ### Simple HTTP client
+ ```sh
+ python3 client_http.py
  ```

  ### Benchmark using Dataset
  ```sh
  num_task=2
+ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
  ```

  ### Benchmark Results

  | Model | Note | Concurrency | Avg Latency | RTF |
  |-------|-----------|-----------------------|---------|--|
+ | Spark-TTS-0.5B | [Code Commit]() | 4 | 253 ms | 0.0394|
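
The benchmark command above relies on `client_grpc.py`, which is not part of this commit. As a rough orientation, a single request over gRPC could look like the sketch below; the input names mirror those in the new `client_http.py`, while the gRPC port (8001, Triton's default, not published in the compose file), the output handling, and the placeholder texts are assumptions rather than code from this repo.

```python
# Minimal single-request gRPC sketch (NOT the repo's client_grpc.py benchmark client).
import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")  # assumed gRPC port

waveform, sr = sf.read("../../example/prompt_audio.wav")  # 16 kHz prompt audio
samples = np.asarray(waveform, dtype=np.float32).reshape(1, -1)
lengths = np.array([[samples.shape[1]]], dtype=np.int32)

def bytes_input(name, text):
    # Triton BYTES tensors travel as numpy object arrays of encoded strings.
    t = grpcclient.InferInput(name, [1, 1], "BYTES")
    t.set_data_from_numpy(np.array([[text.encode("utf-8")]], dtype=object))
    return t

wav = grpcclient.InferInput("reference_wav", list(samples.shape), "FP32")
wav.set_data_from_numpy(samples)
wav_len = grpcclient.InferInput("reference_wav_len", list(lengths.shape), "INT32")
wav_len.set_data_from_numpy(lengths)

inputs = [
    wav,
    wav_len,
    bytes_input("reference_text", "<transcript of the prompt audio>"),  # placeholder
    bytes_input("target_text", "<text to synthesize>"),                 # placeholder
]

result = client.infer("spark_tts", inputs)
output_name = result.get_response().outputs[0].name  # take the first returned tensor
audio = result.as_numpy(output_name).reshape(-1).astype(np.float32)
sf.write("output_grpc.wav", audio, 16000, "PCM_16")
```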
 
 
 
 
runtime/triton_trtllm/build.sh CHANGED
@@ -1,28 +1,73 @@
-
- # pip install -r /workspace_yuekai/spark-tts/Spark-TTS/requirements.txt
- export PYTHONPATH=/workspace_yuekai/spark-tts/Spark-TTS/
-
  model_repo=./model_repo_test
- rm -rf $model_repo
- cp -r ./model_repo $model_repo
-
- ENGINE_PATH=/workspace_yuekai/spark-tts/TensorRT-LLM/examples/qwen/Spark-TTS-0.5B_trt_engines_1gpu_bfloat16
- MAX_QUEUE_DELAY_MICROSECONDS=0
- MODEL_DIR=/workspace_yuekai/spark-tts/Spark-TTS/pretrained_models/Spark-TTS-0.5B
- LLM_TOKENIZER_DIR=/workspace_yuekai/spark-tts/Spark-TTS/pretrained_models/Spark-TTS-0.5B/LLM
- BLS_INSTANCE_NUM=4
- TRITON_MAX_BATCH_SIZE=16
-
- python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
- python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
- python3 scripts/fill_template.py -i ${model_repo}/spark_tts/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
- python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
-
- CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository ${model_repo}

+
+ export PYTHONPATH=../../../Spark-TTS/
+ export CUDA_VISIBLE_DEVICES=0
+ stage=$1
+ stop_stage=$2
+ echo "Start stage: $stage, Stop stage: $stop_stage"
+
+ huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
+ trt_dtype=bfloat16
+ trt_weights_dir=./tllm_checkpoint_${trt_dtype}
+ trt_engines_dir=./trt_engines_${trt_dtype}

  model_repo=./model_repo_test

+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+     echo "Downloading Spark-TTS-0.5B from HuggingFace"
+     huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
+     # pip install -r /workspace_yuekai/spark-tts/Spark-TTS/requirements.txt
+ fi
+
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+     echo "Converting checkpoint to TensorRT weights"
+     python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
+         --output_dir $trt_weights_dir \
+         --dtype $trt_dtype || exit 1
+
+     echo "Building TensorRT engines"
+     trtllm-build --checkpoint_dir $trt_weights_dir \
+         --output_dir $trt_engines_dir \
+         --max_batch_size 16 \
+         --max_num_tokens 32768 \
+         --gemm_plugin $trt_dtype || exit 1
+ fi
+
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+     echo "Creating model repository"
+     rm -rf $model_repo
+     cp -r ./model_repo $model_repo
+
+     ENGINE_PATH=$trt_engines_dir
+     MAX_QUEUE_DELAY_MICROSECONDS=0
+     MODEL_DIR=$huggingface_model_local_dir
+     LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
+     BLS_INSTANCE_NUM=4
+     TRITON_MAX_BATCH_SIZE=16
+
+     python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+     python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+     python3 scripts/fill_template.py -i ${model_repo}/spark_tts/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+     python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
+ fi
+
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+     echo "Starting Triton server"
+     tritonserver --model-repository ${model_repo}
+ fi
+
+ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+     echo "Running client"
+     num_task=4
+     python3 client_grpc.py \
+         --server-addr localhost \
+         --model-name spark_tts \
+         --num-tasks $num_task \
+         --log-dir ./log_${num_task}
+ fi
runtime/triton_trtllm/client_http.py CHANGED
@@ -2,45 +2,139 @@ import requests
  import soundfile as sf
  import json
  import numpy as np

- url = "http://localhost:8000/v2/models/infer_pipeline/infer"
- wav_path = "*********"
- waveform, sr = sf.read(wav_path)
- lang_id = 54
- samples = np.array([waveform], dtype=np.float32)
- lengths = np.array([[len(waveform)]], dtype=np.int32)
- lang_id = np.array([[lang_id]], dtype=np.int8)
-
- data = {
-     "inputs":[
-         {
-             "name": "WAV",
-             "shape": samples.shape,
-             "datatype": "FP32",
-             "data": samples.tolist()
-         },
-         {
-             "name": "WAV_LENS",
-             "shape": lengths.shape,
-             "datatype": "INT32",
-             "data": lengths.tolist(),
-         },
-         {
-             "name": "LANG_ID",
-             "shape": lang_id.shape,
-             "datatype": "INT8",
-             "data": lang_id.tolist()
-         }
-     ]
- }
- rsp = requests.post(
-     url,
-     headers={"Content-Type": "application/json"},
-     json=data,
-     verify=False,
-     params={"request_id": '0'}
- )
- result = rsp.json()
- print(result)
- transcripts = result["outputs"][0]["data"][0]
- print(transcripts)

  import soundfile as sf
  import json
  import numpy as np
+ import argparse

+ def get_args():
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument(
+         "--server-url",
+         type=str,
+         default="localhost:8000",
+         help="Address of the server",
+     )
+
+     parser.add_argument(
+         "--reference-audio",
+         type=str,
+         default="../../example/prompt_audio.wav",
+         help="Path to the reference (prompt) audio file",
+     )
+
+     parser.add_argument(
+         "--reference-text",
+         type=str,
+         default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
+         help="Transcript of the reference audio",
+     )
+
+     parser.add_argument(
+         "--target-text",
+         type=str,
+         default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
+         help="Text to synthesize",
+     )
+
+     parser.add_argument(
+         "--model-name",
+         type=str,
+         default="spark_tts",
+         choices=["f5_tts", "spark_tts"],
+         help="Triton model name to request",
+     )
+
+     parser.add_argument(
+         "--output-audio",
+         type=str,
+         default="output.wav",
+         help="Path to save the output audio",
+     )
+     return parser.parse_args()
+
+ def prepare_request(
+     waveform,
+     reference_text,
+     target_text,
+     sample_rate=16000,
+     padding_duration: int = None,
+     audio_save_dir: str = "./",
+ ):
+     assert len(waveform.shape) == 1, "waveform should be 1D"
+     lengths = np.array([[len(waveform)]], dtype=np.int32)
+     if padding_duration:
+         # pad to the next multiple of padding_duration seconds
+         duration = len(waveform) / sample_rate
+         samples = np.zeros(
+             (
+                 1,
+                 padding_duration
+                 * sample_rate
+                 * ((int(duration) // padding_duration) + 1),
+             ),
+             dtype=np.float32,
+         )
+
+         samples[0, : len(waveform)] = waveform
+     else:
+         samples = waveform
+
+     samples = samples.reshape(1, -1).astype(np.float32)
+
+     data = {
+         "inputs": [
+             {
+                 "name": "reference_wav",
+                 "shape": samples.shape,
+                 "datatype": "FP32",
+                 "data": samples.tolist()
+             },
+             {
+                 "name": "reference_wav_len",
+                 "shape": lengths.shape,
+                 "datatype": "INT32",
+                 "data": lengths.tolist(),
+             },
+             {
+                 "name": "reference_text",
+                 "shape": [1, 1],
+                 "datatype": "BYTES",
+                 "data": [reference_text]
+             },
+             {
+                 "name": "target_text",
+                 "shape": [1, 1],
+                 "datatype": "BYTES",
+                 "data": [target_text]
+             }
+         ]
+     }
+
+     return data
+
+ if __name__ == "__main__":
+     args = get_args()
+     server_url = args.server_url
+     if not server_url.startswith(("http://", "https://")):
+         server_url = f"http://{server_url}"
+
+     url = f"{server_url}/v2/models/{args.model_name}/infer"
+     waveform, sr = sf.read(args.reference_audio)
+     assert sr == 16000, "sample rate hardcoded in server"
+
+     samples = np.array(waveform, dtype=np.float32)
+     data = prepare_request(samples, args.reference_text, args.target_text)
+
+     rsp = requests.post(
+         url,
+         headers={"Content-Type": "application/json"},
+         json=data,
+         verify=False,
+         params={"request_id": '0'}
+     )
+     result = rsp.json()
+     audio = result["outputs"][0]["data"]
+     audio = np.array(audio, dtype=np.float32)
+     sf.write(args.output_audio, audio, 16000, "PCM_16")
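
A typical invocation of the rewritten client, using only the flags defined in its argument parser (the two text values below are placeholders; substitute your own prompt transcript and target text):

```sh
python3 client_http.py \
    --server-url localhost:8000 \
    --reference-audio ../../example/prompt_audio.wav \
    --reference-text "<transcript of the prompt audio>" \
    --target-text "<text to synthesize>" \
    --model-name spark_tts \
    --output-audio output.wav
```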
runtime/triton_trtllm/docker-compose.yml CHANGED
@@ -1,6 +1,6 @@
  services:
    tts:
-     image: soar97/triton-f5-tts:24.12
      shm_size: '1gb'
      ports:
        - "8000:8000"

  services:
    tts:
+     image: soar97/triton-spark-tts:25.02
      shm_size: '1gb'
      ports:
        - "8000:8000"
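
The compose file only publishes the HTTP port (8000), so a quick way to confirm the container is serving before pointing a client at it is Triton's readiness endpoints (the URLs below assume the default KServe v2 HTTP API paths):

```sh
# Prints 200 once the server is up and ready to accept inference requests.
curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8000/v2/health/ready
# Per-model readiness for the BLS entry point.
curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8000/v2/models/spark_tts/ready
```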
runtime/triton_trtllm/run.sh ADDED
@@ -0,0 +1,76 @@
+
+ export PYTHONPATH=../../../Spark-TTS/
+ export CUDA_VISIBLE_DEVICES=0
+ stage=$1
+ stop_stage=$2
+ echo "Start stage: $stage, Stop stage: $stop_stage"
+
+ huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
+ trt_dtype=bfloat16
+ trt_weights_dir=./tllm_checkpoint_${trt_dtype}
+ trt_engines_dir=./trt_engines_${trt_dtype}
+
+ model_repo=./model_repo_test
+
+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+     echo "Downloading Spark-TTS-0.5B from HuggingFace"
+     huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
+     # pip install -r /workspace_yuekai/spark-tts/Spark-TTS/requirements.txt
+ fi
+
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+     echo "Converting checkpoint to TensorRT weights"
+     python convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
+         --output_dir $trt_weights_dir \
+         --dtype $trt_dtype || exit 1
+
+     echo "Building TensorRT engines"
+     trtllm-build --checkpoint_dir $trt_weights_dir \
+         --output_dir $trt_engines_dir \
+         --max_batch_size 16 \
+         --max_num_tokens 32768 \
+         --gemm_plugin $trt_dtype || exit 1
+ fi
+
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+     echo "Creating model repository"
+     rm -rf $model_repo
+     cp -r ./model_repo $model_repo
+
+     ENGINE_PATH=$trt_engines_dir
+     MAX_QUEUE_DELAY_MICROSECONDS=0
+     MODEL_DIR=$huggingface_model_local_dir
+     LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
+     BLS_INSTANCE_NUM=4
+     TRITON_MAX_BATCH_SIZE=16
+
+     python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+     python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+     python3 scripts/fill_template.py -i ${model_repo}/spark_tts/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+     python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
+ fi
+
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+     echo "Starting Triton server"
+     tritonserver --model-repository ${model_repo}
+ fi
+
+ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+     echo "Running client"
+     num_task=4
+     python3 client_grpc.py \
+         --server-addr localhost \
+         --model-name spark_tts \
+         --num-tasks $num_task \
+         --log-dir ./log_${num_task}
+ fi
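
Because every block in `run.sh` is gated on the `stage`/`stop_stage` pair, later steps can be re-run in isolation once the earlier artifacts exist, for example:

```sh
# Recreate the model repository and relaunch the server (skips download and engine build).
bash run.sh 2 3

# Only run the gRPC benchmark client against an already-running server.
bash run.sh 4 4
```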
runtime/triton_trtllm/scripts/build_engine.sh DELETED
@@ -1,46 +0,0 @@
-
- # model_dir=./Qwen2.5-0.5B-Instruct/
- # output_dir=./tllm_checkpoint_1gpu_fp16
- # trt_engines_dir=./trt_engines
-
- model_dir=/workspace_yuekai/spark-tts/Spark-TTS/pretrained_models/Spark-TTS-0.5B/LLM
- base_name=Spark-TTS-0.5B
- dtype=bfloat16
- output_dir=./${base_name}_tllm_checkpoint_1gpu_${dtype}
- trt_engines_dir=./${base_name}_trt_engines_1gpu_${dtype}
-
- # python convert_checkpoint.py --model_dir $model_dir \
- #     --output_dir $output_dir \
- #     --dtype $dtype || exit 1
-
- trtllm-build --checkpoint_dir $output_dir \
-     --output_dir $trt_engines_dir \
-     --max_batch_size 16 \
-     --max_num_tokens 32768 \
-     --gemm_plugin $dtype || exit 1
- # trtllm-build --checkpoint_dir $output_dir \
- #     --output_dir $trt_engines_dir \
- #     --max_batch_size 16 \
- #     --max_num_tokens 32768 \
- #     --gemm_plugin $dtype || exit 1
-
- python3 ../run.py --input_file /workspace_yuekai/spark-tts/Spark-TTS/model_inputs.npy \
-     --max_output_len=1500 \
-     --tokenizer_dir $model_dir \
-     --top_k 50 \
-     --top_p 0.95 \
-     --temperature 0.8 \
-     --output_npy ./output.npy \
-     --engine_dir=$trt_engines_dir || exit 1
-
- # python3 ../run.py --input_file /workspace_yuekai/spark-tts/Spark-TTS/model_inputs.npy \
- #     --max_output_len=1500 \
- #     --tokenizer_dir $model_dir \
- #     --top_k 50 \
- #     --top_p 0.95 \
- #     --temperature 0.8 \
- #     --engine_dir=$trt_engines_dir || exit 1