import gradio as gr import torch import numpy as np from dia.model import Dia import warnings # Suppress warnings for cleaner output warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) # Global model variable model = None def load_model_once(): """Load the Dia model once and cache it globally""" global model if model is None: print("Loading Dia model... This may take a few minutes.") try: # Load model with correct parameters for Dia model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32") # Move model to GPU if available if torch.cuda.is_available(): model = model.cuda() print(f"Model loaded on GPU: {torch.cuda.get_device_name()}") else: print("Model loaded on CPU") print("Model loaded successfully!") except Exception as e: print(f"Error loading model: {e}") raise e return model def generate_audio(text, seed=42): """Generate audio from text input with error handling""" try: # Clear GPU cache before generation if torch.cuda.is_available(): torch.cuda.empty_cache() current_model = load_model_once() # Validate input if not text or not text.strip(): return None, "❌ Please enter some text" # Clean and format text text = text.strip() if not text.startswith('[S1]') and not text.startswith('[S2]'): text = '[S1] ' + text # Set seed for reproducibility if seed: torch.manual_seed(int(seed)) if torch.cuda.is_available(): torch.cuda.manual_seed(int(seed)) print(f"Generating speech for: {text[:100]}...") # Generate audio - disable torch compile for stability with torch.no_grad(): audio_output = current_model.generate( text, use_torch_compile=False, # Disabled for T4 compatibility verbose=False ) # Ensure audio_output is numpy array if isinstance(audio_output, torch.Tensor): audio_output = audio_output.cpu().numpy() # Normalize audio to prevent clipping if len(audio_output) > 0: max_val = np.max(np.abs(audio_output)) if max_val > 1.0: audio_output = audio_output / max_val * 0.95 print("✅ Audio generated successfully!") return (44100, audio_output), "✅ Audio generated successfully!" except torch.cuda.OutOfMemoryError: # Handle GPU memory issues if torch.cuda.is_available(): torch.cuda.empty_cache() error_msg = "❌ GPU memory error. Try shorter text or restart the space." print(error_msg) return None, error_msg except Exception as e: error_msg = f"❌ Error: {str(e)}" print(error_msg) return None, error_msg # Create the Gradio interface - simplified to avoid OAuth triggers demo = gr.Blocks(title="Dia TTS Demo") with demo: gr.HTML("""
Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model
Supported Emotions: (laughs), (sighs), (gasps), (excited), (sad), (angry), (surprised), (whispers), (shouts), (coughs), (clears throat), (sniffs), (chuckles), (groans)