Spaces:
Paused
Paused
| """ | |
| STT Provider Factory for Flare | |
| """ | |
| from typing import Optional | |
| from stt_interface import STTInterface, STTEngineType | |
| from logger import log_info, log_error, log_warning, log_debug | |
| from stt_google import GoogleCloudSTT | |
| from config_provider import ConfigProvider | |
| from stt_interface import STTInterface | |
# Registry of STT engine name -> provider class. Each optional provider is
# imported in its own try/except so that one missing SDK only removes that
# engine from the registry; STTFactory.create_provider falls back to NoSTT
# for any engine name that is not registered here.
# NOTE(review): the unconditional `from stt_google import GoogleCloudSTT`
# near the top of this file still makes the whole module fail to import when
# stt_google is missing, which defeats the guarded import below — consider
# removing that top-level import.
stt_providers = {}
try:
    from stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    # A missing optional provider is a degraded-but-valid state, so it is
    # reported as a warning (consistently for every provider).
    log_warning("⚠️ Google Cloud STT not available")
try:
    from stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")
try:
    from stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")
class NoSTT(STTInterface):
    """Placeholder STT provider used when speech-to-text is disabled.

    Every operation is a no-op: starting and stopping a stream do nothing,
    ``stream_audio`` produces no results, and no languages are advertised.
    """

    async def start_streaming(self, config) -> None:
        """Accept and ignore the streaming configuration."""
        return None

    async def stream_audio(self, audio_chunk: bytes):
        """Discard the audio chunk without producing any results."""
        return
        # Unreachable, but its mere presence makes this method an async
        # generator, so callers can `async for` over it and get nothing.
        yield

    async def stop_streaming(self):
        """There is nothing to flush; the coroutine resolves to None."""
        pass

    def supports_realtime(self) -> bool:
        """This placeholder never offers realtime recognition."""
        return False

    def get_supported_languages(self):
        """Return a new empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier that matches the ``no_stt`` engine setting."""
        return "no_stt"
class STTFactory:
    """Factory for creating STT providers from the global configuration."""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected by ``global_config.stt_engine``.

        Falls back to ``NoSTT`` when STT is disabled (``"no_stt"``), when the
        engine is not in the ``stt_providers`` registry (unknown name or its
        module failed to import), when no API key is configured, when the
        Azure key is malformed, or when any ``Exception`` is raised while
        building the provider. Such exceptions are logged, not propagated.

        Returns:
            A provider instance. Every code path returns an object; ``None``
            is never returned despite the ``Optional`` annotation.
        """
        try:
            cfg = ConfigProvider.get()
            stt_engine = cfg.global_config.stt_engine
            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # How this value is interpreted depends on the engine (see below).
            api_key = cfg.global_config.get_stt_api_key()
            if not api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to the credentials JSON.
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                # Azure expects exactly "subscription_key|region". Both parts
                # must be non-empty; surrounding whitespace is ignored.
                parts = [part.strip() for part in api_key.split('|')]
                if len(parts) != 2 or not all(parts):
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = parts
                return provider_class(subscription_key=subscription_key, region=region)

            # Flicker and any other registered engine take a plain API key.
            return provider_class(api_key=api_key)
        except Exception as e:
            # Top-level boundary: STT is optional, so degrade to NoSTT
            # instead of propagating the failure to the caller.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def get_available_providers():
        """Return the registered STT engine names followed by ``"no_stt"``."""
        return list(stt_providers.keys()) + ["no_stt"]


# NOTE(review): in this copy of the file it is unclear whether
# get_available_providers was a module-level function or a member of
# STTFactory; expose it both ways so existing callers of either keep working.
get_available_providers = STTFactory.get_available_providers