import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path

import torch
import torchaudio
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

from vita_audio.tokenizer import get_audio_tokenizer
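
# Overview: this script evaluates TTS on the SeedTTS test set. It builds
# chat-formatted prompts from the zh/en/hardcase meta lists, shards them across
# distributed ranks, generates interleaved text/audio tokens with a causal LM,
# decodes the audio tokens back to waveforms with the audio tokenizer, and
# writes one .wav file per utterance under --output_dir.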


def collate_fn(batches):
    input_ids = [sample["input_ids"] for sample in batches]
    refs = [sample["ref"] for sample in batches]
    filenames = [sample["filename"] for sample in batches]
    prompt_audio_path = [sample["prompt_audio_path"] for sample in batches]
    return input_ids, refs, filenames, prompt_audio_path
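

# Note: collate_fn returns plain per-sample lists rather than padded tensors,
# because inference() below runs model.generate() on one sample at a time.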


class SeedTTSDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        tokenizer,
        audio_tokenizer,
        default_system_message=None,
        speaker_prompt=False,
        add_generation_prompt=True,
    ):
        self.data = []

        # (split label, meta list relative to data_path). The "hardcase" split
        # shares the zh prompt audio directory; see __getitem__.
        meta_files = [
            ("zh", "seedtts_testset/zh/meta.lst"),
            ("hardcase", "seedtts_testset/zh/hardcase.lst"),
            ("en", "seedtts_testset/en/meta.lst"),
        ]
        for split, meta_rel_path in meta_files:
            meta_path = os.path.join(data_path, meta_rel_path)
            with open(meta_path, "r") as f:
                for line in f:
                    # Each line: filename|prompt_text|prompt_audio|text
                    filename, prompt_text, prompt_audio, text = line.strip().split("|")[:4]
                    self.data.append([split, filename, prompt_text, prompt_audio, text])

        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        self.default_system_message = default_system_message
        self.add_generation_prompt = add_generation_prompt
        self.data_path = data_path
        self.speaker_prompt = speaker_prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        split, filename, prompt_text, prompt_audio, text = sample

        messages = []
        if self.default_system_message is not None:
            # Copy the message dicts so that appending speaker metadata below
            # does not mutate the shared default system message across samples.
            messages = [dict(m) for m in self.default_system_message] + messages

        if self.speaker_prompt:
            # Hardcase entries store their prompt audio under the zh directory.
            prompt_split = "zh" if split == "hardcase" else split
            prompt_audio_path = os.path.join(
                self.data_path, "seedtts_testset", prompt_split, prompt_audio
            )

            if self.audio_tokenizer.apply_to_role("system", is_discrete=True):
                # Discrete codec: embed the prompt audio as <|audio_k|> tokens in
                # the system message so the model can condition on the reference speaker.
                prompt_audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
                prompt_audio_tokens = "".join(f"<|audio_{i}|>" for i in prompt_audio_tokens)
                prompt_text = f"Speaker Metadata:\nAudio: <|begin_of_audio|>{prompt_audio_tokens}<|end_of_audio|>\n"

                if len(messages) > 0 and messages[0]["role"] == "system":
                    messages[0]["content"] += prompt_text
                else:
                    messages.append(
                        {
                            "role": "system",
                            "content": prompt_text,
                        }
                    )
        else:
            prompt_audio_path = None

        role = "user"
        content = "Convert the text to speech.\n" + text

        messages.append(
            {
                "role": role,
                "content": content,
            }
        )

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
            return_tensors="pt",
        )

        ref = text

        return {
            "input_ids": input_ids,
            "ref": ref,
            "filename": split + "/" + filename,
            "prompt_audio_path": prompt_audio_path,
        }
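

# InferenceSampler splits the dataset into contiguous, non-overlapping index
# ranges, one per distributed rank, so every rank generates a disjoint subset
# of the test set exactly once.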


class InferenceSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        # Ranks with index < left take one extra sample so the remainder is spread evenly.
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
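

# During generation the model emits a mix of text tokens and discrete audio
# tokens. inference() separates them by token id: everything at or above the
# id of <|audio_0|> is treated as an audio codec token and later decoded back
# to a waveform by the audio tokenizer.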


def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
    # Token ids at or above this offset are audio codec tokens (<|audio_0|>, <|audio_1|>, ...).
    audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")

    outputs = []
    for batched_input_ids, batched_ref, batched_filename, batched_prompt_audio_path in tqdm.tqdm(
        dataloader
    ):
        for input_ids, ref, filename, prompt_audio_path in zip(
            batched_input_ids, batched_ref, batched_filename, batched_prompt_audio_path
        ):
            responses = model.generate(
                input_ids=input_ids.cuda(),
                max_new_tokens=1024,
                min_new_tokens=1,
            )
            # Keep only the newly generated tokens.
            response = responses[0][len(input_ids[0]) :]

            # Split the generated sequence into text tokens and audio codec tokens.
            text_tokens = []
            audio_tokens = []
            for token_id in response:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

            if len(audio_tokens) == 0:
                continue

            # Decode the audio tokens to a waveform, optionally conditioned on the prompt audio.
            tts_speech = audio_tokenizer.decode(audio_tokens, source_speech_16k=prompt_audio_path)

            wav_path = os.path.join(output_dir, filename + ".wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
            outputs.append((wav_path, filename))

            print("")
            print("=" * 100)
            print(tokenizer.decode(response, skip_special_tokens=False))
            print(f"{filename=}")

    return outputs
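

# The main block expects a distributed launcher (e.g. torchrun) to set the
# WORLD_SIZE, RANK, and LOCAL_RANK environment variables read below.
# A minimal launch sketch; the script name, paths, and tokenizer type are
# placeholders, not values taken from this repository:
#
#   torchrun --nproc_per_node=8 <this_script>.py \
#       --model_name_or_path /path/to/model \
#       --audio_tokenizer_path /path/to/audio_tokenizer \
#       --audio_tokenizer_type <tokenizer_type> \
#       --flow_path /path/to/flow \
#       --data_path /path/to/seedtts_data \
#       --output_dir ./seedtts_outputs \
#       --speaker_prompt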


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="SeedTTS test set TTS inference.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
    parser.add_argument("--data_path", type=str, required=True, help="data_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)
    args = parser.parse_args()
    print(f"{args=}")

    torch.distributed.init_process_group(
        backend="nccl",
        world_size=int(os.getenv("WORLD_SIZE", "1")),
        rank=int(os.getenv("RANK", "0")),
        timeout=timedelta(seconds=7200),
    )
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    random.seed(42)
    torch.manual_seed(42)

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # ================================================================
    # Pick the chat template and default system message for the base model.
    if "glm" in config.model_type.lower():
        from get_chat_template import glm4_chat_template as chat_template

        add_generation_prompt = True
        default_system_message = [
            {
                "role": "system",
                "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
            }
        ]

    elif "qwen2" in config.model_type.lower():
        from get_chat_template import qwen2_chat_template as chat_template

        add_generation_prompt = True
        default_system_message = []

    elif "hunyuan" in config.model_type.lower():
        from get_chat_template import hunyuan_chat_template as chat_template

        add_generation_prompt = False
        default_system_message = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            }
        ]

    else:
        raise ValueError(f"Unsupported model_type for chat template selection: {config.model_type}")
    # ================================================================
    print("Loading model")

    # device_map = "auto"
    device_map = "cuda"
    # torch_dtype = torch.float16
    torch_dtype = torch.bfloat16

    rank = torch.distributed.get_rank()
    audio_tokenizer = get_audio_tokenizer(
        args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()

    model.generation_config = GenerationConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )
    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = True
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    if model.config.model_type == "hunyuan":
        model.generation_config.eos_token_id = tokenizer.eos_id

    # ================================================================
    print("Loading data")

    dataset = SeedTTSDataset(
        data_path=args.data_path,
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        default_system_message=default_system_message,
        speaker_prompt=args.speaker_prompt,
        add_generation_prompt=add_generation_prompt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )

    # ================================================================
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)

    # Gather the per-rank (wav_path, filename) pairs on every rank.
    torch.distributed.barrier()
    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
    merged_outputs = [json.loads(o) for o in merged_outputs]
    merged_outputs = list(itertools.chain.from_iterable(merged_outputs))
    torch.distributed.barrier()

    print("Done.")