Michael Hu committed on
Commit
6514731
·
1 Parent(s): ca57c53

fix dia tts

Browse files
pyproject.toml CHANGED
@@ -26,6 +26,7 @@ dependencies = [
26
  "phonemizer-fork>=3.3.2",
27
  "nemo_toolkit[asr]",
28
  "faster-whisper>=1.1.1",
 
29
  ]
30
 
31
  [project.optional-dependencies]
 
26
  "phonemizer-fork>=3.3.2",
27
  "nemo_toolkit[asr]",
28
  "faster-whisper>=1.1.1",
29
+ "descript-audio-codec"
30
  ]
31
 
32
  [project.optional-dependencies]
src/application/services/audio_processing_service.py CHANGED
@@ -571,7 +571,7 @@ class AudioProcessingApplicationService:
571
  return output_path
572
 
573
  except Exception as e:
574
- logger.error(f"TTS failed: {e} [correlation_id={correlation_id}]", exc_info=True)
575
  raise SpeechSynthesisException(f"Speech synthesis failed: {str(e)}")
576
 
577
  def _get_error_code_from_exception(self, exception: Exception) -> str:
 
571
  return output_path
572
 
573
  except Exception as e:
574
+ logger.error(f"TTS failed: {e} [correlation_id={correlation_id}]", exception=e)
575
  raise SpeechSynthesisException(f"Speech synthesis failed: {str(e)}")
576
 
577
  def _get_error_code_from_exception(self, exception: Exception) -> str:
src/infrastructure/base/file_utils.py CHANGED
@@ -356,7 +356,7 @@ class ErrorHandler:
356
  error_msg += f" during {context}"
357
  error_msg += f": {str(error)}"
358
 
359
- self.logger.error(error_msg, exc_info=True)
360
 
361
  if reraise_as:
362
  raise reraise_as(error_msg) from error
 
356
  error_msg += f" during {context}"
357
  error_msg += f": {str(error)}"
358
 
359
+ self.logger.error(error_msg, exception=error)
360
 
361
  if reraise_as:
362
  raise reraise_as(error_msg) from error
src/infrastructure/base/stt_provider_base.py CHANGED
@@ -312,5 +312,5 @@ class STTProviderBase(ISpeechRecognitionService, ABC):
312
  error_msg += f" during {context}"
313
  error_msg += f": {str(error)}"
314
 
315
- logger.error(error_msg, exc_info=True)
316
  raise SpeechRecognitionException(error_msg) from error
 
312
  error_msg += f" during {context}"
313
  error_msg += f": {str(error)}"
314
 
315
+ logger.error(error_msg, exception=error)
316
  raise SpeechRecognitionException(error_msg) from error
src/infrastructure/base/translation_provider_base.py CHANGED
@@ -315,7 +315,7 @@ class TranslationProviderBase(ITranslationService, ABC):
315
  error_msg += f" during {context}"
316
  error_msg += f": {str(error)}"
317
 
318
- logger.error(error_msg, exc_info=True)
319
  raise TranslationFailedException(error_msg) from error
320
 
321
  def set_chunk_size(self, chunk_size: int) -> None:
 
315
  error_msg += f" during {context}"
316
  error_msg += f": {str(error)}"
317
 
318
+ logger.error(error_msg, exception=error)
319
  raise TranslationFailedException(error_msg) from error
320
 
321
  def set_chunk_size(self, chunk_size: int) -> None:
src/infrastructure/base/tts_provider_base.py CHANGED
@@ -340,5 +340,5 @@ class TTSProviderBase(ISpeechSynthesisService, ABC):
340
  error_msg += f" during {context}"
341
  error_msg += f": {str(error)}"
342
 
343
- logger.error(error_msg, exc_info=True)
344
  raise SpeechSynthesisException(error_msg) from error
 
340
  error_msg += f" during {context}"
341
  error_msg += f": {str(error)}"
342
 
343
+ logger.error(error_msg, exception=error)
344
  raise SpeechSynthesisException(error_msg) from error
src/infrastructure/config/container_setup.py CHANGED
@@ -280,7 +280,7 @@ def create_configured_container(config_file: Optional[str] = None) -> Dependency
280
  _validate_container_setup(container)
281
  logger.info("Container validation completed")
282
  except Exception as validation_error:
283
- logger.error(f"Container validation failed: {validation_error}", exc_info=True)
284
  # For now, let's continue even if validation fails to see if the app works
285
  logger.warning("Continuing despite validation failure...")
286
 
@@ -288,7 +288,7 @@ def create_configured_container(config_file: Optional[str] = None) -> Dependency
288
  return container
289
 
290
  except Exception as e:
291
- logger.error(f"Failed to create configured container: {e}", exc_info=True)
292
  raise
293
 
294
 
@@ -352,7 +352,7 @@ def _validate_container_setup(container: DependencyContainer) -> None:
352
 
353
  except Exception as e:
354
  error_msg = f"Container validation failed during service resolution: {e}"
355
- logger.error(error_msg, exc_info=True)
356
  raise RuntimeError(error_msg)
357
 
358
 
 
280
  _validate_container_setup(container)
281
  logger.info("Container validation completed")
282
  except Exception as validation_error:
283
+ logger.error(f"Container validation failed: {validation_error}", exception=validation_error)
284
  # For now, let's continue even if validation fails to see if the app works
285
  logger.warning("Continuing despite validation failure...")
286
 
 
288
  return container
289
 
290
  except Exception as e:
291
+ logger.error(f"Failed to create configured container: {e}", exception=e)
292
  raise
293
 
294
 
 
352
 
353
  except Exception as e:
354
  error_msg = f"Container validation failed during service resolution: {e}"
355
+ logger.error(error_msg, exception=e)
356
  raise RuntimeError(error_msg)
357
 
358
 
src/infrastructure/config/dependency_container.py CHANGED
@@ -214,7 +214,7 @@ class DependencyContainer:
214
  return result
215
 
216
  except Exception as e:
217
- logger.error(f"Failed to resolve service {service_type.__name__}: {e}", exc_info=True)
218
  raise
219
 
220
  def _create_singleton(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
@@ -260,7 +260,7 @@ class DependencyContainer:
260
  logger.info(f"Factory function completed for {descriptor.service_type.__name__}")
261
  return result
262
  except Exception as e:
263
- logger.error(f"Factory function failed for {descriptor.service_type.__name__}: {e}", exc_info=True)
264
  raise
265
 
266
  # If implementation is a class
@@ -271,7 +271,7 @@ class DependencyContainer:
271
  logger.info(f"Class instantiation completed for {descriptor.service_type.__name__}")
272
  return result
273
  except Exception as e:
274
- logger.error(f"Class instantiation failed for {descriptor.service_type.__name__}: {e}", exc_info=True)
275
  raise
276
 
277
  logger.error(f"Invalid implementation type for {descriptor.service_type.__name__}: {type(implementation)}")
@@ -312,7 +312,14 @@ class DependencyContainer:
312
  factory = self.resolve(TTSProviderFactory)
313
 
314
  if provider_name:
315
- return factory.create_provider(provider_name, **kwargs)
 
 
 
 
 
 
 
316
  else:
317
  preferred_providers = self._config.tts.preferred_providers
318
  return factory.get_provider_with_fallback(preferred_providers, **kwargs)
 
214
  return result
215
 
216
  except Exception as e:
217
+ logger.error(f"Failed to resolve service {service_type.__name__}: {e}", exception=e)
218
  raise
219
 
220
  def _create_singleton(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
 
260
  logger.info(f"Factory function completed for {descriptor.service_type.__name__}")
261
  return result
262
  except Exception as e:
263
+ logger.error(f"Factory function failed for {descriptor.service_type.__name__}: {e}", exception=e)
264
  raise
265
 
266
  # If implementation is a class
 
271
  logger.info(f"Class instantiation completed for {descriptor.service_type.__name__}")
272
  return result
273
  except Exception as e:
274
+ logger.error(f"Class instantiation failed for {descriptor.service_type.__name__}: {e}", exception=e)
275
  raise
276
 
277
  logger.error(f"Invalid implementation type for {descriptor.service_type.__name__}: {type(implementation)}")
 
312
  factory = self.resolve(TTSProviderFactory)
313
 
314
  if provider_name:
315
+ try:
316
+ return factory.create_provider(provider_name, **kwargs)
317
+ except Exception as e:
318
+ logger.warning(f"Failed to create specific TTS provider {provider_name}: {e}")
319
+ logger.info("Falling back to default provider selection")
320
+ # Fall back to default provider selection
321
+ preferred_providers = self._config.tts.preferred_providers
322
+ return factory.get_provider_with_fallback(preferred_providers, **kwargs)
323
  else:
324
  preferred_providers = self._config.tts.preferred_providers
325
  return factory.get_provider_with_fallback(preferred_providers, **kwargs)
src/infrastructure/tts/cosyvoice2_provider.py CHANGED
@@ -61,13 +61,13 @@ class CosyVoice2TTSProvider(TTSProviderBase):
61
  self.model = CosyVoice('pretrained_models/CosyVoice-300M')
62
  logger.info("CosyVoice2 model successfully loaded")
63
  except ImportError as e:
64
- logger.error(f"Failed to import CosyVoice2 dependencies: {str(e)}", exc_info=True)
65
  self.model = None
66
  except FileNotFoundError as e:
67
- logger.error(f"Failed to load CosyVoice2 model files: {str(e)}", exc_info=True)
68
  self.model = None
69
  except Exception as e:
70
- logger.error(f"Failed to initialize CosyVoice2 model: {str(e)}", exc_info=True)
71
  self.model = None
72
 
73
  model_available = self.model is not None
@@ -144,7 +144,7 @@ class CosyVoice2TTSProvider(TTSProviderBase):
144
  return audio_bytes, DEFAULT_SAMPLE_RATE
145
 
146
  except Exception as e:
147
- logger.error(f"CosyVoice2 audio generation failed: {str(e)}", exc_info=True)
148
  self._handle_provider_error(e, "audio generation")
149
 
150
  def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
 
61
  self.model = CosyVoice('pretrained_models/CosyVoice-300M')
62
  logger.info("CosyVoice2 model successfully loaded")
63
  except ImportError as e:
64
+ logger.error(f"Failed to import CosyVoice2 dependencies: {str(e)}", exception=e)
65
  self.model = None
66
  except FileNotFoundError as e:
67
+ logger.error(f"Failed to load CosyVoice2 model files: {str(e)}", exception=e)
68
  self.model = None
69
  except Exception as e:
70
+ logger.error(f"Failed to initialize CosyVoice2 model: {str(e)}", exception=e)
71
  self.model = None
72
 
73
  model_available = self.model is not None
 
144
  return audio_bytes, DEFAULT_SAMPLE_RATE
145
 
146
  except Exception as e:
147
+ logger.error(f"CosyVoice2 audio generation failed: {str(e)}", exception=e)
148
  self._handle_provider_error(e, "audio generation")
149
 
150
  def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
src/infrastructure/tts/provider_factory.py CHANGED
@@ -20,7 +20,7 @@ class TTSProviderFactory:
20
  def _register_default_providers(self):
21
  """Register all available TTS providers."""
22
  # Import providers dynamically to avoid import errors if dependencies are missing
23
-
24
  # Always register dummy provider as fallback
25
  from .dummy_provider import DummyTTSProvider
26
  self._providers['dummy'] = DummyTTSProvider
@@ -39,7 +39,16 @@ class TTSProviderFactory:
39
  self._providers['dia'] = DiaTTSProvider
40
  logger.info("Registered Dia TTS provider")
41
  except ImportError as e:
42
- logger.debug(f"Dia TTS provider not available: {e}")
 
 
 
 
 
 
 
 
 
43
 
44
  # Try to register CosyVoice2 provider
45
  try:
@@ -68,10 +77,10 @@ class TTSProviderFactory:
68
  # Check if provider is available
69
  if self._provider_instances[name].is_available():
70
  available.append(name)
71
-
72
  except Exception as e:
73
  logger.warning(f"Failed to check availability of {name} provider: {e}")
74
-
75
  return available
76
 
77
  def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
@@ -94,9 +103,15 @@ class TTSProviderFactory:
94
  f"Unknown TTS provider: {provider_name}. Available providers: {available}"
95
  )
96
 
 
 
 
 
 
 
97
  try:
98
  provider_class = self._providers[provider_name]
99
-
100
  # Create instance with appropriate parameters
101
  if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
102
  lang_code = kwargs.get('lang_code', 'z')
@@ -133,7 +148,7 @@ class TTSProviderFactory:
133
  preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
134
 
135
  available_providers = self.get_available_providers()
136
-
137
  # Try preferred providers in order
138
  for provider_name in preferred_providers:
139
  if provider_name in available_providers:
@@ -177,7 +192,7 @@ class TTSProviderFactory:
177
  self._provider_instances[provider_name] = provider_class()
178
 
179
  provider = self._provider_instances[provider_name]
180
-
181
  return {
182
  "available": provider.is_available(),
183
  "name": provider.provider_name,
@@ -199,6 +214,6 @@ class TTSProviderFactory:
199
  provider._cleanup_temp_files()
200
  except Exception as e:
201
  logger.warning(f"Failed to cleanup provider {provider.provider_name}: {e}")
202
-
203
  self._provider_instances.clear()
204
  logger.info("Cleaned up TTS provider instances")
 
20
  def _register_default_providers(self):
21
  """Register all available TTS providers."""
22
  # Import providers dynamically to avoid import errors if dependencies are missing
23
+
24
  # Always register dummy provider as fallback
25
  from .dummy_provider import DummyTTSProvider
26
  self._providers['dummy'] = DummyTTSProvider
 
39
  self._providers['dia'] = DiaTTSProvider
40
  logger.info("Registered Dia TTS provider")
41
  except ImportError as e:
42
+ logger.warning(f"Dia TTS provider not available: {e}")
43
+ # Still register it so it can attempt installation later
44
+ try:
45
+ from .dia_provider import DiaTTSProvider
46
+ self._providers['dia'] = DiaTTSProvider
47
+ logger.info("Registered Dia TTS provider (dependencies may be installed on demand)")
48
+ except Exception:
49
+ logger.warning("Failed to register Dia TTS provider")
50
+ except Exception as e:
51
+ logger.warning(f"Failed to register Dia TTS provider: {e}")
52
 
53
  # Try to register CosyVoice2 provider
54
  try:
 
77
  # Check if provider is available
78
  if self._provider_instances[name].is_available():
79
  available.append(name)
80
+
81
  except Exception as e:
82
  logger.warning(f"Failed to check availability of {name} provider: {e}")
83
+
84
  return available
85
 
86
  def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
 
103
  f"Unknown TTS provider: {provider_name}. Available providers: {available}"
104
  )
105
 
106
+ # Check if provider is actually available before creating
107
+ available_providers = self.get_available_providers()
108
+ if provider_name not in available_providers:
109
+ logger.warning(f"TTS provider {provider_name} is registered but not available")
110
+ raise SpeechSynthesisException(f"TTS provider {provider_name} is not available")
111
+
112
  try:
113
  provider_class = self._providers[provider_name]
114
+
115
  # Create instance with appropriate parameters
116
  if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
117
  lang_code = kwargs.get('lang_code', 'z')
 
148
  preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
149
 
150
  available_providers = self.get_available_providers()
151
+
152
  # Try preferred providers in order
153
  for provider_name in preferred_providers:
154
  if provider_name in available_providers:
 
192
  self._provider_instances[provider_name] = provider_class()
193
 
194
  provider = self._provider_instances[provider_name]
195
+
196
  return {
197
  "available": provider.is_available(),
198
  "name": provider.provider_name,
 
214
  provider._cleanup_temp_files()
215
  except Exception as e:
216
  logger.warning(f"Failed to cleanup provider {provider.provider_name}: {e}")
217
+
218
  self._provider_instances.clear()
219
  logger.info("Cleaned up TTS provider instances")