Voff / app.py
TDN-M's picture
Update app.py
b085276 verified
raw
history blame
6.16 kB
import csv
import datetime
import os
import re
import time
import uuid
from io import StringIO
import gradio as gr
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm
# Initialize Hugging Face API
# NOTE(review): `api` is not used anywhere in the visible part of this file,
# and HF_TOKEN is not passed to the downloads below — confirm whether either
# is still needed.
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN)
# Download model files if not already downloaded
print("Downloading viXTTS model files if not already present...")
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False
os.makedirs(checkpoint_dir, exist_ok=True)
# This is a presence check only. If all four file names already exist in
# checkpoint_dir, the downloads are skipped entirely. File contents and
# versions are not verified, and other assets such as the sample reference
# audio used by the UI are not checked.
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    # Fetch the full viXTTS model repository into checkpoint_dir...
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    # ...plus speakers_xtts.pth, which is taken from the upstream
    # coqui/XTTS-v2 repository.
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )
# Load model configuration and initialize model
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)
# Move the model to the GPU when CUDA is available. Otherwise it stays
# wherever init/load placed it (presumably CPU — TODO confirm).
if torch.cuda.is_available():
    MODEL.cuda()
# Supported languages
# NOTE: `supported_languages` is the same list object as `config.languages`
# (no copy is made), so appending "vi" also mutates the loaded config.
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")
# Matches the standalone acronym "AI" or "A.I". The word boundaries keep it
# from firing inside words, e.g. the upper-case Vietnamese syllables "HAI",
# "MAI", "TAI" or English words such as "EMAIL" and "PAID".
_AI_ACRONYM_RE = re.compile(r"\bA\.?I\b")


def normalize_vietnamese_text(text):
    """Normalize Vietnamese prompt text before speech synthesis.

    Runs ``TTSnorm(text, unknown=False, lower=False, rule=True)`` and then
    tidies the result:

    * collapses any run of periods to a single period, and drops a period
      that directly follows "!" or "?";
    * removes a space before "." and ",";
    * strips double and single quotes;
    * spells out the standalone acronym "AI" / "A.I" as "Ây Ai".

    Args:
        text: Raw prompt text.

    Returns:
        The normalized text.
    """
    text = TTSnorm(text, unknown=False, lower=False, rule=True)
    # A regex is used rather than a pairwise replace(): replacing ".." once
    # turns an ellipsis "..." into "..".
    text = re.sub(r"\.{2,}", ".", text)
    text = (
        text.replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
    )
    # A plain substring replace would corrupt any word containing "AI".
    return _AI_ACRONYM_RE.sub("Ây Ai", text)
def calculate_keep_len(text, lang):
    """Return a length budget for the synthesized audio of a short prompt.

    Only prompts with fewer than 10 whitespace-separated words get a budget.
    The rate is 15000 per word below 5 words and 13000 per word for 5-9
    words, plus 2000 for every ".", "!", "?" or "," in the text. The unit is
    presumably audio samples (the app saves 24 kHz audio) — TODO confirm.

    Args:
        text: The (already pre-processed) prompt.
        lang: Language code of the prompt.

    Returns:
        The budget as an int, or -1 ("no limit") for "ja" / "zh-cn" prompts
        and for prompts of 10 or more words.
    """
    if lang in ("ja", "zh-cn"):
        return -1
    n_words = len(text.split())
    if n_words >= 10:
        return -1
    n_punct = sum(text.count(mark) for mark in ".!?,")
    per_word = 15000 if n_words < 5 else 13000
    return per_word * n_words + 2000 * n_punct
def _warn(message):
    """Show ``message`` as a Gradio warning and return it unchanged.

    ``gr.Warning`` only raises a UI notification and does not hand the text
    back. Returning the string lets callers also display it in the Metrics
    textbox.
    """
    gr.Warning(message)
    return message


def predict(prompt, language, audio_file_pth, normalize_text=True):
    """Synthesize ``prompt`` in the voice of the reference audio.

    Args:
        prompt: Text to speak. It must be at least 2 characters long.
        language: Language code. It must be in ``supported_languages``.
        audio_file_pth: Path to the reference speaker audio, used to compute
            the conditioning latents and speaker embedding.
        normalize_text: When true and ``language == "vi"``, the prompt is run
            through ``normalize_vietnamese_text`` before synthesis.

    Returns:
        ``("output.wav", metrics_text)`` on success, where ``metrics_text``
        reports generation time and real-time factor. On a handled failure
        (unsupported language, short prompt, reference-audio error, or a
        ``RuntimeError`` during synthesis) it returns
        ``(None, warning_message)``, and the warning is also shown as a
        Gradio toast.
    """
    if language not in supported_languages:
        return None, _warn(
            f"Language {language} is not supported. Please choose from the dropdown."
        )
    if len(prompt) < 2:
        return None, _warn("Please provide a longer prompt text.")

    try:
        gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
            audio_path=audio_file_pth,
            gpt_cond_len=30,
            gpt_cond_chunk_len=4,
            max_ref_length=60,
        )
    except Exception as e:
        # Deliberately broad: any failure here (missing or cleared upload,
        # unreadable file, ...) is reported as a reference-audio problem.
        print("Speaker encoding error:", str(e))
        return None, _warn("Error with reference audio.")

    try:
        # Insert a space between a word / non-ASCII character and a following
        # ".", "。" or "?". The raw string avoids invalid escape sequences.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|。|\?)", r"\1 \2", prompt)
        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        print("Generating new audio...")
        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=0.75,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        wav = out["wav"]

        # RTF = processing time / audio duration, with 24000 samples per second.
        real_time_factor = inference_time / wav.shape[-1] * 24000
        metrics_text = (
            f"Time to generate audio: {round(inference_time * 1000)} ms\n"
            f"Real-time factor (RTF): {real_time_factor:.2f}\n"
        )

        keep_len = calculate_keep_len(prompt, language)
        # calculate_keep_len returns -1 as its "no limit" sentinel. Using it
        # directly as a slice bound would silently drop the final sample, so
        # trim only for a positive length.
        if keep_len > 0:
            wav = wav[:keep_len]
        torchaudio.save("output.wav", torch.tensor(wav).unsqueeze(0), 24000)
    except RuntimeError as e:
        print("RuntimeError:", str(e))
        return None, _warn("An error occurred during processing.")
    return "output.wav", metrics_text
title = "viXTTS Demo"
# `title` is used both as the browser-tab title and as the page heading.
with gr.Blocks(analytics_enabled=False, title=title) as demo:
    # Header row: heading on the left, empty placeholder column on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"## {title}")
        with gr.Column():
            pass
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Text Prompt",
                info="One or two sentences at a time is better. Up to 200 text characters.",
                value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt",
            )
            language_gr = gr.Dropdown(
                label="Language",
                info="Select an output language for the synthesised speech",
                choices=supported_languages,
                value="vi",
            )
            normalize_text = gr.Checkbox(
                label="Normalize Vietnamese Text",
                info="Normalize Vietnamese Text",
                value=True,
            )
            # NOTE(review): the default sample is assumed to come from the
            # downloaded model snapshot; the setup code does not verify it.
            ref_gr = gr.Audio(
                label="Reference Audio",
                info="Click on the ✎ button to upload your own target speaker audio",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
        # Right column: outputs.
        with gr.Column():
            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
            out_text_gr = gr.Textbox(label="Metrics")
    # Inputs are passed positionally to predict(prompt, language,
    # audio_file_pth, normalize_text).
    tts_button.click(
        predict,
        [input_text_gr, language_gr, ref_gr, normalize_text],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )
demo.queue()
demo.launch(debug=True, show_api=True)