Spaces:
Running
Running
Michael Hu
committed on
Commit
·
91223c9
1
Parent(s):
419e343
enhance logging
Browse files
- utils/tts.py +141 -18
- utils/tts_dia.py +119 -24
utils/tts.py
CHANGED
@@ -28,12 +28,43 @@ except ImportError:
|
|
28 |
# Try to import Dia as fallback
|
29 |
if not KOKORO_AVAILABLE:
|
30 |
try:
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
except ImportError as e:
|
35 |
-
logger.
|
|
|
36 |
logger.warning("Will use dummy TTS implementation as fallback")
|
|
|
|
|
|
|
|
|
37 |
|
38 |
class TTSEngine:
|
39 |
def __init__(self, lang_code='z'):
|
@@ -45,20 +76,34 @@ class TTSEngine:
|
|
45 |
Note: lang_code is only used for Kokoro, not for Dia
|
46 |
"""
|
47 |
logger.info("Initializing TTS Engine")
|
|
|
48 |
self.engine_type = None
|
49 |
|
50 |
if KOKORO_AVAILABLE:
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
# For Dia, we don't need to initialize anything here
|
56 |
# The model will be lazy-loaded when needed
|
57 |
self.pipeline = None
|
58 |
self.engine_type = "dia"
|
59 |
logger.info("TTS engine initialized with Dia (lazy loading)")
|
60 |
-
|
|
|
|
|
61 |
logger.warning("Using dummy TTS implementation as no TTS engines are available")
|
|
|
62 |
self.pipeline = None
|
63 |
self.engine_type = "dummy"
|
64 |
|
@@ -95,13 +140,29 @@ class TTSEngine:
|
|
95 |
elif self.engine_type == "dia":
|
96 |
# Use Dia for TTS generation
|
97 |
try:
|
|
|
98 |
# Import here to avoid circular imports
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
# Call Dia's generate_speech function
|
|
|
101 |
output_path = dia_generate_speech(text)
|
102 |
logger.info(f"Generated audio with Dia: {output_path}")
|
|
|
|
|
|
|
|
|
103 |
except Exception as dia_error:
|
104 |
logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
|
|
|
|
|
105 |
# Fall back to dummy audio if Dia fails
|
106 |
return self._generate_dummy_audio(output_path)
|
107 |
else:
|
@@ -157,14 +218,36 @@ class TTSEngine:
|
|
157 |
# Dia doesn't support streaming natively, so we generate the full audio
|
158 |
# and then yield it as a single chunk
|
159 |
try:
|
|
|
160 |
# Import here to avoid circular imports
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
# Get the Dia model
|
165 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
# Generate audio
|
|
|
168 |
with torch.inference_mode():
|
169 |
output_audio_np = model.generate(
|
170 |
text,
|
@@ -178,12 +261,22 @@ class TTSEngine:
|
|
178 |
)
|
179 |
|
180 |
if output_audio_np is not None:
|
|
|
181 |
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
182 |
else:
|
|
|
|
|
183 |
# Fall back to dummy audio if Dia fails
|
184 |
yield from self._generate_dummy_audio_stream()
|
|
|
|
|
|
|
|
|
|
|
185 |
except Exception as dia_error:
|
186 |
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
|
|
|
|
187 |
# Fall back to dummy audio if Dia fails
|
188 |
yield from self._generate_dummy_audio_stream()
|
189 |
else:
|
@@ -221,14 +314,25 @@ def get_tts_engine(lang_code='a'):
|
|
221 |
Returns:
|
222 |
TTSEngine: Initialized TTS engine instance
|
223 |
"""
|
|
|
224 |
try:
|
225 |
import streamlit as st
|
|
|
226 |
@st.cache_resource
|
227 |
def _get_engine():
|
228 |
-
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
except ImportError:
|
231 |
-
|
|
|
|
|
|
|
232 |
|
233 |
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
234 |
"""Public interface for TTS generation
|
@@ -241,5 +345,24 @@ def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> s
|
|
241 |
Returns:
|
242 |
str: Path to generated audio file
|
243 |
"""
|
244 |
-
|
245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
# Try to import Dia as fallback
|
29 |
if not KOKORO_AVAILABLE:
|
30 |
try:
|
31 |
+
logger.info("Attempting to import Dia TTS engine as fallback")
|
32 |
+
try:
|
33 |
+
# Check if required dependencies for Dia are available
|
34 |
+
import torch
|
35 |
+
logger.info("PyTorch is available for Dia TTS")
|
36 |
+
except ImportError as torch_err:
|
37 |
+
logger.error(f"PyTorch dependency for Dia TTS is missing: {str(torch_err)}")
|
38 |
+
raise ImportError(f"PyTorch dependency required for Dia TTS: {str(torch_err)}") from torch_err
|
39 |
+
|
40 |
+
# Try to import the Dia module
|
41 |
+
try:
|
42 |
+
from utils.tts_dia import _get_model as get_dia_model
|
43 |
+
logger.info("Successfully imported Dia TTS module")
|
44 |
+
|
45 |
+
# Verify the model can be accessed
|
46 |
+
logger.info("Verifying Dia model can be accessed")
|
47 |
+
model_info = get_dia_model.__module__
|
48 |
+
logger.info(f"Dia model module: {model_info}")
|
49 |
+
|
50 |
+
DIA_AVAILABLE = True
|
51 |
+
logger.info("Dia TTS engine is available as fallback")
|
52 |
+
except ImportError as module_err:
|
53 |
+
logger.error(f"Failed to import Dia TTS module: {str(module_err)}")
|
54 |
+
logger.error(f"Module path: {module_err.__traceback__.tb_frame.f_globals.get('__file__', 'unknown')}")
|
55 |
+
raise
|
56 |
+
except AttributeError as attr_err:
|
57 |
+
logger.error(f"Dia TTS module attribute error: {str(attr_err)}")
|
58 |
+
logger.error(f"This may indicate the module exists but has incorrect structure")
|
59 |
+
raise
|
60 |
except ImportError as e:
|
61 |
+
logger.error(f"Dia TTS engine is not available due to import error: {str(e)}")
|
62 |
+
logger.error(f"Import path attempted: {e.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
63 |
logger.warning("Will use dummy TTS implementation as fallback")
|
64 |
+
except Exception as e:
|
65 |
+
logger.error(f"Unexpected error initializing Dia TTS: {str(e)}")
|
66 |
+
logger.error(f"Error type: {type(e).__name__}")
|
67 |
+
logger.error("Will use dummy TTS implementation as fallback")
|
68 |
|
69 |
class TTSEngine:
|
70 |
def __init__(self, lang_code='z'):
|
|
|
76 |
Note: lang_code is only used for Kokoro, not for Dia
|
77 |
"""
|
78 |
logger.info("Initializing TTS Engine")
|
79 |
+
logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")
|
80 |
self.engine_type = None
|
81 |
|
82 |
if KOKORO_AVAILABLE:
|
83 |
+
logger.info(f"Using Kokoro as primary TTS engine with language code: {lang_code}")
|
84 |
+
try:
|
85 |
+
self.pipeline = KPipeline(lang_code=lang_code)
|
86 |
+
self.engine_type = "kokoro"
|
87 |
+
logger.info("TTS engine successfully initialized with Kokoro")
|
88 |
+
except Exception as kokoro_err:
|
89 |
+
logger.error(f"Failed to initialize Kokoro pipeline: {str(kokoro_err)}")
|
90 |
+
logger.error(f"Error type: {type(kokoro_err).__name__}")
|
91 |
+
logger.info("Will try to fall back to Dia TTS engine")
|
92 |
+
# Fall through to try Dia
|
93 |
+
|
94 |
+
# Try Dia if Kokoro is not available or failed to initialize
|
95 |
+
if self.engine_type is None and DIA_AVAILABLE:
|
96 |
+
logger.info("Using Dia as fallback TTS engine")
|
97 |
# For Dia, we don't need to initialize anything here
|
98 |
# The model will be lazy-loaded when needed
|
99 |
self.pipeline = None
|
100 |
self.engine_type = "dia"
|
101 |
logger.info("TTS engine initialized with Dia (lazy loading)")
|
102 |
+
|
103 |
+
# Use dummy if no TTS engines are available
|
104 |
+
if self.engine_type is None:
|
105 |
logger.warning("Using dummy TTS implementation as no TTS engines are available")
|
106 |
+
logger.warning("Check logs above for specific errors that prevented Kokoro or Dia initialization")
|
107 |
self.pipeline = None
|
108 |
self.engine_type = "dummy"
|
109 |
|
|
|
140 |
elif self.engine_type == "dia":
|
141 |
# Use Dia for TTS generation
|
142 |
try:
|
143 |
+
logger.info("Attempting to use Dia TTS for speech generation")
|
144 |
# Import here to avoid circular imports
|
145 |
+
try:
|
146 |
+
logger.info("Importing Dia speech generation module")
|
147 |
+
from utils.tts_dia import generate_speech as dia_generate_speech
|
148 |
+
logger.info("Successfully imported Dia speech generation function")
|
149 |
+
except ImportError as import_err:
|
150 |
+
logger.error(f"Failed to import Dia speech generation function: {str(import_err)}")
|
151 |
+
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
152 |
+
raise
|
153 |
+
|
154 |
# Call Dia's generate_speech function
|
155 |
+
logger.info("Calling Dia's generate_speech function")
|
156 |
output_path = dia_generate_speech(text)
|
157 |
logger.info(f"Generated audio with Dia: {output_path}")
|
158 |
+
except ImportError as import_err:
|
159 |
+
logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
|
160 |
+
logger.error("Falling back to dummy audio generation")
|
161 |
+
return self._generate_dummy_audio(output_path)
|
162 |
except Exception as dia_error:
|
163 |
logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
|
164 |
+
logger.error(f"Error type: {type(dia_error).__name__}")
|
165 |
+
logger.error("Falling back to dummy audio generation")
|
166 |
# Fall back to dummy audio if Dia fails
|
167 |
return self._generate_dummy_audio(output_path)
|
168 |
else:
|
|
|
218 |
# Dia doesn't support streaming natively, so we generate the full audio
|
219 |
# and then yield it as a single chunk
|
220 |
try:
|
221 |
+
logger.info("Attempting to use Dia TTS for speech streaming")
|
222 |
# Import here to avoid circular imports
|
223 |
+
try:
|
224 |
+
logger.info("Importing required modules for Dia streaming")
|
225 |
+
import torch
|
226 |
+
logger.info("PyTorch successfully imported for Dia streaming")
|
227 |
+
|
228 |
+
try:
|
229 |
+
from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
|
230 |
+
logger.info("Successfully imported Dia model and sample rate")
|
231 |
+
except ImportError as import_err:
|
232 |
+
logger.error(f"Failed to import Dia model for streaming: {str(import_err)}")
|
233 |
+
logger.error(f"Import path: {import_err.__traceback__.tb_frame.f_globals.get('__name__', 'unknown')}")
|
234 |
+
raise
|
235 |
+
except ImportError as torch_err:
|
236 |
+
logger.error(f"PyTorch import failed for Dia streaming: {str(torch_err)}")
|
237 |
+
raise
|
238 |
|
239 |
# Get the Dia model
|
240 |
+
logger.info("Getting Dia model instance")
|
241 |
+
try:
|
242 |
+
model = _get_model()
|
243 |
+
logger.info("Successfully obtained Dia model instance")
|
244 |
+
except Exception as model_err:
|
245 |
+
logger.error(f"Failed to get Dia model instance: {str(model_err)}")
|
246 |
+
logger.error(f"Error type: {type(model_err).__name__}")
|
247 |
+
raise
|
248 |
|
249 |
# Generate audio
|
250 |
+
logger.info("Generating audio with Dia model")
|
251 |
with torch.inference_mode():
|
252 |
output_audio_np = model.generate(
|
253 |
text,
|
|
|
261 |
)
|
262 |
|
263 |
if output_audio_np is not None:
|
264 |
+
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
265 |
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
266 |
else:
|
267 |
+
logger.warning("Dia model returned None for audio output")
|
268 |
+
logger.warning("Falling back to dummy audio stream")
|
269 |
# Fall back to dummy audio if Dia fails
|
270 |
yield from self._generate_dummy_audio_stream()
|
271 |
+
except ImportError as import_err:
|
272 |
+
logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
|
273 |
+
logger.error("Falling back to dummy audio stream")
|
274 |
+
# Fall back to dummy audio if Dia fails
|
275 |
+
yield from self._generate_dummy_audio_stream()
|
276 |
except Exception as dia_error:
|
277 |
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
278 |
+
logger.error(f"Error type: {type(dia_error).__name__}")
|
279 |
+
logger.error("Falling back to dummy audio stream")
|
280 |
# Fall back to dummy audio if Dia fails
|
281 |
yield from self._generate_dummy_audio_stream()
|
282 |
else:
|
|
|
314 |
Returns:
|
315 |
TTSEngine: Initialized TTS engine instance
|
316 |
"""
|
317 |
+
logger.info(f"Requesting TTS engine with language code: {lang_code}")
|
318 |
try:
|
319 |
import streamlit as st
|
320 |
+
logger.info("Streamlit detected, using cached TTS engine")
|
321 |
@st.cache_resource
|
322 |
def _get_engine():
|
323 |
+
logger.info("Creating cached TTS engine instance")
|
324 |
+
engine = TTSEngine(lang_code)
|
325 |
+
logger.info(f"Cached TTS engine created with type: {engine.engine_type}")
|
326 |
+
return engine
|
327 |
+
|
328 |
+
engine = _get_engine()
|
329 |
+
logger.info(f"Retrieved TTS engine from cache with type: {engine.engine_type}")
|
330 |
+
return engine
|
331 |
except ImportError:
|
332 |
+
logger.info("Streamlit not available, creating direct TTS engine instance")
|
333 |
+
engine = TTSEngine(lang_code)
|
334 |
+
logger.info(f"Direct TTS engine created with type: {engine.engine_type}")
|
335 |
+
return engine
|
336 |
|
337 |
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
|
338 |
"""Public interface for TTS generation
|
|
|
345 |
Returns:
|
346 |
str: Path to generated audio file
|
347 |
"""
|
348 |
+
logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
|
349 |
+
try:
|
350 |
+
# Get the TTS engine
|
351 |
+
logger.info("Getting TTS engine instance")
|
352 |
+
engine = get_tts_engine()
|
353 |
+
logger.info(f"Using TTS engine type: {engine.engine_type}")
|
354 |
+
|
355 |
+
# Generate speech
|
356 |
+
logger.info("Calling engine.generate_speech")
|
357 |
+
output_path = engine.generate_speech(text, voice, speed)
|
358 |
+
logger.info(f"Speech generation complete, output path: {output_path}")
|
359 |
+
return output_path
|
360 |
+
except Exception as e:
|
361 |
+
logger.error(f"Error in public generate_speech function: {str(e)}", exc_info=True)
|
362 |
+
logger.error(f"Error type: {type(e).__name__}")
|
363 |
+
if hasattr(e, '__traceback__'):
|
364 |
+
tb = e.__traceback__
|
365 |
+
while tb.tb_next:
|
366 |
+
tb = tb.tb_next
|
367 |
+
logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
|
368 |
+
raise
|
utils/tts_dia.py
CHANGED
@@ -27,10 +27,36 @@ def _get_model() -> Dia:
|
|
27 |
if _model is None:
|
28 |
logger.info("Loading Dia model...")
|
29 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
_model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
except Exception as e:
|
33 |
logger.error(f"Error loading Dia model: {e}", exc_info=True)
|
|
|
|
|
34 |
raise
|
35 |
return _model
|
36 |
|
@@ -46,58 +72,127 @@ def generate_speech(text: str, language: str = "zh") -> str:
|
|
46 |
str: Path to the generated audio file
|
47 |
"""
|
48 |
logger.info(f"Generating speech for text length: {len(text)}")
|
|
|
49 |
|
50 |
try:
|
51 |
# Create output directory if it doesn't exist
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
# Generate unique output path
|
55 |
-
|
|
|
|
|
56 |
|
57 |
# Get the model
|
58 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
# Generate audio
|
|
|
61 |
start_time = time.time()
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
end_time = time.time()
|
76 |
-
|
|
|
77 |
|
78 |
# Process the output
|
79 |
if output_audio_np is not None:
|
|
|
|
|
|
|
80 |
# Apply a slight slowdown for better quality (0.94x speed)
|
81 |
speed_factor = 0.94
|
82 |
original_len = len(output_audio_np)
|
83 |
target_len = int(original_len / speed_factor)
|
84 |
|
|
|
85 |
if target_len != original_len and target_len > 0:
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
90 |
|
91 |
# Save the audio file
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
return output_path
|
96 |
else:
|
97 |
-
logger.warning("Generation produced no output
|
98 |
-
|
|
|
|
|
|
|
99 |
|
100 |
except Exception as e:
|
101 |
logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
# Return dummy path in case of error
|
103 |
return "temp/outputs/dummy.wav"
|
|
|
27 |
if _model is None:
|
28 |
logger.info("Loading Dia model...")
|
29 |
try:
|
30 |
+
# Check if torch is available with correct version
|
31 |
+
logger.info(f"PyTorch version: {torch.__version__}")
|
32 |
+
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
33 |
+
if torch.cuda.is_available():
|
34 |
+
logger.info(f"CUDA version: {torch.version.cuda}")
|
35 |
+
logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")
|
36 |
+
|
37 |
+
# Check if model path exists
|
38 |
+
logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
|
39 |
+
|
40 |
+
# Load the model with detailed logging
|
41 |
+
logger.info("Initializing Dia model...")
|
42 |
_model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
|
43 |
+
|
44 |
+
# Log model details
|
45 |
+
logger.info(f"Dia model loaded successfully")
|
46 |
+
logger.info(f"Model type: {type(_model).__name__}")
|
47 |
+
logger.info(f"Model device: {next(_model.parameters()).device}")
|
48 |
+
except ImportError as import_err:
|
49 |
+
logger.error(f"Import error loading Dia model: {import_err}")
|
50 |
+
logger.error(f"This may indicate missing dependencies")
|
51 |
+
raise
|
52 |
+
except FileNotFoundError as file_err:
|
53 |
+
logger.error(f"File not found error loading Dia model: {file_err}")
|
54 |
+
logger.error(f"Model path may be incorrect or inaccessible")
|
55 |
+
raise
|
56 |
except Exception as e:
|
57 |
logger.error(f"Error loading Dia model: {e}", exc_info=True)
|
58 |
+
logger.error(f"Error type: {type(e).__name__}")
|
59 |
+
logger.error(f"This may indicate incompatible versions or missing CUDA support")
|
60 |
raise
|
61 |
return _model
|
62 |
|
|
|
72 |
str: Path to the generated audio file
|
73 |
"""
|
74 |
logger.info(f"Generating speech for text length: {len(text)}")
|
75 |
+
logger.info(f"Text content (first 50 chars): {text[:50]}...")
|
76 |
|
77 |
try:
|
78 |
# Create output directory if it doesn't exist
|
79 |
+
output_dir = "temp/outputs"
|
80 |
+
logger.info(f"Ensuring output directory exists: {output_dir}")
|
81 |
+
try:
|
82 |
+
os.makedirs(output_dir, exist_ok=True)
|
83 |
+
logger.info(f"Output directory ready: {output_dir}")
|
84 |
+
except PermissionError as perm_err:
|
85 |
+
logger.error(f"Permission error creating output directory: {perm_err}")
|
86 |
+
raise
|
87 |
+
except Exception as dir_err:
|
88 |
+
logger.error(f"Error creating output directory: {dir_err}")
|
89 |
+
raise
|
90 |
|
91 |
# Generate unique output path
|
92 |
+
timestamp = int(time.time())
|
93 |
+
output_path = f"{output_dir}/output_{timestamp}.wav"
|
94 |
+
logger.info(f"Output will be saved to: {output_path}")
|
95 |
|
96 |
# Get the model
|
97 |
+
logger.info("Retrieving Dia model instance")
|
98 |
+
try:
|
99 |
+
model = _get_model()
|
100 |
+
logger.info("Successfully retrieved Dia model instance")
|
101 |
+
except Exception as model_err:
|
102 |
+
logger.error(f"Failed to get Dia model: {model_err}")
|
103 |
+
logger.error(f"Error type: {type(model_err).__name__}")
|
104 |
+
raise
|
105 |
|
106 |
# Generate audio
|
107 |
+
logger.info("Starting audio generation with Dia model")
|
108 |
start_time = time.time()
|
109 |
|
110 |
+
try:
|
111 |
+
with torch.inference_mode():
|
112 |
+
logger.info("Calling model.generate() with inference_mode")
|
113 |
+
output_audio_np = model.generate(
|
114 |
+
text,
|
115 |
+
max_tokens=None, # Use default from model config
|
116 |
+
cfg_scale=3.0,
|
117 |
+
temperature=1.3,
|
118 |
+
top_p=0.95,
|
119 |
+
cfg_filter_top_k=35,
|
120 |
+
use_torch_compile=False, # Keep False for stability
|
121 |
+
verbose=False
|
122 |
+
)
|
123 |
+
logger.info("Model.generate() completed")
|
124 |
+
except RuntimeError as rt_err:
|
125 |
+
logger.error(f"Runtime error during generation: {rt_err}")
|
126 |
+
if "CUDA out of memory" in str(rt_err):
|
127 |
+
logger.error("CUDA out of memory error - consider reducing batch size or model size")
|
128 |
+
raise
|
129 |
+
except Exception as gen_err:
|
130 |
+
logger.error(f"Error during audio generation: {gen_err}")
|
131 |
+
logger.error(f"Error type: {type(gen_err).__name__}")
|
132 |
+
raise
|
133 |
|
134 |
end_time = time.time()
|
135 |
+
generation_time = end_time - start_time
|
136 |
+
logger.info(f"Generation finished in {generation_time:.2f} seconds")
|
137 |
|
138 |
# Process the output
|
139 |
if output_audio_np is not None:
|
140 |
+
logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
|
141 |
+
logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
|
142 |
+
|
143 |
# Apply a slight slowdown for better quality (0.94x speed)
|
144 |
speed_factor = 0.94
|
145 |
original_len = len(output_audio_np)
|
146 |
target_len = int(original_len / speed_factor)
|
147 |
|
148 |
+
logger.info(f"Applying speed adjustment factor: {speed_factor}")
|
149 |
if target_len != original_len and target_len > 0:
|
150 |
+
try:
|
151 |
+
x_original = np.arange(original_len)
|
152 |
+
x_resampled = np.linspace(0, original_len - 1, target_len)
|
153 |
+
output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
|
154 |
+
logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
|
155 |
+
except Exception as resample_err:
|
156 |
+
logger.error(f"Error during audio resampling: {resample_err}")
|
157 |
+
logger.warning("Using original audio without resampling")
|
158 |
|
159 |
# Save the audio file
|
160 |
+
logger.info(f"Saving audio to file: {output_path}")
|
161 |
+
try:
|
162 |
+
sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
|
163 |
+
logger.info(f"Audio successfully saved to {output_path}")
|
164 |
+
except Exception as save_err:
|
165 |
+
logger.error(f"Error saving audio file: {save_err}")
|
166 |
+
logger.error(f"Error type: {type(save_err).__name__}")
|
167 |
+
raise
|
168 |
|
169 |
return output_path
|
170 |
else:
|
171 |
+
logger.warning("Generation produced no output (None returned from model)")
|
172 |
+
logger.warning("This may indicate a model configuration issue or empty input text")
|
173 |
+
dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
|
174 |
+
logger.warning(f"Returning dummy audio path: {dummy_path}")
|
175 |
+
return dummy_path
|
176 |
|
177 |
except Exception as e:
|
178 |
logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
|
179 |
+
logger.error(f"Error type: {type(e).__name__}")
|
180 |
+
|
181 |
+
# Log additional diagnostic information based on error type
|
182 |
+
if isinstance(e, ImportError):
|
183 |
+
logger.error(f"Import error - missing dependency: {e.__class__.__module__}.{e.__class__.__name__}")
|
184 |
+
logger.error("Check if all required packages are installed correctly")
|
185 |
+
elif isinstance(e, RuntimeError) and "CUDA" in str(e):
|
186 |
+
logger.error("CUDA-related runtime error - check GPU compatibility and memory")
|
187 |
+
elif isinstance(e, AttributeError):
|
188 |
+
logger.error(f"Attribute error - likely API incompatibility or incorrect module version")
|
189 |
+
if hasattr(e, '__traceback__'):
|
190 |
+
tb = e.__traceback__
|
191 |
+
while tb.tb_next:
|
192 |
+
tb = tb.tb_next
|
193 |
+
logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
|
194 |
+
elif isinstance(e, FileNotFoundError):
|
195 |
+
logger.error(f"File not found - check if model files exist and are accessible")
|
196 |
+
|
197 |
# Return dummy path in case of error
|
198 |
return "temp/outputs/dummy.wav"
|