Michael Hu commited on
Commit
58d9769
·
1 Parent(s): e734196

create fallback flow for tts engines

Browse files
utils/tts_base.py CHANGED
@@ -28,7 +28,7 @@ class TTSEngineBase(ABC):
28
  logger.info(f"Initializing {self.__class__.__name__} with language code: {lang_code}")
29
 
30
  @abstractmethod
31
- def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
32
  """Generate speech from text
33
 
34
  Args:
@@ -39,7 +39,7 @@ class TTSEngineBase(ABC):
39
  Note: Not all engines support speed adjustment
40
 
41
  Returns:
42
- str: Path to the generated audio file
43
  """
44
  pass
45
 
 
28
  logger.info(f"Initializing {self.__class__.__name__} with language code: {lang_code}")
29
 
30
  @abstractmethod
31
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
32
  """Generate speech from text
33
 
34
  Args:
 
39
  Note: Not all engines support speed adjustment
40
 
41
  Returns:
42
+ Optional[str]: Path to the generated audio file, or None if generation fails
43
  """
44
  pass
45
 
utils/tts_cascading.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Tuple, Generator, Optional
3
+ import numpy as np
4
+
5
+ from utils.tts_base import TTSEngineBase, DummyTTSEngine
6
+ from utils.tts_engines import create_engine
7
+
8
+ # Configure logging
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class CascadingTTSEngine(TTSEngineBase):
12
+ """Cascading TTS engine implementation
13
+
14
+ This engine tries multiple TTS engines in order until one succeeds.
15
+ It provides a fallback mechanism to maximize the chances of getting
16
+ quality speech output.
17
+ """
18
+
19
+ def __init__(self, engine_types: List[str], lang_code: str = 'z'):
20
+ """Initialize the cascading TTS engine
21
+
22
+ Args:
23
+ engine_types (List[str]): List of engine types to try in order
24
+ lang_code (str): Language code for the engines
25
+ """
26
+ super().__init__(lang_code)
27
+ self.engine_types = engine_types
28
+ self.lang_code = lang_code
29
+ logger.info(f"Initialized cascading TTS engine with engines: {engine_types}")
30
+
31
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
32
+ """Generate speech by trying multiple engines in order
33
+
34
+ Args:
35
+ text (str): Input text to synthesize
36
+ voice (str): Voice ID to use
37
+ speed (float): Speech speed multiplier
38
+
39
+ Returns:
40
+ str: Path to the generated audio file
41
+ """
42
+ logger.info(f"Generating speech with cascading engine for text length: {len(text)}")
43
+
44
+ # Try each engine in order
45
+ for engine_type in self.engine_types:
46
+ try:
47
+ logger.info(f"Trying TTS engine: {engine_type}")
48
+ engine = create_engine(engine_type, self.lang_code)
49
+
50
+ # Generate speech with the current engine
51
+ result = engine.generate_speech(text, voice, speed)
52
+
53
+ # If the engine returned a valid result, return it
54
+ if result is not None:
55
+ logger.info(f"Successfully generated speech with {engine_type}")
56
+ return result
57
+
58
+ logger.warning(f"TTS engine {engine_type} failed to generate speech, trying next engine")
59
+ except Exception as e:
60
+ logger.error(f"Error with TTS engine {engine_type}: {str(e)}")
61
+ logger.error(f"Error type: {type(e).__name__}")
62
+ logger.warning(f"Trying next TTS engine")
63
+
64
+ # If all engines failed, fall back to dummy engine
65
+ logger.warning("All TTS engines failed, falling back to dummy engine")
66
+ return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
67
+
68
+ def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
69
+ """Generate speech stream by trying multiple engines in order
70
+
71
+ Args:
72
+ text (str): Input text to synthesize
73
+ voice (str): Voice ID to use
74
+ speed (float): Speech speed multiplier
75
+
76
+ Yields:
77
+ tuple: (sample_rate, audio_data) pairs for each segment
78
+ """
79
+ logger.info(f"Generating speech stream with cascading engine for text length: {len(text)}")
80
+
81
+ # Try each engine in order
82
+ for engine_type in self.engine_types:
83
+ try:
84
+ logger.info(f"Trying TTS engine for streaming: {engine_type}")
85
+ engine = create_engine(engine_type, self.lang_code)
86
+
87
+ # Create a generator for the current engine
88
+ generator = engine.generate_speech_stream(text, voice, speed)
89
+
90
+ # Try to get the first chunk to verify the engine works
91
+ first_chunk = next(generator, None)
92
+ if first_chunk is not None:
93
+ # Engine produced a valid first chunk, yield it and continue with this engine
94
+ logger.info(f"Successfully started speech stream with {engine_type}")
95
+ yield first_chunk
96
+
97
+ # Yield the rest of the chunks from this engine
98
+ for chunk in generator:
99
+ yield chunk
100
+
101
+ # Successfully streamed all chunks, return
102
+ return
103
+
104
+ logger.warning(f"TTS engine {engine_type} failed to generate speech stream, trying next engine")
105
+ except Exception as e:
106
+ logger.error(f"Error with TTS engine {engine_type} streaming: {str(e)}")
107
+ logger.error(f"Error type: {type(e).__name__}")
108
+ logger.warning(f"Trying next TTS engine for streaming")
109
+
110
+ # If all engines failed, fall back to dummy engine
111
+ logger.warning("All TTS engines failed for streaming, falling back to dummy engine")
112
+ yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
utils/tts_engines.py CHANGED
@@ -3,7 +3,7 @@ import time
3
  import os
4
  import numpy as np
5
  import soundfile as sf
6
- from typing import Dict, List, Optional, Tuple, Generator, Any
7
 
8
  from utils.tts_base import TTSEngineBase, DummyTTSEngine
9
 
@@ -64,7 +64,7 @@ class KokoroTTSEngine(TTSEngineBase):
64
  logger.error(f"Error type: {type(e).__name__}")
65
  raise
66
 
67
- def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
68
  """Generate speech using Kokoro TTS engine
69
 
70
  Args:
@@ -73,7 +73,7 @@ class KokoroTTSEngine(TTSEngineBase):
73
  speed (float): Speech speed multiplier (0.5 to 2.0)
74
 
75
  Returns:
76
- str: Path to the generated audio file
77
  """
78
  logger.info(f"Generating speech with Kokoro for text length: {len(text)}")
79
 
@@ -126,7 +126,7 @@ class KokoroSpaceTTSEngine(TTSEngineBase):
126
  logger.error(f"Error type: {type(e).__name__}")
127
  raise
128
 
129
- def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
130
  """Generate speech using Kokoro Space TTS engine
131
 
132
  Args:
@@ -135,7 +135,7 @@ class KokoroSpaceTTSEngine(TTSEngineBase):
135
  speed (float): Speech speed multiplier (0.5 to 2.0)
136
 
137
  Returns:
138
- str: Path to the generated audio file
139
  """
140
  logger.info(f"Generating speech with Kokoro Space for text length: {len(text)}")
141
  logger.info(f"Text to generate speech on is: {text[:50]}..." if len(text) > 50 else f"Text to generate speech on is: {text}")
@@ -156,19 +156,19 @@ class KokoroSpaceTTSEngine(TTSEngineBase):
156
  )
157
  logger.info(f"Received audio from Kokoro FastAPI server: {result}")
158
 
159
- # TODO: Process the result and save to output_path
160
- # For now, we'll return the result path directly if it's a string
161
  if isinstance(result, str) and os.path.exists(result):
162
  return result
163
  else:
164
- logger.warning("Unexpected result from Kokoro Space, falling back to dummy audio")
165
- return DummyTTSEngine().generate_speech(text, voice, speed)
166
 
167
  except Exception as e:
168
  logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
169
  logger.error(f"Error type: {type(e).__name__}")
170
- logger.info("Falling back to dummy audio generation")
171
- return DummyTTSEngine().generate_speech(text, voice, speed)
172
 
173
 
174
  class DiaTTSEngine(TTSEngineBase):
@@ -182,7 +182,7 @@ class DiaTTSEngine(TTSEngineBase):
182
  # Dia doesn't need initialization here, it will be lazy-loaded when needed
183
  logger.info("Dia TTS engine initialized (lazy loading)")
184
 
185
- def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
186
  """Generate speech using Dia TTS engine
187
 
188
  Args:
@@ -191,7 +191,7 @@ class DiaTTSEngine(TTSEngineBase):
191
  speed (float): Speech speed multiplier (not used in Dia)
192
 
193
  Returns:
194
- str: Path to the generated audio file
195
  """
196
  logger.info(f"Generating speech with Dia for text length: {len(text)}")
197
 
@@ -201,13 +201,8 @@ class DiaTTSEngine(TTSEngineBase):
201
 
202
  # Check if Dia is available
203
  if not DIA_AVAILABLE:
204
- logger.warning("Dia TTS engine is not available, falling back to Dia Space")
205
- # Try using Dia Space instead
206
- if DIA_SPACE_AVAILABLE:
207
- return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
208
- else:
209
- logger.warning("Dia Space is also not available, falling back to dummy TTS")
210
- return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
211
 
212
  logger.info("Successfully imported Dia speech generation function")
213
 
@@ -218,18 +213,13 @@ class DiaTTSEngine(TTSEngineBase):
218
  return output_path
219
  except ModuleNotFoundError as e:
220
  if "dac" in str(e):
221
- logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to Dia Space")
222
- # Try using Dia Space instead
223
- if DIA_SPACE_AVAILABLE:
224
- return DiaSpaceTTSEngine(self.lang_code).generate_speech(text, voice, speed)
225
- else:
226
- logger.warning("Dia Space is also not available, falling back to dummy TTS")
227
- return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
228
  raise
229
  except Exception as e:
230
  logger.error(f"Error generating speech with Dia: {str(e)}", exc_info=True)
231
- logger.warning("Falling back to dummy TTS engine")
232
- return DummyTTSEngine(self.lang_code).generate_speech(text, voice, speed)
233
 
234
 
235
  class DiaSpaceTTSEngine(TTSEngineBase):
@@ -250,7 +240,7 @@ class DiaSpaceTTSEngine(TTSEngineBase):
250
  logger.error(f"Error type: {type(e).__name__}")
251
  raise
252
 
253
- def generate_speech(self, text: str, voice: str = 'S1', speed: float = 1.0, response_format: str = 'wav') -> str:
254
  """Generate speech using Dia Space TTS engine
255
 
256
  Args:
@@ -260,7 +250,7 @@ class DiaSpaceTTSEngine(TTSEngineBase):
260
  response_format (str): Audio format ('wav', 'mp3', 'opus')
261
 
262
  Returns:
263
- str: Path to the generated audio file
264
  """
265
  logger.info(f"Generating speech with Dia Space for text length: {len(text)}")
266
 
@@ -281,19 +271,19 @@ class DiaSpaceTTSEngine(TTSEngineBase):
281
  except Exception as e:
282
  logger.error(f"Failed to generate speech from Dia Space API: {str(e)}")
283
  logger.error(f"Error type: {type(e).__name__}")
284
- logger.info("Falling back to dummy audio generation")
285
- return DummyTTSEngine().generate_speech(text, voice, speed)
286
 
287
  except ImportError as import_err:
288
  logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
289
- logger.error("Falling back to dummy audio generation")
290
- return DummyTTSEngine().generate_speech(text, voice, speed)
291
 
292
  except Exception as dia_error:
293
  logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
294
  logger.error(f"Error type: {type(dia_error).__name__}")
295
- logger.error("Falling back to dummy audio generation")
296
- return DummyTTSEngine().generate_speech(text, voice, speed)
297
 
298
  def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
299
  """Generate speech stream using Dia TTS engine
 
3
  import os
4
  import numpy as np
5
  import soundfile as sf
6
+ from typing import Dict, List, Optional, Tuple, Generator, Any, Union
7
 
8
  from utils.tts_base import TTSEngineBase, DummyTTSEngine
9
 
 
64
  logger.error(f"Error type: {type(e).__name__}")
65
  raise
66
 
67
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
68
  """Generate speech using Kokoro TTS engine
69
 
70
  Args:
 
73
  speed (float): Speech speed multiplier (0.5 to 2.0)
74
 
75
  Returns:
76
+ Optional[str]: Path to the generated audio file or None if generation fails
77
  """
78
  logger.info(f"Generating speech with Kokoro for text length: {len(text)}")
79
 
 
126
  logger.error(f"Error type: {type(e).__name__}")
127
  raise
128
 
129
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
130
  """Generate speech using Kokoro Space TTS engine
131
 
132
  Args:
 
135
  speed (float): Speech speed multiplier (0.5 to 2.0)
136
 
137
  Returns:
138
+ Optional[str]: Path to the generated audio file or None if generation fails
139
  """
140
  logger.info(f"Generating speech with Kokoro Space for text length: {len(text)}")
141
  logger.info(f"Text to generate speech on is: {text[:50]}..." if len(text) > 50 else f"Text to generate speech on is: {text}")
 
156
  )
157
  logger.info(f"Received audio from Kokoro FastAPI server: {result}")
158
 
159
+ # Process the result and save to output_path
160
+ # Return the result path directly if it's a string
161
  if isinstance(result, str) and os.path.exists(result):
162
  return result
163
  else:
164
+ logger.warning("Unexpected result from Kokoro Space")
165
+ return None
166
 
167
  except Exception as e:
168
  logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(e)}")
169
  logger.error(f"Error type: {type(e).__name__}")
170
+ logger.info("Kokoro Space TTS engine failed")
171
+ return None
172
 
173
 
174
  class DiaTTSEngine(TTSEngineBase):
 
182
  # Dia doesn't need initialization here, it will be lazy-loaded when needed
183
  logger.info("Dia TTS engine initialized (lazy loading)")
184
 
185
+ def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
186
  """Generate speech using Dia TTS engine
187
 
188
  Args:
 
191
  speed (float): Speech speed multiplier (not used in Dia)
192
 
193
  Returns:
194
+ Optional[str]: Path to the generated audio file or None if generation fails
195
  """
196
  logger.info(f"Generating speech with Dia for text length: {len(text)}")
197
 
 
201
 
202
  # Check if Dia is available
203
  if not DIA_AVAILABLE:
204
+ logger.warning("Dia TTS engine is not available")
205
+ return None
 
 
 
 
 
206
 
207
  logger.info("Successfully imported Dia speech generation function")
208
 
 
213
  return output_path
214
  except ModuleNotFoundError as e:
215
  if "dac" in str(e):
216
+ logger.warning("Dia TTS engine failed due to missing 'dac' module")
217
+ return None
 
 
 
 
 
218
  raise
219
  except Exception as e:
220
  logger.error(f"Error generating speech with Dia: {str(e)}", exc_info=True)
221
+ logger.warning("Dia TTS engine failed")
222
+ return None
223
 
224
 
225
  class DiaSpaceTTSEngine(TTSEngineBase):
 
240
  logger.error(f"Error type: {type(e).__name__}")
241
  raise
242
 
243
+ def generate_speech(self, text: str, voice: str = 'S1', speed: float = 1.0, response_format: str = 'wav') -> Optional[str]:
244
  """Generate speech using Dia Space TTS engine
245
 
246
  Args:
 
250
  response_format (str): Audio format ('wav', 'mp3', 'opus')
251
 
252
  Returns:
253
+ Optional[str]: Path to the generated audio file or None if generation fails
254
  """
255
  logger.info(f"Generating speech with Dia Space for text length: {len(text)}")
256
 
 
271
  except Exception as e:
272
  logger.error(f"Failed to generate speech from Dia Space API: {str(e)}")
273
  logger.error(f"Error type: {type(e).__name__}")
274
+ logger.info("Dia Space TTS engine failed")
275
+ return None
276
 
277
  except ImportError as import_err:
278
  logger.error(f"Dia TTS generation failed due to import error: {str(import_err)}")
279
+ logger.error("Dia Space TTS engine failed")
280
+ return None
281
 
282
  except Exception as dia_error:
283
  logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
284
  logger.error(f"Error type: {type(dia_error).__name__}")
285
+ logger.error("Dia Space TTS engine failed")
286
+ return None
287
 
288
  def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
289
  """Generate speech stream using Dia TTS engine
utils/tts_factory.py CHANGED
@@ -6,6 +6,7 @@ logger = logging.getLogger(__name__)
6
 
7
  # Import the base class
8
  from utils.tts_base import TTSEngineBase, DummyTTSEngine
 
9
 
10
  class TTSFactory:
11
  """Factory class for creating TTS engines
@@ -36,17 +37,41 @@ class TTSFactory:
36
  if engine_type is not None:
37
  if engine_type in available_engines:
38
  logger.info(f"Creating requested engine: {engine_type}")
39
- return create_engine(engine_type, lang_code)
 
40
  else:
41
  logger.warning(f"Requested engine '{engine_type}' is not available")
42
 
43
- # Try to create the best available engine
44
- # Priority: kokoro > kokoro_space > dia > dummy
45
- for engine in ['kokoro', 'kokoro_space', 'dia']:
46
- if engine in available_engines:
47
- logger.info(f"Creating best available engine: {engine}")
48
- return create_engine(engine, lang_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Fall back to dummy engine
51
- logger.warning("No TTS engines available, falling back to dummy engine")
52
- return DummyTTSEngine(lang_code)
 
6
 
7
  # Import the base class
8
  from utils.tts_base import TTSEngineBase, DummyTTSEngine
9
+ from utils.tts_cascading import CascadingTTSEngine
10
 
11
  class TTSFactory:
12
  """Factory class for creating TTS engines
 
37
  if engine_type is not None:
38
  if engine_type in available_engines:
39
  logger.info(f"Creating requested engine: {engine_type}")
40
+ engine = create_engine(engine_type, lang_code)
41
+ return engine
42
  else:
43
  logger.warning(f"Requested engine '{engine_type}' is not available")
44
 
45
+ # Fall back to dummy engine if no engines are available
46
+ if not available_engines or (len(available_engines) == 1 and available_engines[0] == 'dummy'):
47
+ logger.warning("No TTS engines available, falling back to dummy engine")
48
+ return DummyTTSEngine(lang_code)
49
+
50
+ return TTSFactory.create_cascading_engine(available_engines, lang_code)
51
+
52
+ @staticmethod
53
+ def create_cascading_engine(available_engines: List[str], lang_code: str = 'z') -> TTSEngineBase:
54
+ """Create a cascading TTS engine that tries multiple engines in order
55
+
56
+ Args:
57
+ available_engines (List[str]): List of available engine names
58
+ lang_code (str): Language code for the engines
59
+
60
+ Returns:
61
+ TTSEngineBase: A cascading TTS engine instance
62
+ """
63
+ from utils.tts_engines import create_engine
64
+
65
+ # Define the priority order for engines
66
+ priority_order = ['kokoro', 'kokoro_space', 'dia', 'dia_space', 'dummy']
67
+
68
+ # Filter and sort available engines by priority
69
+ engines_by_priority = [engine for engine in priority_order if engine in available_engines]
70
+
71
+ # Always ensure dummy is the last fallback
72
+ if 'dummy' not in engines_by_priority:
73
+ engines_by_priority.append('dummy')
74
+
75
+ logger.info(f"Creating cascading engine with priority: {engines_by_priority}")
76
 
77
+ return CascadingTTSEngine(engines_by_priority, lang_code)