from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import numpy as np
import ChatTTS
import re

DEFAULT_TTS_MODEL_NAME = "HKUSTAudio/LLasa-1B"
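
# Each demo example pairs a reference wav with its verbatim transcript. For
# LLaSA-style voice cloning the transcript is prepended to the target text
# (see `predict` below), so it must match the audio exactly.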
DEMO_EXAMPLES = [
    ["太乙真人.wav", "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"],
    ["邓紫棋.wav", "特别大的不同,因为以前在香港是过年的时候,我们可能见到的亲戚都是爸爸那边的亲戚"],
    ["雷军.wav", "这是个好问题,我把来龙去脉给你简单讲,就是这个社会对小米有很多的误解,有很多的误解,呃,也能理解啊,就是小米这个模式呢"],
    ["周杰伦.wav", "但如果你这兴趣可以得到很大的回响,那会更开心"],
    ["Taylor Swift.wav", "It's actually uh, it's a concept record, but it's my first directly autobiographical album in a while because the last album that I put out was, uh, a rework."]
]

class TTSapi:
    def __init__(self,
                 model_name=DEFAULT_TTS_MODEL_NAME,
                 codec_model_name="HKUST-Audio/xcodec2",
                 device=torch.device("cuda:0")):
        self.reload(model_name, codec_model_name, device)

    def reload(self,
               model_name=DEFAULT_TTS_MODEL_NAME,
               codec_model_name="HKUST-Audio/xcodec2",
               device=torch.device("cuda:0")):
        if 'llasa' in model_name.lower():
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.model.eval().to(device)
            self.codec_model = XCodec2Model.from_pretrained(codec_model_name)
            self.codec_model.eval().to(device)
            self.device = device
            self.codec_model_name = codec_model_name
            self.sr = 16000
        elif 'chattts' in model_name.lower():
            self.model = ChatTTS.Chat()
            self.model.load(compile=False)  # Set to True for better performance, but it significantly slows down loading
            self.sr = 24000
            self.punctuation = r'[,,.。??!!~~;;]'
        else:
            raise ValueError(f'Unsupported TTS model: {model_name}')
        self.model_name = model_name
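
    # The next two helpers round-trip between integer codec ids and the textual
    # speech tokens in the LLaSA vocabulary, e.g. [123, 456] <-> ['<|s_123|>', '<|s_456|>'].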
    def ids_to_speech_tokens(self, speech_ids):
        speech_tokens_str = []
        for speech_id in speech_ids:
            speech_tokens_str.append(f"<|s_{speech_id}|>")
        return speech_tokens_str

    def extract_speech_ids(self, speech_tokens_str):
        speech_ids = []
        for token_str in speech_tokens_str:
            if token_str.startswith('<|s_') and token_str.endswith('|>'):
                num_str = token_str[4:-2]
                num = int(num_str)
                speech_ids.append(num)
            else:
                print(f"Unexpected token: {token_str}")
        return speech_ids

    def forward(self, input_text, speech_prompt=None, save_path='wavs/generated/gen.wav'):
        # TTS start!
        with torch.no_grad():
            if 'chattts' in self.model_name.lower():
                # rand_spk = chat.sample_random_speaker()
                # print(rand_spk)  # save it for later timbre recovery
                # params_infer_code = ChatTTS.Chat.InferCodeParams(
                #     spk_emb=rand_spk,  # add sampled speaker
                #     temperature=.3,    # using custom temperature
                #     top_P=0.7,         # top P decode
                #     top_K=20,          # top K decode
                # )
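                # Scale the number of prosodic break tags with the clause count
                # (text split on punctuation), clamped to the range [2, 7].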
                break_num = max(min(len(re.split(self.punctuation, input_text)), 7), 2)
                params_refine_text = ChatTTS.Chat.RefineTextParams(
                    prompt=f'[oral_2][laugh_0][break_{break_num}]',
                )
                wavs = self.model.infer([input_text],
                                        params_refine_text=params_refine_text,
                                        )
                gen_wav_save = wavs[0]
                sf.write(save_path, gen_wav_save, 24000)
            else:
                if speech_prompt:
                    # Only 16 kHz speech is supported!
                    prompt_wav, sr = sf.read(speech_prompt)  # you can find wav in Files
                    prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)
                    # Encode the prompt wav
                    vq_code_prompt = self.codec_model.encode_code(input_waveform=prompt_wav)
                    print("Prompt VQ code shape:", vq_code_prompt.shape)
                    vq_code_prompt = vq_code_prompt[0, 0, :]
                    # Convert int 12345 to token <|s_12345|>
                    speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
                else:
                    speech_ids_prefix = ''
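                # Voice cloning by continuation: the prompt's speech tokens sit right
                # after <|SPEECH_GENERATION_START|>, so the LM continues generating
                # speech in the prompt speaker's voice.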
                formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
                # Tokenize the text (and the speech prefix)
                chat = [
                    {"role": "user", "content": "Convert the text to speech:" + formatted_text},
                    {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
                ]
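                # continue_final_message=True keeps the assistant turn open, so
                # generation resumes right after the speech-token prefix instead
                # of starting a new message.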
                input_ids = self.tokenizer.apply_chat_template(
                    chat,
                    tokenize=True,
                    return_tensors='pt',
                    continue_final_message=True
                )
                input_ids = input_ids.to(self.device)
                speech_end_id = self.tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
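                # Using <|SPEECH_GENERATION_END|> as eos_token_id makes generate()
                # stop as soon as the model emits the end-of-speech marker.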
                # Generate the speech tokens autoregressively
                outputs = self.model.generate(
                    input_ids,
                    max_length=2048,  # We trained our model with a max length of 2048
                    eos_token_id=speech_end_id,
                    do_sample=True,
                    top_p=1,        # Adjusts the diversity of generated content
                    temperature=1,  # Controls randomness in output
                )
                # Extract the speech tokens: keep the prompt prefix plus the newly
                # generated tokens, and drop the trailing end-of-speech token
                generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
                speech_tokens = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                # Convert token <|s_23456|> to int 23456
                speech_tokens = self.extract_speech_ids(speech_tokens)
                speech_tokens = torch.tensor(speech_tokens).to(self.device).unsqueeze(0).unsqueeze(0)
                # Decode the speech tokens to a speech waveform
                gen_wav = self.codec_model.decode_code(speech_tokens)
                # If only the generated part is needed, trim off the re-decoded prompt
                if speech_prompt:
                    gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]
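                    # prompt_wav.shape[1] is the prompt length in samples; the codec
                    # runs at 16 kHz end to end, so sample counts line up.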
                gen_wav_save = gen_wav[0, 0, :].cpu().numpy()
                sf.write(save_path, gen_wav_save, 16000)
                # gen_wav_save = np.clip(gen_wav_save, -1, 1)
                # gen_wav_save = (gen_wav_save * 32767).astype(np.int16)
        return gen_wav_save


if __name__ == '__main__':
    # Llasa-8B shows better text understanding ability.
    # input_text = " He shouted, 'Everyone, please gather 'round! Here's the plan: 1) Set-up at 9:15 a.m.; 2) Lunch at 12:00 p.m. (please RSVP!); 3) Playing — e.g., games, music, etc. — from 1:15 to 4:45; and 4) Clean-up at 5 p.m.'"
    # prompt_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
    # input_text = prompt_text + '嘻嘻,臭宝儿你真可爱,我好喜欢你呀。'
    # save_root = 'wavs/generated/'
    # save_path = save_root + 'test.wav'
    # speech_ref = 'wavs/ref/太乙真人.wav'
    # # speech_ref = None
    # # 帘外雨潺潺,春意阑珊。罗衾不耐五更寒。梦里不知身是客,一晌贪欢。独自莫凭栏,无限江山。别时容易见时难。流水落花春去也,天上人间。
    # llasa_tts = TTSapi()
    # gen = llasa_tts.forward(input_text, speech_prompt=speech_ref, save_path=save_path)
    # print(gen.shape)

    import gradio as gr

    synthesiser = TTSapi()
    TTS_LOADED = True
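
    # `predict` reuses this module-level TTSapi instance; it only reloads weights
    # when the dropdown selection differs from the currently loaded model.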
    def predict(config):
        global TTS_LOADED, synthesiser
        print(f"Text to synthesize: {config['msg']}")
        print(f"Selected TTS model: {config['tts_model']}")
        print(f"Reference audio path: {config['ref_audio']}")
        print(f"Reference audio transcript: {config['ref_audio_transcribe']}")
        text = config['msg']
        audio_output = np.array([0], dtype=np.int16)  # fallback so the return below always has a value
        try:
            if len(text) == 0:
                print("Empty input; nothing to synthesize")
            else:
                if not TTS_LOADED:
                    print('Loading the TTS model for the first time...')
                    gr.Info("Loading the TTS model for the first time, please wait...", duration=63)
                    synthesiser = TTSapi(model_name=config['tts_model'])  # , device="cuda:2")
                    TTS_LOADED = True
                    print('Loading finished...')
                # Check whether the loaded model matches the selected one
                if config['tts_model'] != synthesiser.model_name:
                    print(f'Current TTS model {synthesiser.model_name} is not the selected one; reloading')
                    synthesiser.reload(model_name=config['tts_model'])
                # If a reference audio is given, its transcript must be prepended to the text as a prefix
                if config['ref_audio']:
                    prompt_text = config['ref_audio_transcribe']
                    if prompt_text is None:
                        # prompt_text = ...
                        raise NotImplementedError('A transcript must be provided for now')  # TODO: consider adding an ASR model later
                    text = prompt_text + text
                audio_output = synthesiser.forward(text, speech_prompt=config['ref_audio'])
        except Exception as e:
            print('!!!!!!!!')
            print(e)
            print('!!!!!!!!')
        return (synthesiser.sr if synthesiser else 16000, audio_output)
    with gr.Blocks(title="TTS Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
        # Personalized TTS Demo
        ## How to use
        * Upload a recording (about 10 s) of the target speaker and fill in its transcript below, or simply click one of the examples.
        * Enter the text you want to synthesize, click the "Synthesize" button, and wait a moment.
        """)
        with gr.Row():
            with gr.Column():
                # TTS model selection
                tts_model = gr.Dropdown(
                    label="Select TTS model",
                    choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                    value=DEFAULT_TTS_MODEL_NAME,
                    interactive=True,
                    visible=False  # hidden for the product demo; model selection is not exposed yet
                )
                # Reference audio upload
                ref_audio = gr.Audio(
                    label="Upload reference audio",
                    type="filepath",
                    interactive=True
                )
                ref_audio_transcribe = gr.Textbox(label="Transcript of the reference audio", visible=True)
                # Example data
                examples = gr.Examples(
                    examples=DEMO_EXAMPLES,
                    inputs=[ref_audio, ref_audio_transcribe],
                    fn=predict
                )
            with gr.Column():
                audio_player = gr.Audio(
                    label="Hear my voice~",
                    type="numpy",
                    interactive=False
                )
                msg = gr.Textbox(label="Input text", placeholder="Enter the text to synthesize")
                submit_btn = gr.Button("Synthesize", variant="primary")

        current_config = gr.State({
            "msg": None,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "ref_audio": None,
            "ref_audio_transcribe": None
        })
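        # Mirror every input widget into one config dict so the submit handler
        # receives a single consistent snapshot of the UI state.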
        gr.on(triggers=[msg.change, tts_model.change, ref_audio.change,
                        ref_audio_transcribe.change],
              fn=lambda text, model, audio, ref_text: {"msg": text, "tts_model": model, "ref_audio": audio,
                                                       "ref_audio_transcribe": ref_text},
              inputs=[msg, tts_model, ref_audio, ref_audio_transcribe],
              outputs=current_config
              )
        submit_btn.click(
            predict,
            [current_config],
            [audio_player],
            queue=False
        )
    demo.launch(share=False, server_name='0.0.0.0', server_port=7863, inbrowser=True)