import os
import re
import time

import gradio as gr
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm

# Initialize Hugging Face API
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN)
PASSWORD = os.environ.get("KEY")

# Download model files if not already present
print("Downloading viXTTS model files if not already present...")
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)

required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )

# Load model configuration and initialize the model
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)
if torch.cuda.is_available():
    MODEL.cuda()


def authenticate(password):
    if password == PASSWORD:
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    return (
        gr.update(visible=False),
        gr.update(visible=True, value="Invalid password"),
        gr.update(visible=True),
    )


# Supported languages
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")


def normalize_vietnamese_text(text):
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
        .replace("AI", "Ây Ai")
        .replace("A.I", "Ây Ai")
    )
    return text


def calculate_keep_len(text, lang):
    # Heuristic trim length (in samples) to cut trailing artifacts on short prompts.
    if lang in ["ja", "zh-cn"]:
        return -1
    word_count = len(text.split())
    num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")
    if word_count < 5:
        return 15000 * word_count + 2000 * num_punct
    elif word_count < 10:
        return 13000 * word_count + 2000 * num_punct
    return -1


def predict(prompt, language, audio_file_pth, normalize_text=True):
    if language not in supported_languages:
        metrics_text = gr.Warning(
            f"Language {language} is not supported. Please choose from the dropdown."
        )
        return None, metrics_text
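    # Reject prompts that are too short to synthesize anything meaningful
    # before running speaker conditioning or inference.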
    if len(prompt) < 2:
        metrics_text = gr.Warning("Please provide a longer prompt text.")
        return None, metrics_text

    try:
        metrics_text = ""
        t_latent = time.time()
        try:
            gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
                audio_path=audio_file_pth,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error:", str(e))
            metrics_text = gr.Warning("Error with reference audio.")
            return None, metrics_text

        # Insert a space before sentence-final punctuation so text splitting
        # handles non-ASCII scripts consistently.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|。|\?)", r"\1 \2", prompt)

        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        print("Generating new audio...")
        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=0.75,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        metrics_text += f"Time to generate audio: {round(inference_time * 1000)} ms\n"
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"

        keep_len = calculate_keep_len(prompt, language)
        out["wav"] = out["wav"][:keep_len]

        torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
    except RuntimeError as e:
        print("RuntimeError:", str(e))
        metrics_text = gr.Warning("An error occurred during processing.")
        return None, metrics_text

    return "output.wav", metrics_text


title = "Phòng Thu VMC"

with gr.Blocks(analytics_enabled=False) as demo:
    # Password gate: the login column is shown first; the main UI stays hidden
    # until the key from the KEY environment variable is entered correctly.
    with gr.Column(visible=True) as login_column:
        password_input = gr.Textbox(label="Password", type="password")
        login_btn = gr.Button("Login")
        error_message = gr.Markdown(visible=False)

    with gr.Column(visible=False) as main_column:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## VMC LAB")
            with gr.Column():
                pass

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Text Prompt",
                    info="One or two sentences at a time is better. Up to 200 text characters.",
                    value="Xin chào, hãy nhập nội dung cần thu âm vào đây",
                )
                language_gr = gr.Dropdown(
                    label="Language",
                    info="Select an output language for the synthesised speech",
                    choices=supported_languages,
                    value="vi",
                )
                normalize_text = gr.Checkbox(
                    label="Normalize Vietnamese Text",
                    info="Normalize Vietnamese Text",
                    value=True,
                )
                # gr.Audio takes no `info` parameter; the upload hint lives here:
                # "Click on the ✎ button to upload your own target speaker audio".
                ref_gr = gr.Audio(
                    label="Reference Audio",
                    type="filepath",
                    value="model/samples/nu-luu-loat.wav",
                )
                tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

            with gr.Column():
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
                out_text_gr = gr.Textbox(label="Metrics")

        tts_button.click(
            predict,
            [input_text_gr, language_gr, ref_gr, normalize_text],
            outputs=[audio_gr, out_text_gr],
            api_name="predict",
        )

    login_btn.click(
        fn=authenticate,
        inputs=password_input,
        outputs=[main_column, error_message, login_column],
    )

demo.queue(max_size=10).launch(show_error=True, show_api=False)
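# Quick smoke test (a sketch, left commented out because launch() above blocks):
# call predict() directly from a REPL with this module imported. The reference
# wav path is the UI default and is assumed to exist after the download step.
#
# wav_path, metrics = predict(
#     "Xin chào các bạn",
#     "vi",
#     "model/samples/nu-luu-loat.wav",
#     normalize_text=True,
# )
# print(wav_path, metrics)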