Spaces:
Build error
Build error
"""TTS provider factory for creating and managing TTS providers."""

import logging
from typing import Dict, List, Optional, Type

from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException

# Module-level logger used by the factory for registration/availability messages.
logger = logging.getLogger(__name__)
class TTSProviderFactory:
    """Factory for creating and managing TTS providers.

    Provider classes are registered by name when the factory is constructed.
    Optional providers whose modules cannot be imported are skipped, so only
    the ``'dummy'`` provider is guaranteed to be registered.
    """

    # Providers whose constructors are called with ``lang_code`` only; any
    # other keyword arguments passed to ``create_provider`` are NOT forwarded.
    _LANG_CODE_PROVIDERS = frozenset({'kokoro', 'dia', 'cosyvoice2'})

    # Order tried by ``get_provider_with_fallback`` when no preference is given.
    _DEFAULT_PREFERENCE = ('kokoro', 'dia', 'cosyvoice2', 'dummy')

    def __init__(self):
        """Initialize the TTS provider factory and register known providers."""
        # Registry: provider name -> provider class.
        self._providers: Dict[str, Type[TTSProviderBase]] = {}
        # Cache of default-constructed instances used for availability checks
        # and ``get_provider_info``; filled lazily by ``_get_instance``.
        # Instances returned by ``create_provider`` are not stored here.
        self._provider_instances: Dict[str, TTSProviderBase] = {}
        self._register_default_providers()

    def _register_default_providers(self):
        """Register the dummy provider plus every optional provider that imports.

        Imports happen here instead of at module level, so a provider whose
        dependencies are missing (an ``ImportError``) is skipped rather than
        breaking this module.
        """
        # Always register the dummy provider as the last-resort fallback.
        from .dummy_provider import DummyTTSProvider
        self._providers['dummy'] = DummyTTSProvider

        try:
            from .kokoro_provider import KokoroTTSProvider
            self._providers['kokoro'] = KokoroTTSProvider
            logger.info("Registered Kokoro TTS provider")
        except ImportError as e:
            logger.debug(f"Kokoro TTS provider not available: {e}")

        try:
            from .dia_provider import DiaTTSProvider
            self._providers['dia'] = DiaTTSProvider
            logger.info("Registered Dia TTS provider")
        except ImportError as e:
            logger.debug(f"Dia TTS provider not available: {e}")

        try:
            from .cosyvoice2_provider import CosyVoice2TTSProvider
            self._providers['cosyvoice2'] = CosyVoice2TTSProvider
            logger.info("Registered CosyVoice2 TTS provider")
        except ImportError as e:
            logger.debug(f"CosyVoice2 TTS provider not available: {e}")

    def _get_instance(self, provider_name: str) -> TTSProviderBase:
        """Return the cached default instance of a registered provider.

        The instance is constructed with no arguments on first use and reused
        afterwards.

        Raises:
            KeyError: If ``provider_name`` is not registered.
            Exception: Anything the provider's constructor raises.
        """
        if provider_name not in self._provider_instances:
            self._provider_instances[provider_name] = self._providers[provider_name]()
        return self._provider_instances[provider_name]

    def get_available_providers(self) -> List[str]:
        """Get the registered providers whose ``is_available()`` returns True.

        Providers that fail to construct or whose availability check raises
        are logged at warning level and omitted.

        Returns:
            List[str]: Provider names, in registration order.
        """
        available = []
        for name in self._providers:
            try:
                if self._get_instance(name).is_available():
                    available.append(name)
            except Exception as e:
                logger.warning(f"Failed to check availability of {name} provider: {e}")
        return available

    def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
        """
        Create a new TTS provider instance.

        Args:
            provider_name: Name of the provider to create.
            **kwargs: Additional arguments for provider initialization. For
                'kokoro', 'dia' and 'cosyvoice2' only ``lang_code`` is used
                (default 'z'); the rest are ignored. Other providers receive
                all kwargs.

        Returns:
            TTSProviderBase: The created provider instance (not cached).

        Raises:
            SpeechSynthesisException: If the provider is unknown, cannot be
                constructed, or reports itself as unavailable.
        """
        if provider_name not in self._providers:
            available = list(self._providers.keys())
            raise SpeechSynthesisException(
                f"Unknown TTS provider: {provider_name}. Available providers: {available}"
            )

        provider_class = self._providers[provider_name]
        try:
            if provider_name in self._LANG_CODE_PROVIDERS:
                provider = provider_class(lang_code=kwargs.get('lang_code', 'z'))
            else:
                provider = provider_class(**kwargs)
            is_available = provider.is_available()
        except Exception as e:
            logger.error(f"Failed to create TTS provider {provider_name}: {e}")
            raise SpeechSynthesisException(f"Failed to create TTS provider {provider_name}: {e}") from e

        # Checked outside the try so this exception is not caught and
        # re-wrapped into a doubled "Failed to create ..." message.
        if not is_available:
            logger.error(f"TTS provider {provider_name} is not available")
            raise SpeechSynthesisException(f"TTS provider {provider_name} is not available")

        logger.info(f"Created TTS provider: {provider_name}")
        return provider

    def get_provider_with_fallback(
        self, preferred_providers: Optional[List[str]] = None, **kwargs
    ) -> TTSProviderBase:
        """
        Get a TTS provider, falling back through alternatives.

        Args:
            preferred_providers: Providers to try first, in order of
                preference. Defaults to kokoro, dia, cosyvoice2, dummy.
            **kwargs: Additional arguments forwarded to ``create_provider``.

        Returns:
            TTSProviderBase: The first provider that was created successfully.

        Raises:
            SpeechSynthesisException: If no provider could be created.
        """
        if preferred_providers is None:
            preferred_providers = list(self._DEFAULT_PREFERENCE)

        available_providers = self.get_available_providers()

        # First pass: preferred providers, in the caller's order.
        for provider_name in preferred_providers:
            if provider_name not in available_providers:
                continue
            try:
                return self.create_provider(provider_name, **kwargs)
            except Exception as e:
                logger.warning(f"Failed to create preferred provider {provider_name}: {e}")

        # Second pass: any other available provider, in registration order.
        for provider_name in available_providers:
            if provider_name in preferred_providers:
                continue
            try:
                return self.create_provider(provider_name, **kwargs)
            except Exception as e:
                logger.warning(f"Failed to create fallback provider {provider_name}: {e}")

        raise SpeechSynthesisException("No TTS providers are available")

    def get_provider_info(self, provider_name: str) -> Dict:
        """
        Get information about a specific provider.

        Args:
            provider_name: Name of the provider.

        Returns:
            Dict: On success, keys ``available``, ``name``,
            ``supported_languages`` and ``available_voices`` (empty when the
            provider is unavailable). On failure, ``available`` is False and
            ``error`` holds the reason. Never raises.
        """
        if provider_name not in self._providers:
            return {"available": False, "error": "Provider not registered"}

        try:
            provider = self._get_instance(provider_name)
            # Query availability once and reuse it for the voices lookup.
            is_available = provider.is_available()
            return {
                "available": is_available,
                "name": provider.provider_name,
                "supported_languages": provider.supported_languages,
                "available_voices": provider.get_available_voices() if is_available else [],
            }
        except Exception as e:
            return {
                "available": False,
                "error": str(e),
            }

    def cleanup_providers(self):
        """Clean up cached provider instances and clear the cache.

        Calls ``_cleanup_temp_files()`` on each cached instance that defines
        it. Failures are logged and do not stop the remaining cleanups.
        Instances returned by ``create_provider`` are not cached and are
        therefore not cleaned up here.
        """
        for name, provider in self._provider_instances.items():
            try:
                if hasattr(provider, '_cleanup_temp_files'):
                    provider._cleanup_temp_files()
            except Exception as e:
                # getattr: don't raise a second error from inside the handler
                # if the instance lacks ``provider_name``.
                label = getattr(provider, 'provider_name', name)
                logger.warning(f"Failed to cleanup provider {label}: {e}")
        self._provider_instances.clear()
        logger.info("Cleaned up TTS provider instances")