VITA-Audio / vita_audio /tokenizer.py
shenyunhang's picture
-a
52e4f53
from .constants import (
BOX_END_TOKEN,
BOX_START_TOKEN,
IMG_CONTEXT_TOKEN,
IMG_END_TOKEN,
IMG_START_TOKEN,
IMG_TAG_TOKEN,
PATCH_CONTEXT_TOKEN,
PATCH_END_TOKEN,
PATCH_START_TOKEN,
QUAD_END_TOKEN,
QUAD_START_TOKEN,
REF_END_TOKEN,
REF_START_TOKEN,
VID_CONTEXT_TOKEN,
VID_END_TOKEN,
VID_START_TOKEN,
VID_TAG_TOKEN,
)
def update_tokenizer(tokenizer):
    """Register the multimodal marker tokens on ``tokenizer`` and return it.

    The image, video, patch, quad, ref, box and tag marker constants are
    passed to ``tokenizer.add_tokens`` in a single call with
    ``special_tokens=True``. The value returned by ``add_tokens`` is not used.

    Args:
        tokenizer: Object exposing ``add_tokens(tokens, special_tokens=...)``.

    Returns:
        The same ``tokenizer`` object that was passed in.
    """
    # NOTE(review): the registration order presumably determines the ids the
    # new tokens receive, so it is kept exactly as before -- confirm against
    # the tokenizer implementation before reordering.
    marker_tokens = [
        # image markers
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        # video markers
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        # patch markers
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        # quad / ref / box delimiters
        QUAD_START_TOKEN,
        QUAD_END_TOKEN,
        REF_START_TOKEN,
        REF_END_TOKEN,
        BOX_START_TOKEN,
        BOX_END_TOKEN,
        # tag tokens
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
    ]
    tokenizer.add_tokens(marker_tokens, special_tokens=True)
    return tokenizer
def update_tokenizer_for_s2s(tokenizer, model_type):
    """Update ``tokenizer`` using the helper that matches ``model_type``.

    Each backend helper is imported inside its own branch, so only the module
    for the selected backend is imported by a given call.

    Args:
        tokenizer: Tokenizer object; forwarded unchanged to the selected
            helper.
        model_type: ``None`` to apply only ``update_tokenizer``, or one of
            ``"glm4voice"``, ``"cosyvoice2"``, ``"snac24khz"``,
            ``"sensevoice_sparktts"`` or ``"sensevoice_glm4voice"``.

    Returns:
        Whatever the selected helper returns.

    Raises:
        NotImplementedError: If ``model_type`` is none of the values above.
    """
    if model_type is None:
        return update_tokenizer(tokenizer)

    if model_type == "glm4voice":
        from .tokenizer_glm4voice import update_tokenizer_for_glm4voice

        return update_tokenizer_for_glm4voice(tokenizer)

    if model_type == "cosyvoice2":
        from .tokenizer_cosyvoice2 import update_tokenizer_for_cosyvoice2

        return update_tokenizer_for_cosyvoice2(tokenizer)

    if model_type == "snac24khz":
        from .tokenizer_snac import update_tokenizer_for_snac

        return update_tokenizer_for_snac(tokenizer)

    if model_type == "sensevoice_sparktts":
        from .tokenizer_sensevoice_sparktts import (
            update_tokenizer_for_sensevoice_sparktts,
        )

        return update_tokenizer_for_sensevoice_sparktts(tokenizer)

    if model_type == "sensevoice_glm4voice":
        from .tokenizer_sensevoice_glm4voice import (
            update_tokenizer_for_sensevoice_glm4voice,
        )

        return update_tokenizer_for_sensevoice_glm4voice(tokenizer)

    # Name the rejected value so a misconfigured model_type is easy to spot.
    raise NotImplementedError(
        f"Unsupported model_type for tokenizer update: {model_type!r}"
    )
def get_audio_tokenizer(model_name_or_path, model_type, flow_path=None, rank=None):
    """Construct the audio tokenizer that matches ``model_type``.

    Each backend class is imported inside its own branch, so only the module
    for the selected backend is imported by a given call.

    Args:
        model_name_or_path: Passed as the first positional argument to the
            selected tokenizer class.
        model_type: ``None`` for no audio tokenizer, or one of
            ``"glm4voice"``, ``"cosyvoice2"``, ``"snac24khz"``,
            ``"sensevoice_sparktts"`` or ``"sensevoice_glm4voice"``.
        flow_path: Forwarded only to the ``"glm4voice"`` and
            ``"sensevoice_glm4voice"`` tokenizers; ignored otherwise.
        rank: Forwarded as ``rank=`` to every tokenizer class.

    Returns:
        ``None`` when ``model_type`` is ``None``, otherwise a new instance of
        the selected tokenizer class.

    Raises:
        NotImplementedError: If ``model_type`` is none of the values above.
    """
    if model_type is None:
        return None

    if model_type == "glm4voice":
        from .tokenizer_glm4voice import GLM4VoiceTokenizer

        return GLM4VoiceTokenizer(model_name_or_path, flow_path=flow_path, rank=rank)

    if model_type == "cosyvoice2":
        from .tokenizer_cosyvoice2 import CosyVoice2Tokenizer

        return CosyVoice2Tokenizer(model_name_or_path, rank=rank)

    if model_type == "snac24khz":
        from .tokenizer_snac import SNACTokenizer

        return SNACTokenizer(model_name_or_path, rank=rank)

    if model_type == "sensevoice_sparktts":
        from .tokenizer_sensevoice_sparktts import SenseVoiceSparkTTSTokenizer

        return SenseVoiceSparkTTSTokenizer(model_name_or_path, rank=rank)

    if model_type == "sensevoice_glm4voice":
        from .tokenizer_sensevoice_glm4voice import SenseVoiceGLM4VoiceTokenizer

        return SenseVoiceGLM4VoiceTokenizer(
            model_name_or_path, flow_path=flow_path, rank=rank
        )

    # Name the rejected value so a misconfigured model_type is easy to spot.
    raise NotImplementedError(
        f"Unsupported model_type for audio tokenizer: {model_type!r}"
    )