import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path

import torch
import torchaudio
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer


def collate_fn(batches):
    # Keep every field as a plain list; samples have variable-length token and audio
    # sequences, so no padding or stacking is done here.
    input_ids = [sample["input_ids"] for sample in batches]
    audios = [sample["audios"] for sample in batches]
    audio_indices = [sample["audio_indices"] for sample in batches]
    refs = [sample["ref"] for sample in batches]
    filenames = [sample["filename"] for sample in batches]

    return input_ids, audios, audio_indices, refs, filenames


class STSDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        json_path,
        tokenizer,
        audio_tokenizer,
        default_system_message=None,
        add_generation_prompt=True,
    ):
        data = load_dataset("json", data_files=json_path, keep_in_memory=False)
        self.data = data["train"]

        self.tokenizer = tokenizer
        self.add_generation_prompt = add_generation_prompt
        self.audio_tokenizer = audio_tokenizer
        self.default_system_message = default_system_message

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        assert len(sample["audios"]) == 1
        audio_path = sample["audios"][0]

        if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
        else:
            audio_tokens = None

        messages = []
        if len(sample["messages"]) == 2:
            assert sample["messages"][0]["role"] == "user"
            assert sample["messages"][1]["role"] == "assistant"

            if self.default_system_message is not None:
                messages = self.default_system_message + messages

        elif len(sample["messages"]) == 3:
            assert sample["messages"][0]["role"] == "system"
            assert sample["messages"][1]["role"] == "user"
            assert sample["messages"][2]["role"] == "assistant"

        else:
            raise NotImplementedError

        for conv in sample["messages"][:-1]:
            new_conv = {}
            new_conv["role"] = conv["role"]

            content = conv["content"]
            if isinstance(content, list):
                assert len(content) == 1
                content = content[0]

            if audio_tokens is not None:
                content = content.replace(
                    "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
                )

            new_conv["content"] = content
            messages.append(new_conv)

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
            # return_tensors="pt",
        )

        # The final assistant message is the reference transcript.
        ref = sample["messages"][-1]["content"]

        if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
            # contiguous codec
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long)

        filename = os.path.basename(audio_path)
        filename = os.path.splitext(filename)[0]

        return {
            "input_ids": input_ids,
            "audios": audios,
            "audio_indices": audio_indices,
            "ref": ref,
            "filename": filename,
        }


class InferenceSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        # Split the dataset into contiguous, near-equal shards, one per rank.
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)


def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir, asr_model):
    audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")

    outputs = []
    for _, (
        batched_input_ids,
        batched_audios,
        batched_audio_indices,
        batched_ref,
        batched_filename,
    ) in enumerate(tqdm.tqdm(dataloader)):
        for input_ids, audios, audio_indices, ref, filename in zip(
            batched_input_ids, batched_audios, batched_audio_indices, batched_ref, batched_filename
        ):
            responses = model.generate(
                input_ids=input_ids.cuda(),
                audios=audios,
                audio_indices=audio_indices,
                # temperature=0.2,
                # top_p=0.8,
                # do_sample=False,
                # temperature=1.0,
                max_new_tokens=1024,
                min_new_tokens=1,
            )
            response = responses[0][len(input_ids[0]) :]

            # Split generated tokens: ids at or above audio_offset are audio codec tokens,
            # everything below is text.
            text_tokens = []
            audio_tokens = []
            for token_id in response:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

            hyp_text = tokenizer.decode(text_tokens, skip_special_tokens=True)

            if len(audio_tokens) == 0:
                continue

            tts_speech = audio_tokenizer.decode(audio_tokens)

            wav_dir = os.path.join(output_dir, "audio")
            wav_path = os.path.join(wav_dir, filename + ".wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

            # hyp_speech = asr_model.transcribe(wav_path)["text"].strip()
            hyp_speech = asr_model(wav_path, return_timestamps=True)["text"].strip()
            # hyp_speech = ""

            outputs.append((hyp_text, hyp_speech, ref))

            print("")
            print("=" * 100)
            print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
            print(f" {hyp_text=}")
            print(f"{hyp_speech=}")
            print(f" {ref=}")
            print(f"{filename=}")

    return outputs


def load_asr_model():
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

    rank = torch.distributed.get_rank()
    device = f"cuda:{rank}"
    torch_dtype = torch.float16
    model_id = "/data/models/openai/whisper-large-v3"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    return pipe


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
    parser.add_argument("--json_path", type=str, required=True, help="json_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)
    args = parser.parse_args()
print(f"{args=}") torch.distributed.init_process_group( backend="nccl", world_size=int(os.getenv("WORLD_SIZE", "1")), rank=int(os.getenv("RANK", "0")), timeout=timedelta(seconds=7200), ) torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0))) random.seed(42) torch.manual_seed(42) config = AutoConfig.from_pretrained( args.model_name_or_path, trust_remote_code=True, ) # ================================================================ if "glm" in config.model_type.lower(): from get_chat_template import glm4_chat_template as chat_template add_generation_prompt = True default_system_message = [ { "role": "system", "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.", } ] if "qwen2" in config.model_type.lower(): from get_chat_template import qwen2_chat_template as chat_template add_generation_prompt = True default_system_message = [] if "hunyuan" in config.model_type.lower(): from get_chat_template import hunyuan_chat_template as chat_template add_generation_prompt = False default_system_message = [ { "role": "system", "content": "You are a helpful AI assistant.", } ] default_system_message = [ { "role": "system", # "content": "Your Name: Luke\nYour Gender: male\nRespond in a text-audio interleaved manner.", # "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.", "content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.", }, ] # ================================================================ print("Loading model") device = "cuda" # device_map = "auto" device_map = "cuda" # torch_dtype=torch.float16 torch_dtype = torch.bfloat16 rank = torch.distributed.get_rank() audio_tokenizer = get_audio_tokenizer( args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank ) tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, trust_remote_code=True, chat_template=chat_template, ) # print("tokenizer", tokenizer) model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, trust_remote_code=True, device_map=device_map, torch_dtype=torch_dtype, attn_implementation="flash_attention_2", ).eval() # print("model", model) model.generation_config = GenerationConfig.from_pretrained( args.model_name_or_path, trust_remote_code=True ) model.generation_config.max_new_tokens = 4096 model.generation_config.chat_format = "chatml" model.generation_config.max_window_size = 8192 model.generation_config.use_cache = True model.generation_config.do_sample = False model.generation_config.temperature = None model.generation_config.top_p = None model.generation_config.top_k = None model.generation_config.pad_token_id = tokenizer.pad_token_id if model.config.model_type == "hunyuan": model.generation_config.eos_token_id = tokenizer.eos_id asr_model = load_asr_model() # ================================================================ print("Loading data") dataset = STSDataset( json_path=args.json_path, tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, default_system_message=default_system_message, add_generation_prompt=add_generation_prompt, ) dataloader = torch.utils.data.DataLoader( dataset=dataset, sampler=InferenceSampler(len(dataset)), batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True, drop_last=False, collate_fn=partial( collate_fn, ), ) # ================================================================ outputs = 
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir, asr_model)

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]

    if torch.distributed.get_rank() == 0:
        # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
        json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem

        hyp_text_path = os.path.join(args.output_dir, f"{json_name}_hyp_text.txt")
        hyp_speech_path = os.path.join(args.output_dir, f"{json_name}_hyp_speech.txt")
        ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")

        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        os.makedirs(os.path.dirname(hyp_text_path), exist_ok=True)
        os.makedirs(os.path.dirname(hyp_speech_path), exist_ok=True)

        hyp_text_file = open(hyp_text_path, "w")
        hyp_speech_file = open(hyp_speech_path, "w")
        ref_file = open(ref_path, "w")

        for sample_idx, (hyp_text, hyp_speech, ref) in enumerate(merged_outputs):
            hyp_text_file.write(f"{sample_idx} {hyp_text}" + "\n")
            hyp_speech_file.write(f"{sample_idx} {hyp_speech}" + "\n")
            ref_file.write(f"{sample_idx} {ref}" + "\n")

        hyp_text_file.close()
        hyp_speech_file.close()
        ref_file.close()

        outputs_speech = [[x[1], x[2]] for x in merged_outputs]
        outputs_text = [[x[0], x[2]] for x in merged_outputs]

        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_text.json")
        hyp_ref_file = open(hyp_ref_path, "w")
        json.dump(outputs_text, hyp_ref_file, indent=4)
        hyp_ref_file.close()

        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_speech.json")
        hyp_ref_file = open(hyp_ref_path, "w")
        json.dump(outputs_speech, hyp_ref_file, indent=4)
        hyp_ref_file.close()

    torch.distributed.barrier()
    print("Done.")
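
# ----------------------------------------------------------------------------
# Example launch (a sketch only, not part of the script): the process group is
# initialized from WORLD_SIZE / RANK / LOCAL_RANK, so the script is intended to be
# launched with one process per GPU, e.g. via torchrun. All paths, the script
# filename, and the tokenizer type below are placeholders, not values shipped here.
#
#   torchrun --nproc_per_node=8 <this_script>.py \
#       --model_name_or_path /path/to/model \
#       --audio_tokenizer_path /path/to/audio_tokenizer \
#       --audio_tokenizer_type <audio_tokenizer_type> \
#       --flow_path /path/to/flow_model \
#       --json_path /path/to/eval_set.json \
#       --output_dir /path/to/output_dir
# ----------------------------------------------------------------------------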