Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torchaudio | |
import gradio as gr | |
from zonos.model import Zonos | |
from zonos.conditioning import make_cond_dict, supported_language_codes | |
# Global cache to hold the loaded model (populated lazily by load_model()).
MODEL = None
# Device used for the model, speaker embeddings and conditioning tensors.
device = "cuda"


def load_model():
    """Load the Zonos model once and cache it in the module-level ``MODEL``.

    Adjust ``model_name`` if you want to switch from hybrid to transformer, etc.

    Returns:
        The cached model, with gradients disabled, in eval mode and cast to
        bfloat16.
    """
    global MODEL
    if MODEL is None:
        model_name = "Zyphra/Zonos-v0.1-hybrid"
        print(f"Loading model: {model_name}")
        # Build and configure a local first and publish it to the cache only
        # once fully set up, so a failure part-way through (e.g. in the dtype
        # cast) does not leave a half-initialised model cached for later calls.
        model = Zonos.from_pretrained(model_name, device=device)
        model = model.requires_grad_(False).eval()
        model.bfloat16()  # optional if your GPU supports bfloat16
        MODEL = model
        print("Model loaded successfully!")
    return MODEL
def _reference_audio_to_tensor(wav_np):
    """Convert Gradio numpy audio samples to a float32 tensor of shape (channels, num_samples).

    Integer PCM (Gradio typically delivers int16) is rescaled to roughly
    [-1, 1]; floating-point input is only cast to float32. A 1-D (mono) array
    becomes (1, num_samples); a 2-D array laid out as (num_samples, channels)
    is transposed to (channels, num_samples).
    """
    wav_tensor = torch.from_numpy(wav_np)
    if wav_tensor.is_floating_point():
        wav_tensor = wav_tensor.float()
    else:
        info = torch.iinfo(wav_tensor.dtype)
        if info.min < 0:
            # Signed PCM: full scale (32768 for int16) maps onto [-1.0, 1.0).
            wav_tensor = wav_tensor.float() / -info.min
        else:
            # Unsigned PCM (e.g. 8-bit WAV) is centred on its mid-point.
            mid = (info.max + 1) / 2
            wav_tensor = (wav_tensor.float() - mid) / mid

    if wav_tensor.dim() == 1:
        # Mono: (num_samples,) -> (1, num_samples).
        wav_tensor = wav_tensor.unsqueeze(0)
    elif wav_tensor.dim() == 2 and wav_tensor.shape[0] > wav_tensor.shape[1]:
        # Samples normally far outnumber channels, so a taller-than-wide array
        # is (num_samples, channels); flip it to (channels, num_samples).
        wav_tensor = wav_tensor.T
    return wav_tensor.contiguous()


def tts(text, speaker_audio, selected_language):
    """Synthesize ``text`` in the voice of ``speaker_audio``.

    Args:
        text: Prompt to speak. Empty, ``None`` or whitespace-only text yields
            ``None``.
        speaker_audio: ``(sample_rate, numpy_array)`` from a Gradio ``Audio``
            component with ``type="numpy"``, or ``None``.
        selected_language: Language code (e.g. ``"en-us"``, ``"es-es"``).

    Returns:
        ``(sample_rate, waveform)`` for a Gradio audio output, or ``None`` when
        the text or the reference audio is missing.
    """
    # Validate inputs before touching the model so bad input never triggers
    # a (slow) model load.
    if not text or not text.strip():
        return None
    if speaker_audio is None:
        return None

    model = load_model()

    # Gradio provides audio as (sample_rate, numpy_array).
    sr, wav_np = speaker_audio
    wav_tensor = _reference_audio_to_tensor(wav_np)

    with torch.no_grad():
        # Speaker embedding from the reference clip.
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

        cond_dict = make_cond_dict(
            text=text,                   # The text prompt
            speaker=spk_embedding,       # Speaker embedding
            language=selected_language,  # Language from the Dropdown
            device=device,
        )
        conditioning = model.prepare_conditioning(cond_dict)
        codes = model.generate(conditioning)

        # Decode the codes into raw audio. The explicit .float() guards against
        # a reduced-precision output dtype, since numpy has no bfloat16.
        wav_out = model.autoencoder.decode(codes).cpu().float().squeeze()

    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())
def build_demo():
    """Assemble the Gradio Blocks UI that feeds text, reference audio and language into ``tts``."""
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio + Language)")

        # Prompt and speaker reference side by side.
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            speaker_clip = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )

        # Language selector for the conditioning dictionary.
        language_picker = gr.Dropdown(
            label="Language",
            choices=supported_language_codes,
            value="en-us",
            interactive=True,
        )
        run_button = gr.Button("Generate")

        # Playable widget receiving the (sample_rate, waveform) tuple from tts.
        synthesized = gr.Audio(label="Synthesized Output", type="numpy")

        tts_inputs = [prompt_box, speaker_clip, language_picker]
        run_button.click(fn=tts, inputs=tts_inputs, outputs=synthesized)
    return demo
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 and request a public share link.
    app = build_demo()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )