File size: 6,602 Bytes
998a350 2c1a7ab 998a350 2c1a7ab 998a350 2c1a7ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import gradio as gr
import torch
import numpy as np
from dia.model import Dia
import warnings
# Keep the console output readable: silence the two noisiest warning
# categories. Filters are installed in this order (FutureWarning first).
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
# Process-wide cache for the Dia model; filled in by load_model_once().
model = None


def load_model_once():
    """Return the shared Dia model, loading it on the first call.

    The ``nari-labs/Dia-1.6B`` checkpoint is loaded with
    ``compute_dtype="float32"``. When CUDA is available the model is created
    directly on the GPU; otherwise the library's default device is used.

    The module-level ``model`` global is only assigned after loading has
    fully succeeded, so a failed attempt is retried on the next call instead
    of leaving a half-initialised object in the cache.

    Returns:
        The cached model instance.

    Raises:
        Exception: Any error raised while loading; it is printed and then
            re-raised unchanged.
    """
    global model
    if model is not None:
        return model

    print("Loading Dia model... This may take a few minutes.")
    use_gpu = torch.cuda.is_available()
    load_kwargs = {"compute_dtype": "float32"}
    if use_gpu:
        # NOTE(review): in the upstream nari-labs/dia package the Dia wrapper
        # is, as far as I know, not an nn.Module and has no .cuda() method;
        # the device is selected through from_pretrained(device=...).
        # Confirm against the installed dia version.
        load_kwargs["device"] = torch.device("cuda")

    try:
        loaded = Dia.from_pretrained("nari-labs/Dia-1.6B", **load_kwargs)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # bare raise keeps the original traceback intact

    if use_gpu:
        print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
    else:
        print("Model loaded on CPU")
    print("Model loaded successfully!")

    # Publish to the cache only once everything above has worked.
    model = loaded
    return model
def generate_audio(text, seed=42):
    """Synthesize speech for ``text`` and return it for the Gradio outputs.

    Args:
        text: Dialogue script. Surrounding whitespace is stripped and
            ``[S1] `` is prepended when the text does not already start with
            an ``[S1]`` or ``[S2]`` speaker tag.
        seed: Integer-like seed for torch's random number generators.
            ``None`` leaves the generators untouched; ``0`` is a valid seed.

    Returns:
        A ``(audio, status)`` pair. On success ``audio`` is
        ``(44100, samples)`` with ``samples`` a NumPy array, and ``status`` is
        a success message. On failure ``audio`` is ``None`` and ``status``
        describes the problem. Errors derived from ``Exception`` are caught
        and reported through ``status`` rather than raised.
    """
    try:
        # Reject blank input first so an empty request never triggers the
        # (slow) model load or touches the GPU.
        if not text or not text.strip():
            return None, "❌ Please enter some text"

        # Clear GPU cache before generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        current_model = load_model_once()

        # Clean the text and make sure it begins with a speaker tag
        text = text.strip()
        if not text.startswith(('[S1]', '[S2]')):
            text = '[S1] ' + text

        # Seed for reproducibility. Compare against None rather than relying
        # on truthiness, otherwise a seed of 0 would be silently ignored.
        if seed is not None:
            torch.manual_seed(int(seed))
            if torch.cuda.is_available():
                torch.cuda.manual_seed(int(seed))

        print(f"Generating speech for: {text[:100]}...")

        with torch.no_grad():
            samples = current_model.generate(
                text,
                use_torch_compile=False,  # Disabled for T4 compatibility
                verbose=False
            )

        # Gradio's numpy audio output needs an ndarray, not a tensor
        if isinstance(samples, torch.Tensor):
            samples = samples.cpu().numpy()

        # Do not report success for a missing or empty waveform
        if samples is None or len(samples) == 0:
            error_msg = "❌ Error: the model returned no audio"
            print(error_msg)
            return None, error_msg

        # Scale down only when the peak exceeds full scale, to prevent clipping
        peak = np.max(np.abs(samples))
        if peak > 1.0:
            samples = samples / peak * 0.95

        success_msg = "✅ Audio generated successfully!"
        print(success_msg)
        return (44100, samples), success_msg

    except torch.cuda.OutOfMemoryError:
        # Free whatever the failed attempt left cached on the GPU
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        error_msg = "❌ GPU memory error. Try shorter text or restart the space."
        print(error_msg)
        return None, error_msg

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg
# Static HTML shown above the controls (page title and tagline).
_HEADER_HTML = """
<div style="text-align: center; padding: 20px;">
<h1>ποΈ Dia TTS - Ultra-Realistic Text-to-Speech</h1>
<p style="font-size: 18px; color: #666;">
Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model
</p>
</div>
"""

# Static HTML shown below the controls (usage tips and supported tags).
_USAGE_TIPS_HTML = """
<div style="margin-top: 20px; padding: 15px; background: #f0f8ff; border-radius: 8px;">
<h3>π‘ Usage Tips:</h3>
<ul>
<li><strong>Speaker Tags:</strong> Use [S1] and [S2] to switch between speakers</li>
<li><strong>Emotions:</strong> Add (laughs), (sighs), (excited), (whispers), (sad), etc.</li>
<li><strong>Length:</strong> Keep text moderate length (5-20 seconds of speech works best)</li>
<li><strong>Seeds:</strong> Use the same seed number for consistent voice characteristics</li>
</ul>
<p><strong>Supported Emotions:</strong> (laughs), (sighs), (gasps), (excited), (sad), (angry),
(surprised), (whispers), (shouts), (coughs), (clears throat), (sniffs), (chuckles), (groans)</p>
</div>
"""


def _fill_with(sample):
    """Return a zero-argument callback that produces ``sample``."""
    return lambda: sample


# Build the Gradio interface - kept simple to avoid OAuth triggers
with gr.Blocks(title="Dia TTS Demo") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # First column: the inputs and the button that starts generation
        with gr.Column():
            text_input = gr.Textbox(
                label="π Text to Speech",
                placeholder="[S1] Hello there! How are you today? [S2] I'm doing great, thanks for asking! (laughs)",
                lines=6,
                value="[S1] Welcome to the Dia TTS demo! [S2] This is amazing technology!",
                info="Use [S1] and [S2] for different speakers. Add emotions like (laughs), (sighs), etc."
            )
            seed_input = gr.Number(
                label="π² Random Seed",
                value=42,
                precision=0,
                info="Same seed = consistent voices"
            )
            generate_btn = gr.Button("π΅ Generate Speech", variant="primary")

        # Second column: the generated audio and a read-only status line
        with gr.Column():
            audio_output = gr.Audio(label="π Generated Audio", type="numpy")
            status_text = gr.Textbox(label="π Status", interactive=False, lines=2)

    # generate_audio returns (audio, status), matching the two outputs below
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, seed_input],
        outputs=[audio_output, status_text]
    )

    # Example buttons: each one overwrites the text box with a sample dialogue
    with gr.Row():
        example_btn1 = gr.Button("π» Podcast", size="sm")
        example_btn2 = gr.Button("π Chat", size="sm")
        example_btn3 = gr.Button("π Drama", size="sm")

    for _example_button, _example_text in (
        (example_btn1, "[S1] Welcome to our podcast! [S2] Thanks for having me on the show!"),
        (example_btn2, "[S1] Did you see the game? [S2] Yes! (laughs) It was incredible!"),
        (example_btn3, "[S1] I can't believe you're leaving. (sighs) [S2] I know, it's hard. (sad)"),
    ):
        _example_button.click(_fill_with(_example_text), outputs=text_input)

    gr.HTML(_USAGE_TIPS_HTML)

# Start the app with default settings, only when run as a script
if __name__ == "__main__":
    demo.launch()
|