import warnings

import gradio as gr
import numpy as np
import torch
from dia.model import Dia

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Hugging Face model id loaded by load_model_once().
MODEL_ID = "nari-labs/Dia-1.6B"
# Sample rate (Hz) the generated waveform is labelled with when handed to
# gr.Audio(type="numpy").
SAMPLE_RATE = 44100

# Global model cache; stays None until load_model_once() fully succeeds.
model = None


def load_model_once():
    """Load the Dia model on first use and cache it in the module global.

    The target device (CUDA if available, else CPU) is passed to
    ``Dia.from_pretrained`` directly rather than calling ``.cuda()``
    afterwards. NOTE(review): this assumes ``Dia`` is a wrapper object and
    not an ``nn.Module``, so it has no ``.cuda()``; confirm against the
    installed dia version.

    Returns:
        The cached ``Dia`` instance.

    Raises:
        Exception: whatever ``Dia.from_pretrained`` raises; the global cache
        is left as ``None`` so a later call can retry.
    """
    global model
    if model is None:
        print("Loading Dia model... This may take a few minutes.")
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        try:
            # Load model with correct parameters for Dia
            loaded = Dia.from_pretrained(
                MODEL_ID, compute_dtype="float32", device=device
            )
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        if use_cuda:
            print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("Model loaded on CPU")
        print("Model loaded successfully!")
        # Publish only after a fully successful load, so a failure never
        # leaves a half-initialised model in the cache.
        model = loaded
    return model


def _prepare_text(text):
    """Strip *text* and default it to speaker 1 when it has no speaker tag.

    Returns the cleaned string, or ``None`` when the input is None/blank.
    """
    if not text or not text.strip():
        return None
    text = text.strip()
    if not text.startswith(("[S1]", "[S2]")):
        text = "[S1] " + text
    return text


def _apply_seed(seed):
    """Seed torch's RNGs for reproducible voices.

    ``None`` (an empty Gradio number field) leaves the RNGs untouched. The
    check is ``is not None`` rather than truthiness so that seed 0 is honoured.
    """
    if seed is None:
        return
    # torch.manual_seed seeds the generators of all devices, CUDA included.
    torch.manual_seed(int(seed))


def _to_playable_array(audio_output):
    """Convert model output to a numpy array, rescaling only if it would clip.

    When the peak absolute sample exceeds 1.0, the signal is scaled so the
    peak becomes 0.95; otherwise values are returned unchanged.
    """
    if isinstance(audio_output, torch.Tensor):
        audio_output = audio_output.detach().cpu().numpy()
    audio = np.asarray(audio_output)
    if audio.size > 0:
        peak = np.max(np.abs(audio))
        if peak > 1.0:
            audio = audio / peak * 0.95
    return audio


def generate_audio(text, seed=42):
    """Generate speech for *text*; used as the Gradio click handler.

    Args:
        text: Dialogue script. Lines without a leading ``[S1]``/``[S2]`` tag
            are attributed to speaker 1.
        seed: RNG seed for reproducible voices; ``None`` skips seeding.

    Returns:
        A ``(audio, status)`` pair. On success, ``audio`` is
        ``(SAMPLE_RATE, numpy_array)``. On any failure it is ``None``, and
        ``status`` carries a user-facing message (errors are caught, not raised).
    """
    try:
        # Validate first so blank input never triggers the expensive model load.
        prepared = _prepare_text(text)
        if prepared is None:
            return None, "❌ Please enter some text"

        # Clear GPU cache before generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        current_model = load_model_once()
        _apply_seed(seed)

        print(f"Generating speech for: {prepared[:100]}...")
        # Generate audio - disable torch compile for stability
        with torch.no_grad():
            audio_output = current_model.generate(
                prepared,
                use_torch_compile=False,  # Disabled for T4 compatibility
                verbose=False,
            )

        if audio_output is None:
            error_msg = "❌ Error: the model returned no audio"
            print(error_msg)
            return None, error_msg

        audio = _to_playable_array(audio_output)
        if audio.size == 0:
            error_msg = "❌ Error: the model returned no audio"
            print(error_msg)
            return None, error_msg

        print("✅ Audio generated successfully!")
        return (SAMPLE_RATE, audio), "✅ Audio generated successfully!"
    except torch.cuda.OutOfMemoryError:
        # Handle GPU memory issues
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        error_msg = "❌ GPU memory error. Try shorter text or restart the space."
        print(error_msg)
        return None, error_msg
    except Exception as e:
        # Top-level UI boundary: report the error to the user instead of crashing.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg


# Page header shown at the top of the UI (content reproduced verbatim).
HEADER_HTML = """

🎙️ Dia TTS - Ultra-Realistic Text-to-Speech

Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model

"""

# Usage tips shown at the bottom of the UI (content reproduced verbatim).
# NOTE(review): several listed tags (e.g. (excited), (sad), (angry)) may not
# be among Dia's documented non-verbal tags; verify against the Dia README.
TIPS_HTML = """

💡 Usage Tips:

Supported Emotions: (laughs), (sighs), (gasps), (excited), (sad), (angry), (surprised), (whispers), (shouts), (coughs), (clears throat), (sniffs), (chuckles), (groans)

"""

# Create the Gradio interface - simplified to avoid OAuth triggers
demo = gr.Blocks(title="Dia TTS Demo")
with demo:
    gr.HTML(HEADER_HTML)

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="📝 Text to Speech",
                placeholder="[S1] Hello there! How are you today? [S2] I'm doing great, thanks for asking! (laughs)",
                lines=6,
                value="[S1] Welcome to the Dia TTS demo! [S2] This is amazing technology!",
                info="Use [S1] and [S2] for different speakers. Add emotions like (laughs), (sighs), etc.",
            )
            seed_input = gr.Number(
                label="🎲 Random Seed",
                value=42,
                precision=0,
                info="Same seed = consistent voices",
            )
            generate_btn = gr.Button("🎵 Generate Speech", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(
                label="🔊 Generated Audio",
                type="numpy",
            )
            status_text = gr.Textbox(
                label="📊 Status",
                interactive=False,
                lines=2,
            )

    # Connect the button to the function
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, seed_input],
        outputs=[audio_output, status_text],
    )

    # Add example buttons
    with gr.Row():
        example_btn1 = gr.Button("📻 Podcast", size="sm")
        example_btn2 = gr.Button("😄 Chat", size="sm")
        example_btn3 = gr.Button("🎭 Drama", size="sm")

    # Example button functions: each one fills the text box with a sample script.
    example_btn1.click(
        lambda: "[S1] Welcome to our podcast! [S2] Thanks for having me on the show!",
        outputs=text_input,
    )
    example_btn2.click(
        lambda: "[S1] Did you see the game? [S2] Yes! (laughs) It was incredible!",
        outputs=text_input,
    )
    example_btn3.click(
        lambda: "[S1] I can't believe you're leaving. (sighs) [S2] I know, it's hard. (sad)",
        outputs=text_input,
    )

    # Usage instructions
    gr.HTML(TIPS_HTML)

# Launch with basic configuration
if __name__ == "__main__":
    demo.launch()