Michael Hu
add more logs
fdc056d
raw
history blame
6.44 kB
"""Factory for creating STT provider instances."""
import logging
from typing import Dict, Type, Optional
from ..base.stt_provider_base import STTProviderBase
from .whisper_provider import WhisperSTTProvider
from .parakeet_provider import ParakeetSTTProvider
from ...domain.exceptions import SpeechRecognitionException
# One logger per module, named after the module path so log output can be
# filtered/configured hierarchically from the application's logging setup.
logger = logging.getLogger(name=__name__)
class STTProviderFactory:
    """Factory for creating STT provider instances with availability checking and fallback logic.

    Providers live in a class-level registry keyed by lower-case name. Because
    ``register_provider`` writes into that same class attribute, the registry is
    shared by every caller of the factory.
    """

    # Registry: lower-case provider name -> provider class.
    _providers: Dict[str, Type[STTProviderBase]] = {
        "whisper": WhisperSTTProvider,
        "parakeet": ParakeetSTTProvider,
    }

    # Order in which ``create_provider_with_fallback`` tries providers after the
    # preferred one fails. Providers added later via ``register_provider`` are
    # NOT appended here, so they are never used as fallbacks.
    _fallback_order = ["whisper", "parakeet"]

    @classmethod
    def create_provider(cls, provider_name: str) -> STTProviderBase:
        """
        Create an STT provider instance by name.

        Args:
            provider_name: Name of the provider to create (case-insensitive).

        Returns:
            STTProviderBase: The created provider instance.

        Raises:
            SpeechRecognitionException: If the name is unknown, if the provider
                reports itself as unavailable, or if constructing the provider
                or checking its availability raises.
        """
        provider_name = provider_name.lower()
        if provider_name not in cls._providers:
            raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")
        provider_class = cls._providers[provider_name]

        # Only the calls that can fail unexpectedly are guarded. The
        # "not available" error below is raised outside the try so it is not
        # caught by the broad handler and double-wrapped into a
        # "Failed to create ..." message.
        try:
            provider = provider_class()
            available = provider.is_available()
        except Exception as e:
            logger.error(f"Failed to create STT provider {provider_name}: {str(e)}")
            raise SpeechRecognitionException(
                f"Failed to create STT provider {provider_name}: {str(e)}"
            ) from e

        if not available:
            logger.warning(f"STT provider {provider_name} is not available")
            raise SpeechRecognitionException(f"STT provider {provider_name} is not available")

        logger.info(f"Created STT provider: {provider_name}")
        return provider

    @classmethod
    def create_provider_with_fallback(cls, preferred_provider: str) -> STTProviderBase:
        """
        Create an STT provider, falling back to other providers on failure.

        The preferred provider is tried first; then each name in
        ``_fallback_order`` (skipping the preferred one) is tried in turn.

        Args:
            preferred_provider: The preferred provider name (case-insensitive).

        Returns:
            STTProviderBase: The first provider that could be created.

        Raises:
            SpeechRecognitionException: If no provider could be created. The
                last underlying failure is attached as ``__cause__``.
        """
        preferred = preferred_provider.lower()
        last_error: Optional[SpeechRecognitionException] = None

        # Try preferred provider first
        try:
            return cls.create_provider(preferred_provider)
        except SpeechRecognitionException as e:
            logger.warning(f"Preferred STT provider {preferred_provider} failed: {str(e)}")
            last_error = e  # keep a reference; `e` is unbound after the except clause

        # Try fallback providers
        for provider_name in cls._fallback_order:
            if provider_name.lower() == preferred:
                continue  # Skip the preferred provider we already tried
            try:
                logger.info(f"Trying fallback STT provider: {provider_name}")
                return cls.create_provider(provider_name)
            except SpeechRecognitionException as e:
                logger.warning(f"Fallback STT provider {provider_name} failed: {str(e)}")
                last_error = e

        # Chain the last failure so the root cause is not lost in tracebacks.
        raise SpeechRecognitionException("No STT providers are available") from last_error

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get the names of registered STT providers that report themselves available.

        Each registered provider class is instantiated and asked
        ``is_available()``; any exception is logged and treated as unavailable.

        Returns:
            list[str]: Names of available providers, in registry order.
        """
        available = []
        for provider_name, provider_class in cls._providers.items():
            try:
                provider = provider_class()
                if provider.is_available():
                    available.append(provider_name)
            except Exception as e:
                logger.info(f"Provider {provider_name} not available: {str(e)}")
        return available

    @classmethod
    def get_provider_info(cls, provider_name: str) -> Optional[dict]:
        """
        Get information about a specific provider.

        Args:
            provider_name: Name of the provider (case-insensitive).

        Returns:
            Optional[dict]: ``None`` if the name is not registered. Otherwise a
            dict with ``name``, ``available``, ``supported_languages``,
            ``available_models`` and ``default_model``. If instantiation or
            querying raises, the dict has ``name``, ``available=False`` and
            ``error`` instead.
        """
        provider_name = provider_name.lower()
        if provider_name not in cls._providers:
            return None
        provider_class = cls._providers[provider_name]
        try:
            provider = provider_class()
            # Query availability once and reuse it: the previous version
            # called is_available() three times per lookup.
            available = provider.is_available()
            return {
                "name": provider.provider_name,
                "available": available,
                "supported_languages": provider.supported_languages,
                "available_models": provider.get_available_models() if available else [],
                "default_model": provider.get_default_model() if available else None,
            }
        except Exception as e:
            logger.info(f"Failed to get info for provider {provider_name}: {str(e)}")
            return {
                "name": provider_name,
                "available": False,
                "error": str(e),
            }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[STTProviderBase]) -> None:
        """
        Register (or replace) an STT provider under a case-insensitive name.

        An existing entry with the same lower-cased name is overwritten. The
        new provider is not added to ``_fallback_order``.

        Args:
            name: Name of the provider.
            provider_class: The provider class; it is instantiated with no
                arguments by the factory methods.
        """
        cls._providers[name.lower()] = provider_class
        logger.info(f"Registered STT provider: {name}")
# Legacy compatibility: ASRFactory is a thin wrapper class (not a true alias)
# that maps old model names onto STTProviderFactory.
class ASRFactory:
    """Legacy ASRFactory for backward compatibility."""

    # Legacy model name -> STTProviderFactory provider name. Names not listed
    # here are passed through unchanged (lower-cased).
    _LEGACY_NAME_MAP = {
        "whisper": "whisper",
        "parakeet": "parakeet",
        "faster-whisper": "whisper",
    }

    @staticmethod
    def get_model(model_name: str = "parakeet") -> STTProviderBase:
        """
        Get STT provider by model name (legacy interface).

        The requested provider is tried first; if it cannot be created, the
        other providers in the factory's fallback order are tried.

        Args:
            model_name: Name of the model/provider to use (case-insensitive).

        Returns:
            STTProviderBase: The provider instance.

        Raises:
            SpeechRecognitionException: If no provider could be created.
        """
        requested = model_name.lower()
        provider_name = ASRFactory._LEGACY_NAME_MAP.get(requested, requested)
        # create_provider_with_fallback already tries provider_name first, so
        # calling create_provider beforehand (as before) constructed and
        # failed on the requested provider twice.
        return STTProviderFactory.create_provider_with_fallback(provider_name)