import logging
import os
import uuid
import torch
from .constants import (
AUD_END_TOKEN,
AUD_START_TOKEN,
AUD_TAG_TOKEN,
BOX_END_TOKEN,
BOX_START_TOKEN,
IMG_CONTEXT_TOKEN,
IMG_END_TOKEN,
IMG_START_TOKEN,
IMG_TAG_TOKEN,
PATCH_CONTEXT_TOKEN,
PATCH_END_TOKEN,
PATCH_START_TOKEN,
QUAD_END_TOKEN,
QUAD_START_TOKEN,
REF_END_TOKEN,
REF_START_TOKEN,
VID_CONTEXT_TOKEN,
VID_END_TOKEN,
VID_START_TOKEN,
VID_TAG_TOKEN,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def update_tokenizer_for_cosyvoice2(tokenizer):
    # Multimodal delimiter/context tokens used by the surrounding model.
    token_list = [
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        QUAD_START_TOKEN,
        QUAD_END_TOKEN,
        REF_START_TOKEN,
        REF_END_TOKEN,
        BOX_START_TOKEN,
        BOX_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
        AUD_TAG_TOKEN,
    ]
    num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
    # The 6,561 (= 3^8) discrete speech tokens emitted by the CosyVoice2
    # speech tokenizer, added as regular (non-special) tokens.
    audio_token_list = [f"<|audio_{i}|>" for i in range(6561)]
    num_new_tokens += tokenizer.add_tokens(audio_token_list, special_tokens=False)
    logger.info(f"Added {num_new_tokens} new tokens to the tokenizer")
    return tokenizer
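
# Usage sketch (illustrative only): extend a Hugging Face tokenizer with the
# CosyVoice2 vocabulary. `AutoTokenizer` comes from `transformers`; the
# checkpoint path is a placeholder, and resizing the model's embedding table
# afterwards is the caller's responsibility.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/base-model")
#     tokenizer = update_tokenizer_for_cosyvoice2(tokenizer)
#     model.resize_token_embeddings(len(tokenizer))  # if training a model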

class CosyVoice2Tokenizer:
    def __init__(self, model_name_or_path, rank=None):
        self.model_name_or_path = model_name_or_path
        if rank is None and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            # Map the global rank to a local GPU index (assumes 8 GPUs per node).
            self.rank = rank % 8
        else:
            self.rank = rank
        logger.info(f"{self.rank=}")
        self.is_discrete = True
        self.is_contiguous = False
        # Interleaving ratio of text (T) to audio (A) tokens:
        # 13 text tokens per 26 audio tokens.
        self.text_audio_interval_ratio = [13, 26]
    def load_model(self):
        # Lazily load the CosyVoice2 model once; subsequent calls are no-ops.
        if hasattr(self, "cosyvoice"):
            return
        logger.info("Loading CosyVoice2Tokenizer")
        from cosyvoice.cli.cosyvoice import CosyVoice2
        from cosyvoice.utils.file_utils import load_wav

        if self.rank is not None:
            torch.cuda.set_device(self.rank)
        else:
            # No rank given: hide all GPUs and run on CPU.
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
        logger.info(f"{self.rank=}")
        self.cosyvoice = CosyVoice2(
            self.model_name_or_path, load_jit=False, load_trt=False, fp16=True
        )
        # Only the speech tokenizer/vocoder parts are needed here, so drop
        # the text-to-token LLM to save memory.
        del self.cosyvoice.model.llm
        self.load_wav = load_wav
    def encode(self, audio_path, **kwargs):
        # Convert a waveform on disk into a list of discrete speech token ids.
        if not hasattr(self, "cosyvoice"):
            self.load_model()
        speech_16k = self.load_wav(audio_path, 16000)
        try:
            speech_token, _ = self.cosyvoice.frontend._extract_speech_token(speech_16k)
            speech_token = speech_token[0].cpu().tolist()
        except Exception as error:
            logger.warning(f"speech token extraction failed: {error}")
            speech_token = []
        return speech_token
    def decode(self, prompt_speech_token, source_speech_16k=None):
        # Synthesize a waveform from discrete speech tokens. If a source
        # waveform is given, its speaker embedding conditions the vocoder;
        # otherwise a zero embedding is used.
        if not hasattr(self, "cosyvoice"):
            self.load_model()
        prompt_speech_token = torch.tensor(prompt_speech_token).unsqueeze(0)
        flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
        prompt_speech_feat = torch.zeros(1, 0, 80)
        if source_speech_16k is None:
            flow_embedding = torch.zeros(1, 192)
        else:
            flow_embedding = self.cosyvoice.frontend._extract_spk_embedding(source_speech_16k)
        # Use a unique id per call so concurrent decodes do not collide on
        # the vocoder cache entry.
        this_uuid = str(uuid.uuid1())
        self.cosyvoice.model.hift_cache_dict[this_uuid] = None
        token_offset = 0
        tts_speech = self.cosyvoice.model.token2wav(
            token=prompt_speech_token,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=flow_embedding,
            uuid=this_uuid,
            token_offset=token_offset,
            finalize=True,
        )
        tts_speech = tts_speech.squeeze().cpu()
        return tts_speech
    def apply_to_role(self, role, **kwargs):
        # This tokenizer produces discrete tokens, so it applies to discrete
        # requests but not to contiguous-feature requests.
        is_discrete = kwargs.get("is_discrete", False)
        if is_discrete:
            return True
        is_contiguous = kwargs.get("is_contiguous", False)
        if is_contiguous:
            return False
        return True
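
if __name__ == "__main__":
    # Minimal round-trip sketch, assuming a local CosyVoice2 checkpoint and a
    # 16 kHz-compatible input wav; both paths below are placeholders.
    import torchaudio

    cv_tokenizer = CosyVoice2Tokenizer("pretrained_models/CosyVoice2-0.5B", rank=0)
    tokens = cv_tokenizer.encode("example.wav")
    logger.info(f"encoded {len(tokens)} speech tokens")

    # Resynthesize a waveform from the discrete tokens; CosyVoice2 outputs
    # audio at 24 kHz.
    waveform = cv_tokenizer.decode(tokens)
    torchaudio.save("reconstructed.wav", waveform.unsqueeze(0), 24000)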