|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
from dia.model import Dia |
|
import warnings |
|
|
|
|
|
# Globally silence FutureWarning and UserWarning for this process; presumably
# this keeps the demo's console output free of third-party library chatter.
# Filters are installed in the same order as before (FutureWarning first).
for _category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_category)
del _category

# Lazily-populated cache for the Dia model; load_model_once() fills it on
# first use and returns the same instance afterwards.
model = None
|
|
|
def load_model_once():
    """Return the process-wide Dia model, loading it on first use.

    The model is created from the ``nari-labs/Dia-1.6B`` checkpoint with
    ``compute_dtype="float32"`` and cached in the module-level ``model``
    global, so subsequent calls return the same instance without reloading.

    Returns:
        The cached Dia model instance.

    Raises:
        Exception: whatever loading or moving the model raises. The error is
            printed and re-raised unchanged, and the cache is left empty so
            the next call retries the load from scratch.
    """
    global model

    if model is not None:
        return model

    print("Loading Dia model... This may take a few minutes.")
    try:
        # Build into a local first: the global is only published after every
        # step has succeeded, so a failure part-way through (e.g. during the
        # device move) cannot leave a half-initialised model cached forever.
        loaded = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")

        if torch.cuda.is_available():
            # NOTE(review): the Dia wrapper may not be a torch.nn.Module and may
            # choose its device itself; only call .cuda() when it is exposed,
            # instead of failing with AttributeError.
            if hasattr(loaded, "cuda"):
                loaded = loaded.cuda()
            print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("Model loaded on CPU")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Bare raise keeps the original traceback intact.
        raise

    model = loaded
    print("Model loaded successfully!")
    return model
|
|
|
def generate_audio(text, seed=42):
    """Synthesize speech for *text* with the cached Dia model.

    Args:
        text: Dialogue script. Speaker tags such as ``[S1]`` / ``[S2]`` are
            expected; if the stripped text starts with neither, ``"[S1] "`` is
            prepended.
        seed: Seed for torch's random number generators. ``None`` leaves the
            RNG state untouched; any other value (including ``0``) is applied
            via ``int(seed)``.

    Returns:
        A ``(audio, status)`` pair matching the two Gradio outputs. On success
        ``audio`` is ``(44100, waveform)`` where ``waveform`` is a NumPy array,
        rescaled so its peak is 0.95 if it exceeded 1.0 in magnitude. On any
        failure ``audio`` is ``None`` and ``status`` holds the error message.
        Exceptions are never propagated to the caller.
    """
    # NOTE(review): 44.1 kHz is assumed to be Dia's output sample rate --
    # confirm if the checkpoint or codec changes.
    sample_rate = 44100
    no_audio_msg = "❌ The model returned no audio. Try different text or another seed."

    try:
        # Validate before touching the model, so an empty submission neither
        # triggers the slow first-time model load nor flushes the CUDA cache.
        if not text or not text.strip():
            return None, "❌ Please enter some text"

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        current_model = load_model_once()

        text = text.strip()
        if not text.startswith(("[S1]", "[S2]")):
            text = "[S1] " + text

        # `if seed:` would silently skip seed 0; only None means "unseeded".
        # torch.manual_seed already seeds every device, CUDA included.
        if seed is not None:
            torch.manual_seed(int(seed))

        print(f"Generating speech for: {text[:100]}...")

        with torch.no_grad():
            waveform = current_model.generate(
                text,
                use_torch_compile=False,
                verbose=False,
            )

        # Guard against an empty result instead of failing later with an
        # unhelpful TypeError / ValueError.
        if waveform is None:
            print(no_audio_msg)
            return None, no_audio_msg

        if isinstance(waveform, torch.Tensor):
            waveform = waveform.cpu().numpy()
        waveform = np.asarray(waveform)

        if waveform.size == 0:
            print(no_audio_msg)
            return None, no_audio_msg

        # Only attenuate: audio already within [-1, 1] is left untouched, while
        # clipping audio is rescaled so its loudest sample sits at 0.95.
        peak = np.max(np.abs(waveform))
        if peak > 1.0:
            waveform = waveform / peak * 0.95

        print("✅ Audio generated successfully!")
        return (sample_rate, waveform), "✅ Audio generated successfully!"

    except torch.cuda.OutOfMemoryError:
        # Release cached allocator blocks so the next request has a chance.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        error_msg = "❌ GPU memory error. Try shorter text or restart the space."
        print(error_msg)
        return None, error_msg

    except Exception as e:
        error_msg = f"❌ Error: {e}"
        print(error_msg)
        return None, error_msg
|
|
|
|
|
# Build the Gradio UI: a text/seed input column, an audio/status output
# column, quick-fill example buttons, and a usage-tips panel.
with gr.Blocks(title="Dia TTS Demo") as demo:
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
    <h1>🎙️ Dia TTS - Ultra-Realistic Text-to-Speech</h1>
    <p style="font-size: 18px; color: #666;">
    Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model
    </p>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="📝 Text to Speech",
                placeholder="[S1] Hello there! How are you today? [S2] I'm doing great, thanks for asking! (laughs)",
                lines=6,
                value="[S1] Welcome to the Dia TTS demo! [S2] This is amazing technology!",
                info="Use [S1] and [S2] for different speakers. Add non-verbal sounds like (laughs), (sighs), etc.",
            )

            # precision=0 makes Gradio deliver an integer seed.
            seed_input = gr.Number(
                label="🎲 Random Seed",
                value=42,
                precision=0,
                info="Same seed = consistent voices",
            )

            generate_btn = gr.Button("🎵 Generate Speech", variant="primary")

        with gr.Column():
            # type="numpy" matches generate_audio's (sample_rate, array) output.
            audio_output = gr.Audio(
                label="🔊 Generated Audio",
                type="numpy",
            )

            status_text = gr.Textbox(
                label="📊 Status",
                interactive=False,
                lines=2,
            )

    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, seed_input],
        outputs=[audio_output, status_text],
    )

    # Example buttons only replace the textbox contents; they do not generate.
    with gr.Row():
        example_btn1 = gr.Button("📻 Podcast", size="sm")
        example_btn2 = gr.Button("😄 Chat", size="sm")
        example_btn3 = gr.Button("🎭 Drama", size="sm")

    example_btn1.click(
        lambda: "[S1] Welcome to our podcast! [S2] Thanks for having me on the show!",
        outputs=text_input,
    )

    example_btn2.click(
        lambda: "[S1] Did you see the game? [S2] Yes! (laughs) It was incredible!",
        outputs=text_input,
    )

    example_btn3.click(
        lambda: "[S1] I can't believe you're leaving. (sighs) [S2] I know, it's hard. (sniffs)",
        outputs=text_input,
    )

    # Only non-verbal tags documented by the Dia project are advertised here;
    # undocumented "emotion" tags are likely to be read aloud or ignored.
    gr.HTML("""
    <div style="margin-top: 20px; padding: 15px; background: #f0f8ff; border-radius: 8px;">
    <h3>💡 Usage Tips:</h3>
    <ul>
    <li><strong>Speaker Tags:</strong> Use [S1] and [S2] to switch between speakers</li>
    <li><strong>Non-verbal sounds:</strong> Add (laughs), (sighs), (gasps), (coughs), etc. Use them sparingly.</li>
    <li><strong>Length:</strong> Keep text moderate length (5-20 seconds of speech works best)</li>
    <li><strong>Seeds:</strong> Use the same seed number for consistent voice characteristics</li>
    </ul>
    <p><strong>Supported non-verbal tags:</strong> (laughs), (clears throat), (sighs), (gasps), (coughs),
    (singing), (sings), (mumbles), (beep), (groans), (sniffs), (claps), (screams), (inhales), (exhales),
    (applause), (burps), (humming), (sneezes), (chuckle), (whistles)</p>
    </div>
    """)


if __name__ == "__main__":
    demo.launch()
|
|