# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) # 2025 (authors: Yuekai Zhang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py """ Example Usage torchrun --nproc_per_node=1 \ benchmark.py --output-dir $log_dir \ --batch-size $batch_size \ --enable-warmup \ --split-name $split_name \ --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \ --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \ --vocoder-trt-engine-path $vocoder_trt_engine_path \ --backend-type $backend_type \ --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1 """ import argparse import json import os import time from typing import List, Dict, Union import torch import torch.distributed as dist import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence import torchaudio import jieba from pypinyin import Style, lazy_pinyin from datasets import load_dataset import datasets from huggingface_hub import hf_hub_download from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm from vocos import Vocos from f5_tts_trtllm import F5TTS import tensorrt as trt from tensorrt_llm.runtime.session import Session, TensorInfo from tensorrt_llm.logger import logger from tensorrt_llm._utils import trt_dtype_to_torch torch.manual_seed(0) def get_args(): parser = argparse.ArgumentParser(description="extract speech code") parser.add_argument( "--split-name", type=str, default="wenetspeech4tts", choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], help="huggingface dataset split name", ) parser.add_argument("--output-dir", required=True, type=str, help="dir to save result") parser.add_argument( "--vocab-file", required=True, type=str, help="vocab file", ) parser.add_argument( "--model-path", required=True, type=str, help="model path, to load text embedding", ) parser.add_argument( "--tllm-model-dir", required=True, type=str, help="tllm model dir", ) parser.add_argument( "--batch-size", required=True, type=int, help="batch size (per-device) for inference", ) parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader") parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader") parser.add_argument( "--vocoder", default="vocos", type=str, help="vocoder name", ) parser.add_argument( "--vocoder-trt-engine-path", default=None, type=str, help="vocoder trt engine path", ) parser.add_argument("--enable-warmup", action="store_true") parser.add_argument("--remove-input-padding", action="store_true") parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance") parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type") args = parser.parse_args() return args def padded_mel_batch(ref_mels, max_seq_len): padded_ref_mels = [] for mel in ref_mels: # pad along the last dimension padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0) padded_ref_mels.append(padded_ref_mel) padded_ref_mels = torch.stack(padded_ref_mels) return padded_ref_mels def data_collator(batch, vocab_char_map, device="cuda", use_perf=False): if use_perf: torch.cuda.nvtx.range_push("data_collator") target_sample_rate = 24000 target_rms = 0.1 ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = ( [], [], [], [], [], ) for i, item in enumerate(batch): item_id, prompt_text, target_text = ( item["id"], item["prompt_text"], item["target_text"], ) ids.append(item_id) reference_target_texts_list.append(prompt_text + target_text) ref_audio_org, ref_sr = ( item["prompt_audio"]["array"], item["prompt_audio"]["sampling_rate"], ) ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) if ref_rms < target_rms: ref_audio_org = ref_audio_org * target_rms / ref_rms if ref_sr != target_sample_rate: resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) ref_audio = resampler(ref_audio_org) else: ref_audio = ref_audio_org if use_perf: torch.cuda.nvtx.range_push(f"mel_spectrogram {i}") ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda") if use_perf: torch.cuda.nvtx.range_pop() ref_mel = ref_mel.squeeze() ref_mel_len = ref_mel.shape[0] assert ref_mel.shape[1] == 100 ref_mel_list.append(ref_mel) ref_mel_len_list.append(ref_mel_len) estimated_reference_target_mel_len.append(int(ref_mel.shape[0] * (1 + len(target_text) / len(prompt_text)))) max_seq_len = max(estimated_reference_target_mel_len) ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len) ref_mel_len_batch = torch.LongTensor(ref_mel_len_list) pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map) for i, item in enumerate(text_pad_sequence): text_pad_sequence[i] = F.pad( item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1 ) text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device) text_pad_sequence = F.pad( text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1 ) if use_perf: torch.cuda.nvtx.range_pop() return { "ids": ids, "ref_mel_batch": ref_mel_batch, "ref_mel_len_batch": ref_mel_len_batch, "text_pad_sequence": text_pad_sequence, "estimated_reference_target_mel_len": estimated_reference_target_mel_len, } def init_distributed(): world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) rank = int(os.environ.get("RANK", 0)) print( "Inference on multiple gpus, this gpu {}".format(local_rank) + ", rank {}, world_size {}".format(rank, world_size) ) torch.cuda.set_device(local_rank) # Initialize process group with explicit device IDs dist.init_process_group( "nccl", ) return world_size, local_rank, rank def get_tokenizer(vocab_file_path: str): """ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file - "char" for char-wise tokenizer, need .txt vocab_file - "byte" for utf-8 tokenizer - "custom" if you're directly passing in a path to the vocab.txt you want to use vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols - if use "char", derived from unfiltered character & symbol counts of custom dataset - if use "byte", set to 256 (unicode byte range) """ with open(vocab_file_path, "r", encoding="utf-8") as f: vocab_char_map = {} for i, char in enumerate(f): vocab_char_map[char[:-1]] = i vocab_size = len(vocab_char_map) return vocab_char_map, vocab_size def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): final_reference_target_texts_list = [] custom_trans = str.maketrans( {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} ) # add custom trans here, to address oov def is_chinese(c): return "\u3100" <= c <= "\u9fff" # common chinese characters for text in reference_target_texts_list: char_list = [] text = text.translate(custom_trans) for seg in jieba.cut(text): seg_byte_len = len(bytes(seg, "UTF-8")) if seg_byte_len == len(seg): # if pure alphabets and symbols if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): char_list.append(" ") char_list.append(seg_[i]) else: # if mixed characters, alphabets and symbols for c in seg: if ord(c) < 256: char_list.extend(c) elif is_chinese(c): char_list.append(" ") char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) else: char_list.append(c) final_reference_target_texts_list.append(char_list) return final_reference_target_texts_list def list_str_to_idx( text: Union[List[str], List[List[str]]], vocab_char_map: Dict[str, int], # {char: idx} padding_value=-1, ): list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return list_idx_tensors def load_vocoder( vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None ): if vocoder_name == "vocos": if vocoder_trt_engine_path is not None: vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path) else: # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) if is_local: print(f"Load vocos from local path {local_path}") config_path = f"{local_path}/config.yaml" model_path = f"{local_path}/pytorch_model.bin" else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) from vocos.feature_extractors import EncodecFeatures if isinstance(vocoder.feature_extractor, EncodecFeatures): encodec_parameters = { "feature_extractor.encodec." + key: value for key, value in vocoder.feature_extractor.encodec.state_dict().items() } state_dict.update(encodec_parameters) vocoder.load_state_dict(state_dict) vocoder = vocoder.eval().to(device) elif vocoder_name == "bigvgan": raise NotImplementedError("BigVGAN is not implemented yet") return vocoder def mel_spectrogram(waveform, vocoder="vocos", device="cuda"): if vocoder == "vocos": mel_stft = torchaudio.transforms.MelSpectrogram( sample_rate=24000, n_fft=1024, win_length=1024, hop_length=256, n_mels=100, power=1, center=True, normalized=False, norm=None, ).to(device) mel = mel_stft(waveform.to(device)) mel = mel.clamp(min=1e-5).log() return mel.transpose(1, 2) class VocosTensorRT: def __init__(self, engine_path="./vocos_vocoder.plan", stream=None): TRT_LOGGER = trt.Logger(trt.Logger.WARNING) trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="") logger.info(f"Loading vae engine from {engine_path}") self.engine_path = engine_path with open(engine_path, "rb") as f: engine_buffer = f.read() self.session = Session.from_serialized_engine(engine_buffer) self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream def decode(self, mels): mels = mels.contiguous() inputs = {"mel": mels} output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)]) outputs = { t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info } ok = self.session.run(inputs, outputs, self.stream) assert ok, "Runtime execution failed for vae session" samples = outputs["waveform"] return samples def main(): args = get_args() os.makedirs(args.output_dir, exist_ok=True) assert torch.cuda.is_available() world_size, local_rank, rank = init_distributed() device = torch.device(f"cuda:{local_rank}") vocab_char_map, vocab_size = get_tokenizer(args.vocab_file) tllm_model_dir = args.tllm_model_dir config_file = os.path.join(tllm_model_dir, "config.json") with open(config_file) as f: config = json.load(f) if args.backend_type == "trt": model = F5TTS( config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size ) elif args.backend_type == "pytorch": import sys sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/") from f5_tts.model import DiT from f5_tts.infer.utils_infer import load_model F5TTS_model_cfg = dict( dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, pe_attn_head=1, text_mask_padding=False, ) model = load_model(DiT, F5TTS_model_cfg, args.model_path) vocoder = load_vocoder( vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path ) dataset = load_dataset( "yuekai/seed_tts", split=args.split_name, trust_remote_code=True, ) def add_estimated_duration(example): prompt_audio_len = example["prompt_audio"]["array"].shape[0] scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"]) estimated_duration = prompt_audio_len * scale_factor example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"] return example dataset = dataset.map(add_estimated_duration) dataset = dataset.sort("estimated_duration", reverse=True) if args.use_perf: # dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000 dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719 # dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002 # dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long) dataset = datasets.concatenate_datasets(dataset_list_short) if world_size > 1: sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) else: # This would disable shuffling sampler = None dataloader = DataLoader( dataset, batch_size=args.batch_size, sampler=sampler, shuffle=False, num_workers=args.num_workers, prefetch_factor=args.prefetch, collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf), ) total_steps = len(dataset) if args.enable_warmup: for batch in dataloader: ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) text_pad_seq = batch["text_pad_sequence"].to(device) total_mel_lens = batch["estimated_reference_target_mel_len"] if args.backend_type == "trt": _ = model.sample( text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding ) elif args.backend_type == "pytorch": with torch.inference_mode(): text_pad_seq -= 1 text_pad_seq[text_pad_seq == -2] = -1 total_mel_lens = torch.tensor(total_mel_lens, device=device) generated, _ = model.sample( cond=ref_mels, text=text_pad_seq, duration=total_mel_lens, steps=16, cfg_strength=2.0, sway_sampling_coef=-1, ) if rank == 0: progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") decoding_time = 0 vocoder_time = 0 total_duration = 0 if args.use_perf: torch.cuda.cudart().cudaProfilerStart() total_decoding_time = time.time() for batch in dataloader: if args.use_perf: torch.cuda.nvtx.range_push("data sample") ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) text_pad_seq = batch["text_pad_sequence"].to(device) total_mel_lens = batch["estimated_reference_target_mel_len"] if args.use_perf: torch.cuda.nvtx.range_pop() if args.backend_type == "trt": generated, cost_time = model.sample( text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding, use_perf=args.use_perf, ) elif args.backend_type == "pytorch": total_mel_lens = torch.tensor(total_mel_lens, device=device) with torch.inference_mode(): start_time = time.time() text_pad_seq -= 1 text_pad_seq[text_pad_seq == -2] = -1 generated, _ = model.sample( cond=ref_mels, text=text_pad_seq, duration=total_mel_lens, lens=ref_mel_lens, steps=16, cfg_strength=2.0, sway_sampling_coef=-1, ) cost_time = time.time() - start_time decoding_time += cost_time vocoder_start_time = time.time() for i, gen in enumerate(generated): gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) if args.vocoder == "vocos": if args.use_perf: torch.cuda.nvtx.range_push("vocoder decode") generated_wave = vocoder.decode(gen_mel_spec).cpu() if args.use_perf: torch.cuda.nvtx.range_pop() else: generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() target_rms = 0.1 target_sample_rate = 24_000 # if ref_rms_list[i] < target_rms: # generated_wave = generated_wave * ref_rms_list[i] / target_rms rms = torch.sqrt(torch.mean(torch.square(generated_wave))) if rms < target_rms: generated_wave = generated_wave * target_rms / rms utt = batch["ids"][i] torchaudio.save( f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate, ) total_duration += generated_wave.shape[1] / target_sample_rate vocoder_time += time.time() - vocoder_start_time if rank == 0: progress_bar.update(world_size * len(batch["ids"])) total_decoding_time = time.time() - total_decoding_time if rank == 0: progress_bar.close() rtf = total_decoding_time / total_duration s = f"RTF: {rtf:.4f}\n" s += f"total_duration: {total_duration:.3f} seconds\n" s += f"({total_duration / 3600:.2f} hours)\n" s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n" s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n" s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n" s += f"batch size: {args.batch_size}\n" print(s) with open(f"{args.output_dir}/rtf.txt", "w") as f: f.write(s) dist.barrier() dist.destroy_process_group() if __name__ == "__main__": main()