# vita_audio/tokenizer_sensevoice_glm4voice.py
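"""Audio tokenizer for VITA-Audio.

Combines a SenseVoice frontend (continuous fbank features for user speech) with the
GLM-4-Voice Whisper-VQ encoder (discrete audio tokens for assistant speech) and an
optional flow/HiFT decoder that turns discrete tokens back into waveforms.
"""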
import glob
import io
import logging
import math
import os
import tarfile
import uuid
import safetensors
import torch
import torchaudio
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
from speech_tokenizer.modeling_whisper import WhisperVQEncoder
from flow_inference import AudioDecoder
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.sense_voice.model import SenseVoiceSmall
from .constants import (
AUD_CONTEXT_TOKEN,
AUD_END_TOKEN,
AUD_START_TOKEN,
AUD_TAG_TOKEN,
BOX_END_TOKEN,
BOX_START_TOKEN,
IMG_CONTEXT_TOKEN,
IMG_END_TOKEN,
IMG_START_TOKEN,
IMG_TAG_TOKEN,
PATCH_CONTEXT_TOKEN,
PATCH_END_TOKEN,
PATCH_START_TOKEN,
QUAD_END_TOKEN,
QUAD_START_TOKEN,
REF_END_TOKEN,
REF_START_TOKEN,
VID_CONTEXT_TOKEN,
VID_END_TOKEN,
VID_START_TOKEN,
VID_TAG_TOKEN,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def update_tokenizer_for_sensevoice_glm4voice(tokenizer):
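    """Register the multimodal special tokens and the 16384 discrete `<|audio_i|>`
    codebook tokens on `tokenizer`, then return it."""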
token_list = [
IMG_START_TOKEN,
IMG_END_TOKEN,
IMG_CONTEXT_TOKEN,
VID_START_TOKEN,
VID_END_TOKEN,
VID_CONTEXT_TOKEN,
PATCH_START_TOKEN,
PATCH_END_TOKEN,
PATCH_CONTEXT_TOKEN,
AUD_START_TOKEN,
AUD_END_TOKEN,
AUD_CONTEXT_TOKEN,
QUAD_START_TOKEN,
QUAD_END_TOKEN,
REF_START_TOKEN,
REF_END_TOKEN,
BOX_START_TOKEN,
BOX_END_TOKEN,
IMG_TAG_TOKEN,
VID_TAG_TOKEN,
AUD_TAG_TOKEN,
]
num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
token_list = [f"<|audio_{i}|>" for i in range(16384)]
num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=False)
# logger.info(f"tokenizer {tokenizer}")
return tokenizer
class SenseVoiceGLM4VoiceTokenizer:
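    """Audio tokenizer pairing SenseVoice fbank extraction (continuous, user turns) with
    the GLM-4-Voice Whisper-VQ encoder and flow/HiFT decoder (discrete, assistant turns)."""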
def __init__(self, model_name_or_path, flow_path=None, rank=None):
self.model_name_or_path = model_name_or_path
self.flow_path = flow_path
if rank is None and torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
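            # Map the global rank to a local GPU index (assumes up to 8 GPUs per node).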
self.rank = rank % 8
else:
self.rank = rank
logger.info(f"{self.rank=}")
self.sample_rate = 16000
self.is_discrete = True
self.is_contiguous = True
# # T A
# text_audio_interval_ratio = [13, 26]
# # T A T A T A
# text_audio_interval_ratio = [1, 4, 3, 8, 4, 10]
# # T A T A
# text_audio_interval_ratio = [1, 10, 4, 10]
# self.text_audio_interval_ratio = text_audio_interval_ratio
def load_model(self):
if hasattr(self, "whisper_model"):
return
import faulthandler
faulthandler.enable()
if self.rank is not None:
self.device = f"cuda:{self.rank}"
#torch.cuda.set_device(self.rank)
else:
self.device = "cpu"
logger.info(f"{self.device=} Loading SenseVoiceSmall")
from huggingface_hub import snapshot_download
model_dir = snapshot_download(repo_id="FunAudioLLM/SenseVoiceSmall")
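        # Only the frontend kwargs (fbank extraction config) are used later; the
        # SenseVoice ASR model itself is discarded.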
_, self.kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device=self.device)
logger.info(f"{self.device=} Loading SenseVoiceSmall Done")
logger.info(f"{self.device=} Loading GLM4VoiceTokenizer")
self.whisper_model = (
WhisperVQEncoder.from_pretrained(self.model_name_or_path).eval().to(self.device)
)
self.feature_extractor = WhisperFeatureExtractor.from_pretrained(self.model_name_or_path)
if self.flow_path is not None:
flow_config = os.path.join(self.flow_path, "config.yaml")
flow_checkpoint = os.path.join(self.flow_path, "flow.pt")
hift_checkpoint = os.path.join(self.flow_path, "hift.pt")
# Flow & Hift
self.audio_decoder = AudioDecoder(
config_path=flow_config,
flow_ckpt_path=flow_checkpoint,
hift_ckpt_path=hift_checkpoint,
device=self.device,
)
logger.info(f"{self.device=} Loading GLM4VoiceTokenizer Done")
def encode(self, audio_path, is_discrete=False, is_contiguous=True, **kwargs):
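        """Encode `audio_path` into either discrete GLM-4-Voice token ids (`is_discrete`)
        or a continuous SenseVoice fbank feature tensor (`is_contiguous`); the two modes
        are mutually exclusive."""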
if not hasattr(self, "whisper_model"):
self.load_model()
assert not (is_discrete and is_contiguous)
assert is_discrete or is_contiguous
if is_discrete:
audio_tokens = extract_speech_token(
self.whisper_model, self.feature_extractor, [audio_path], device=self.device
)[0]
return audio_tokens
if is_contiguous:
audio, sample_rate = torchaudio.load(audio_path)
audio = audio.mean(0)
if sample_rate != self.sample_rate:
if sample_rate not in _resample_buffer:
_resample_buffer[sample_rate] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=self.sample_rate
).to(self.device)
audio = audio.to(self.device)
audio = _resample_buffer[sample_rate](audio[None, :])[0, :]
audio = audio.cpu()
# resampler = torchaudio.transforms.Resample(
# orig_freq=sample_rate, new_freq=self.sample_rate
# )
# audio = resampler(audio[None, :])[0, :]
# audio = audio.to(self.device)
frontend = self.kwargs["frontend"]
speech, speech_lengths = extract_fbank(audio, data_type="sound", frontend=frontend)
speech = speech[0]
# print(f"{speech_lengths=}")
# print(f"{speech.size()=}")
return speech
def decode(self, audio_tokens, option_steps=10, **kwargs):
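        """Synthesize a mono waveform tensor from a list of discrete audio token ids via
        the flow/HiFT decoder."""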
if not hasattr(self, "whisper_model"):
self.load_model()
        this_uuid = str(uuid.uuid4())
tts_token = torch.tensor(audio_tokens, device=self.device).unsqueeze(0)
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(self.device)
prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
tts_speech, tts_mel = self.audio_decoder.token2wav(
tts_token,
uuid=this_uuid,
prompt_token=flow_prompt_speech_token.to(self.device),
prompt_feat=prompt_speech_feat.to(self.device),
finalize=True,
option_steps=option_steps,
)
        tts_speech = tts_speech.squeeze().cpu()
return tts_speech
def apply_to_role(self, role, **kwargs):
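        """Return True if this tokenizer handles audio for `role`: discrete tokens are
        produced for assistant turns, continuous features for user turns."""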
is_discrete = kwargs.get("is_discrete", False)
if is_discrete and role in ["assistant", "gpt"]:
return True
is_contiguous = kwargs.get("is_contiguous", False)
if is_contiguous and role in ["user", "human"]:
return True
return False
# Cache of torchaudio Resample transforms, keyed by source sample rate, shared across calls.
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}
def extract_speech_token(model, feature_extractor, utts, device="cuda"):
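    """Run each utterance through the Whisper-VQ encoder in 30-second chunks and return,
    per utterance, the concatenated list of quantized speech token ids."""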
with torch.no_grad():
audios, indices = [], []
for idx, utt in enumerate(utts):
if isinstance(utt, tuple):
audio, sample_rate = utt
else:
audio, sample_rate = torchaudio.load(utt)
audio = audio.to(device)
if sample_rate != 16000:
if sample_rate not in _resample_buffer:
_resample_buffer[sample_rate] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
).to(device)
audio = _resample_buffer[sample_rate](audio)
# if audio.shape[0] > 1:
# audio = audio[:1]
audio = audio[0]
audio = audio.cpu().numpy()
time_step = 0
while time_step * 16000 < audio.shape[0]:
audio_segment = audio[time_step * 16000 : (time_step + 30) * 16000]
audios.append(audio_segment)
indices.append(idx)
time_step += 30
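        # Each quantized token spans `stride` raw samples (conv1/conv2 strides x pooling
        # kernel x feature-extractor hop length); batches are padded to a multiple of it
        # so every segment maps to a whole number of tokens.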
pooling_kernel_size = model.config.pooling_kernel_size or 1
stride = (
model.conv1.stride[0]
* model.conv2.stride[0]
* pooling_kernel_size
* feature_extractor.hop_length
)
all_speech_tokens = [[] for _ in range(len(utts))]
batch_size = 128
for start in range(0, len(audios), batch_size):
features = feature_extractor(
audios[start : start + batch_size],
sampling_rate=16000,
return_attention_mask=True,
return_tensors="pt",
device=device,
padding="longest",
pad_to_multiple_of=stride,
)
features = features.to(device=device)
outputs = model(**features)
speech_tokens = outputs.quantized_token_ids
attention_mask = features.attention_mask[
:, :: model.conv1.stride[0] * model.conv2.stride[0]
]
attention_mask = attention_mask[:, :: model.config.pooling_kernel_size]
assert attention_mask.shape == speech_tokens.shape
for i in range(len(speech_tokens)):
idx = indices[start + i]
speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
all_speech_tokens[idx].extend(speech_token)
return all_speech_tokens
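
# Minimal usage sketch, assuming "THUDM/glm-4-voice-tokenizer" as the Whisper-VQ
# checkpoint, a local "./glm-4-voice-decoder" directory with config.yaml / flow.pt /
# hift.pt, an "example.wav" input, and a 22.05 kHz decoder output rate; rank=None
# keeps everything on CPU for illustration (a GPU rank would normally be passed).
if __name__ == "__main__":
    audio_tokenizer = SenseVoiceGLM4VoiceTokenizer(
        model_name_or_path="THUDM/glm-4-voice-tokenizer",  # assumed encoder checkpoint
        flow_path="./glm-4-voice-decoder",  # assumed flow/HiFT directory
        rank=None,
    )
    # Discrete path: wav -> GLM-4-Voice token ids -> waveform.
    audio_tokens = audio_tokenizer.encode("example.wav", is_discrete=True, is_contiguous=False)
    waveform = audio_tokenizer.decode(audio_tokens)
    torchaudio.save("reconstructed.wav", waveform.unsqueeze(0), 22050)  # assumed output rate
    # Continuous path: wav -> SenseVoice fbank features.
    features = audio_tokenizer.encode("example.wav", is_discrete=False, is_contiguous=True)
    print(f"{len(audio_tokens)=} {features.shape=}")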