Spaces:
Sleeping
Sleeping
Michael Hu
commited on
Commit
·
91223c9
1
Parent(s):
419e343
enhance logging
Browse files- utils/tts.py +141 -18
- utils/tts_dia.py +119 -24
utils/tts.py
CHANGED
|
@@ -28,12 +28,43 @@ except ImportError:
|
|
| 28 |
# Try to import Dia as fallback
|
| 29 |
if not KOKORO_AVAILABLE:
|
| 30 |
try:
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
except ImportError as e:
|
| 35 |
-
logger.
|
|
|
|
| 36 |
logger.warning("Will use dummy TTS implementation as fallback")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
class TTSEngine:
|
| 39 |
def __init__(self, lang_code='z'):
|
|
@@ -45,20 +76,34 @@ class TTSEngine:
|
|
| 45 |
Note: lang_code is only used for Kokoro, not for Dia
|
| 46 |
"""
|
| 47 |
logger.info("Initializing TTS Engine")
|
|
|
|
| 48 |
self.engine_type = None
|
| 49 |
|
| 50 |
if KOKORO_AVAILABLE:
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
# For Dia, we don't need to initialize anything here
|
| 56 |
# The model will be lazy-loaded when needed
|
| 57 |
self.pipeline = None
|
| 58 |
self.engine_type = "dia"
|
| 59 |
logger.info("TTS engine initialized with Dia (lazy loading)")
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
logger.warning("Using dummy TTS implementation as no TTS engines are available")
|
|
|
|
| 62 |
self.pipeline = None
|
| 63 |
self.engine_type = "dummy"
|
| 64 |
|
|
@@ -95,13 +140,29 @@ class TTSEngine:
|
|
| 95 |
elif self.engine_type == "dia":
|
| 96 |
# Use Dia for TTS generation
|
| 97 |
try:
|
|
|
|
| 98 |
# Import here to avoid circular imports
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
# Call Dia's generate_speech function
|
|
|
|
| 101 |
output_path = dia_generate_speech(text)
|
| 102 |
logger.info(f"Generated audio with Dia: {output_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
except Exception as dia_error:
|
| 104 |
logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
|
|
|
|
|
|
|
| 105 |
# Fall back to dummy audio if Dia fails
|
| 106 |
return self._generate_dummy_audio(output_path)
|
| 107 |
else:
|
|
@@ -157,14 +218,36 @@ class TTSEngine:
|
|
| 157 |
# Dia doesn't support streaming natively, so we generate the full audio
|
| 158 |
# and then yield it as a single chunk
|
| 159 |
try:
|
|
|
|
| 160 |
# Import here to avoid circular imports
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
# Get the Dia model
|
| 165 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# Generate audio
|
|
|
|
| 168 |
with torch.inference_mode():
|
| 169 |
output_audio_np = model.generate(
|
| 170 |
text,
|
|
@@ -178,12 +261,22 @@ class TTSEngine:
|
|
| 178 |
)
|
| 179 |
|
| 180 |
if output_audio_np is not None:
|
|
|
|
| 181 |
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
| 182 |
else:
|
|
|
|
|
|
|
| 183 |
# Fall back to dummy audio if Dia fails
|
| 184 |
yield from self._generate_dummy_audio_stream()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
except Exception as dia_error:
|
| 186 |
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
|
|
|
|
|
|
| 187 |
# Fall back to dummy audio if Dia fails
|
| 188 |
yield from self._generate_dummy_audio_stream()
|
| 189 |
else:
|
|
@@ -221,14 +314,25 @@ def get_tts_engine(lang_code='a'):
|
|
| 221 |
Returns:
|
| 222 |
TTSEngine: Initialized TTS engine instance
|
| 223 |
"""
|
|
|
|
| 224 |
try:
|
| 225 |
import streamlit as st
|
|
|
|
| 226 |
@st.cache_resource
|
| 227 |
def _get_engine():
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
except ImportError:
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
| 234 |
"""Public interface for TTS generation
|
|
@@ -241,5 +345,24 @@ def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> s
|
|
| 241 |
Returns:
|
| 242 |
str: Path to generated audio file
|
| 243 |
"""
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# Try to import Dia as fallback
|
| 29 |
if not KOKORO_AVAILABLE:
|
| 30 |
try:
|
| 31 |
+
logger.info("Attempting to import Dia TTS engine as fallback")
|
| 32 |
+
try:
|
| 33 |
+
# Check if required dependencies for Dia are available
|
| 34 |
+
import torch
|
| 35 |
+
logger.info("PyTorch is available for Dia TTS")
|
| 36 |
+
except ImportError as torch_err:
|
| 37 |
+
logger.error(f"PyTorch dependency for Dia TTS is missing: {str(torch_err)}")
|
| 38 |
+
raise ImportError(f"PyTorch dependency required for Dia TTS: {str(torch_err)}") from torch_err
|
| 39 |
+
|
| 40 |
+
# Try to import the Dia module
|
| 41 |
+
try:
|
| 42 |
+
from utils.tts_dia import _get_model as get_dia_model
|
| 43 |
+
logger.info("Successfully imported Dia TTS module")
|
| 44 |
+
|
| 45 |
+
# Verify the model can be accessed
|
| 46 |
+
logger.info("Verifying Dia model can be accessed")
|
| 47 |
+
model_info = get_dia_model.__module__
|
| 48 |
+
logger.info(f"Dia model module: {model_info}")
|
| 49 |
+
|
| 50 |
+
DIA_AVAILABLE = True
|
| 51 |
+
logger.info("Dia TTS engine is available as fallback")
|
| 52 |
+
except ImportError as module_err:
|
| 53 |
+
logger.error(f"Failed to import Dia TTS module: {str(module_err)}")
|
| 54 |
+
logger.error(f"Module path: {module_err.__traceback__.tb_frame.f_globals.get('__file__', 'unknown')}")
|
| 55 |
+
raise
|
| 56 |
+
except AttributeError as attr_err:
|
| 57 |
+
logger.error(f"Dia TTS module attribute error: {str(attr_err)}")
|
| 58 |
+
logger.error(f"This may indicate the module exists but has incorrect structure")
|
| 59 |
+
raise
|
| 60 |
except ImportError as e:
|
| 61 |
+
logger.error(f"Dia TTS engine is not available due to import error: {str(e)}")
|
| 62 |
+
logger.error(f"Import path attempted: {e.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
| 63 |
logger.warning("Will use dummy TTS implementation as fallback")
|
| 64 |
+
except Exception as e:
|
| 65 |
+
logger.error(f"Unexpected error initializing Dia TTS: {str(e)}")
|
| 66 |
+
logger.error(f"Error type: {type(e).__name__}")
|
| 67 |
+
logger.error("Will use dummy TTS implementation as fallback")
|
| 68 |
|
| 69 |
class TTSEngine:
|
| 70 |
def __init__(self, lang_code='z'):
|
|
|
|
| 76 |
Note: lang_code is only used for Kokoro, not for Dia
|
| 77 |
"""
|
| 78 |
logger.info("Initializing TTS Engine")
|
| 79 |
+
logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
|
| 80 |
self.engine_type = None
|
| 81 |
|
| 82 |
if KOKORO_AVAILABLE:
|
| 83 |
+
logger.info(f"Using Kokoro as primary TTS engine with language code: {lang_code}")
|
| 84 |
+
try:
|
| 85 |
+
self.pipeline = KPipeline(lang_code=lang_code)
|
| 86 |
+
self.engine_type = "kokoro"
|
| 87 |
+
logger.info("TTS engine successfully initialized with Kokoro")
|
| 88 |
+
except Exception as kokoro_err:
|
| 89 |
+
logger.error(f"Failed to initialize Kokoro pipeline: {str(kokoro_err)}")
|
| 90 |
+
logger.error(f"Error type: {type(kokoro_err).__name__}")
|
| 91 |
+
logger.info("Will try to fall back to Dia TTS engine")
|
| 92 |
+
# Fall through to try Dia
|
| 93 |
+
|
| 94 |
+
# Try Dia if Kokoro is not available or failed to initialize
|
| 95 |
+
if self.engine_type is None and DIA_AVAILABLE:
|
| 96 |
+
logger.info("Using Dia as fallback TTS engine")
|
| 97 |
# For Dia, we don't need to initialize anything here
|
| 98 |
# The model will be lazy-loaded when needed
|
| 99 |
self.pipeline = None
|
| 100 |
self.engine_type = "dia"
|
| 101 |
logger.info("TTS engine initialized with Dia (lazy loading)")
|
| 102 |
+
|
| 103 |
+
# Use dummy if no TTS engines are available
|
| 104 |
+
if self.engine_type is None:
|
| 105 |
logger.warning("Using dummy TTS implementation as no TTS engines are available")
|
| 106 |
+
logger.warning("Check logs above for specific errors that prevented Kokoro or Dia initialization")
|
| 107 |
self.pipeline = None
|
| 108 |
self.engine_type = "dummy"
|
| 109 |
|
|
|
|
| 140 |
elif self.engine_type == "dia":
|
| 141 |
# Use Dia for TTS generation
|
| 142 |
try:
|
| 143 |
+
logger.info("Attempting to use Dia TTS for speech generation")
|
| 144 |
# Import here to avoid circular imports
|
| 145 |
+
try:
|
| 146 |
+
logger.info("Importing Dia speech generation module")
|
| 147 |
+
from utils.tts_dia import generate_speech as dia_generate_speech
|
| 148 |
+
logger.info("Successfully imported Dia speech generation function")
|
| 149 |
+
except ImportError as import_err:
|
| 150 |
+
logger.error(f"Failed to import Dia speech generation function: {str(import_err)}")
|
| 151 |
+
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
| 152 |
+
raise
|
| 153 |
+
|
| 154 |
# Call Dia's generate_speech function
|
| 155 |
+
logger.info("Calling Dia's generate_speech function")
|
| 156 |
output_path = dia_generate_speech(text)
|
| 157 |
logger.info(f"Generated audio with Dia: {output_path}")
|
| 158 |
+
except ImportError as import_err:
|
| 159 |
+
logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
|
| 160 |
+
logger.error("Falling back to dummy audio generation")
|
| 161 |
+
return self._generate_dummy_audio(output_path)
|
| 162 |
except Exception as dia_error:
|
| 163 |
logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
|
| 164 |
+
logger.error(f"Error type: {type(dia_error).__name__}")
|
| 165 |
+
logger.error("Falling back to dummy audio generation")
|
| 166 |
# Fall back to dummy audio if Dia fails
|
| 167 |
return self._generate_dummy_audio(output_path)
|
| 168 |
else:
|
|
|
|
| 218 |
# Dia doesn't support streaming natively, so we generate the full audio
|
| 219 |
# and then yield it as a single chunk
|
| 220 |
try:
|
| 221 |
+
logger.info("Attempting to use Dia TTS for speech streaming")
|
| 222 |
# Import here to avoid circular imports
|
| 223 |
+
try:
|
| 224 |
+
logger.info("Importing required modules for Dia streaming")
|
| 225 |
+
import torch
|
| 226 |
+
logger.info("PyTorch successfully imported for Dia streaming")
|
| 227 |
+
|
| 228 |
+
try:
|
| 229 |
+
from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
|
| 230 |
+
logger.info("Successfully imported Dia model and sample rate")
|
| 231 |
+
except ImportError as import_err:
|
| 232 |
+
logger.error(f"Failed to import Dia model for streaming: {str(import_err)}")
|
| 233 |
+
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
| 234 |
+
raise
|
| 235 |
+
except ImportError as torch_err:
|
| 236 |
+
logger.error(f"PyTorch import failed for Dia streaming: {str(torch_err)}")
|
| 237 |
+
raise
|
| 238 |
|
| 239 |
# Get the Dia model
|
| 240 |
+
logger.info("Getting Dia model instance")
|
| 241 |
+
try:
|
| 242 |
+
model = _get_model()
|
| 243 |
+
logger.info("Successfully obtained Dia model instance")
|
| 244 |
+
except Exception as model_err:
|
| 245 |
+
logger.error(f"Failed to get Dia model instance: {str(model_err)}")
|
| 246 |
+
logger.error(f"Error type: {type(model_err).__name__}")
|
| 247 |
+
raise
|
| 248 |
|
| 249 |
# Generate audio
|
| 250 |
+
logger.info("Generating audio with Dia model")
|
| 251 |
with torch.inference_mode():
|
| 252 |
output_audio_np = model.generate(
|
| 253 |
text,
|
|
|
|
| 261 |
)
|
| 262 |
|
| 263 |
if output_audio_np is not None:
|
| 264 |
+
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
| 265 |
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
| 266 |
else:
|
| 267 |
+
logger.warning("Dia model returned None for audio output")
|
| 268 |
+
logger.warning("Falling back to dummy audio stream")
|
| 269 |
# Fall back to dummy audio if Dia fails
|
| 270 |
yield from self._generate_dummy_audio_stream()
|
| 271 |
+
except ImportError as import_err:
|
| 272 |
+
logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
|
| 273 |
+
logger.error("Falling back to dummy audio stream")
|
| 274 |
+
# Fall back to dummy audio if Dia fails
|
| 275 |
+
yield from self._generate_dummy_audio_stream()
|
| 276 |
except Exception as dia_error:
|
| 277 |
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
| 278 |
+
logger.error(f"Error type: {type(dia_error).__name__}")
|
| 279 |
+
logger.error("Falling back to dummy audio stream")
|
| 280 |
# Fall back to dummy audio if Dia fails
|
| 281 |
yield from self._generate_dummy_audio_stream()
|
| 282 |
else:
|
|
|
|
| 314 |
Returns:
|
| 315 |
TTSEngine: Initialized TTS engine instance
|
| 316 |
"""
|
| 317 |
+
logger.info(f"Requesting TTS engine with language code: {lang_code}")
|
| 318 |
try:
|
| 319 |
import streamlit as st
|
| 320 |
+
logger.info("Streamlit detected, using cached TTS engine")
|
| 321 |
@st.cache_resource
|
| 322 |
def _get_engine():
|
| 323 |
+
logger.info("Creating cached TTS engine instance")
|
| 324 |
+
engine = TTSEngine(lang_code)
|
| 325 |
+
logger.info(f"Cached TTS engine created with type: {engine.engine_type}")
|
| 326 |
+
return engine
|
| 327 |
+
|
| 328 |
+
engine = _get_engine()
|
| 329 |
+
logger.info(f"Retrieved TTS engine from cache with type: {engine.engine_type}")
|
| 330 |
+
return engine
|
| 331 |
except ImportError:
|
| 332 |
+
logger.info("Streamlit not available, creating direct TTS engine instance")
|
| 333 |
+
engine = TTSEngine(lang_code)
|
| 334 |
+
logger.info(f"Direct TTS engine created with type: {engine.engine_type}")
|
| 335 |
+
return engine
|
| 336 |
|
| 337 |
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
| 338 |
"""Public interface for TTS generation
|
|
|
|
| 345 |
Returns:
|
| 346 |
str: Path to generated audio file
|
| 347 |
"""
|
| 348 |
+
logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
|
| 349 |
+
try:
|
| 350 |
+
# Get the TTS engine
|
| 351 |
+
logger.info("Getting TTS engine instance")
|
| 352 |
+
engine = get_tts_engine()
|
| 353 |
+
logger.info(f"Using TTS engine type: {engine.engine_type}")
|
| 354 |
+
|
| 355 |
+
# Generate speech
|
| 356 |
+
logger.info("Calling engine.generate_speech")
|
| 357 |
+
output_path = engine.generate_speech(text, voice, speed)
|
| 358 |
+
logger.info(f"Speech generation complete, output path: {output_path}")
|
| 359 |
+
return output_path
|
| 360 |
+
except Exception as e:
|
| 361 |
+
logger.error(f"Error in public generate_speech function: {str(e)}", exc_info=True)
|
| 362 |
+
logger.error(f"Error type: {type(e).__name__}")
|
| 363 |
+
if hasattr(e, '__traceback__'):
|
| 364 |
+
tb = e.__traceback__
|
| 365 |
+
while tb.tb_next:
|
| 366 |
+
tb = tb.tb_next
|
| 367 |
+
logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
|
| 368 |
+
raise
|
utils/tts_dia.py
CHANGED
|
@@ -27,10 +27,36 @@ def _get_model() -> Dia:
|
|
| 27 |
if _model is None:
|
| 28 |
logger.info("Loading Dia model...")
|
| 29 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
_model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
except Exception as e:
|
| 33 |
logger.error(f"Error loading Dia model: {e}", exc_info=True)
|
|
|
|
|
|
|
| 34 |
raise
|
| 35 |
return _model
|
| 36 |
|
|
@@ -46,58 +72,127 @@ def generate_speech(text: str, language: str = "zh") -> str:
|
|
| 46 |
str: Path to the generated audio file
|
| 47 |
"""
|
| 48 |
logger.info(f"Generating speech for text length: {len(text)}")
|
|
|
|
| 49 |
|
| 50 |
try:
|
| 51 |
# Create output directory if it doesn't exist
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# Generate unique output path
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# Get the model
|
| 58 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# Generate audio
|
|
|
|
| 61 |
start_time = time.time()
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
end_time = time.time()
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
# Process the output
|
| 79 |
if output_audio_np is not None:
|
|
|
|
|
|
|
|
|
|
| 80 |
# Apply a slight slowdown for better quality (0.94x speed)
|
| 81 |
speed_factor = 0.94
|
| 82 |
original_len = len(output_audio_np)
|
| 83 |
target_len = int(original_len / speed_factor)
|
| 84 |
|
|
|
|
| 85 |
if target_len != original_len and target_len > 0:
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Save the audio file
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
return output_path
|
| 96 |
else:
|
| 97 |
-
logger.warning("Generation produced no output
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
except Exception as e:
|
| 101 |
logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
# Return dummy path in case of error
|
| 103 |
return "temp/outputs/dummy.wav"
|
|
|
|
| 27 |
if _model is None:
|
| 28 |
logger.info("Loading Dia model...")
|
| 29 |
try:
|
| 30 |
+
# Check if torch is available with correct version
|
| 31 |
+
logger.info(f"PyTorch version: {torch.__version__}")
|
| 32 |
+
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
| 33 |
+
if torch.cuda.is_available():
|
| 34 |
+
logger.info(f"CUDA version: {torch.version.cuda}")
|
| 35 |
+
logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")
|
| 36 |
+
|
| 37 |
+
# Check if model path exists
|
| 38 |
+
logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
|
| 39 |
+
|
| 40 |
+
# Load the model with detailed logging
|
| 41 |
+
logger.info("Initializing Dia model...")
|
| 42 |
_model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
|
| 43 |
+
|
| 44 |
+
# Log model details
|
| 45 |
+
logger.info(f"Dia model loaded successfully")
|
| 46 |
+
logger.info(f"Model type: {type(_model).__name__}")
|
| 47 |
+
logger.info(f"Model device: {next(_model.parameters()).device}")
|
| 48 |
+
except ImportError as import_err:
|
| 49 |
+
logger.error(f"Import error loading Dia model: {import_err}")
|
| 50 |
+
logger.error(f"This may indicate missing dependencies")
|
| 51 |
+
raise
|
| 52 |
+
except FileNotFoundError as file_err:
|
| 53 |
+
logger.error(f"File not found error loading Dia model: {file_err}")
|
| 54 |
+
logger.error(f"Model path may be incorrect or inaccessible")
|
| 55 |
+
raise
|
| 56 |
except Exception as e:
|
| 57 |
logger.error(f"Error loading Dia model: {e}", exc_info=True)
|
| 58 |
+
logger.error(f"Error type: {type(e).__name__}")
|
| 59 |
+
logger.error(f"This may indicate incompatible versions or missing CUDA support")
|
| 60 |
raise
|
| 61 |
return _model
|
| 62 |
|
|
|
|
| 72 |
str: Path to the generated audio file
|
| 73 |
"""
|
| 74 |
logger.info(f"Generating speech for text length: {len(text)}")
|
| 75 |
+
logger.info(f"Text content (first 50 chars): {text[:50]}...")
|
| 76 |
|
| 77 |
try:
|
| 78 |
# Create output directory if it doesn't exist
|
| 79 |
+
output_dir = "temp/outputs"
|
| 80 |
+
logger.info(f"Ensuring output directory exists: {output_dir}")
|
| 81 |
+
try:
|
| 82 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 83 |
+
logger.info(f"Output directory ready: {output_dir}")
|
| 84 |
+
except PermissionError as perm_err:
|
| 85 |
+
logger.error(f"Permission error creating output directory: {perm_err}")
|
| 86 |
+
raise
|
| 87 |
+
except Exception as dir_err:
|
| 88 |
+
logger.error(f"Error creating output directory: {dir_err}")
|
| 89 |
+
raise
|
| 90 |
|
| 91 |
# Generate unique output path
|
| 92 |
+
timestamp = int(time.time())
|
| 93 |
+
output_path = f"{output_dir}/output_{timestamp}.wav"
|
| 94 |
+
logger.info(f"Output will be saved to: {output_path}")
|
| 95 |
|
| 96 |
# Get the model
|
| 97 |
+
logger.info("Retrieving Dia model instance")
|
| 98 |
+
try:
|
| 99 |
+
model = _get_model()
|
| 100 |
+
logger.info("Successfully retrieved Dia model instance")
|
| 101 |
+
except Exception as model_err:
|
| 102 |
+
logger.error(f"Failed to get Dia model: {model_err}")
|
| 103 |
+
logger.error(f"Error type: {type(model_err).__name__}")
|
| 104 |
+
raise
|
| 105 |
|
| 106 |
# Generate audio
|
| 107 |
+
logger.info("Starting audio generation with Dia model")
|
| 108 |
start_time = time.time()
|
| 109 |
|
| 110 |
+
try:
|
| 111 |
+
with torch.inference_mode():
|
| 112 |
+
logger.info("Calling model.generate() with inference_mode")
|
| 113 |
+
output_audio_np = model.generate(
|
| 114 |
+
text,
|
| 115 |
+
max_tokens=None, # Use default from model config
|
| 116 |
+
cfg_scale=3.0,
|
| 117 |
+
temperature=1.3,
|
| 118 |
+
top_p=0.95,
|
| 119 |
+
cfg_filter_top_k=35,
|
| 120 |
+
use_torch_compile=False, # Keep False for stability
|
| 121 |
+
verbose=False
|
| 122 |
+
)
|
| 123 |
+
logger.info("Model.generate() completed")
|
| 124 |
+
except RuntimeError as rt_err:
|
| 125 |
+
logger.error(f"Runtime error during generation: {rt_err}")
|
| 126 |
+
if "CUDA out of memory" in str(rt_err):
|
| 127 |
+
logger.error("CUDA out of memory error - consider reducing batch size or model size")
|
| 128 |
+
raise
|
| 129 |
+
except Exception as gen_err:
|
| 130 |
+
logger.error(f"Error during audio generation: {gen_err}")
|
| 131 |
+
logger.error(f"Error type: {type(gen_err).__name__}")
|
| 132 |
+
raise
|
| 133 |
|
| 134 |
end_time = time.time()
|
| 135 |
+
generation_time = end_time - start_time
|
| 136 |
+
logger.info(f"Generation finished in {generation_time:.2f} seconds")
|
| 137 |
|
| 138 |
# Process the output
|
| 139 |
if output_audio_np is not None:
|
| 140 |
+
logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
|
| 141 |
+
logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
|
| 142 |
+
|
| 143 |
# Apply a slight slowdown for better quality (0.94x speed)
|
| 144 |
speed_factor = 0.94
|
| 145 |
original_len = len(output_audio_np)
|
| 146 |
target_len = int(original_len / speed_factor)
|
| 147 |
|
| 148 |
+
logger.info(f"Applying speed adjustment factor: {speed_factor}")
|
| 149 |
if target_len != original_len and target_len > 0:
|
| 150 |
+
try:
|
| 151 |
+
x_original = np.arange(original_len)
|
| 152 |
+
x_resampled = np.linspace(0, original_len - 1, target_len)
|
| 153 |
+
output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
|
| 154 |
+
logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
|
| 155 |
+
except Exception as resample_err:
|
| 156 |
+
logger.error(f"Error during audio resampling: {resample_err}")
|
| 157 |
+
logger.warning("Using original audio without resampling")
|
| 158 |
|
| 159 |
# Save the audio file
|
| 160 |
+
logger.info(f"Saving audio to file: {output_path}")
|
| 161 |
+
try:
|
| 162 |
+
sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
|
| 163 |
+
logger.info(f"Audio successfully saved to {output_path}")
|
| 164 |
+
except Exception as save_err:
|
| 165 |
+
logger.error(f"Error saving audio file: {save_err}")
|
| 166 |
+
logger.error(f"Error type: {type(save_err).__name__}")
|
| 167 |
+
raise
|
| 168 |
|
| 169 |
return output_path
|
| 170 |
else:
|
| 171 |
+
logger.warning("Generation produced no output (None returned from model)")
|
| 172 |
+
logger.warning("This may indicate a model configuration issue or empty input text")
|
| 173 |
+
dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
|
| 174 |
+
logger.warning(f"Returning dummy audio path: {dummy_path}")
|
| 175 |
+
return dummy_path
|
| 176 |
|
| 177 |
except Exception as e:
|
| 178 |
logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
|
| 179 |
+
logger.error(f"Error type: {type(e).__name__}")
|
| 180 |
+
|
| 181 |
+
# Log additional diagnostic information based on error type
|
| 182 |
+
if isinstance(e, ImportError):
|
| 183 |
+
logger.error(f"Import error - missing dependency: {e.__class__.__module__}.{e.__class__.__name__}")
|
| 184 |
+
logger.error("Check if all required packages are installed correctly")
|
| 185 |
+
elif isinstance(e, RuntimeError) and "CUDA" in str(e):
|
| 186 |
+
logger.error("CUDA-related runtime error - check GPU compatibility and memory")
|
| 187 |
+
elif isinstance(e, AttributeError):
|
| 188 |
+
logger.error(f"Attribute error - likely API incompatibility or incorrect module version")
|
| 189 |
+
if hasattr(e, '__traceback__'):
|
| 190 |
+
tb = e.__traceback__
|
| 191 |
+
while tb.tb_next:
|
| 192 |
+
tb = tb.tb_next
|
| 193 |
+
logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
|
| 194 |
+
elif isinstance(e, FileNotFoundError):
|
| 195 |
+
logger.error(f"File not found - check if model files exist and are accessible")
|
| 196 |
+
|
| 197 |
# Return dummy path in case of error
|
| 198 |
return "temp/outputs/dummy.wav"
|