|
import csv |
|
import datetime |
|
import os |
|
import re |
|
import time |
|
import uuid |
|
from io import StringIO |
|
|
|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
from huggingface_hub import HfApi, hf_hub_download, snapshot_download |
|
from TTS.tts.configs.xtts_config import XttsConfig |
|
from TTS.tts.models.xtts import Xtts |
|
from vinorm import TTSnorm |
|
|
|
|
|
# Optional Hugging Face access token taken from the environment (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): `api` is not referenced in the visible code; kept in case other
# code relies on this module-level client.
api = HfApi(token=HF_TOKEN)
|
|
|
|
|
print("Downloading viXTTS model files if not already present...")
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)

# model.pth / config.json / vocab.json come from the viXTTS snapshot below;
# speakers_xtts.pth is fetched separately from coqui/XTTS-v2.
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)

# Fetch each source only when one of *its* files is missing. Previously any
# missing file triggered both downloads, so a missing speaker file forced a
# full re-sync of the viXTTS snapshot (and vice versa).
if any(
    name not in files_in_dir for name in ("model.pth", "config.json", "vocab.json")
):
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )

if "speakers_xtts.pth" not in files_in_dir:
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )
|
|
|
|
|
# Build the XTTS model from the downloaded config and load its checkpoint.
config = XttsConfig()
xtts_config = os.path.join(checkpoint_dir, "config.json")
config.load_json(xtts_config)

MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config,
    checkpoint_dir=checkpoint_dir,
    use_deepspeed=use_deepspeed,
)

# Move the model to the GPU when CUDA is available.
if torch.cuda.is_available():
    MODEL.cuda()
|
|
|
|
|
# Languages offered in the UI. This is the *same list object* as
# config.languages, so adding Vietnamese here also registers it on the config.
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages += ["vi"]  # in-place extend keeps the alias intact
|
|
|
|
|
def normalize_vietnamese_text(text):
    """Normalize Vietnamese text before synthesis.

    Runs vinorm's ``TTSnorm`` (casing preserved, ``lower=False``), then cleans
    up punctuation artefacts, strips single/double quotes and spells out the
    acronym "AI" / "A.I" as "Ây Ai".

    Args:
        text: Raw prompt text.

    Returns:
        The normalized text.
    """
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
    )
    # Replace only the standalone acronym. A plain substring replace also hit
    # "AI" inside words (e.g. uppercase "HAI", "MAI", "TRAINING"), producing
    # "HÂy Ai" etc. \b is Unicode-aware for str patterns, so Vietnamese letters
    # count as word characters.
    text = re.sub(r"\bA\.?I\b", "Ây Ai", text)
    return text
|
|
|
|
|
def calculate_keep_len(text, lang):
    """Estimate how much of the generated waveform to keep for short prompts.

    The result is used by ``predict`` as a slice bound on the output samples.

    Args:
        text: The (possibly normalized) prompt.
        lang: Language code of the prompt.

    Returns:
        -1 for "ja" / "zh-cn" and for prompts of 10 or more whitespace-separated
        words. Otherwise a per-word budget (15000 below 5 words, 13000 below
        10) plus 2000 for each ".", "!", "?" or ",".
    """
    if lang in ("ja", "zh-cn"):
        return -1

    n_words = len(text.split())
    if n_words >= 10:
        return -1

    n_marks = sum(text.count(mark) for mark in ".!?,")
    per_word = 15000 if n_words < 5 else 13000
    return per_word * n_words + 2000 * n_marks
|
|
|
|
|
def predict(prompt, language, audio_file_pth, normalize_text=True):
    """Synthesize ``prompt`` in ``language`` using the voice in ``audio_file_pth``.

    Args:
        prompt: Text to speak.
        language: Language code; must be in ``supported_languages``.
        audio_file_pth: Path to the reference speaker recording.
        normalize_text: When True and ``language == "vi"``, the prompt is passed
            through ``normalize_vietnamese_text`` first.

    Returns:
        ``(wav_path, metrics_text)``. On success ``wav_path`` is "output.wav" and
        ``metrics_text`` holds timing info. On a validation, reference-audio or
        model RuntimeError, ``wav_path`` is None and ``metrics_text`` is the
        value returned by ``gr.Warning`` (which shows the warning to the user).
    """
    if language not in supported_languages:
        metrics_text = gr.Warning(
            f"Language {language} is not supported. Please choose from the dropdown."
        )
        return None, metrics_text

    if len(prompt) < 2:
        metrics_text = gr.Warning("Please provide a longer prompt text.")
        return None, metrics_text

    try:
        metrics_text = ""

        try:
            gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
                audio_path=audio_file_pth,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error:", str(e))
            metrics_text = gr.Warning("Error with reference audio.")
            return None, metrics_text

        # Insert a space before ".", "。" or "?" that directly follows a word or
        # non-ASCII character. Raw string: "\w" / "\." are invalid escapes in a
        # normal string literal (SyntaxWarning on Python 3.12+).
        prompt = re.sub(r"([^\x00-\x7F]|\w)([.。?])", r"\1 \2", prompt)

        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        print("Generating new audio...")
        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=0.75,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        metrics_text += f"Time to generate audio: {round(inference_time * 1000)} ms\n"
        # RTF = generation time / audio duration; audio is written at 24000 Hz below.
        real_time_factor = inference_time / out["wav"].shape[-1] * 24000
        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"

        # calculate_keep_len returns -1 meaning "keep everything" (and 0 for a
        # prompt with no words). Slicing with those would drop the final sample
        # or yield an empty clip, so only trim on a positive length.
        keep_len = calculate_keep_len(prompt, language)
        if keep_len > 0:
            out["wav"] = out["wav"][:keep_len]

        torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)

    except RuntimeError as e:
        print("RuntimeError:", str(e))
        metrics_text = gr.Warning("An error occurred during processing.")
        return None, metrics_text

    return "output.wav", metrics_text
|
|
|
|
|
title = "viXTTS Demo"

with gr.Blocks(analytics_enabled=False) as demo:
    # Header row: title on the left, an intentionally empty column on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"## {title}")
        with gr.Column():
            pass

    with gr.Row():
        # Left column: prompt, language, normalization toggle, reference voice.
        with gr.Column():
            input_text_gr = gr.Textbox(
                value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt",
                label="Text Prompt",
                info="One or two sentences at a time is better. Up to 200 text characters.",
            )
            language_gr = gr.Dropdown(
                choices=supported_languages,
                value="vi",
                label="Language",
                info="Select an output language for the synthesised speech",
            )
            normalize_text = gr.Checkbox(
                value=True,
                label="Normalize Vietnamese Text",
                info="Normalize Vietnamese Text",
            )
            ref_gr = gr.Audio(
                value="model/samples/nu-luu-loat.wav",
                type="filepath",
                label="Reference Audio",
                info="Click on the ✎ button to upload your own target speaker audio",
            )
            tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

        # Right column: synthesized audio and timing metrics.
        with gr.Column():
            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
            out_text_gr = gr.Textbox(label="Metrics")

    # Wire the button to predict(); also exposed over the API as "predict".
    tts_button.click(
        fn=predict,
        inputs=[input_text_gr, language_gr, ref_gr, normalize_text],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

demo.queue().launch(debug=True, show_api=True)