import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path

import torch
import torchaudio
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

from vita_audio.tokenizer import get_audio_tokenizer


def collate_fn(batches):
    input_ids = [sample["input_ids"] for sample in batches]
    refs = [sample["ref"] for sample in batches]
    filenames = [sample["filename"] for sample in batches]
    prompt_audio_path = [sample["prompt_audio_path"] for sample in batches]

    return input_ids, refs, filenames, prompt_audio_path


class SeedTTSDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        tokenizer,
        audio_tokenizer,
        default_system_message=None,
        speaker_prompt=False,
        add_generation_prompt=True,
    ):
        self.data = []

        # Each meta file is pipe-separated: filename|prompt_text|prompt_audio|text
        meta_files = [
            ("zh", "seedtts_testset/zh/meta.lst"),
            ("hardcase", "seedtts_testset/zh/hardcase.lst"),
            ("en", "seedtts_testset/en/meta.lst"),
        ]
        for split, meta_file in meta_files:
            meta_path = os.path.join(data_path, meta_file)
            with open(meta_path, "r") as f:
                lines = f.readlines()
            for line in lines:
                filename, prompt_text, prompt_audio, text = line.strip().split("|")[:4]
                self.data.append([split, filename, prompt_text, prompt_audio, text])

        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        self.default_system_message = default_system_message
        self.add_generation_prompt = add_generation_prompt
        self.data_path = data_path
        self.speaker_prompt = speaker_prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        split, filename, prompt_text, prompt_audio, text = sample

        messages = []
        if self.default_system_message is not None:
            # Copy the default system message so that appending the per-sample
            # speaker prompt below does not mutate the shared list across items.
            messages = [dict(m) for m in self.default_system_message] + messages

        if self.speaker_prompt:
            # The hardcase list lives under the zh split and reuses its prompt audio.
            if split == "hardcase":
                prompt_audio_path = os.path.join(
                    self.data_path, "seedtts_testset", "zh", prompt_audio
                )
            else:
                prompt_audio_path = os.path.join(
                    self.data_path, "seedtts_testset", split, prompt_audio
                )

            if self.audio_tokenizer.apply_to_role("system", is_discrete=True):
                # Discrete codec: embed the speaker prompt as audio tokens in the system message.
                prompt_audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
                prompt_audio_tokens = "".join(f"<|audio_{i}|>" for i in prompt_audio_tokens)
                prompt_text = (
                    "Speaker Metadata:\n"
                    f"Audio: <|begin_of_audio|>{prompt_audio_tokens}<|end_of_audio|>\n"
                )

            if len(messages) > 0 and messages[0]["role"] == "system":
                messages[0]["content"] += prompt_text
            else:
                messages.append(
                    {
                        "role": "system",
                        "content": prompt_text,
                    }
                )
        else:
            prompt_audio_path = None

        role = "user"
        content = "Convert the text to speech.\n" + text

        messages.append(
            {
                "role": role,
                "content": content,
            }
        )

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
            return_tensors="pt",
        )

        ref = text

        return {
            "input_ids": input_ids,
            "ref": ref,
            "filename": split + "/" + filename,
            "prompt_audio_path": prompt_audio_path,
        }
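
# Minimal usage sketch (illustrative; assumes the SeedTTS test set is unpacked
# under data_path and that tokenizer / audio_tokenizer are built as in __main__
# below). Each meta.lst / hardcase.lst line is pipe-separated as
# filename|prompt_text|prompt_audio|text, and a single item can be inspected with:
#
#   ds = SeedTTSDataset(data_path, tokenizer, audio_tokenizer,
#                       default_system_message=[], speaker_prompt=False)
#   item = ds[0]
#   print(item["filename"])                        # e.g. "zh/<utterance id>"
#   print(tokenizer.decode(item["input_ids"][0]))  # "Convert the text to speech.\n" + text
#   print(item["prompt_audio_path"])               # None unless speaker_prompt=True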


class InferenceSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        # Split total_size indices into contiguous, near-equal shards, one per rank.
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
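

# Quick sanity check of the sharding logic (illustrative; calls the static helper
# directly, so no process group is required): with 10 samples over 3 ranks the
# shard sizes are [4, 3, 3] and every index is visited exactly once.
#
#   shards = [InferenceSampler._get_local_indices(10, 3, r) for r in range(3)]
#   # -> [range(0, 4), range(4, 7), range(7, 10)]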


def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
    audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")

    outputs = []
    for _, (
        batched_input_ids,
        batched_ref,
        batched_filename,
        batched_prompt_audio_path,
    ) in enumerate(tqdm.tqdm(dataloader)):
        for input_ids, ref, filename, prompt_audio_path in zip(
            batched_input_ids, batched_ref, batched_filename, batched_prompt_audio_path
        ):
            responses = model.generate(
                input_ids=input_ids.cuda(),
                # temperature=0.2,
                # top_p=0.8,
                # do_sample=False,
                # temperature=1.0,
                max_new_tokens=1024,
                min_new_tokens=1,
            )
            response = responses[0][len(input_ids[0]) :]

            # Split the interleaved output into text tokens and (offset-shifted) audio codec tokens.
            text_tokens = []
            audio_tokens = []
            for token_id in response:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

            if len(audio_tokens) == 0:
                continue

            tts_speech = audio_tokenizer.decode(audio_tokens, source_speech_16k=prompt_audio_path)

            wav_path = os.path.join(output_dir, filename + ".wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

            outputs.append((wav_path, filename))

            print("")
            print("=" * 100)
            # print(f"{len(input_ids)=}")
            # print(f"{len(response)=}")
            print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
            print(f"{filename=}")

    return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
    parser.add_argument("--data_path", type=str, required=True, help="data_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)
    args = parser.parse_args()
    print(f"{args=}")

    torch.distributed.init_process_group(
        backend="nccl",
        world_size=int(os.getenv("WORLD_SIZE", "1")),
        rank=int(os.getenv("RANK", "0")),
        timeout=timedelta(seconds=7200),
    )
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    random.seed(42)
    torch.manual_seed(42)

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # ================================================================
    # Select the chat template and default system message per model family.
    if "glm" in config.model_type.lower():
        from get_chat_template import glm4_chat_template as chat_template

        add_generation_prompt = True
        default_system_message = [
            {
                "role": "system",
                "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
            }
        ]
    elif "qwen2" in config.model_type.lower():
        from get_chat_template import qwen2_chat_template as chat_template

        add_generation_prompt = True
        default_system_message = []
    elif "hunyuan" in config.model_type.lower():
        from get_chat_template import hunyuan_chat_template as chat_template

        add_generation_prompt = False
        default_system_message = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            }
        ]
    else:
        raise NotImplementedError(f"Unsupported model type: {config.model_type}")

    # ================================================================
    print("Loading model")
    device = "cuda"
    # device_map = "auto"
    device_map = "cuda"
    # torch_dtype = torch.float16
    torch_dtype = torch.bfloat16

    rank = torch.distributed.get_rank()
    audio_tokenizer = get_audio_tokenizer(
        args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print("tokenizer", tokenizer)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print("model", model)

    model.generation_config = GenerationConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )
    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = True
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    if model.config.model_type == "hunyuan":
        model.generation_config.eos_token_id = tokenizer.eos_id

    # ================================================================
    print("Loading data")
    dataset = SeedTTSDataset(
        data_path=args.data_path,
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        default_system_message=default_system_message,
        speaker_prompt=args.speaker_prompt,
        add_generation_prompt=add_generation_prompt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(collate_fn),
    )

    # ================================================================
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)

    torch.distributed.barrier()

    # Gather each rank's (wav_path, filename) pairs onto every rank.
    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]

    torch.distributed.barrier()
    print("Done.")
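
# Example launch (illustrative sketch; the script name, process count, and all
# paths below are placeholders, adjust them to your checkpoints, flow decoder,
# and SeedTTS data layout):
#
#   torchrun --nproc_per_node=8 <this_file>.py \
#       --model_name_or_path /path/to/model \
#       --audio_tokenizer_path /path/to/audio_tokenizer \
#       --audio_tokenizer_type <tokenizer_type> \
#       --flow_path /path/to/flow \
#       --data_path /path/to/seedtts_data \
#       --output_dir ./outputs/seedtts \
#       --speaker_prompt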