# VITA-Audio: vita_audio/tokenizer_snac.py

import logging
import os
import uuid
import torch
import torchaudio
from .constants import (
AUD_END_TOKEN,
AUD_START_TOKEN,
AUD_TAG_TOKEN,
BOX_END_TOKEN,
BOX_START_TOKEN,
IMG_CONTEXT_TOKEN,
IMG_END_TOKEN,
IMG_START_TOKEN,
IMG_TAG_TOKEN,
PATCH_CONTEXT_TOKEN,
PATCH_END_TOKEN,
PATCH_START_TOKEN,
QUAD_END_TOKEN,
QUAD_START_TOKEN,
REF_END_TOKEN,
REF_START_TOKEN,
VID_CONTEXT_TOKEN,
VID_END_TOKEN,
VID_START_TOKEN,
VID_TAG_TOKEN,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def update_tokenizer_for_snac(tokenizer):
    """Register the multimodal special tokens and the discrete audio codebook tokens."""
token_list = [
IMG_START_TOKEN,
IMG_END_TOKEN,
IMG_CONTEXT_TOKEN,
VID_START_TOKEN,
VID_END_TOKEN,
VID_CONTEXT_TOKEN,
PATCH_START_TOKEN,
PATCH_END_TOKEN,
PATCH_CONTEXT_TOKEN,
AUD_START_TOKEN,
AUD_END_TOKEN,
QUAD_START_TOKEN,
QUAD_END_TOKEN,
REF_START_TOKEN,
REF_END_TOKEN,
BOX_START_TOKEN,
BOX_END_TOKEN,
IMG_TAG_TOKEN,
VID_TAG_TOKEN,
AUD_TAG_TOKEN,
]
num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
    # Plain (non-special) tokens for the shifted SNAC audio codes (see shift_code below).
    token_list = [f"<|audio_{i}|>" for i in range(4 * 4096)]
    num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=False)
# logger.info(f"tokenizer {tokenizer}")
return tokenizer
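

# Usage sketch for update_tokenizer_for_snac (illustration only, not part of the original
# file; the base model name below is a placeholder):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("path/to/base-model")
#     tok = update_tokenizer_for_snac(tok)
#     model.resize_token_embeddings(len(tok))  # assumes `model` is a Hugging Face model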


class SNACTokenizer:
    """Discrete audio tokenizer that wraps the SNAC neural audio codec."""

    def __init__(self, model_name_or_path, rank=None):
self.model_name_or_path = model_name_or_path
        if rank is None and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            # Map the global rank to a local GPU index (assumes up to 8 GPUs per node).
            self.rank = rank % 8
        else:
            self.rank = rank
logger.info(f"{self.rank=}")
self.is_discrete = True
self.is_contiguous = False
        # Text : Audio interleaving ratio (13 text tokens for every 26 audio tokens).
text_audio_interval_ratio = [13, 26]
self.text_audio_interval_ratio = text_audio_interval_ratio

    def load_model(self):
        """Lazily load the SNAC codec onto this rank's GPU (called on first use)."""
if hasattr(self, "model"):
return
logger.info("Loading SNACTokenizer")
from snac import SNAC
self.device = f"cuda:{self.rank}"
torch.cuda.set_device(self.rank)
self.model = SNAC.from_pretrained(self.model_name_or_path).eval().to(self.device)

    def encode(self, audio_path, **kwargs):
        """Encode an audio file into a flat list of offset-shifted SNAC code tokens."""
if not hasattr(self, "model"):
self.load_model()
        # Load the waveform and resample it to the codec's native sampling rate.
        audio, sampling_rate = torchaudio.load(audio_path)
        audio = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=self.model.sampling_rate
        )(audio)
        audio = audio.unsqueeze(0)  # (channels, T) -> (1, channels, T)
        audio = audio.to(self.device)
with torch.inference_mode():
codes = self.model.encode(audio)
        # Flatten the multi-scale codebooks into one token stream with per-position vocabulary offsets.
        codes = shift_code(codes, self.model.codebook_size, self.model.vq_strides)
        audio_tokens = codes.cpu().tolist()
return audio_tokens

    def decode(self, audio_tokens, **kwargs):
        """Decode a flat list of shifted SNAC code tokens back into a waveform."""
        if not hasattr(self, "model"):
            self.load_model()

        # Pad to a whole number of frames (sum(vq_strides) tokens per frame); each padding
        # token repeats the last code shifted into the next position's vocabulary range (+4096).
        while len(audio_tokens) % sum(self.model.vq_strides):
            audio_tokens += [
                audio_tokens[-1] + 4096,
            ]
        codes = torch.tensor(audio_tokens, device=self.device)
        # Undo the per-position vocabulary offsets and clamp to the valid codebook range.
        codes = inverse_shift_code(codes, self.model.codebook_size, self.model.vq_strides)
        codes = [torch.clamp(x, min=0, max=self.model.codebook_size - 1) for x in codes]
# logger.info(f"codes {codes} {[x.size() for x in codes]}")
with torch.inference_mode():
audio_hat = self.model.decode(codes)
# logger.info(f"audio_hat {audio_hat.size()}")
audio_hat = audio_hat.squeeze(0).squeeze(0).cpu()
return audio_hat

    def apply_to_role(self, role, **kwargs):
        """Return whether this (discrete, non-contiguous) tokenizer applies for the given role."""
is_discrete = kwargs.get("is_discrete", False)
if is_discrete:
return True
is_contiguous = kwargs.get("is_contiguous", False)
if is_contiguous:
return False
return True


def shift_code(codes, codebook_size, vq_strides):
    """Interleave SNAC's multi-scale codes frame by frame and shift each of the
    sum(vq_strides) per-frame positions into its own vocabulary range.
    """
    # codes: e.g. [torch.Size([1, 43]), torch.Size([1, 86]), torch.Size([1, 172])]
    # Alternative using 3 * 4096 new vocabularies (one offset per codebook):
    # codes = torch.cat([x.reshape(1, -1, vq_strides[-i-1]) + i * codebook_size for i, x in enumerate(codes)], dim=-1).reshape(-1)
    # Used here: 7 * 4096 new vocabularies (one offset per frame position).
codes = [x.reshape(1, -1, s) for s, x in zip(vq_strides[::-1], codes)]
codes = torch.cat(
[
x + i * codebook_size
for i, x in enumerate(torch.cat(codes, dim=-1).chunk(sum(vq_strides), dim=-1))
],
dim=-1,
).reshape(-1)
return codes
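
# Illustrative shape walk-through for shift_code (assuming codebook_size = 4096 and
# vq_strides = [4, 2, 1], matching the example shapes in the comment above):
#   input  : [(1, 43), (1, 86), (1, 172)]                      one tensor per codebook
#   reshape: [(1, 43, 1), (1, 43, 2), (1, 43, 4)] -> cat -> (1, 43, 7)
#   offset : per-frame position i (0..6) is shifted by i * 4096
#   output : flat tensor of length 43 * 7 = 301
# inverse_shift_code below reverses these steps exactly.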


def inverse_shift_code(codes, codebook_size, vq_strides):
    """Invert shift_code: remove the per-position offsets and split the flat stream
    back into one tensor per SNAC codebook.
    """
    # codes: e.g. torch.Size([301])
    # Alternative for 3 * 4096 new vocabularies (one offset per codebook):
    # codes = [x.reshape(1, -1) - i * codebook_size for i, x in enumerate(codes.reshape(1, -1, sum(vq_strides)).split(vq_strides[::-1], dim=-1))]
    # Used here: 7 * 4096 new vocabularies (one offset per frame position).
codes = torch.cat(
[
x - i * codebook_size
for i, x in enumerate(
codes.reshape(1, -1, sum(vq_strides)).chunk(sum(vq_strides), dim=-1)
)
],
dim=-1,
).split(vq_strides[::-1], dim=-1)
codes = [x.reshape(1, -1) for x in codes]
return codes
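

if __name__ == "__main__":
    # Minimal round-trip sketch (illustration only, not part of the original file).
    # Assumes a CUDA device; the SNAC model id and the audio paths are placeholders.
    tok = SNACTokenizer("hubertsiuzdak/snac_24khz", rank=0)
    tokens = tok.encode("example.wav")
    waveform = tok.decode(tokens)
    torchaudio.save("example_reconstructed.wav", waveform.unsqueeze(0), tok.model.sampling_rate)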