# VITA-Audio/vita_audio/tokenizer_glm4voice.py
import glob
import io
import logging
import math
import os
import tarfile
import uuid
import safetensors
import torch
import torchaudio
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
from speech_tokenizer.modeling_whisper import WhisperVQEncoder
from flow_inference import AudioDecoder
from .constants import (
    AUD_END_TOKEN,
    AUD_START_TOKEN,
    AUD_TAG_TOKEN,
    BOX_END_TOKEN,
    BOX_START_TOKEN,
    IMG_CONTEXT_TOKEN,
    IMG_END_TOKEN,
    IMG_START_TOKEN,
    IMG_TAG_TOKEN,
    PATCH_CONTEXT_TOKEN,
    PATCH_END_TOKEN,
    PATCH_START_TOKEN,
    QUAD_END_TOKEN,
    QUAD_START_TOKEN,
    REF_END_TOKEN,
    REF_START_TOKEN,
    VID_CONTEXT_TOKEN,
    VID_END_TOKEN,
    VID_START_TOKEN,
    VID_TAG_TOKEN,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def update_tokenizer_for_glm4voice(tokenizer):
    # Special multimodal delimiter/context tokens used by VITA-Audio.
    token_list = [
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        QUAD_START_TOKEN,
        QUAD_END_TOKEN,
        REF_START_TOKEN,
        REF_END_TOKEN,
        BOX_START_TOKEN,
        BOX_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
        AUD_TAG_TOKEN,
    ]
    num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)

    # One text token per discrete speech token in the 16384-entry GLM-4-Voice codebook.
    token_list = [f"<|audio_{i}|>" for i in range(16384)]
    num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=False)
    # logger.info(f"tokenizer {tokenizer}")
    return tokenizer
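

# Illustrative usage sketch (not part of the original module): extending a text tokenizer
# with the audio vocabulary registered above. The base model name is a placeholder
# assumption; after adding tokens, the paired language model's embedding matrix would
# typically be resized to the new vocabulary size.
def _example_extend_text_tokenizer(base_model_name_or_path="<your-base-llm>"):
    from transformers import AutoTokenizer  # local import keeps the sketch self-contained

    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, trust_remote_code=True)
    old_size = len(tokenizer)
    tokenizer = update_tokenizer_for_glm4voice(tokenizer)
    logger.info(f"vocab size {old_size} -> {len(tokenizer)}")
    # e.g. model.resize_token_embeddings(len(tokenizer)) on the paired language model
    return tokenizer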


class GLM4VoiceTokenizer:
    def __init__(self, model_name_or_path, flow_path=None, rank=None):
        self.model_name_or_path = model_name_or_path
        self.flow_path = flow_path

        if rank is None and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            # Map the global rank to a local GPU index (assumes 8 GPUs per node).
            self.rank = rank % 8
        # elif rank > 0:
        #     self.rank = None
        else:
            self.rank = rank
        logger.info(f"{self.rank=}")
        # print(f"{self.rank=}")

        # This tokenizer produces discrete token ids rather than contiguous features.
        self.is_discrete = True
        self.is_contiguous = False

        # # T A
        # text_audio_interval_ratio = [13, 26]
        # # T A T A T A
        # text_audio_interval_ratio = [1, 4, 3, 8, 4, 10]
        # # T A T A
        # text_audio_interval_ratio = [1, 10, 4, 10]
        # self.text_audio_interval_ratio = text_audio_interval_ratio

    def load_model(self):
        # Load lazily and only once.
        if hasattr(self, "whisper_model"):
            return

        if self.rank is not None:
            self.device = f"cuda:{self.rank}"
            torch.cuda.set_device(self.rank)
        else:
            self.device = "cpu"
        logger.info(f"{self.device=} Loading GLM4VoiceTokenizer")

        # Whisper-based VQ encoder that maps waveforms to discrete speech tokens.
        self.whisper_model = (
            WhisperVQEncoder.from_pretrained(self.model_name_or_path).eval().to(self.device)
        )
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(self.model_name_or_path)

        if self.flow_path is not None:
            flow_config = os.path.join(self.flow_path, "config.yaml")
            flow_checkpoint = os.path.join(self.flow_path, "flow.pt")
            hift_checkpoint = os.path.join(self.flow_path, "hift.pt")

            # Flow & HiFT decoder stack used to turn speech tokens back into a waveform.
            self.audio_decoder = AudioDecoder(
                config_path=flow_config,
                flow_ckpt_path=flow_checkpoint,
                hift_ckpt_path=hift_checkpoint,
                device=self.device,
            )

        logger.info(f"{self.device=} Loading GLM4VoiceTokenizer Done")

    def encode(self, audio_path, **kwargs):
        # Convert an audio file path (or a (waveform, sample_rate) tuple) into discrete speech tokens.
        if not hasattr(self, "whisper_model"):
            self.load_model()

        audio_tokens = extract_speech_token(
            self.whisper_model, self.feature_extractor, [audio_path], device=self.device
        )[0]
        return audio_tokens

    def decode(self, audio_tokens, option_steps=10, **kwargs):
        # Synthesize a waveform from discrete speech tokens via the flow + HiFT decoder.
        if not hasattr(self, "whisper_model"):
            self.load_model()

        this_uuid = str(uuid.uuid4())

        tts_token = torch.tensor(audio_tokens, device=self.device).unsqueeze(0)
        # No prompt speech is used: empty prompt token ids and an empty 80-dim mel prompt.
        flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(self.device)
        prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)

        tts_speech, tts_mel = self.audio_decoder.token2wav(
            tts_token,
            uuid=this_uuid,
            prompt_token=flow_prompt_speech_token.to(self.device),
            prompt_feat=prompt_speech_feat.to(self.device),
            finalize=True,
            option_steps=option_steps,
        )
        tts_speech = tts_speech.squeeze().cpu()
        return tts_speech

    def apply_to_role(self, role, **kwargs):
        # This tokenizer serves discrete-token requests for any role, but not contiguous features.
        is_discrete = kwargs.get("is_discrete", False)
        if is_discrete:
            return True

        is_contiguous = kwargs.get("is_contiguous", False)
        if is_contiguous:
            return False

        return True
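

# Illustrative round-trip sketch (not part of the original module): encode a wav file to
# discrete speech tokens and synthesize it back. The checkpoint paths, output filename,
# and rank=0 (a single CUDA device) are placeholder assumptions; decoding requires
# flow_path to point at a decoder directory containing config.yaml, flow.pt, and hift.pt.
def _example_roundtrip(
    tokenizer_path="THUDM/glm-4-voice-tokenizer",
    decoder_path="/path/to/glm-4-voice-decoder",
    wav_path="/path/to/input.wav",
):
    speech_tokenizer = GLM4VoiceTokenizer(tokenizer_path, flow_path=decoder_path, rank=0)
    audio_tokens = speech_tokenizer.encode(wav_path)  # list[int] speech token ids
    waveform = speech_tokenizer.decode(audio_tokens)  # 1-D float tensor on CPU
    # Saved at 22.05 kHz here; adjust if the vocoder emits a different sample rate.
    torchaudio.save("roundtrip.wav", waveform.unsqueeze(0), 22050)
    return audio_tokens, waveform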


# Cache of Resample transforms keyed by source sample rate, so repeated calls with the
# same input rate reuse the same (device-resident) resampler.
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}


def extract_speech_token(model, feature_extractor, utts, device="cuda"):
    with torch.no_grad():
        # Load every utterance, resample to 16 kHz mono, and split it into 30-second segments.
        audios, indices = [], []
        for idx, utt in enumerate(utts):
            if isinstance(utt, tuple):
                audio, sample_rate = utt
            else:
                audio, sample_rate = torchaudio.load(utt)
            audio = audio.to(device)
            if sample_rate != 16000:
                if sample_rate not in _resample_buffer:
                    _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
                        orig_freq=sample_rate, new_freq=16000
                    ).to(device)
                audio = _resample_buffer[sample_rate](audio)
            # if audio.shape[0] > 1:
            #     audio = audio[:1]
            audio = audio[0]
            audio = audio.cpu().numpy()
            time_step = 0
            while time_step * 16000 < audio.shape[0]:
                audio_segment = audio[time_step * 16000 : (time_step + 30) * 16000]
                audios.append(audio_segment)
                indices.append(idx)
                time_step += 30

        # Total downsampling factor from waveform samples to quantized tokens
        # (conv strides x optional pooling x feature-extractor hop length).
        pooling_kernel_size = model.config.pooling_kernel_size or 1
        stride = (
            model.conv1.stride[0]
            * model.conv2.stride[0]
            * pooling_kernel_size
            * feature_extractor.hop_length
        )

        all_speech_tokens = [[] for _ in range(len(utts))]
        batch_size = 128
        for start in range(0, len(audios), batch_size):
            features = feature_extractor(
                audios[start : start + batch_size],
                sampling_rate=16000,
                return_attention_mask=True,
                return_tensors="pt",
                device=device,
                padding="longest",
                pad_to_multiple_of=stride,
            )
            features = features.to(device=device)
            outputs = model(**features)
            speech_tokens = outputs.quantized_token_ids
            # Downsample the attention mask to token resolution so padded frames can be dropped.
            attention_mask = features.attention_mask[
                :, :: model.conv1.stride[0] * model.conv2.stride[0]
            ]
            attention_mask = attention_mask[:, :: model.config.pooling_kernel_size]
            assert attention_mask.shape == speech_tokens.shape
            for i in range(len(speech_tokens)):
                idx = indices[start + i]
                speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
                all_speech_tokens[idx].extend(speech_token)
        return all_speech_tokens
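

# Illustrative helpers (not part of the original module): map the integer speech token ids
# produced by extract_speech_token / GLM4VoiceTokenizer.encode to the "<|audio_k|>" strings
# registered by update_tokenizer_for_glm4voice, and back. Function names are hypothetical.
def _example_audio_ids_to_strings(audio_tokens):
    return [f"<|audio_{t}|>" for t in audio_tokens]


def _example_strings_to_audio_ids(token_strings):
    # Assumes every entry has the "<|audio_k|>" form added above.
    return [int(s[len("<|audio_") : -len("|>")]) for s in token_strings]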