# voxtream / app.py
# Revision 03e05c4 ("Add demo"), listed under user "herimor"; file size 3.82 kB.
import json
from pathlib import Path
import nltk
import torch
import spaces
import gradio as gr
import numpy as np
from voxtream.generator import SpeechGenerator, SpeechGeneratorConfig
# --- One-time module setup: runs at import, before any request is served. ---

# Read the generator settings. The encoding is given explicitly so that
# parsing the JSON file does not depend on the platform's default locale.
with open("configs/generator.json", encoding="utf-8") as f:
    config = SpeechGeneratorConfig(**json.load(f))

# Loading speaker encoder
# The object returned by torch.hub.load is discarded here.
# NOTE(review): presumably this call only exists to fetch/cache the speaker
# encoder before SpeechGenerator is constructed -- confirm against
# voxtream.generator.
torch.hub.load(
    config.spk_enc_repo,
    config.spk_enc_model,
    model_name=config.spk_enc_model_name,
    train_type=config.spk_enc_train_type,
    dataset=config.spk_enc_dataset,
    trust_repo=True,
    verbose=False,
)

# Loading NLTK packages
# raise_on_error=True makes a failed download abort startup instead of
# surfacing later as a missing-resource error.
nltk.download("averaged_perceptron_tagger_eng", quiet=True, raise_on_error=True)
nltk.download("punkt", quiet=True, raise_on_error=True)

# Initialize speech generator
speech_generator = SpeechGenerator(config)
# Stylesheet handed to gr.Blocks(css=...) in main(). The #cols, #left-col,
# #right-col, #submit and #clear selectors correspond to the elem_id values
# assigned to the rows, columns and buttons built there.
CUSTOM_CSS = """
/* overall width */
.gradio-container {max-width: 1100px !important}
/* stack labels tighter and even heights */
#cols .wrap > .form {gap: 10px}
#left-col, #right-col {gap: 14px}
/* make submit centered + bigger */
#submit {width: 260px; margin: 10px auto 0 auto;}
/* make clear align left and look secondary */
#clear {width: 120px;}
/* give audio a little breathing room */
audio {outline: none;}
"""
def _apply_fade_out(waveform, sample_rate, fade_len_sec=0.1):
    """Linearly fade the tail of ``waveform`` down to zero, in place.

    Args:
        waveform: NumPy array of audio samples with time along the first
            axis (presumably 1-D mono -- confirm against the generator).
            It is modified in place.
        sample_rate: Samples per second; converts the fade length to samples.
        fade_len_sec: Requested fade duration in seconds.

    Returns:
        The same ``waveform`` array that was passed in.
    """
    # Clamp the ramp to the clip length: with a clip shorter than the fade
    # window, the slice would be shorter than the ramp and the in-place
    # multiply would fail with a broadcasting error.
    fade_len = min(int(sample_rate * fade_len_sec), len(waveform))
    # A zero-length fade must be skipped, because waveform[-0:] is the whole
    # array rather than an empty tail.
    if fade_len > 0:
        waveform[-fade_len:] *= np.linspace(1.0, 0.0, fade_len)
    return waveform


@spaces.GPU
def synthesize_fn(prompt_audio_path, prompt_text, target_text):
    """Synthesize ``target_text`` in the voice of the prompt recording.

    Args:
        prompt_audio_path: Path of the prompt audio file.
        prompt_text: Transcript passed to the generator as ``prompt_text``.
        target_text: Text to synthesize.

    Returns:
        ``None`` when the prompt audio or target text is missing, or when
        the generator yields no frames. Otherwise a tuple
        ``(config.mimi_sr, waveform)`` where ``waveform`` is a float32 NumPy
        array whose tail has been faded out.
    """
    # Reject incomplete requests before doing any device work.
    if not prompt_audio_path or not target_text:
        return None

    # Normalise through torch.device so the check works whether ``.device``
    # is a string or a torch.device; comparing a torch.device directly to
    # the string "cpu" is not a reliable test.
    model_device = torch.device(speech_generator.model.device)
    if model_device.type == "cpu" and torch.cuda.is_available():
        speech_generator.model.to("cuda")
        speech_generator.mimi.to("cuda")
        speech_generator.spk_enc.to("cuda")
        speech_generator.aligner.to("cuda")

    stream = speech_generator.generate_stream(
        prompt_text=prompt_text,
        prompt_audio_path=Path(prompt_audio_path),
        text=target_text,
    )

    # Each stream item is a pair; only the first element (the audio frame)
    # is kept.
    frames = [frame for frame, _ in stream]
    if not frames:
        return None

    # astype() returns a new array, so the fade below does not touch the
    # generator's own frame buffers.
    waveform = np.concatenate(frames).astype(np.float32)

    # Fade out
    _apply_fade_out(waveform, config.mimi_sr, fade_len_sec=0.1)

    return (config.mimi_sr, waveform)
def main():
    """Assemble the VoXtream Gradio page, hook up its buttons and launch it."""
    with gr.Blocks(title="VoXtream", css=CUSTOM_CSS) as demo:
        gr.Markdown("# VoXtream TTS demo")

        with gr.Row(elem_id="cols", equal_height=True):
            # Left side: the voice prompt (recording + its transcript).
            with gr.Column(elem_id="left-col", scale=1):
                voice_sample = gr.Audio(
                    label="Prompt audio (3-5 sec of target voice)",
                    type="filepath",
                    sources=["microphone", "upload"],
                )
                voice_transcript = gr.Textbox(
                    label="Prompt transcript",
                    placeholder="Text that matches the prompt audio (Required)",
                    lines=3,
                )

            # Right side: what to say, and the resulting audio.
            with gr.Column(elem_id="right-col", scale=1):
                text_to_speak = gr.Textbox(
                    label="Target text",
                    placeholder="What you want the model to say",
                    lines=3,
                )
                result_player = gr.Audio(
                    label="Synthesized audio",
                    type="numpy",
                    interactive=False,
                )

        with gr.Row():
            reset_button = gr.Button("Clear", variant="secondary", elem_id="clear")
            run_button = gr.Button("Submit", variant="primary", elem_id="submit")

        synthesis_inputs = [voice_sample, voice_transcript, text_to_speak]
        every_field = [*synthesis_inputs, result_player]

        # Submit runs synthesis and shows the result in the output player.
        run_button.click(
            fn=synthesize_fn,
            inputs=synthesis_inputs,
            outputs=result_player,
        )

        # Clear empties both audio widgets and both text boxes.
        reset_button.click(
            fn=lambda: (None, "", "", None),
            inputs=[],
            outputs=every_field,
        )

    demo.launch()
# Start the demo only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()