File size: 8,201 Bytes
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""TTS provider factory for creating and managing TTS providers."""

import logging
from typing import Dict, List, Optional, Type
from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException

logger = logging.getLogger(__name__)


class TTSProviderFactory:
    """Factory for creating and managing TTS providers.

    The factory keeps two registries:

    * ``_providers`` maps a provider name to its class. It is filled once at
      construction time with every provider whose module can be imported.
    * ``_provider_instances`` caches one default-constructed instance per
      provider. These cached instances are only used to answer availability
      and information queries; ``create_provider`` always builds a fresh
      instance for the caller.
    """

    # Providers that are constructed with only a ``lang_code`` keyword argument.
    _LANG_CODE_PROVIDERS = ('kokoro', 'dia', 'cosyvoice2')

    # Order tried by get_provider_with_fallback when the caller supplies none.
    _DEFAULT_PREFERENCE = ('kokoro', 'dia', 'cosyvoice2', 'dummy')

    def __init__(self):
        """Initialize the TTS provider factory and register known providers."""
        self._providers: Dict[str, Type[TTSProviderBase]] = {}
        self._provider_instances: Dict[str, TTSProviderBase] = {}
        self._register_default_providers()

    def _register_default_providers(self):
        """Register all available TTS providers.

        Provider modules are imported here rather than at module level so that
        a provider whose import fails with ImportError is skipped instead of
        breaking the whole factory.
        """
        # Always register dummy provider as fallback (not guarded: an import
        # failure here propagates to the caller).
        from .dummy_provider import DummyTTSProvider
        self._providers['dummy'] = DummyTTSProvider

        # Try to register Kokoro provider
        try:
            from .kokoro_provider import KokoroTTSProvider
            self._providers['kokoro'] = KokoroTTSProvider
            logger.info("Registered Kokoro TTS provider")
        except ImportError as e:
            logger.debug("Kokoro TTS provider not available: %s", e)

        # Try to register Dia provider
        try:
            from .dia_provider import DiaTTSProvider
            self._providers['dia'] = DiaTTSProvider
            logger.info("Registered Dia TTS provider")
        except ImportError as e:
            logger.debug("Dia TTS provider not available: %s", e)

        # Try to register CosyVoice2 provider
        try:
            from .cosyvoice2_provider import CosyVoice2TTSProvider
            self._providers['cosyvoice2'] = CosyVoice2TTSProvider
            logger.info("Registered CosyVoice2 TTS provider")
        except ImportError as e:
            logger.debug("CosyVoice2 TTS provider not available: %s", e)

    def _get_cached_instance(self, name: str) -> TTSProviderBase:
        """Return the cached instance for a registered provider, creating it on first use.

        Args:
            name: Name of a provider present in ``self._providers``

        Returns:
            TTSProviderBase: The cached, default-constructed provider instance

        Raises:
            KeyError: If ``name`` is not registered
            Exception: Whatever the provider constructor raises
        """
        if name not in self._provider_instances:
            self._provider_instances[name] = self._providers[name]()
        return self._provider_instances[name]

    def get_available_providers(self) -> List[str]:
        """
        Get list of available TTS providers.

        Each registered provider is instantiated (and cached) if necessary and
        asked whether it is available. A provider whose construction or
        availability check raises is logged and left out of the result.

        Returns:
            List[str]: Names of providers reporting themselves available, in
            registration order
        """
        available = []
        for name in self._providers:
            try:
                if self._get_cached_instance(name).is_available():
                    available.append(name)
            except Exception as e:
                logger.warning("Failed to check availability of %s provider: %s", name, e)

        return available

    def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
        """
        Create a TTS provider instance.

        Args:
            provider_name: Name of the provider to create
            **kwargs: Additional arguments for provider initialization. For
                'kokoro', 'dia' and 'cosyvoice2' only ``lang_code`` is
                forwarded (defaulting to 'z'); other providers receive all
                keyword arguments unchanged.

        Returns:
            TTSProviderBase: The created provider instance

        Raises:
            SpeechSynthesisException: If provider is not available or creation fails
        """
        if provider_name not in self._providers:
            available = list(self._providers.keys())
            raise SpeechSynthesisException(
                f"Unknown TTS provider: {provider_name}. Available providers: {available}"
            )

        provider_class = self._providers[provider_name]

        try:
            # Create instance with appropriate parameters
            if provider_name in self._LANG_CODE_PROVIDERS:
                provider = provider_class(lang_code=kwargs.get('lang_code', 'z'))
            else:
                provider = provider_class(**kwargs)

            is_available = provider.is_available()
        except Exception as e:
            logger.error("Failed to create TTS provider %s: %s", provider_name, e)
            raise SpeechSynthesisException(f"Failed to create TTS provider {provider_name}: {e}") from e

        # Raised outside the try block so the "not available" error is not
        # caught by the handler above and wrapped a second time.
        if not is_available:
            message = f"TTS provider {provider_name} is not available"
            logger.error(message)
            raise SpeechSynthesisException(message)

        logger.info("Created TTS provider: %s", provider_name)
        return provider

    def get_provider_with_fallback(
        self, preferred_providers: Optional[List[str]] = None, **kwargs
    ) -> TTSProviderBase:
        """
        Get a TTS provider with fallback logic.

        Preferred providers are tried first, in the given order; if none of
        them can be created, every other available provider is tried.

        Args:
            preferred_providers: List of preferred providers in order of
                preference. When None, the order is kokoro, dia, cosyvoice2,
                dummy.
            **kwargs: Additional arguments for provider initialization

        Returns:
            TTSProviderBase: The first available provider

        Raises:
            SpeechSynthesisException: If no providers are available
        """
        if preferred_providers is None:
            preferred_providers = list(self._DEFAULT_PREFERENCE)

        available_providers = self.get_available_providers()

        # Try preferred providers in order
        for provider_name in preferred_providers:
            if provider_name not in available_providers:
                continue
            try:
                return self.create_provider(provider_name, **kwargs)
            except SpeechSynthesisException as e:
                logger.warning("Failed to create preferred provider %s: %s", provider_name, e)

        # If no preferred providers work, try any available provider
        for provider_name in available_providers:
            if provider_name in preferred_providers:
                continue
            try:
                return self.create_provider(provider_name, **kwargs)
            except SpeechSynthesisException as e:
                logger.warning("Failed to create fallback provider %s: %s", provider_name, e)

        raise SpeechSynthesisException("No TTS providers are available")

    def get_provider_info(self, provider_name: str) -> Dict:
        """
        Get information about a specific provider.

        This method does not raise for provider failures: any exception while
        instantiating or querying the provider is reported in the returned
        dictionary instead.

        Args:
            provider_name: Name of the provider

        Returns:
            Dict: On success, the keys "available", "name",
            "supported_languages" and "available_voices" (an empty list when
            the provider is unavailable). On failure or for an unregistered
            provider, "available" is False and "error" holds the reason.
        """
        if provider_name not in self._providers:
            return {"available": False, "error": "Provider not registered"}

        try:
            provider = self._get_cached_instance(provider_name)

            # Evaluate availability once so both fields below agree.
            is_available = provider.is_available()

            return {
                "available": is_available,
                "name": provider.provider_name,
                "supported_languages": provider.supported_languages,
                "available_voices": provider.get_available_voices() if is_available else []
            }

        except Exception as e:
            return {
                "available": False,
                "error": str(e)
            }

    def cleanup_providers(self):
        """Clean up provider instances and resources.

        Calls ``_cleanup_temp_files`` on every cached instance that defines
        it, logging (not raising) any failure, then empties the instance
        cache. The provider class registry is left untouched.
        """
        for name, provider in self._provider_instances.items():
            try:
                if hasattr(provider, '_cleanup_temp_files'):
                    provider._cleanup_temp_files()
            except Exception as e:
                # Fall back to the registry key so a provider lacking
                # provider_name cannot make the handler itself fail and skip
                # the cache clear below.
                logger.warning(
                    "Failed to cleanup provider %s: %s",
                    getattr(provider, 'provider_name', name),
                    e,
                )

        self._provider_instances.clear()
        logger.info("Cleaned up TTS provider instances")