Spaces:
Running
Running
Michael Hu
committed on
Commit
·
cb90410
1
Parent(s):
a316f58
fix import issue
Browse files
- utils/tts_dia.py +47 -7
- utils/tts_engines.py +37 -5
utils/tts_dia.py
CHANGED
@@ -1,18 +1,36 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
import logging
|
4 |
-
import torch
|
5 |
import numpy as np
|
6 |
import soundfile as sf
|
7 |
from pathlib import Path
|
8 |
from typing import Optional
|
9 |
|
10 |
-
from dia.model import Dia
|
11 |
-
|
12 |
# Configure logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
# Constants
|
17 |
DEFAULT_SAMPLE_RATE = 44100
|
18 |
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
|
@@ -21,9 +39,15 @@ DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
|
|
21 |
_model = None
|
22 |
|
23 |
|
24 |
-
def _get_model()
|
25 |
"""Lazy-load the Dia model to avoid loading it until needed"""
|
26 |
global _model
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
if _model is None:
|
28 |
logger.info("Loading Dia model...")
|
29 |
try:
|
@@ -80,16 +104,32 @@ def generate_speech(text: str, language: str = "zh") -> str:
|
|
80 |
"""
|
81 |
logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")
|
82 |
|
83 |
-
#
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
|
|
|
86 |
try:
|
|
|
|
|
|
|
87 |
# Create a Dia engine and generate speech
|
88 |
dia_engine = DiaTTSEngine(language)
|
89 |
return dia_engine.generate_speech(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
except Exception as e:
|
91 |
logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)
|
92 |
# Fall back to dummy TTS
|
93 |
from utils.tts_base import DummyTTSEngine
|
94 |
-
dummy_engine = DummyTTSEngine()
|
95 |
return dummy_engine.generate_speech(text)
|
|
|
1 |
import os
|
2 |
import time
|
3 |
import logging
|
|
|
4 |
import numpy as np
|
5 |
import soundfile as sf
|
6 |
from pathlib import Path
|
7 |
from typing import Optional
|
8 |
|
|
|
|
|
9 |
# Configure logging
|
10 |
logging.basicConfig(level=logging.INFO)
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
+
# Flag to track Dia availability
|
14 |
+
DIA_AVAILABLE = False
|
15 |
+
|
16 |
+
# Try to import required dependencies
|
17 |
+
try:
|
18 |
+
import torch
|
19 |
+
# Try to import Dia, which will try to import dac
|
20 |
+
try:
|
21 |
+
from dia.model import Dia
|
22 |
+
DIA_AVAILABLE = True
|
23 |
+
logger.info("Dia TTS engine is available")
|
24 |
+
except ModuleNotFoundError as e:
|
25 |
+
if "dac" in str(e):
|
26 |
+
logger.warning("Dia TTS engine is not available due to missing 'dac' module")
|
27 |
+
else:
|
28 |
+
logger.warning(f"Dia TTS engine is not available: {str(e)}")
|
29 |
+
DIA_AVAILABLE = False
|
30 |
+
except ImportError:
|
31 |
+
logger.warning("Torch not available, Dia TTS engine cannot be used")
|
32 |
+
DIA_AVAILABLE = False
|
33 |
+
|
34 |
# Constants
|
35 |
DEFAULT_SAMPLE_RATE = 44100
|
36 |
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
|
|
|
39 |
_model = None
|
40 |
|
41 |
|
42 |
+
def _get_model():
|
43 |
"""Lazy-load the Dia model to avoid loading it until needed"""
|
44 |
global _model
|
45 |
+
|
46 |
+
# Check if Dia is available before attempting to load
|
47 |
+
if not DIA_AVAILABLE:
|
48 |
+
logger.warning("Dia is not available, cannot load model")
|
49 |
+
raise ImportError("Dia module is not available")
|
50 |
+
|
51 |
if _model is None:
|
52 |
logger.info("Loading Dia model...")
|
53 |
try:
|
|
|
104 |
"""
|
105 |
logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")
|
106 |
|
107 |
+
# Check if Dia is available
|
108 |
+
if not DIA_AVAILABLE:
|
109 |
+
logger.warning("Dia is not available, falling back to dummy TTS engine")
|
110 |
+
from utils.tts_base import DummyTTSEngine
|
111 |
+
dummy_engine = DummyTTSEngine(language)
|
112 |
+
return dummy_engine.generate_speech(text)
|
113 |
|
114 |
+
# Use the new implementation via factory pattern
|
115 |
try:
|
116 |
+
# Import here to avoid circular imports
|
117 |
+
from utils.tts_engines import DiaTTSEngine
|
118 |
+
|
119 |
# Create a Dia engine and generate speech
|
120 |
dia_engine = DiaTTSEngine(language)
|
121 |
return dia_engine.generate_speech(text)
|
122 |
+
except ModuleNotFoundError as e:
|
123 |
+
logger.error(f"Module not found error in Dia generate_speech: {str(e)}")
|
124 |
+
if "dac" in str(e):
|
125 |
+
logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
|
126 |
+
# Fall back to dummy TTS
|
127 |
+
from utils.tts_base import DummyTTSEngine
|
128 |
+
dummy_engine = DummyTTSEngine(language)
|
129 |
+
return dummy_engine.generate_speech(text)
|
130 |
except Exception as e:
|
131 |
logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)
|
132 |
# Fall back to dummy TTS
|
133 |
from utils.tts_base import DummyTTSEngine
|
134 |
+
dummy_engine = DummyTTSEngine(language)
|
135 |
return dummy_engine.generate_speech(text)
|
utils/tts_engines.py
CHANGED
@@ -197,7 +197,18 @@ class DiaTTSEngine(TTSEngineBase):
|
|
197 |
|
198 |
try:
|
199 |
# Import here to avoid circular imports
|
200 |
-
from utils.tts_dia import generate_speech as dia_generate_speech
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
logger.info("Successfully imported Dia speech generation function")
|
202 |
|
203 |
# Call Dia's generate_speech function
|
@@ -211,7 +222,14 @@ class DiaTTSEngine(TTSEngineBase):
|
|
211 |
# Try using Dia Space instead
|
212 |
if DIA_SPACE_AVAILABLE:
|
213 |
return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
|
|
|
|
|
|
|
214 |
raise
|
|
|
|
|
|
|
|
|
215 |
|
216 |
|
217 |
class DiaSpaceTTSEngine(TTSEngineBase):
|
@@ -292,8 +310,15 @@ class DiaSpaceTTSEngine(TTSEngineBase):
|
|
292 |
|
293 |
try:
|
294 |
# Import required modules
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
import torch
|
296 |
-
from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
|
297 |
|
298 |
# Get the Dia model
|
299 |
model = _get_model()
|
@@ -317,18 +342,25 @@ class DiaSpaceTTSEngine(TTSEngineBase):
|
|
317 |
else:
|
318 |
logger.warning("Dia model returned None for audio output")
|
319 |
logger.warning("Falling back to dummy audio stream")
|
320 |
-
yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
|
322 |
except ImportError as import_err:
|
323 |
logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
|
324 |
logger.error("Falling back to dummy audio stream")
|
325 |
-
yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
|
326 |
|
327 |
except Exception as dia_error:
|
328 |
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
329 |
logger.error(f"Error type: {type(dia_error).__name__}")
|
330 |
logger.error("Falling back to dummy audio stream")
|
331 |
-
yield from DummyTTSEngine().generate_speech_stream(text, voice, speed)
|
332 |
|
333 |
|
334 |
def get_available_engines() -> List[str]:
|
|
|
197 |
|
198 |
try:
|
199 |
# Import here to avoid circular imports
|
200 |
+
from utils.tts_dia import generate_speech as dia_generate_speech, DIA_AVAILABLE
|
201 |
+
|
202 |
+
# Check if Dia is available
|
203 |
+
if not DIA_AVAILABLE:
|
204 |
+
logger.warning("Dia TTS engine is not available, falling back to Dia Space")
|
205 |
+
# Try using Dia Space instead
|
206 |
+
if DIA_SPACE_AVAILABLE:
|
207 |
+
return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
|
208 |
+
else:
|
209 |
+
logger.warning("Dia Space is also not available, falling back to dummy TTS")
|
210 |
+
return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
|
211 |
+
|
212 |
logger.info("Successfully imported Dia speech generation function")
|
213 |
|
214 |
# Call Dia's generate_speech function
|
|
|
222 |
# Try using Dia Space instead
|
223 |
if DIA_SPACE_AVAILABLE:
|
224 |
return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
|
225 |
+
else:
|
226 |
+
logger.warning("Dia Space is also not available, falling back to dummy TTS")
|
227 |
+
return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
|
228 |
raise
|
229 |
+
except Exception as e:
|
230 |
+
logger.error(f"Error generating speech with Dia: {str(e)}", exc_info=True)
|
231 |
+
logger.warning("Falling back to dummy TTS engine")
|
232 |
+
return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
|
233 |
|
234 |
|
235 |
class DiaSpaceTTSEngine(TTSEngineBase):
|
|
|
310 |
|
311 |
try:
|
312 |
# Import required modules
|
313 |
+
from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE, DIA_AVAILABLE
|
314 |
+
|
315 |
+
# Check if Dia is available
|
316 |
+
if not DIA_AVAILABLE:
|
317 |
+
logger.warning("Dia TTS engine is not available, falling back to dummy audio stream")
|
318 |
+
yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
|
319 |
+
return
|
320 |
+
|
321 |
import torch
|
|
|
322 |
|
323 |
# Get the Dia model
|
324 |
model = _get_model()
|
|
|
342 |
else:
|
343 |
logger.warning("Dia model returned None for audio output")
|
344 |
logger.warning("Falling back to dummy audio stream")
|
345 |
+
yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
|
346 |
+
|
347 |
+
except ModuleNotFoundError as e:
|
348 |
+
if "dac" in str(e):
|
349 |
+
logger.warning("Dia TTS streaming failed due to missing 'dac' module, falling back to dummy audio stream")
|
350 |
+
else:
|
351 |
+
logger.error(f"Module not found error in Dia TTS streaming: {str(e)}")
|
352 |
+
yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
|
353 |
|
354 |
except ImportError as import_err:
|
355 |
logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
|
356 |
logger.error("Falling back to dummy audio stream")
|
357 |
+
yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
|
358 |
|
359 |
except Exception as dia_error:
|
360 |
logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
|
361 |
logger.error(f"Error type: {type(dia_error).__name__}")
|
362 |
logger.error("Falling back to dummy audio stream")
|
363 |
+
yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
|
364 |
|
365 |
|
366 |
def get_available_engines() -> List[str]:
|