File size: 3,324 Bytes
4846929
 
 
 
8a08517
4846929
8a08517
4846929
 
 
 
 
 
74c6c9c
4846929
 
 
 
 
74c6c9c
4846929
 
 
 
 
74c6c9c
4846929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1b9e28
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
STT Provider Factory for Flare
"""
from typing import Optional
from .stt_interface import STTInterface, STTEngineType
from utils.logger import log_info, log_error, log_warning, log_debug
from .stt_google import GoogleCloudSTT
from config.config_provider import ConfigProvider

# Import providers conditionally: each optional provider module is tried in
# turn, and only the ones that import successfully are registered in
# `stt_providers` (engine name -> provider class). A missing optional
# dependency is expected, so it is reported as a warning rather than an error.
# NOTE(review): GoogleCloudSTT is also imported unconditionally near the top
# of this file, which defeats the ImportError guard below for Google — confirm
# whether that top-level import is still needed.
stt_providers = {}

try:
    from .stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    log_warning("⚠️ Google Cloud STT not available")

try:
    from .stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")

try:
    from .stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")

class NoSTT(STTInterface):
    """Placeholder STT provider used when speech-to-text is turned off.

    Every operation is a no-op: streaming never produces results, no
    languages are advertised, and realtime support is reported as absent.
    """

    async def start_streaming(self, config) -> None:
        """Accept and ignore the streaming configuration."""

    async def stream_audio(self, audio_chunk: bytes):
        """Discard the audio chunk; this async generator yields nothing."""
        if False:
            # Unreachable, but the presence of `yield` makes this an async
            # generator, so callers can still `async for` over it.
            yield

    async def stop_streaming(self):
        """Nothing to tear down; always resolves to None."""

    def supports_realtime(self) -> bool:
        """Report that this placeholder never streams in realtime."""
        return False

    def get_supported_languages(self):
        """Return an empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier matching the "no_stt" configuration value."""
        return "no_stt"

class STTFactory:
    """Factory for creating STT providers"""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected in the global configuration.

        Reads ``name`` and ``api_key`` from
        ``ConfigProvider.get().global_config.stt_provider`` and instantiates
        the matching class from the module-level ``stt_providers`` registry.

        Returns:
            A provider instance. Falls back to a ``NoSTT`` placeholder when
            STT is disabled ("no_stt"), the provider is not registered, no API
            key is configured, the Azure key is malformed, or any exception is
            raised. Despite the ``Optional`` annotation, ``None`` is never
            returned.
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name

            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only providers whose modules imported successfully are registered.
            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # Get API key or credentials
            api_key = stt_provider_config.api_key
            if not api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to the credentials JSON file.
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                credentials = STTFactory._parse_azure_key(api_key)
                if credentials is None:
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = credentials
                return provider_class(subscription_key=subscription_key, region=region)

            # Every other provider (e.g. flicker) takes the key directly.
            return provider_class(api_key=api_key)

        except Exception as e:
            # Top-level boundary: provider construction must never crash the
            # caller, so log the failure and fall back to the no-op provider.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def _parse_azure_key(api_key: str) -> Optional[tuple]:
        """Split an Azure credential of the form ``subscription_key|region``.

        Returns:
            The whitespace-stripped ``(subscription_key, region)`` pair, or
            ``None`` when the value does not contain exactly one ``|`` or
            either part is empty/blank.
        """
        parts = [part.strip() for part in api_key.split('|')]
        if len(parts) != 2 or not all(parts):
            return None
        return parts[0], parts[1]

    @staticmethod
    def get_available_providers():
        """Return the names of all importable STT providers plus "no_stt"."""
        return list(stt_providers.keys()) + ["no_stt"]