import spaces import subprocess import os subprocess.run( 'pip install flash-attn --no-build-isolation', shell=True, env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"} ) # subprocess.run('FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation', shell=True) # subprocess.run('pip install flash-attn==2.2.0 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) # subprocess.run('git clone https://github.com/Dao-AILab/flash-attention.git && cd flash-attention && pip install flash-attn . --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "False"}, shell=True) import tempfile import random import shutil import pickle import gradio as gr import soundfile as sf from pathlib import Path import torch import torchaudio from huggingface_hub import hf_hub_download from infer import load_model, eval_model from spkr import SpeakerEmbedding @spaces.GPU def spkr_model_init(): spkr_model = SpeakerEmbedding(device="cpu") return spkr_model spkr_model = spkr_model_init() spkr_model.model.to("cuda") spkr_model.device = "cuda" model, tokenizer, tokenizer_voila, model_type = load_model("maitrix-org/Voila-chat", "maitrix-org/Voila-Tokenizer") model = model.to("cuda") tokenizer_voila.to("cuda") default_ref_file = "examples/character_ref_emb_demo.pkl" default_ref_name = "Homer Simpson" million_voice_ref_file = hf_hub_download(repo_id="maitrix-org/Voila-million-voice", filename="character_ref_emb_chunk0.pkl", repo_type="dataset") instruction = "You are a smart AI agent created by Maitrix.org." save_path = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir()) intro = """**Voila** For more demos, please goto [https://voila.maitrix.org](https://voila.maitrix.org).""" default_ref_emb_mask_list = pickle.load(open(default_ref_file, "rb")) million_voice_ref_emb_mask_list = pickle.load(open(million_voice_ref_file, "rb")) @spaces.GPU def get_ref_embs(ref_audio): wav, sr = torchaudio.load(ref_audio) ref_embs = spkr_model(wav, sr).cpu() return ref_embs def delete_directory(request: gr.Request): if not request.session_hash: return user_dir = Path(f"{save_path}/{str(request.session_hash)}") if user_dir.exists(): shutil.rmtree(str(user_dir)) def add_message(history, message): history.append({"role": "user", "content": {"path": message}}) return history, gr.Audio(value=None), gr.Button(interactive=False) @spaces.GPU def call_bot(history, ref_embs, request: gr.Request): formated_history = { "instruction": instruction, "conversations": [{'from': item["role"], 'audio': {"file": item["content"][0]}} for item in history], } formated_history["conversations"].append({"from": "assistant"}) print(formated_history) ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cpu") ref_embs_mask = torch.tensor([1], device="cpu") ref_embs.to("cuda") ref_embs_mask.to("cuda") out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_aiao", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512) if 'audio' in out: wav, sr = out['audio'] user_dir = Path(f"{save_path}/{str(request.session_hash)}") user_dir.mkdir(exist_ok=True) save_name = f"{user_dir}/{len(history)}.wav" sf.write(save_name, wav, sr) history.append({"role": "assistant", "content": {"path": save_name}}) else: history.append({"role": "assistant", "content": {"text": out['text']}}) return history @spaces.GPU def run_tts(text, ref_embs): formated_history = { "instruction": "", "conversations": [{'from': "user", 'text': text}], } formated_history["conversations"].append({"from": "assistant"}) ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cpu") ref_embs_mask = torch.tensor([1], device="cpu") ref_embs.to("cuda") ref_embs_mask.to("cuda") out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_tts", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512) if 'audio' in out: wav, sr = out['audio'] return (sr, wav) else: raise Exception("No audio output") @spaces.GPU def run_asr(audio): formated_history = { "instruction": "", "conversations": [{'from': "user", 'audio': {"file": audio}}], } formated_history["conversations"].append({"from": "assistant"}) out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_asr", formated_history, None, None, max_new_tokens=512) if 'text' in out: return out['text'] else: raise Exception("No text output") def markdown_ref_name(ref_name): return f"### Current voice id: {ref_name}" def random_million_voice(): voice_id = random.choice(list(million_voice_ref_emb_mask_list.keys())) return markdown_ref_name(voice_id), million_voice_ref_emb_mask_list[voice_id] def get_ref_modules(cur_ref_embs): with gr.Row() as ref_row: with gr.Row(): current_ref_name = gr.Markdown(markdown_ref_name(default_ref_name)) with gr.Row() as ref_name_row: with gr.Column(scale=2, min_width=160): ref_name_dropdown = gr.Dropdown( choices=list(default_ref_emb_mask_list.keys()), value=default_ref_name, label="Reference voice", min_width=160, ) with gr.Column(scale=1, min_width=80): random_ref_button = gr.Button( "Random from Million Voice", size="md", ) with gr.Row(visible=False) as ref_audio_row: with gr.Column(scale=2, min_width=80): ref_audio = gr.Audio( sources=["microphone", "upload"], type="filepath", show_label=False, min_width=80, ) with gr.Column(scale=1, min_width=80): change_ref_button = gr.Button( "Change voice", interactive=False, min_width=80, ) ref_name_dropdown.change( lambda x: (markdown_ref_name(x), default_ref_emb_mask_list[x]), ref_name_dropdown, [current_ref_name, cur_ref_embs] ) random_ref_button.click( random_million_voice, None, [current_ref_name, cur_ref_embs], ) ref_audio.input(lambda: gr.Button(interactive=True), None, change_ref_button) # If custom ref voice checkbox is checked, show the Audio component to record or upload a reference voice custom_ref_voice = gr.Checkbox(label="Use custom voice", value=False) # Checked: enable audio and button # Unchecked: disable audio and button def custom_ref_voice_change(x, cur_ref_embs, cur_ref_embs_mask): if not x: cur_ref_embs = default_ref_emb_mask_list[default_ref_name] return [gr.Row(visible=not x), gr.Audio(value=None), gr.Row(visible=x), markdown_ref_name("Custom voice"), cur_ref_embs] custom_ref_voice.change( custom_ref_voice_change, [custom_ref_voice, cur_ref_embs], [ref_name_row, ref_audio, ref_audio_row, current_ref_name, cur_ref_embs] ) # When change ref button is clicked, get the reference voice and update the reference voice state change_ref_button.click( lambda: gr.Button(interactive=False), None, [change_ref_button] ).then( get_ref_embs, ref_audio, cur_ref_embs ) return ref_row def get_chat_tab(): cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name]) with gr.Row() as chat_tab: with gr.Column(scale=1): ref_row = get_ref_modules(cur_ref_embs) # Voice chat input chat_input = gr.Audio( sources=["microphone", "upload"], type="filepath", show_label=False, ) submit = gr.Button("Submit", interactive=False) gr.Markdown(intro) with gr.Column(scale=9): chatbot = gr.Chatbot( elem_id="chatbot", type="messages", bubble_full_width=False, scale=1, show_copy_button=False, avatar_images=( None, # os.path.join("files", "avatar.png"), None, # os.path.join("files", "avatar.png"), ), ) chat_input.input(lambda: gr.Button(interactive=True), None, submit) submit.click( add_message, [chatbot, chat_input], [chatbot, chat_input, submit] ).then( call_bot, [chatbot, cur_ref_embs], chatbot, api_name="bot_response" ) return chat_tab def get_tts_tab(): cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name]) with gr.Row() as tts_tab: with gr.Column(scale=1): ref_row = get_ref_modules(cur_ref_embs) gr.Markdown(intro) with gr.Column(scale=9): tts_output = gr.Audio(label="TTS output", interactive=False) with gr.Row(): text_input = gr.Textbox(label="Text", placeholder="Text to TTS") submit = gr.Button("Submit") submit.click( run_tts, [text_input, cur_ref_embs], tts_output ) return tts_tab def get_asr_tab(): with gr.Row() as asr_tab: with gr.Column(): asr_input = gr.Audio( label="ASR input", sources=["microphone", "upload"], type="filepath", ) submit = gr.Button("Submit") gr.Markdown(intro) with gr.Column(): asr_output = gr.Textbox(label="ASR output", interactive=False) submit.click( run_asr, [asr_input], asr_output ) return asr_tab with gr.Blocks(fill_height=True) as demo: with gr.Tab("Chat"): chat_tab = get_chat_tab() with gr.Tab("TTS"): tts_tab = get_tts_tab() with gr.Tab("ASR"): asr_tab = get_asr_tab() demo.unload(delete_directory) if __name__ == "__main__": demo.launch()