Spaces:
Build error
Build error
"""Factory for creating STT provider instances.""" | |
import logging | |
from typing import Dict, Type, Optional | |
from ..base.stt_provider_base import STTProviderBase | |
from .whisper_provider import WhisperSTTProvider | |
from .parakeet_provider import ParakeetSTTProvider | |
from ...domain.exceptions import SpeechRecognitionException | |
# Module-scoped logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
class STTProviderFactory:
    """Factory for creating STT provider instances with availability checking and fallback logic.

    All public methods are classmethods that operate on the class-level
    registry ``_providers``; the factory itself is not meant to be instantiated.
    """

    # Registry mapping lower-case provider names to provider classes.
    # Shared by the whole class; ``register_provider`` mutates it in place.
    _providers: Dict[str, Type[STTProviderBase]] = {
        "whisper": WhisperSTTProvider,
        "parakeet": ParakeetSTTProvider,
    }

    # Order in which ``create_provider_with_fallback`` tries providers after the preferred one.
    _fallback_order = ["whisper", "parakeet"]

    @classmethod
    def create_provider(cls, provider_name: str) -> STTProviderBase:
        """
        Create an STT provider instance by name.

        Args:
            provider_name: Name of the provider to create (case-insensitive)

        Returns:
            STTProviderBase: The created provider instance

        Raises:
            SpeechRecognitionException: If the provider is unknown, not available,
                or its construction / availability check fails
        """
        provider_name = provider_name.lower()
        provider_class = cls._providers.get(provider_name)
        if provider_class is None:
            raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")

        # Only construction and the availability probe can fail unexpectedly;
        # wrap those errors so callers only need to catch SpeechRecognitionException.
        try:
            provider = provider_class()
            available = provider.is_available()
        except Exception as e:
            logger.error("Failed to create STT provider %s: %s", provider_name, e)
            raise SpeechRecognitionException(f"Failed to create STT provider {provider_name}: {str(e)}") from e

        # Checked outside the try so this error is not re-wrapped as a "Failed to create" error.
        if not available:
            logger.warning("STT provider %s is not available", provider_name)
            raise SpeechRecognitionException(f"STT provider {provider_name} is not available")

        logger.info("Created STT provider: %s", provider_name)
        return provider

    @classmethod
    def create_provider_with_fallback(cls, preferred_provider: str) -> STTProviderBase:
        """
        Create an STT provider with fallback to other available providers.

        The preferred provider is tried first, then every name in ``_fallback_order``
        except the preferred one, in order.

        Args:
            preferred_provider: The preferred provider name (case-insensitive)

        Returns:
            STTProviderBase: The first provider that could be created

        Raises:
            SpeechRecognitionException: If no providers are available
        """
        # Try preferred provider first
        try:
            return cls.create_provider(preferred_provider)
        except SpeechRecognitionException as e:
            logger.warning("Preferred STT provider %s failed: %s", preferred_provider, e)

        preferred = preferred_provider.lower()

        # Try fallback providers
        for provider_name in cls._fallback_order:
            if provider_name.lower() == preferred:
                continue  # Skip the preferred provider we already tried
            try:
                logger.info("Trying fallback STT provider: %s", provider_name)
                return cls.create_provider(provider_name)
            except SpeechRecognitionException as e:
                logger.warning("Fallback STT provider %s failed: %s", provider_name, e)

        raise SpeechRecognitionException("No STT providers are available")

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available STT providers.

        Each registered provider class is instantiated and asked ``is_available()``;
        providers whose construction or check raises are treated as unavailable.

        Returns:
            list[str]: List of available provider names, in registry order
        """
        available = []
        for provider_name, provider_class in cls._providers.items():
            try:
                if provider_class().is_available():
                    available.append(provider_name)
            except Exception as e:
                logger.info("Provider %s not available: %s", provider_name, e)
        return available

    @classmethod
    def get_provider_info(cls, provider_name: str) -> Optional[dict]:
        """
        Get information about a specific provider.

        Args:
            provider_name: Name of the provider (case-insensitive)

        Returns:
            Optional[dict]: Provider information, or None if the name is not registered.
                If instantiating or querying the provider raises, a dict with
                ``name``, ``available=False`` and ``error`` is returned instead.
        """
        provider_name = provider_name.lower()
        provider_class = cls._providers.get(provider_name)
        if provider_class is None:
            return None

        try:
            provider = provider_class()
            # Query availability once and reuse it for the model fields below.
            available = provider.is_available()
            return {
                "name": provider.provider_name,
                "available": available,
                "supported_languages": provider.supported_languages,
                "available_models": provider.get_available_models() if available else [],
                "default_model": provider.get_default_model() if available else None,
            }
        except Exception as e:
            logger.info("Failed to get info for provider %s: %s", provider_name, e)
            return {
                "name": provider_name,
                "available": False,
                "error": str(e),
            }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[STTProviderBase]) -> None:
        """
        Register a new STT provider.

        The name is stored lower-cased; registering an existing name replaces it.
        Note that newly registered providers are not added to ``_fallback_order``.

        Args:
            name: Name of the provider
            provider_class: The provider class
        """
        cls._providers[name.lower()] = provider_class
        logger.info("Registered STT provider: %s", name)
# Legacy compatibility: ASRFactory keeps the old model-name based API and delegates to STTProviderFactory.
class ASRFactory:
    """Legacy ASRFactory for backward compatibility."""

    # Legacy model names mapped to current provider names; unknown names pass through unchanged.
    _LEGACY_PROVIDER_MAPPING = {
        "whisper": "whisper",
        "parakeet": "parakeet",
        "faster-whisper": "whisper",
    }

    @staticmethod
    def get_model(model_name: str = "parakeet") -> STTProviderBase:
        """
        Get STT provider by model name (legacy interface).

        The requested provider is tried first; if it cannot be created, the other
        providers are tried in the factory's fallback order.

        Args:
            model_name: Name of the model/provider to use (case-insensitive)

        Returns:
            STTProviderBase: The provider instance

        Raises:
            SpeechRecognitionException: If no provider could be created
        """
        normalized = model_name.lower()
        provider_name = ASRFactory._LEGACY_PROVIDER_MAPPING.get(normalized, normalized)
        # create_provider_with_fallback already tries the requested provider first,
        # so calling it directly avoids instantiating a failing provider twice.
        return STTProviderFactory.create_provider_with_fallback(provider_name)