Michael Hu committed on
Commit
a316f58
·
1 Parent(s): 60bd17d

handle dia model not available

Browse files
Files changed (2) hide show
  1. utils/tts_dia_space.py +154 -0
  2. utils/tts_engines.py +67 -0
utils/tts_dia_space.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ import requests
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from typing import Optional, Tuple, Generator
8
+
9
+ # Configure logging. NOTE(review): basicConfig() runs as an import-time side effect of this module; per the logging docs it is a no-op when the root logger already has handlers.
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Constants shared by the Dia Space helpers in this module
14
+ DEFAULT_SAMPLE_RATE = 44100  # not referenced by any function visible in this diff
15
+ DEFAULT_API_URL = "https://droolingpanda-dia-tts-server.hf.space"  # base URL; "/docs" and "/v1/audio/speech" are appended to it
16
+ DEFAULT_MODEL = "dia-1.6b"  # sent as the "model" field of the speech request payload
17
+
18
+ # Shared client instance: stays None until _get_client() creates a requests.Session on first use
19
+ _client = None
20
+
21
+
22
def _get_client():
    """Return the shared ``requests.Session`` for the Dia Space API.

    The session is created lazily on the first call. As part of creation the
    API's ``/docs`` page is requested as a connectivity probe:

    * a 200 response is logged as success;
    * any other status code is logged as a warning, but the session is still
      cached and returned;
    * an exception during the probe is logged and re-raised.

    The session is stored in the module-level ``_client`` only after the
    probe request has completed. A failed probe therefore does not leave a
    cached, never-verified client behind; the next call probes again.

    Returns:
        requests.Session: The cached session.

    Raises:
        ImportError: If ``requests`` cannot be imported.
        Exception: Whatever the probe request raised (connection error,
            timeout, ...), re-raised unchanged after logging.
    """
    global _client
    if _client is not None:
        return _client

    logger.info("Loading Dia Space client...")
    session = None
    try:
        # Function-scope import keeps this block self-contained; it is a
        # no-op when the module-level import has already succeeded.
        import requests

        # Initialize the client (just a session for now)
        logger.info("Initializing Dia Space client")
        session = requests.Session()

        # Test connection to the API. requests waits indefinitely unless a
        # timeout is passed, so bound the probe explicitly (seconds).
        response = session.get(f"{DEFAULT_API_URL}/docs", timeout=10)
        if response.status_code == 200:
            logger.info("Dia Space client loaded successfully")
            logger.info(f"Client type: {type(session).__name__}")
        else:
            logger.warning(f"Dia Space API returned status code {response.status_code}")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia Space client: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except Exception as e:
        # Release the half-initialized session instead of caching it.
        if session is not None:
            session.close()
        logger.error(f"Error loading Dia Space client: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        raise

    # Publish only after the probe finished without raising.
    _client = session
    return _client
51
+
52
+
53
def generate_speech(text: str, language: str = "zh", voice: str = "S1", response_format: str = "wav", speed: float = 1.0) -> str:
    """Synthesize ``text`` through the Dia Space engine (legacy entry point).

    Kept for backward compatibility; it simply delegates to
    ``DiaSpaceTTSEngine``. If constructing the engine or generating speech
    raises, the error is logged and ``DummyTTSEngine`` is used instead.

    Args:
        text (str): Input text to synthesize
        language (str): Passed to ``DiaSpaceTTSEngine`` as its language code
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone)
        response_format (str): Audio format ('wav', 'mp3', 'opus')
        speed (float): Speech speed multiplier

    Returns:
        str: Path to the generated audio file
    """
    logger.info(f"Legacy Dia Space generate_speech called with text length: {len(text)}")

    # Imported inside the function, as in the original (presumably to avoid a
    # circular import with utils.tts_engines -- confirm).
    from utils.tts_engines import DiaSpaceTTSEngine

    try:
        engine = DiaSpaceTTSEngine(language)
        audio_path = engine.generate_speech(text, voice, speed, response_format)
    except Exception as err:
        logger.error(f"Error in legacy Dia Space generate_speech: {str(err)}", exc_info=True)
        # Fall back to dummy TTS; only ``text`` is forwarded on this path.
        from utils.tts_base import DummyTTSEngine
        fallback = DummyTTSEngine()
        return fallback.generate_speech(text)
    return audio_path
84
+
85
+
86
def _create_output_dir() -> str:
    """Make sure the audio output directory exists.

    Returns:
        str: The relative path ``temp/outputs`` (resolved against the current
        working directory by the OS).
    """
    target = "temp/outputs"
    # exist_ok=True makes repeated calls harmless once the directory exists.
    os.makedirs(target, exist_ok=True)
    return target
95
+
96
+
97
def _generate_output_path(prefix: str = "output", extension: str = "wav") -> str:
    """Generate a unique output path for audio files.

    The file name has the form ``<prefix>_<unix seconds>_<8 hex chars>.<extension>``.
    The timestamp alone only has one-second resolution, so two calls within
    the same second would otherwise return the same path and the second
    writer would overwrite the first file; the random suffix prevents that.

    Args:
        prefix (str): Prefix for the output filename
        extension (str): File extension for the output file (without the dot)

    Returns:
        str: Path to the output file inside the directory returned by
        ``_create_output_dir()``. The file itself is not created here.
    """
    # Function-scope import so this block does not rely on a new
    # module-level import.
    import uuid

    output_dir = _create_output_dir()
    timestamp = int(time.time())
    unique_suffix = uuid.uuid4().hex[:8]
    return f"{output_dir}/{prefix}_{timestamp}_{unique_suffix}.{extension}"
110
+
111
+
112
def _call_dia_api(text: str, voice: str = "S1", response_format: str = "wav", speed: float = 1.0) -> bytes:
    """Call the Dia Space API to generate speech.

    Sends a JSON POST to ``<DEFAULT_API_URL>/v1/audio/speech`` using the
    shared session from ``_get_client()`` and returns the raw response body.

    Args:
        text (str): Input text to synthesize
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone)
        response_format (str): Audio format ('wav', 'mp3', 'opus')
        speed (float): Speech speed multiplier

    Returns:
        bytes: Audio data (the response body of a 200 response)

    Raises:
        RuntimeError: If the API answers with a status code other than 200.
        Exception: Any error raised while sending the request (connection
            failure, timeout, ...), re-raised unchanged after logging.
    """
    client = _get_client()

    # Prepare the request payload
    payload = {
        "model": DEFAULT_MODEL,
        "input": text,
        "voice": voice,
        "response_format": response_format,
        "speed": speed,
    }

    # Make the API request
    logger.info(f"Calling Dia Space API with voice: {voice}, format: {response_format}, speed: {speed}")
    try:
        response = client.post(
            f"{DEFAULT_API_URL}/v1/audio/speech",
            json=payload,
            headers={"Content-Type": "application/json"},
            # (connect, read) seconds. Without a timeout requests can block
            # forever; the read value is generous on the assumption that
            # synthesis can be slow -- tune as needed.
            timeout=(10, 300),
        )
    except Exception as e:
        logger.error(f"Error calling Dia Space API: {str(e)}", exc_info=True)
        raise

    # The status check lives outside the try block so an HTTP error is
    # logged once here rather than a second time by the handler above.
    if response.status_code != 200:
        logger.error(f"Dia Space API returned error: {response.status_code}")
        logger.error(f"Response: {response.text}")
        raise RuntimeError(f"Dia Space API error: {response.status_code}")

    logger.info("Dia Space API call successful")
    return response.content
utils/tts_engines.py CHANGED
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
14
  KOKORO_AVAILABLE = False
15
  KOKORO_SPACE_AVAILABLE = True
16
  DIA_AVAILABLE = False
 
17
 
18
  # Try to import Kokoro
19
  try:
@@ -39,6 +40,12 @@ try:
39
  logger.info("Dia TTS engine is available")
40
  except ImportError:
41
  logger.warning("Dia TTS engine is not available")
 
 
 
 
 
 
42
 
43
 
44
  class KokoroTTSEngine(TTSEngineBase):
@@ -198,6 +205,66 @@ class DiaTTSEngine(TTSEngineBase):
198
  output_path = dia_generate_speech(text, language=self.lang_code)
199
  logger.info(f"Generated audio with Dia: {output_path}")
200
  return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  except ImportError as import_err:
203
  logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
 
14
  KOKORO_AVAILABLE = False
15
  KOKORO_SPACE_AVAILABLE = True
16
  DIA_AVAILABLE = False
17
+ DIA_SPACE_AVAILABLE = True
18
 
19
  # Try to import Kokoro
20
  try:
 
40
  logger.info("Dia TTS engine is available")
41
  except ImportError:
42
  logger.warning("Dia TTS engine is not available")
43
+ except ModuleNotFoundError as e:
44
+ if "dac" in str(e):
45
+ logger.warning("Dia TTS engine is not available due to missing 'dac' module")
46
+ else:
47
+ logger.warning(f"Dia TTS engine is not available: {str(e)}")
48
+ DIA_AVAILABLE = False
49
 
50
 
51
  class KokoroTTSEngine(TTSEngineBase):
 
205
  output_path = dia_generate_speech(text, language=self.lang_code)
206
  logger.info(f"Generated audio with Dia: {output_path}")
207
  return output_path
208
+ except ModuleNotFoundError as e:
209
+ if "dac" in str(e):
210
+ logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to Dia Space")
211
+ # Try using Dia Space instead
212
+ if DIA_SPACE_AVAILABLE:
213
+ return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
214
+ raise
215
+
216
+
217
class DiaSpaceTTSEngine(TTSEngineBase):
    """TTS engine that delegates synthesis to the Dia TTS Server API.

    Construction obtains the shared HTTP client from ``utils.tts_dia_space``;
    speech generation posts the text to the API and writes the returned audio
    bytes to a file.
    """

    def __init__(self, lang_code: str = 'z'):
        """Set up the engine and fetch the shared Dia Space client.

        Args:
            lang_code (str): Language code forwarded to the base class

        Raises:
            Exception: Anything raised while importing or obtaining the
                client, re-raised after logging.
        """
        super().__init__(lang_code)
        try:
            # Import here to avoid circular imports
            from utils.tts_dia_space import _get_client

            self.client = _get_client()
            logger.info("Dia Space TTS engine successfully initialized")
        except Exception as err:
            logger.error(f"Failed to initialize Dia Space client: {str(err)}")
            logger.error(f"Error type: {type(err).__name__}")
            raise

    def generate_speech(self, text: str, voice: str = 'S1', speed: float = 1.0, response_format: str = 'wav') -> str:
        """Generate speech using Dia Space TTS engine

        Any failure (API call, path generation, or file write) is logged and
        answered with the output of ``DummyTTSEngine`` instead of raising.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone)
            speed (float): Speech speed multiplier
            response_format (str): Audio format ('wav', 'mp3', 'opus'); also used as the file extension

        Returns:
            str: Path to the generated audio file
        """
        logger.info(f"Generating speech with Dia Space for text length: {len(text)}")

        try:
            # Import here to avoid circular imports
            from utils.tts_dia_space import _call_dia_api, _generate_output_path

            # Note the helper's parameter order: format comes before speed.
            audio_bytes = _call_dia_api(text, voice, response_format, speed)

            # Persist the returned bytes under a freshly generated path.
            destination = _generate_output_path(prefix="dia_space", extension=response_format)
            with open(destination, 'wb') as sink:
                sink.write(audio_bytes)

            logger.info(f"Generated audio with Dia Space: {destination}")
            return destination
        except Exception as err:
            logger.error(f"Failed to generate speech from Dia Space API: {str(err)}")
            logger.error(f"Error type: {type(err).__name__}")
            logger.info("Falling back to dummy audio generation")
            fallback = DummyTTSEngine()
            return fallback.generate_speech(text, voice, speed)
268
 
269
  except ImportError as import_err:
270
  logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")