Michael Hu committed on
Commit
7b25fdd
·
1 Parent(s): 030c851

Use Dia TTS as the fallback model if Kokoro is not available

Browse files
Files changed (1) hide show
  1. utils/tts.py +140 -45
utils/tts.py CHANGED
@@ -5,42 +5,72 @@ import soundfile as sf
5
 
6
  logger = logging.getLogger(__name__)
7
 
8
- # Wrap the problematic import in a try-except block
 
 
 
 
9
  try:
10
  from kokoro import KPipeline
11
  KOKORO_AVAILABLE = True
 
12
  except AttributeError as e:
13
  # Specifically catch the EspeakWrapper.set_data_path error
14
  if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
15
  logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue")
16
- KOKORO_AVAILABLE = False
17
  else:
18
  # Re-raise if it's a different error
 
19
  raise
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  class TTSEngine:
22
  def __init__(self, lang_code='z'):
23
- """Initialize TTS Engine with Kokoro
24
 
25
  Args:
26
  lang_code (str): Language code ('a' for US English, 'b' for British English,
27
  'j' for Japanese, 'z' for Mandarin Chinese)
 
28
  """
29
  logger.info("Initializing TTS Engine")
30
- if not KOKORO_AVAILABLE:
31
- logger.warning("Using dummy TTS implementation as Kokoro is not available")
32
- self.pipeline = None
33
- else:
34
  self.pipeline = KPipeline(lang_code=lang_code)
 
35
  logger.info("TTS engine initialized with Kokoro")
 
 
 
 
 
 
 
 
 
 
36
 
37
  def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
38
- """Generate speech from text using Kokoro
39
 
40
  Args:
41
  text (str): Input text to synthesize
42
  voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
 
43
  speed (float): Speech speed multiplier (0.5 to 2.0)
 
44
 
45
  Returns:
46
  str: Path to the generated audio file
@@ -54,26 +84,29 @@ class TTSEngine:
54
  # Generate unique output path
55
  output_path = f"temp/outputs/output_{int(time.time())}.wav"
56
 
57
- if not KOKORO_AVAILABLE:
58
- # Generate a simple sine wave as dummy audio
59
- import numpy as np
60
- sample_rate = 24000
61
- duration = 3.0 # seconds
62
- t = np.linspace(0, duration, int(sample_rate * duration), False)
63
- tone = np.sin(2 * np.pi * 440 * t) * 0.3
64
-
65
- logger.info(f"Saving dummy audio to {output_path}")
66
- sf.write(output_path, tone, sample_rate)
67
- logger.info(f"Dummy audio generation complete: {output_path}")
68
- return output_path
69
-
70
- # Get the first generated segment
71
- # We only take the first segment since the original code handled single segments
72
- generator = self.pipeline(text, voice=voice, speed=speed)
73
- for _, _, audio in generator:
74
- logger.info(f"Saving audio to {output_path}")
75
- sf.write(output_path, audio, 24000)
76
- break
 
 
 
77
 
78
  logger.info(f"Audio generation complete: {output_path}")
79
  return output_path
@@ -81,6 +114,26 @@ class TTSEngine:
81
  except Exception as e:
82
  logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
83
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
86
  """Generate speech from text and yield each segment
@@ -94,27 +147,69 @@ class TTSEngine:
94
  tuple: (sample_rate, audio_data) pairs for each segment
95
  """
96
  try:
97
- if not KOKORO_AVAILABLE:
98
- # Generate dummy audio chunks
99
- import numpy as np
100
- sample_rate = 24000
101
- duration = 1.0 # seconds per chunk
102
-
103
- # Create 3 chunks of dummy audio
104
- for i in range(3):
105
- t = np.linspace(0, duration, int(sample_rate * duration), False)
106
- freq = 440 + (i * 220) # Different frequency for each chunk
107
- tone = np.sin(2 * np.pi * freq * t) * 0.3
108
- yield sample_rate, tone
109
- return
110
-
111
- generator = self.pipeline(text, voice=voice, speed=speed)
112
- for _, _, audio in generator:
113
- yield 24000, audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  except Exception as e:
116
  logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
117
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # Initialize TTS engine with cache decorator if using Streamlit
120
  def get_tts_engine(lang_code='a'):
 
5
 
6
logger = logging.getLogger(__name__)

# Availability flags for the TTS backends, resolved once at import time.
KOKORO_AVAILABLE = False
DIA_AVAILABLE = False

# Preferred backend: Kokoro.
try:
    from kokoro import KPipeline
except AttributeError as import_err:
    # Known incompatibility: EspeakWrapper lacking set_data_path.
    if "EspeakWrapper" in str(import_err) and "set_data_path" in str(import_err):
        logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue")
    else:
        # Any other AttributeError is unexpected — surface it.
        logger.error(f"Kokoro import failed with unexpected error: {str(import_err)}")
        raise
except ImportError:
    logger.warning("Kokoro TTS engine is not available")
else:
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")

# Fallback backend: Dia — only probed when Kokoro is missing.
if not KOKORO_AVAILABLE:
    try:
        from utils.tts_dia import _get_model as get_dia_model
    except ImportError as import_err:
        logger.warning(f"Dia TTS engine is not available: {str(import_err)}")
        logger.warning("Will use dummy TTS implementation as fallback")
    else:
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available as fallback")
38
class TTSEngine:
    def __init__(self, lang_code='z'):
        """Initialize the TTS engine, preferring Kokoro with Dia as fallback.

        Args:
            lang_code (str): Language code ('a' for US English, 'b' for British
                English, 'j' for Japanese, 'z' for Mandarin Chinese).
                Only meaningful for the Kokoro backend; Dia ignores it.
        """
        logger.info("Initializing TTS Engine")
        self.engine_type = None

        if KOKORO_AVAILABLE:
            # Kokoro builds its pipeline eagerly.
            self.pipeline = KPipeline(lang_code=lang_code)
            self.engine_type = "kokoro"
            logger.info("TTS engine initialized with Kokoro")
        elif DIA_AVAILABLE:
            # Dia is lazy-loaded at generation time, so nothing to construct yet.
            self.pipeline = None
            self.engine_type = "dia"
            logger.info("TTS engine initialized with Dia (lazy loading)")
        else:
            # No real engine present — degrade to the sine-wave dummy backend.
            logger.warning("Using dummy TTS implementation as no TTS engines are available")
            self.pipeline = None
            self.engine_type = "dummy"
64
 
65
  def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
66
+ """Generate speech from text using available TTS engine
67
 
68
  Args:
69
  text (str): Input text to synthesize
70
  voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
71
+ Note: voice parameter is only used for Kokoro, not for Dia
72
  speed (float): Speech speed multiplier (0.5 to 2.0)
73
+ Note: speed parameter is only used for Kokoro, not for Dia
74
 
75
  Returns:
76
  str: Path to the generated audio file
 
84
  # Generate unique output path
85
  output_path = f"temp/outputs/output_{int(time.time())}.wav"
86
 
87
+ # Use the appropriate TTS engine based on availability
88
+ if self.engine_type == "kokoro":
89
+ # Use Kokoro for TTS generation
90
+ generator = self.pipeline(text, voice=voice, speed=speed)
91
+ for _, _, audio in generator:
92
+ logger.info(f"Saving Kokoro audio to {output_path}")
93
+ sf.write(output_path, audio, 24000)
94
+ break
95
+ elif self.engine_type == "dia":
96
+ # Use Dia for TTS generation
97
+ try:
98
+ # Import here to avoid circular imports
99
+ from utils.tts_dia import generate_speech as dia_generate_speech
100
+ # Call Dia's generate_speech function
101
+ output_path = dia_generate_speech(text)
102
+ logger.info(f"Generated audio with Dia: {output_path}")
103
+ except Exception as dia_error:
104
+ logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
105
+ # Fall back to dummy audio if Dia fails
106
+ return self._generate_dummy_audio(output_path)
107
+ else:
108
+ # Generate dummy audio as fallback
109
+ return self._generate_dummy_audio(output_path)
110
 
111
  logger.info(f"Audio generation complete: {output_path}")
112
  return output_path
 
114
  except Exception as e:
115
  logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
116
  raise
117
+
118
+ def _generate_dummy_audio(self, output_path):
119
+ """Generate a dummy audio file with a simple sine wave
120
+
121
+ Args:
122
+ output_path (str): Path to save the dummy audio file
123
+
124
+ Returns:
125
+ str: Path to the generated dummy audio file
126
+ """
127
+ import numpy as np
128
+ sample_rate = 24000
129
+ duration = 3.0 # seconds
130
+ t = np.linspace(0, duration, int(sample_rate * duration), False)
131
+ tone = np.sin(2 * np.pi * 440 * t) * 0.3
132
+
133
+ logger.info(f"Saving dummy audio to {output_path}")
134
+ sf.write(output_path, tone, sample_rate)
135
+ logger.info(f"Dummy audio generation complete: {output_path}")
136
+ return output_path
137
 
138
  def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
139
  """Generate speech from text and yield each segment
 
147
  tuple: (sample_rate, audio_data) pairs for each segment
148
  """
149
  try:
150
+ # Use the appropriate TTS engine based on availability
151
+ if self.engine_type == "kokoro":
152
+ # Use Kokoro for streaming TTS
153
+ generator = self.pipeline(text, voice=voice, speed=speed)
154
+ for _, _, audio in generator:
155
+ yield 24000, audio
156
+ elif self.engine_type == "dia":
157
+ # Dia doesn't support streaming natively, so we generate the full audio
158
+ # and then yield it as a single chunk
159
+ try:
160
+ # Import here to avoid circular imports
161
+ import torch
162
+ from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
163
+
164
+ # Get the Dia model
165
+ model = _get_model()
166
+
167
+ # Generate audio
168
+ with torch.inference_mode():
169
+ output_audio_np = model.generate(
170
+ text,
171
+ max_tokens=None,
172
+ cfg_scale=3.0,
173
+ temperature=1.3,
174
+ top_p=0.95,
175
+ cfg_filter_top_k=35,
176
+ use_torch_compile=False,
177
+ verbose=False
178
+ )
179
+
180
+ if output_audio_np is not None:
181
+ yield DEFAULT_SAMPLE_RATE, output_audio_np
182
+ else:
183
+ # Fall back to dummy audio if Dia fails
184
+ yield from self._generate_dummy_audio_stream()
185
+ except Exception as dia_error:
186
+ logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
187
+ # Fall back to dummy audio if Dia fails
188
+ yield from self._generate_dummy_audio_stream()
189
+ else:
190
+ # Generate dummy audio chunks as fallback
191
+ yield from self._generate_dummy_audio_stream()
192
 
193
  except Exception as e:
194
  logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
195
  raise
196
+
197
+ def _generate_dummy_audio_stream(self):
198
+ """Generate dummy audio chunks with simple sine waves
199
+
200
+ Yields:
201
+ tuple: (sample_rate, audio_data) pairs for each dummy segment
202
+ """
203
+ import numpy as np
204
+ sample_rate = 24000
205
+ duration = 1.0 # seconds per chunk
206
+
207
+ # Create 3 chunks of dummy audio
208
+ for i in range(3):
209
+ t = np.linspace(0, duration, int(sample_rate * duration), False)
210
+ freq = 440 + (i * 220) # Different frequency for each chunk
211
+ tone = np.sin(2 * np.pi * freq * t) * 0.3
212
+ yield sample_rate, tone
213
 
214
  # Initialize TTS engine with cache decorator if using Streamlit
215
  def get_tts_engine(lang_code='a'):