Michael Hu commited on
Commit
cb90410
·
1 Parent(s): a316f58

fix import issue

Browse files
Files changed (2) hide show
  1. utils/tts_dia.py +47 -7
  2. utils/tts_engines.py +37 -5
utils/tts_dia.py CHANGED
@@ -1,18 +1,36 @@
1
  import os
2
  import time
3
  import logging
4
- import torch
5
  import numpy as np
6
  import soundfile as sf
7
  from pathlib import Path
8
  from typing import Optional
9
 
10
- from dia.model import Dia
11
-
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Constants
17
  DEFAULT_SAMPLE_RATE = 44100
18
  DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
@@ -21,9 +39,15 @@ DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
21
  _model = None
22
 
23
 
24
- def _get_model() -> Dia:
25
  """Lazy-load the Dia model to avoid loading it until needed"""
26
  global _model
 
 
 
 
 
 
27
  if _model is None:
28
  logger.info("Loading Dia model...")
29
  try:
@@ -80,16 +104,32 @@ def generate_speech(text: str, language: str = "zh") -> str:
80
  """
81
  logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")
82
 
83
- # Use the new implementation via factory pattern
84
- from utils.tts_engines import DiaTTSEngine
 
 
 
 
85
 
 
86
  try:
 
 
 
87
  # Create a Dia engine and generate speech
88
  dia_engine = DiaTTSEngine(language)
89
  return dia_engine.generate_speech(text)
 
 
 
 
 
 
 
 
90
  except Exception as e:
91
  logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)
92
  # Fall back to dummy TTS
93
  from utils.tts_base import DummyTTSEngine
94
- dummy_engine = DummyTTSEngine()
95
  return dummy_engine.generate_speech(text)
 
1
  import os
2
  import time
3
  import logging
 
4
  import numpy as np
5
  import soundfile as sf
6
  from pathlib import Path
7
  from typing import Optional
8
 
 
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Flag to track Dia availability
14
+ DIA_AVAILABLE = False
15
+
16
+ # Try to import required dependencies
17
+ try:
18
+ import torch
19
+ # Try to import Dia, which will try to import dac
20
+ try:
21
+ from dia.model import Dia
22
+ DIA_AVAILABLE = True
23
+ logger.info("Dia TTS engine is available")
24
+ except ModuleNotFoundError as e:
25
+ if "dac" in str(e):
26
+ logger.warning("Dia TTS engine is not available due to missing 'dac' module")
27
+ else:
28
+ logger.warning(f"Dia TTS engine is not available: {str(e)}")
29
+ DIA_AVAILABLE = False
30
+ except ImportError:
31
+ logger.warning("Torch not available, Dia TTS engine cannot be used")
32
+ DIA_AVAILABLE = False
33
+
34
  # Constants
35
  DEFAULT_SAMPLE_RATE = 44100
36
  DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
 
39
  _model = None
40
 
41
 
42
+ def _get_model():
43
  """Lazy-load the Dia model to avoid loading it until needed"""
44
  global _model
45
+
46
+ # Check if Dia is available before attempting to load
47
+ if not DIA_AVAILABLE:
48
+ logger.warning("Dia is not available, cannot load model")
49
+ raise ImportError("Dia module is not available")
50
+
51
  if _model is None:
52
  logger.info("Loading Dia model...")
53
  try:
 
104
  """
105
  logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")
106
 
107
+ # Check if Dia is available
108
+ if not DIA_AVAILABLE:
109
+ logger.warning("Dia is not available, falling back to dummy TTS engine")
110
+ from utils.tts_base import DummyTTSEngine
111
+ dummy_engine = DummyTTSEngine(language)
112
+ return dummy_engine.generate_speech(text)
113
 
114
+ # Use the new implementation via factory pattern
115
  try:
116
+ # Import here to avoid circular imports
117
+ from utils.tts_engines import DiaTTSEngine
118
+
119
  # Create a Dia engine and generate speech
120
  dia_engine = DiaTTSEngine(language)
121
  return dia_engine.generate_speech(text)
122
+ except ModuleNotFoundError as e:
123
+ logger.error(f"Module not found error in Dia generate_speech: {str(e)}")
124
+ if "dac" in str(e):
125
+ logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
126
+ # Fall back to dummy TTS
127
+ from utils.tts_base import DummyTTSEngine
128
+ dummy_engine = DummyTTSEngine(language)
129
+ return dummy_engine.generate_speech(text)
130
  except Exception as e:
131
  logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)
132
  # Fall back to dummy TTS
133
  from utils.tts_base import DummyTTSEngine
134
+ dummy_engine = DummyTTSEngine(language)
135
  return dummy_engine.generate_speech(text)
utils/tts_engines.py CHANGED
@@ -197,7 +197,18 @@ class DiaTTSEngine(TTSEngineBase):
197
 
198
  try:
199
  # Import here to avoid circular imports
200
- from utils.tts_dia import generate_speech as dia_generate_speech
 
 
 
 
 
 
 
 
 
 
 
201
  logger.info("Successfully imported Dia speech generation function")
202
 
203
  # Call Dia's generate_speech function
@@ -211,7 +222,14 @@ class DiaTTSEngine(TTSEngineBase):
211
  # Try using Dia Space instead
212
  if DIA_SPACE_AVAILABLE:
213
  return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
 
 
 
214
  raise
 
 
 
 
215
 
216
 
217
  class DiaSpaceTTSEngine(TTSEngineBase):
@@ -292,8 +310,15 @@ class DiaSpaceTTSEngine(TTSEngineBase):
292
 
293
  try:
294
  # Import required modules
 
 
 
 
 
 
 
 
295
  import torch
296
- from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
297
 
298
  # Get the Dia model
299
  model = _get_model()
@@ -317,18 +342,25 @@ class DiaSpaceTTSEngine(TTSEngineBase):
317
  else:
318
  logger.warning("Dia model returned None for audio output")
319
  logger.warning("Falling back to dummy audio stream")
320
- yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
 
 
 
 
 
 
 
321
 
322
  except ImportError as import_err:
323
  logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
324
  logger.error("Falling back to dummy audio stream")
325
- yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
326
 
327
  except Exception as dia_error:
328
  logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
329
  logger.error(f"Error type: {type(dia_error).__name__}")
330
  logger.error("Falling back to dummy audio stream")
331
- yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
332
 
333
 
334
  def get_available_engines() -> List[str]:
 
197
 
198
  try:
199
  # Import here to avoid circular imports
200
+ from utils.tts_dia import generate_speech as dia_generate_speech, DIA_AVAILABLE
201
+
202
+ # Check if Dia is available
203
+ if not DIA_AVAILABLE:
204
+ logger.warning("Dia TTS engine is not available, falling back to Dia Space")
205
+ # Try using Dia Space instead
206
+ if DIA_SPACE_AVAILABLE:
207
+ return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
208
+ else:
209
+ logger.warning("Dia Space is also not available, falling back to dummy TTS")
210
+ return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
211
+
212
  logger.info("Successfully imported Dia speech generation function")
213
 
214
  # Call Dia's generate_speech function
 
222
  # Try using Dia Space instead
223
  if DIA_SPACE_AVAILABLE:
224
  return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
225
+ else:
226
+ logger.warning("Dia Space is also not available, falling back to dummy TTS")
227
+ return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
228
  raise
229
+ except Exception as e:
230
+ logger.error(f"Error generating speech with Dia: {str(e)}", exc_info=True)
231
+ logger.warning("Falling back to dummy TTS engine")
232
+ return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
233
 
234
 
235
  class DiaSpaceTTSEngine(TTSEngineBase):
 
310
 
311
  try:
312
  # Import required modules
313
+ from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE, DIA_AVAILABLE
314
+
315
+ # Check if Dia is available
316
+ if not DIA_AVAILABLE:
317
+ logger.warning("Dia TTS engine is not available, falling back to dummy audio stream")
318
+ yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
319
+ return
320
+
321
  import torch
 
322
 
323
  # Get the Dia model
324
  model = _get_model()
 
342
  else:
343
  logger.warning("Dia model returned None for audio output")
344
  logger.warning("Falling back to dummy audio stream")
345
+ yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
346
+
347
+ except ModuleNotFoundError as e:
348
+ if "dac" in str(e):
349
+ logger.warning("Dia TTS streaming failed due to missing 'dac' module, falling back to dummy audio stream")
350
+ else:
351
+ logger.error(f"Module not found error in Dia TTS streaming: {str(e)}")
352
+ yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
353
 
354
  except ImportError as import_err:
355
  logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
356
  logger.error("Falling back to dummy audio stream")
357
+ yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
358
 
359
  except Exception as dia_error:
360
  logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
361
  logger.error(f"Error type: {type(dia_error).__name__}")
362
  logger.error("Falling back to dummy audio stream")
363
+ yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
364
 
365
 
366
  def get_available_engines() -> List[str]: