Spaces:
Running
Running
File size: 8,772 Bytes
030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 91223c9 030c851 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import logging
import os
import time
import uuid
from pathlib import Path
from typing import Optional

import numpy as np
import soundfile as sf
import torch

from dia.model import Dia
# Root logging at INFO level. basicConfig is a no-op if the root logger
# already has handlers (e.g. configured by a hosting app).
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

# Checkpoint passed to Dia.from_pretrained() by _get_model().
DEFAULT_MODEL_NAME: str = "nari-labs/Dia-1.6B"
# Sample rate used when writing generated audio to disk (presumably the
# Dia decoder's output rate -- confirm against the model config).
DEFAULT_SAMPLE_RATE: int = 44100

# Process-wide model cache: stays None until _get_model() loads the model.
_model: Optional["Dia"] = None
def _describe_model_device(model: object) -> str:
    """Return a printable device for *model*, for logging only; never raises.

    The Dia wrapper may not be a ``torch.nn.Module`` (and so may lack
    ``parameters()``), so this tries a ``device`` attribute first, then the
    device of the first parameter, and falls back to ``"unknown"``.
    """
    device = getattr(model, "device", None)
    if device is not None:
        return str(device)
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        try:
            return str(next(iter(parameters())).device)
        except (StopIteration, AttributeError, TypeError):
            pass  # diagnostics only; must not turn a good load into a failure
    return "unknown"


def _get_model() -> Dia:
    """Return the shared Dia model, loading it on the first call.

    Loading is deferred until speech is first requested, so importing this
    module does not load model weights. After a successful load the instance
    is cached in the module-level ``_model`` and returned by every later call.

    Returns:
        Dia: the loaded model.

    Raises:
        Exception: any error raised by the environment checks or by
            ``Dia.from_pretrained`` is logged and re-raised unchanged.
            ``_model`` is only set after a successful load, so a failed
            attempt leaves it ``None`` and the next call retries.
    """
    global _model
    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        # Environment diagnostics, useful when loading fails on a new host.
        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")
        logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
        logger.info("Initializing Dia model...")
        # Load into a local: the global cache is only populated once the
        # load has fully succeeded.
        model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia model: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error(f"File not found error loading Dia model: {file_err}")
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia model: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    logger.info("Dia model loaded successfully")
    logger.info(f"Model type: {type(model).__name__}")
    # Non-raising helper: a diagnostic log line must never make a
    # successful load look like a failure.
    logger.info(f"Model device: {_describe_model_device(model)}")
    _model = model
    return _model
def _ensure_output_dir(output_dir: str) -> None:
    """Create *output_dir* (and parents) if missing; log and re-raise on failure."""
    logger.info(f"Ensuring output directory exists: {output_dir}")
    try:
        os.makedirs(output_dir, exist_ok=True)
    except PermissionError as perm_err:
        logger.error(f"Permission error creating output directory: {perm_err}")
        raise
    except Exception as dir_err:
        logger.error(f"Error creating output directory: {dir_err}")
        raise
    logger.info(f"Output directory ready: {output_dir}")


def _synthesize(text: str) -> Optional[np.ndarray]:
    """Run the Dia model on *text* and return its raw audio output (may be None).

    Errors from model loading or generation are logged with context and
    re-raised for the caller to handle.
    """
    logger.info("Retrieving Dia model instance")
    try:
        model = _get_model()
    except Exception as model_err:
        logger.error(f"Failed to get Dia model: {model_err}")
        logger.error(f"Error type: {type(model_err).__name__}")
        raise
    logger.info("Successfully retrieved Dia model instance")

    logger.info("Starting audio generation with Dia model")
    start_time = time.perf_counter()
    try:
        with torch.inference_mode():
            logger.info("Calling model.generate() with inference_mode")
            audio = model.generate(
                text,
                max_tokens=None,  # use the default from the model config
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,  # keep False for stability
                verbose=False,
            )
        logger.info("Model.generate() completed")
    except RuntimeError as rt_err:
        logger.error(f"Runtime error during generation: {rt_err}")
        if "CUDA out of memory" in str(rt_err):
            logger.error("CUDA out of memory error - consider reducing batch size or model size")
        raise
    except Exception as gen_err:
        logger.error(f"Error during audio generation: {gen_err}")
        logger.error(f"Error type: {type(gen_err).__name__}")
        raise
    logger.info(f"Generation finished in {time.perf_counter() - start_time:.2f} seconds")
    return audio


def _apply_speed_factor(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """Linearly resample 1-D *audio* to ``len(audio) / speed_factor`` samples.

    Written back at the same sample rate, the result plays at *speed_factor*
    times the original speed (values < 1 lengthen the clip; as with any plain
    resample the pitch shifts too). Returns *audio* itself, unchanged, when the
    target length equals the input length or would be empty. Raises for a
    non-positive *speed_factor* or input ``np.interp`` cannot handle
    (e.g. multi-channel arrays).
    """
    if speed_factor <= 0:
        raise ValueError(f"speed_factor must be positive, got {speed_factor}")
    original_len = len(audio)
    target_len = int(original_len / speed_factor)
    if target_len == original_len or target_len <= 0:
        return audio
    x_original = np.arange(original_len)
    x_resampled = np.linspace(0, original_len - 1, target_len)
    return np.interp(x_resampled, x_original, audio)


def _log_generation_failure(err: Exception) -> None:
    """Log *err* with its traceback plus a troubleshooting hint keyed on its type."""
    logger.error(f"TTS generation failed: {err}", exc_info=err)
    logger.error(f"Error type: {type(err).__name__}")
    if isinstance(err, ImportError):
        # ImportError.name holds the module that failed to import (may be None).
        logger.error(f"Import error - missing dependency: {err.name or err}")
        logger.error("Check if all required packages are installed correctly")
    elif isinstance(err, RuntimeError) and "CUDA" in str(err):
        logger.error("CUDA-related runtime error - check GPU compatibility and memory")
    elif isinstance(err, AttributeError):
        logger.error("Attribute error - likely API incompatibility or incorrect module version")
        tb = err.__traceback__
        if tb is not None:
            # Walk to the innermost frame: that is where the attribute lookup failed.
            while tb.tb_next is not None:
                tb = tb.tb_next
            logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
    elif isinstance(err, FileNotFoundError):
        logger.error("File not found - check if model files exist and are accessible")


def generate_speech(
    text: str,
    language: str = "zh",
    *,
    speed_factor: float = 0.94,
    output_dir: str = "temp/outputs",
) -> str:
    """Public interface for TTS generation using the Dia model.

    Synthesizes *text*, applies a playback-speed adjustment and writes the
    result as a WAV file at ``DEFAULT_SAMPLE_RATE``.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code. Not used by the Dia model; kept for
            API compatibility.
        speed_factor (float): Speed adjustment applied after generation
            (see ``_apply_speed_factor``). The default 0.94 gives a slight
            slowdown. If resampling fails, the unadjusted audio is saved.
        output_dir (str): Directory for the WAV file; created if missing.

    Returns:
        str: Path to the generated audio file. If the model returns no audio,
        ``{output_dir}/dummy_{timestamp}.wav`` is returned; on any other
        failure, ``{output_dir}/dummy.wav``. NOTE: no file is written at
        either dummy path; they only signal that generation failed.
    """
    logger.info(f"Generating speech for text length: {len(text)}")
    logger.info(f"Text content (first 50 chars): {text[:50]}...")
    try:
        _ensure_output_dir(output_dir)

        timestamp = int(time.time())
        # A second-resolution timestamp alone lets requests that start in the
        # same second overwrite each other's file; the random suffix prevents it.
        output_path = f"{output_dir}/output_{timestamp}_{uuid.uuid4().hex[:8]}.wav"
        logger.info(f"Output will be saved to: {output_path}")

        audio = _synthesize(text)

        # An empty array is treated like None: the min/max stats below cannot
        # reduce a zero-size array (assumes generate() returns a NumPy array).
        if audio is None or audio.size == 0:
            if audio is None:
                logger.warning("Generation produced no output (None returned from model)")
            else:
                logger.warning("Generation produced an empty audio array")
            logger.warning("This may indicate a model configuration issue or empty input text")
            dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
            logger.warning(f"Returning dummy audio path: {dummy_path}")
            return dummy_path

        logger.info(f"Generated audio array shape: {audio.shape}, dtype: {audio.dtype}")
        logger.info(f"Audio stats - min: {audio.min():.4f}, max: {audio.max():.4f}, mean: {audio.mean():.4f}")

        logger.info(f"Applying speed adjustment factor: {speed_factor}")
        try:
            adjusted = _apply_speed_factor(audio, speed_factor)
        except Exception as resample_err:
            # Best effort: the unadjusted clip is still usable output.
            logger.error(f"Error during audio resampling: {resample_err}")
            logger.warning("Using original audio without resampling")
        else:
            if adjusted is not audio:
                logger.info(f"Resampled audio from {len(audio)} to {len(adjusted)} samples for {speed_factor:.2f}x speed")
            audio = adjusted

        logger.info(f"Saving audio to file: {output_path}")
        try:
            sf.write(output_path, audio, DEFAULT_SAMPLE_RATE)
        except Exception as save_err:
            logger.error(f"Error saving audio file: {save_err}")
            logger.error(f"Error type: {type(save_err).__name__}")
            raise
        logger.info(f"Audio successfully saved to {output_path}")
        return output_path
    except Exception as e:
        _log_generation_failure(e)
        # Deliberate best-effort contract: callers get a path, never an exception.
        return f"{output_dir}/dummy.wav"