# Source: Hugging Face Space file "app.py" (6.6 kB), uploaded by Devakumar868.
# Commit 2c1a7ab (verified): "Update app.py".
import gradio as gr
import torch
import numpy as np
from dia.model import Dia
import warnings
# Keep the console log readable: drop the two warning categories that are
# silenced for this app.
warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("ignore", category=UserWarning)

# Process-wide cache slot for the Dia model; stays None until the first
# successful load.
model = None
def load_model_once():
    """Return the shared Dia model, loading it on the first call.

    The loaded model is stored in the module-level ``model`` variable so
    later calls return immediately. The global is only assigned when
    ``Dia.from_pretrained`` succeeds, so a failed load leaves the cache
    empty and the next call retries.

    Returns:
        The cached Dia model instance.

    Raises:
        Exception: Whatever ``Dia.from_pretrained`` raised; the error is
            printed and then re-raised with its original traceback.
    """
    global model
    if model is not None:
        return model

    print("Loading Dia model... This may take a few minutes.")
    use_gpu = torch.cuda.is_available()
    try:
        # Device placement is requested through from_pretrained's `device`
        # argument rather than an nn.Module-style `.cuda()` call on the
        # returned wrapper. With no CUDA device we pass None and let the
        # library pick its own default.
        model = Dia.from_pretrained(
            "nari-labs/Dia-1.6B",
            compute_dtype="float32",
            device=torch.device("cuda") if use_gpu else None,
        )
        if use_gpu:
            print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("Model loaded on CPU")
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Bare raise keeps the original traceback intact.
        raise
    return model
def generate_audio(text, seed=42):
    """Generate speech for ``text`` and return it with a status message.

    Args:
        text: Dialogue script. Blank or missing text is rejected before the
            model is loaded. A script that does not start with ``[S1]`` or
            ``[S2]`` gets ``[S1] `` prepended.
        seed: Integer-like value used to seed torch's RNGs before
            generation. ``None`` skips seeding; ``0`` is a valid seed.

    Returns:
        ``((44100, samples), status)`` on success, where ``samples`` is the
        (possibly rescaled) output of the model's ``generate`` call, or
        ``(None, status)`` when the input is empty or generation fails.
        Errors derived from ``Exception`` are reported through ``status``
        instead of being raised.
    """
    try:
        # Validate first so an empty request never triggers the model load.
        if not text or not text.strip():
            return None, "❌ Please enter some text"

        # Clear GPU cache before generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        current_model = load_model_once()

        # Clean and format text: default to speaker 1 when no tag is given.
        text = text.strip()
        if not text.startswith(('[S1]', '[S2]')):
            text = '[S1] ' + text

        # Seed for reproducibility. Compare against None (not truthiness)
        # so that seed 0 is honoured like any other seed.
        if seed is not None:
            seed_value = int(seed)
            torch.manual_seed(seed_value)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed_value)

        print(f"Generating speech for: {text[:100]}...")

        # Generate audio - torch compile stays disabled for stability.
        with torch.no_grad():
            audio_output = current_model.generate(
                text,
                use_torch_compile=False,  # Disabled for T4 compatibility
                verbose=False,
            )

        # Ensure audio_output is a numpy array
        if isinstance(audio_output, torch.Tensor):
            audio_output = audio_output.cpu().numpy()

        # Rescale only when the peak exceeds full scale, to prevent clipping.
        if len(audio_output) > 0:
            max_val = np.max(np.abs(audio_output))
            if max_val > 1.0:
                audio_output = audio_output / max_val * 0.95

        success_msg = "✅ Audio generated successfully!"
        print(success_msg)
        return (44100, audio_output), success_msg

    except torch.cuda.OutOfMemoryError:
        # Free whatever the failed run left allocated before reporting.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        error_msg = "❌ GPU memory error. Try shorter text or restart the space."
        print(error_msg)
        return None, error_msg

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg
# Create the Gradio interface - simplified to avoid OAuth triggers.
# Layout: header, a two-column row (inputs on the left, outputs on the
# right), a row of example buttons, and a usage-tips footer.
with gr.Blocks(title="Dia TTS Demo") as demo:
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1>🎙️ Dia TTS - Ultra-Realistic Text-to-Speech</h1>
        <p style="font-size: 18px; color: #666;">
            Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model
        </p>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="📝 Text to Speech",
                placeholder="[S1] Hello there! How are you today? [S2] I'm doing great, thanks for asking! (laughs)",
                lines=6,
                value="[S1] Welcome to the Dia TTS demo! [S2] This is amazing technology!",
                info="Use [S1] and [S2] for different speakers. Add emotions like (laughs), (sighs), etc.",
            )
            seed_input = gr.Number(
                label="🎲 Random Seed",
                value=42,
                precision=0,
                info="Same seed = consistent voices",
            )
            generate_btn = gr.Button("🎵 Generate Speech", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(
                label="🔊 Generated Audio",
                type="numpy",
            )
            status_text = gr.Textbox(
                label="📊 Status",
                interactive=False,
                lines=2,
            )

    # Connect the button to the function: generate_audio returns
    # (audio, status), matching the two output components in order.
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, seed_input],
        outputs=[audio_output, status_text],
    )

    # Add example buttons
    with gr.Row():
        example_btn1 = gr.Button("📻 Podcast", size="sm")
        example_btn2 = gr.Button("😄 Chat", size="sm")
        example_btn3 = gr.Button("🎭 Drama", size="sm")

    # Each example button overwrites the text box with a canned script.
    example_btn1.click(
        lambda: "[S1] Welcome to our podcast! [S2] Thanks for having me on the show!",
        outputs=text_input,
    )
    example_btn2.click(
        lambda: "[S1] Did you see the game? [S2] Yes! (laughs) It was incredible!",
        outputs=text_input,
    )
    example_btn3.click(
        lambda: "[S1] I can't believe you're leaving. (sighs) [S2] I know, it's hard. (sad)",
        outputs=text_input,
    )

    # Usage instructions
    gr.HTML("""
    <div style="margin-top: 20px; padding: 15px; background: #f0f8ff; border-radius: 8px;">
        <h3>💡 Usage Tips:</h3>
        <ul>
            <li><strong>Speaker Tags:</strong> Use [S1] and [S2] to switch between speakers</li>
            <li><strong>Emotions:</strong> Add (laughs), (sighs), (excited), (whispers), (sad), etc.</li>
            <li><strong>Length:</strong> Keep text moderate length (5-20 seconds of speech works best)</li>
            <li><strong>Seeds:</strong> Use the same seed number for consistent voice characteristics</li>
        </ul>
        <p><strong>Supported Emotions:</strong> (laughs), (sighs), (gasps), (excited), (sad), (angry),
        (surprised), (whispers), (shouts), (coughs), (clears throat), (sniffs), (chuckles), (groans)</p>
    </div>
    """)

# Launch with basic configuration
if __name__ == "__main__":
    demo.launch()