Yuekai Zhang committed
Commit 40280a8 · 1 Parent(s): 1f70431

add benchmark code

src/f5_tts/runtime/triton_trtllm/README.md CHANGED
@@ -30,18 +30,40 @@ bash run.sh 0 4 F5TTS_Base
  python3 client_http.py
  ```

- ### Benchmark using Dataset
+ ### Benchmark using Client-Server Mode
  ```sh
  num_task=2
  python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
  ```

+ ### Benchmark using Offline TRT-LLM Mode
+ ```sh
+ batch_size=1
+ split_name=wenetspeech4tts
+ backend_type=trt
+ log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+ rm -r $log_dir
+ ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+ torchrun --nproc_per_node=1 \
+     benchmark.py --output-dir $log_dir \
+     --batch-size $batch_size \
+     --enable-warmup \
+     --split-name $split_name \
+     --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+     --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+     --vocoder-trt-engine-path $vocoder_trt_engine_path \
+     --backend-type $backend_type \
+     --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+ ```
+
  ### Benchmark Results
  Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.

- | Model | Concurrency | Avg Latency | RTF |
- |-------|-------------|-------------|--------|
- | F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394 |
+ | Model | Concurrency | Avg Latency | RTF | Mode |
+ |-------|-------------|-------------|--------|------|
+ | F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
+ | F5-TTS Base (Vocos) | 1 (batch size) | - | 0.0402 | Offline TRT-LLM |
+ | F5-TTS Base (Vocos) | 1 (batch size) | - | 0.1467 | Offline PyTorch |

  ### Credits
  1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
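For reference, the RTF column above is computed the same way benchmark.py reports it: total decoding wall-clock time divided by the total duration of the generated audio (lower is better). A minimal sketch with hypothetical numbers:

```python
# Minimal sketch of the RTF (real-time factor) metric reported in the table above.
# The timings are hypothetical placeholders, not measured values.
decoding_seconds = 10.0          # wall-clock time spent synthesizing (hypothetical)
generated_audio_seconds = 250.0  # total duration of the generated waveforms (hypothetical)

rtf = decoding_seconds / generated_audio_seconds
print(f"RTF: {rtf:.4f}")         # 0.0400 -> roughly 25x faster than real time
```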
src/f5_tts/runtime/triton_trtllm/benchmark.py ADDED
@@ -0,0 +1,565 @@
1
+ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
2
+ # 2025 (authors: Yuekai Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
16
+ """ Example Usage
17
+ split=test_zh
18
+ llm_path=f5-tts/exp_zh/checkpoint-805000
19
+ huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
20
+ model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
21
+ huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
22
+ vocoder=./bigvgan_v2_24khz_100band_256x
23
+ torchrun --nproc_per_node=2 \
24
+ f5-tts/infer_dist.py \
25
+ --output_dir $output_dir \
26
+ --batch_size 1 \
27
+ --num_workers 2 \
28
+ --llm-model-name-or-path $llm_path \
29
+ --flow-matching-model-path $model_path \
30
+ --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
31
+ --use-cosyvoice-semantic-token True \
32
+ --vocoder-dir $vocoder \
33
+ --split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
34
+ --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
35
+ """
36
+
37
+ import argparse
38
+ import json
39
+ import os
40
+ import time
41
+ from typing import List, Dict, Union
42
+
43
+ import torch
44
+ import torch.distributed as dist
45
+ import torch.nn.functional as F
46
+ from torch.nn.utils.rnn import pad_sequence
47
+ import torchaudio
48
+ import jieba
49
+ from pypinyin import Style, lazy_pinyin
50
+ from datasets import load_dataset
51
+ import datasets
52
+ from huggingface_hub import hf_hub_download
53
+ from torch.utils.data import DataLoader, DistributedSampler
54
+ from tqdm import tqdm
55
+ from vocos import Vocos
56
+ from f5_tts_trtllm import F5TTS
57
+ import tensorrt as trt
58
+ from tensorrt_llm.runtime.session import Session, TensorInfo
59
+ from tensorrt_llm.logger import logger
60
+ from tensorrt_llm._utils import trt_dtype_to_torch
61
+
62
+ torch.manual_seed(0)
63
+
64
+
65
+ def get_args():
66
+ parser = argparse.ArgumentParser(description="benchmark F5-TTS offline decoding with the TRT-LLM or native PyTorch backend")
67
+ parser.add_argument(
68
+ "--split-name",
69
+ type=str,
70
+ default="wenetspeech4tts",
71
+ choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
72
+ help="huggingface dataset split name",
73
+ )
74
+ parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
75
+ parser.add_argument(
76
+ "--vocab-file",
77
+ required=True,
78
+ type=str,
79
+ help="vocab file",
80
+ )
81
+ parser.add_argument(
82
+ "--model-path",
83
+ required=True,
84
+ type=str,
85
+ help="model path, to load text embedding",
86
+ )
87
+ parser.add_argument(
88
+ "--tllm-model-dir",
89
+ required=True,
90
+ type=str,
91
+ help="tllm model dir",
92
+ )
93
+ parser.add_argument(
94
+ "--batch-size",
95
+ required=True,
96
+ type=int,
97
+ help="batch size (per-device) for inference",
98
+ )
99
+ parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
100
+ parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
101
+ parser.add_argument(
102
+ "--vocoder",
103
+ default="vocos",
104
+ type=str,
105
+ help="vocoder name",
106
+ )
107
+ parser.add_argument(
108
+ "--vocoder-trt-engine-path",
109
+ default=None,
110
+ type=str,
111
+ help="vocoder trt engine path",
112
+ )
113
+ parser.add_argument("--enable-warmup", action="store_true")
114
+ parser.add_argument("--remove-input-padding", action="store_true")
115
+ parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
116
+ parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
117
+ args = parser.parse_args()
118
+ return args
119
+
120
+
121
+ def padded_mel_batch(ref_mels, max_seq_len):
122
+ padded_ref_mels = []
123
+ for mel in ref_mels:
124
+ # pad along the last dimension
125
+ padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
126
+ padded_ref_mels.append(padded_ref_mel)
127
+ padded_ref_mels = torch.stack(padded_ref_mels)
128
+ return padded_ref_mels
129
+
130
+
131
+ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
132
+ if use_perf:
133
+ torch.cuda.nvtx.range_push("data_collator")
134
+ target_sample_rate = 24000
135
+ target_rms = 0.1
136
+ ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
137
+ [],
138
+ [],
139
+ [],
140
+ [],
141
+ [],
142
+ )
143
+ for i, item in enumerate(batch):
144
+ item_id, prompt_text, target_text = (
145
+ item["id"],
146
+ item["prompt_text"],
147
+ item["target_text"],
148
+ )
149
+ ids.append(item_id)
150
+ reference_target_texts_list.append(prompt_text + target_text)
151
+
152
+ ref_audio_org, ref_sr = (
153
+ item["prompt_audio"]["array"],
154
+ item["prompt_audio"]["sampling_rate"],
155
+ )
156
+ ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
157
+ ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
158
+ if ref_rms < target_rms:
159
+ ref_audio_org = ref_audio_org * target_rms / ref_rms
160
+
161
+ if ref_sr != target_sample_rate:
162
+ resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
163
+ ref_audio = resampler(ref_audio_org)
164
+ else:
165
+ ref_audio = ref_audio_org
166
+
167
+ if use_perf:
168
+ torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
169
+ ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
170
+ if use_perf:
171
+ torch.cuda.nvtx.range_pop()
172
+ ref_mel = ref_mel.squeeze()
173
+ ref_mel_len = ref_mel.shape[0]
174
+ assert ref_mel.shape[1] == 100
175
+
176
+ ref_mel_list.append(ref_mel)
177
+ ref_mel_len_list.append(ref_mel_len)
178
+
179
+ # Estimate the combined reference + generated mel length by scaling the reference length with the target/prompt character-count ratio
+ estimated_reference_target_mel_len.append(int(ref_mel.shape[0] * (1 + len(target_text) / len(prompt_text))))
180
+
181
+ max_seq_len = max(estimated_reference_target_mel_len)
182
+ ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
183
+ ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
184
+
185
+ pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
186
+ text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
187
+
188
+ for i, item in enumerate(text_pad_sequence):
189
+ text_pad_sequence[i] = F.pad(
190
+ item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
191
+ )
192
+ text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
193
+ text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
194
+ text_pad_sequence = F.pad(
195
+ text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
196
+ )
197
+ if use_perf:
198
+ torch.cuda.nvtx.range_pop()
199
+ return {
200
+ "ids": ids,
201
+ "ref_mel_batch": ref_mel_batch,
202
+ "ref_mel_len_batch": ref_mel_len_batch,
203
+ "text_pad_sequence": text_pad_sequence,
204
+ "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
205
+ }
206
+
207
+
208
+ def init_distributed():
209
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
210
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
211
+ rank = int(os.environ.get("RANK", 0))
212
+ print(f"Inference on multiple GPUs: local_rank {local_rank}, rank {rank}, world_size {world_size}")
216
+ torch.cuda.set_device(local_rank)
217
+ # Initialize the NCCL process group; the device has already been selected via set_device above
218
+ dist.init_process_group(
219
+ "nccl",
220
+ )
221
+ return world_size, local_rank, rank
222
+
223
+
224
+ def get_tokenizer(vocab_file_path: str):
225
+ """
226
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
227
+ - "char" for char-wise tokenizer, need .txt vocab_file
228
+ - "byte" for utf-8 tokenizer
229
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
230
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
231
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
232
+ - if use "byte", set to 256 (unicode byte range)
233
+ """
234
+ with open(vocab_file_path, "r", encoding="utf-8") as f:
235
+ vocab_char_map = {}
236
+ for i, char in enumerate(f):
237
+ vocab_char_map[char[:-1]] = i
238
+ vocab_size = len(vocab_char_map)
239
+ return vocab_char_map, vocab_size
240
+
241
+
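+ # convert_char_to_pinyin turns mixed Chinese/English text into a flat token list:
+ # Chinese characters become tone-numbered pinyin syllables (TONE3 style, e.g. "ni3"),
+ # Latin text is kept as individual ASCII characters, and spaces are inserted as word
+ # boundaries, matching the token granularity of the pinyin vocab file.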
242
+ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
243
+ final_reference_target_texts_list = []
244
+ custom_trans = str.maketrans(
245
+ {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
246
+ ) # add custom trans here, to address oov
247
+
248
+ def is_chinese(c):
249
+ return "\u3100" <= c <= "\u9fff" # common chinese characters
250
+
251
+ for text in reference_target_texts_list:
252
+ char_list = []
253
+ text = text.translate(custom_trans)
254
+ for seg in jieba.cut(text):
255
+ seg_byte_len = len(bytes(seg, "UTF-8"))
256
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
257
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
258
+ char_list.append(" ")
259
+ char_list.extend(seg)
260
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
261
+ seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
262
+ for i, c in enumerate(seg):
263
+ if is_chinese(c):
264
+ char_list.append(" ")
265
+ char_list.append(seg_[i])
266
+ else: # if mixed characters, alphabets and symbols
267
+ for c in seg:
268
+ if ord(c) < 256:
269
+ char_list.extend(c)
270
+ elif is_chinese(c):
271
+ char_list.append(" ")
272
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
273
+ else:
274
+ char_list.append(c)
275
+ final_reference_target_texts_list.append(char_list)
276
+
277
+ return final_reference_target_texts_list
278
+
279
+
280
+ def list_str_to_idx(
281
+ text: Union[List[str], List[List[str]]],
282
+ vocab_char_map: Dict[str, int], # {char: idx}
283
+ padding_value=-1,
284
+ ):
285
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
286
+ # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
287
+ return list_idx_tensors
288
+
289
+
290
+ def load_vocoder(
291
+ vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
292
+ ):
293
+ if vocoder_name == "vocos":
294
+ if vocoder_trt_engine_path is not None:
295
+ vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
296
+ else:
297
+ # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
298
+ if is_local:
299
+ print(f"Load vocos from local path {local_path}")
300
+ config_path = f"{local_path}/config.yaml"
301
+ model_path = f"{local_path}/pytorch_model.bin"
302
+ else:
303
+ print("Download Vocos from huggingface charactr/vocos-mel-24khz")
304
+ repo_id = "charactr/vocos-mel-24khz"
305
+ config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
306
+ model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
307
+ vocoder = Vocos.from_hparams(config_path)
308
+ state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
309
+ from vocos.feature_extractors import EncodecFeatures
310
+
311
+ if isinstance(vocoder.feature_extractor, EncodecFeatures):
312
+ encodec_parameters = {
313
+ "feature_extractor.encodec." + key: value
314
+ for key, value in vocoder.feature_extractor.encodec.state_dict().items()
315
+ }
316
+ state_dict.update(encodec_parameters)
317
+ vocoder.load_state_dict(state_dict)
318
+ vocoder = vocoder.eval().to(device)
319
+ elif vocoder_name == "bigvgan":
320
+ raise NotImplementedError("BigVGAN is not implemented yet")
321
+ return vocoder
322
+
323
+
324
+ def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
325
+ if vocoder == "vocos":
326
+ mel_stft = torchaudio.transforms.MelSpectrogram(
327
+ sample_rate=24000,
328
+ n_fft=1024,
329
+ win_length=1024,
330
+ hop_length=256,
331
+ n_mels=100,
332
+ power=1,
333
+ center=True,
334
+ normalized=False,
335
+ norm=None,
336
+ ).to(device)
337
+ mel = mel_stft(waveform.to(device))
338
+ mel = mel.clamp(min=1e-5).log()
339
+ return mel.transpose(1, 2)
340
+
341
+
342
+ class VocosTensorRT:
343
+ def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
344
+ TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
345
+ trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
346
+ logger.info(f"Loading vae engine from {engine_path}")
347
+ self.engine_path = engine_path
348
+ with open(engine_path, "rb") as f:
349
+ engine_buffer = f.read()
350
+ self.session = Session.from_serialized_engine(engine_buffer)
351
+ self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
352
+
353
+ def decode(self, mels):
354
+ mels = mels.contiguous()
355
+ inputs = {"mel": mels}
356
+ output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
357
+ outputs = {
358
+ t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
359
+ }
360
+ ok = self.session.run(inputs, outputs, self.stream)
361
+
362
+ assert ok, "Runtime execution failed for vae session"
363
+
364
+ samples = outputs["waveform"]
365
+ return samples
366
+
367
+
368
+ def main():
369
+ args = get_args()
370
+ os.makedirs(args.output_dir, exist_ok=True)
371
+
372
+ assert torch.cuda.is_available()
373
+ world_size, local_rank, rank = init_distributed()
374
+ device = torch.device(f"cuda:{local_rank}")
375
+
376
+ vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
377
+
378
+ tllm_model_dir = args.tllm_model_dir
379
+ config_file = os.path.join(tllm_model_dir, "config.json")
380
+ with open(config_file) as f:
381
+ config = json.load(f)
382
+ if args.backend_type == "trt":
383
+ model = F5TTS(
384
+ config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
385
+ )
386
+ elif args.backend_type == "pytorch":
387
+ import sys
388
+
389
+ sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
390
+ from f5_tts.model import DiT
391
+ from f5_tts.infer.utils_infer import load_model
392
+
393
+ F5TTS_model_cfg = dict(
394
+ dim=1024,
395
+ depth=22,
396
+ heads=16,
397
+ ff_mult=2,
398
+ text_dim=512,
399
+ conv_layers=4,
400
+ pe_attn_head=1,
401
+ text_mask_padding=False,
402
+ )
403
+ model = load_model(DiT, F5TTS_model_cfg, args.model_path)
404
+
405
+ vocoder = load_vocoder(
406
+ vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
407
+ )
408
+
409
+ dataset = load_dataset(
410
+ "yuekai/seed_tts",
411
+ split=args.split_name,
412
+ trust_remote_code=True,
413
+ )
414
+
415
+ def add_estimated_duration(example):
416
+ prompt_audio_len = example["prompt_audio"]["array"].shape[0]
417
+ scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
418
+ estimated_duration = prompt_audio_len * scale_factor
419
+ example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
420
+ return example
421
+
422
+ dataset = dataset.map(add_estimated_duration)
423
+ dataset = dataset.sort("estimated_duration", reverse=True)
424
+ if args.use_perf:
425
+ # dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000
426
+ dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719
427
+ # dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
428
+ # dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
429
+ dataset = datasets.concatenate_datasets(dataset_list_short)
430
+ if world_size > 1:
431
+ sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
432
+ else:
433
+ # This would disable shuffling
434
+ sampler = None
435
+
436
+ dataloader = DataLoader(
437
+ dataset,
438
+ batch_size=args.batch_size,
439
+ sampler=sampler,
440
+ shuffle=False,
441
+ num_workers=args.num_workers,
442
+ prefetch_factor=args.prefetch,
443
+ collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
444
+ )
445
+
446
+ total_steps = len(dataset)
447
+
448
+ if args.enable_warmup:
449
+ for batch in dataloader:
450
+ ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
451
+ text_pad_seq = batch["text_pad_sequence"].to(device)
452
+ total_mel_lens = batch["estimated_reference_target_mel_len"]
453
+ if args.backend_type == "trt":
454
+ _ = model.sample(
455
+ text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
456
+ )
457
+ elif args.backend_type == "pytorch":
458
+ with torch.inference_mode():
459
+ text_pad_seq -= 1
460
+ text_pad_seq[text_pad_seq == -2] = -1
461
+ total_mel_lens = torch.tensor(total_mel_lens, device=device)
462
+ generated, _ = model.sample(
463
+ cond=ref_mels,
464
+ text=text_pad_seq,
465
+ duration=total_mel_lens,
466
+ steps=16,
467
+ cfg_strength=2.0,
468
+ sway_sampling_coef=-1,
469
+ )
470
+
471
+ if rank == 0:
472
+ progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
473
+
474
+ decoding_time = 0
475
+ vocoder_time = 0
476
+ total_duration = 0
477
+ if args.use_perf:
478
+ torch.cuda.cudart().cudaProfilerStart()
479
+ total_decoding_time = time.time()
480
+ for batch in dataloader:
481
+ if args.use_perf:
482
+ torch.cuda.nvtx.range_push("data sample")
483
+ ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
484
+ text_pad_seq = batch["text_pad_sequence"].to(device)
485
+ total_mel_lens = batch["estimated_reference_target_mel_len"]
486
+
487
+ if args.use_perf:
488
+ torch.cuda.nvtx.range_pop()
489
+ if args.backend_type == "trt":
490
+ generated, cost_time = model.sample(
491
+ text_pad_seq,
492
+ ref_mels,
493
+ ref_mel_lens,
494
+ total_mel_lens,
495
+ remove_input_padding=args.remove_input_padding,
496
+ use_perf=args.use_perf,
497
+ )
498
+ elif args.backend_type == "pytorch":
499
+ total_mel_lens = torch.tensor(total_mel_lens, device=device)
500
+ with torch.inference_mode():
501
+ start_time = time.time()
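+ # Undo the +1 offset applied in data_collator (where 0 is reserved as the TRT engine's
+ # padding id) and map all padding back to -1, the convention the native F5-TTS model expects.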
502
+ text_pad_seq -= 1
503
+ text_pad_seq[text_pad_seq == -2] = -1
504
+ generated, _ = model.sample(
505
+ cond=ref_mels,
506
+ text=text_pad_seq,
507
+ duration=total_mel_lens,
508
+ lens=ref_mel_lens,
509
+ steps=16,
510
+ cfg_strength=2.0,
511
+ sway_sampling_coef=-1,
512
+ )
513
+ cost_time = time.time() - start_time
514
+ decoding_time += cost_time
515
+ vocoder_start_time = time.time()
516
+ for i, gen in enumerate(generated):
517
+ gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
518
+ gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
519
+ if args.vocoder == "vocos":
520
+ if args.use_perf:
521
+ torch.cuda.nvtx.range_push("vocoder decode")
522
+ generated_wave = vocoder.decode(gen_mel_spec).cpu()
523
+ if args.use_perf:
524
+ torch.cuda.nvtx.range_pop()
525
+ else:
526
+ generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
527
+ target_rms = 0.1
528
+ target_sample_rate = 24_000
529
+ # if ref_rms_list[i] < target_rms:
530
+ # generated_wave = generated_wave * ref_rms_list[i] / target_rms
531
+ rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
532
+ if rms < target_rms:
533
+ generated_wave = generated_wave * target_rms / rms
534
+ utt = batch["ids"][i]
535
+ torchaudio.save(
536
+ f"{args.output_dir}/{utt}.wav",
537
+ generated_wave,
538
+ target_sample_rate,
539
+ )
540
+ total_duration += generated_wave.shape[1] / target_sample_rate
541
+ vocoder_time += time.time() - vocoder_start_time
542
+ if rank == 0:
543
+ progress_bar.update(world_size * len(batch["ids"]))
544
+ total_decoding_time = time.time() - total_decoding_time
545
+ if rank == 0:
546
+ progress_bar.close()
547
+ rtf = total_decoding_time / total_duration
548
+ s = f"RTF: {rtf:.4f}\n"
549
+ s += f"total_duration: {total_duration:.3f} seconds\n"
550
+ s += f"({total_duration / 3600:.2f} hours)\n"
551
+ s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
552
+ s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
553
+ s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
554
+ s += f"batch size: {args.batch_size}\n"
555
+ print(s)
556
+
557
+ with open(f"{args.output_dir}/rtf.txt", "w") as f:
558
+ f.write(s)
559
+
560
+ dist.barrier()
561
+ dist.destroy_process_group()
562
+
563
+
564
+ if __name__ == "__main__":
565
+ main()
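For reference, a minimal standalone sketch of the text-token convention shared by the two backends above: data_collator shifts the vocab indices by +1 so that 0 can serve as the padding id for the TRT-LLM engine, while the PyTorch branch of main() shifts back and restores -1 padding for the native model. It is simplified (the real collator also pads each item to its estimated reference+target mel length before the shift), and the toy vocabulary is hypothetical:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical toy vocabulary; the real mapping comes from vocab.txt via get_tokenizer().
vocab_char_map = {" ": 0, "a": 1, "b": 2, "ni3": 3, "hao3": 4}

# Two utterances of different lengths, already converted to pinyin/char tokens.
batch_tokens = [["ni3", "hao3", " ", "a"], ["b", "a"]]
idx_list = [torch.tensor([vocab_char_map[t] for t in toks]) for toks in batch_tokens]

# As in data_collator: shift by +1 so 0 becomes the padding id for the TRT-LLM engine,
# then pad the batch with -1 beyond each item's length.
idx_list = [t + 1 for t in idx_list]
text_trt = pad_sequence(idx_list, padding_value=-1, batch_first=True)
print(text_trt.tolist())    # [[4, 5, 1, 2], [3, 2, -1, -1]]

# As in the PyTorch branch of main(): shift back and remap the leftover -2 padding to -1,
# recovering the raw vocab indices and -1 padding expected by the native F5-TTS model.
text_torch = text_trt - 1
text_torch[text_torch == -2] = -1
print(text_torch.tolist())  # [[3, 4, 0, 1], [2, 1, -1, -1]]
```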
src/f5_tts/runtime/triton_trtllm/requirements-pytorch.txt ADDED
@@ -0,0 +1,24 @@
1
+ accelerate>=0.33.0
2
+ bitsandbytes>0.37.0
3
+ cached_path
4
+ click
5
+ datasets
6
+ ema_pytorch>=0.5.2
7
+ gradio>=3.45.2
8
+ hydra-core>=1.3.0
9
+ jieba
10
+ librosa
11
+ matplotlib
12
+ numpy<=1.26.4
13
+ pydub
14
+ pypinyin
15
+ safetensors
16
+ soundfile
17
+ tomli
18
+ torch>=2.0.0
19
+ # torchaudio>=2.0.0
20
+ torchdiffeq
21
+ tqdm>=4.65.0
22
+ transformers
23
+ x_transformers>=1.31.14
24
+ packaging>=24.2
src/f5_tts/runtime/triton_trtllm/run.sh CHANGED
@@ -2,8 +2,8 @@ stage=$1
  stop_stage=$2
  model=$3 # F5TTS_Base
  if [ -z "$model" ]; then
-     echo "Model is none"
-     exit 1
+     echo "Model is none, using default model F5TTS_Base"
+     model=F5TTS_Base
  fi
  echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
  export CUDA_VISIBLE_DEVICES=0
@@ -68,3 +68,43 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
      target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
      python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
  fi
+
+ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+     echo "TRT-LLM: offline decoding benchmark test"
+     batch_size=1
+     split_name=wenetspeech4tts
+     backend_type=trt
+     log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+     rm -r $log_dir
+     ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+     torchrun --nproc_per_node=1 \
+         benchmark.py --output-dir $log_dir \
+         --batch-size $batch_size \
+         --enable-warmup \
+         --split-name $split_name \
+         --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+         --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+         --vocoder-trt-engine-path $vocoder_trt_engine_path \
+         --backend-type $backend_type \
+         --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+ fi
+
+ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+     echo "Native PyTorch: offline decoding benchmark test"
+     pip install -r requirements-pytorch.txt
+     batch_size=1
+     split_name=wenetspeech4tts
+     backend_type=pytorch
+     log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+     rm -r $log_dir
+     ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
+     torchrun --nproc_per_node=1 \
+         benchmark.py --output-dir $log_dir \
+         --batch-size $batch_size \
+         --split-name $split_name \
+         --enable-warmup \
+         --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
+         --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+         --backend-type $backend_type \
+         --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+ fi
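Once stages 7 and 8 have both run (e.g. `bash run.sh 7 8 F5TTS_Base`), each log directory contains the `rtf.txt` summary written by benchmark.py. A minimal sketch for comparing the two backends, assuming the default log directory names used above:

```python
# Minimal sketch: compare the RTF summaries written by the TRT-LLM and PyTorch benchmark runs.
# Assumes the default log directory names from stages 7 and 8 above.
from pathlib import Path

log_dirs = {
    "trt": Path("./log_benchmark_batch_size_1_wenetspeech4tts_trt"),
    "pytorch": Path("./log_benchmark_batch_size_1_wenetspeech4tts_pytorch"),
}

for backend, log_dir in log_dirs.items():
    rtf_file = log_dir / "rtf.txt"
    if not rtf_file.exists():
        print(f"{backend}: no rtf.txt found in {log_dir}")
        continue
    for line in rtf_file.read_text().splitlines():
        if line.startswith(("RTF:", "total decoding time")):
            print(f"{backend}: {line}")
```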