File size: 3,463 Bytes
c1b9e28
 
 
 
 
9f79da5
c1b9e28
9f79da5
c1b9e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
 
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
 
 
 
 
 
9f79da5
c1b9e28
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
 
 
 
9f79da5
c1b9e28
 
9f79da5
c1b9e28
 
 
9f79da5
c1b9e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f79da5
c1b9e28
 
 
9f79da5
c1b9e28
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""

STT Provider Factory for Flare

"""
from typing import Optional
from stt_interface import STTInterface, STTEngineType
from utils.logger import log_info, log_error, log_warning, log_debug
from stt_google import GoogleCloudSTT
from config.config_provider import ConfigProvider
from stt_interface import STTInterface

# Registry of usable STT backends, keyed by the engine name used in config.
# Each backend is imported inside its own guard so that an ImportError only
# leaves that engine out of the registry instead of breaking this module.
# A missing backend is reported at warning level for every engine, matching
# how the factory reports an unavailable provider.
# NOTE(review): stt_google is also imported unconditionally at the top of this
# file, so a missing Google backend fails there before this guard is reached.
stt_providers = {}

try:
    from stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    log_warning("⚠️ Google Cloud STT not available")

try:
    from stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")

try:
    from stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")

class NoSTT(STTInterface):
    """Placeholder STT provider that performs no speech recognition."""

    async def start_streaming(self, config) -> None:
        """Accept a streaming configuration and ignore it."""
        return None

    async def stream_audio(self, audio_chunk: bytes):
        """Async generator that ignores the audio chunk and yields nothing."""
        # The unreachable ``yield`` keeps this an async generator rather than
        # a plain coroutine, so callers can still iterate over it.
        if False:
            yield

    async def stop_streaming(self):
        """End streaming; there is never a final result to return."""
        pass

    def supports_realtime(self) -> bool:
        """Report that realtime recognition is not supported."""
        return False

    def get_supported_languages(self):
        """Return an empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier of this placeholder provider."""
        name = "no_stt"
        return name

class STTFactory:
    """Factory for creating STT providers"""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected in the global configuration.

        Reads ``global_config.stt_provider`` from ``ConfigProvider.get()``,
        looks its ``name`` up in ``stt_providers`` and instantiates the
        matching class with the configured ``api_key``.

        Returns:
            The provider instance. A ``NoSTT`` instance is returned instead
            when the engine is "no_stt", the engine is not registered, no API
            key is configured, the Azure key is malformed, or any exception is
            raised while building the provider (the exception is logged, not
            propagated).
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name

            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only engines whose import succeeded are present in the registry.
            provider_class = stt_providers.get(stt_engine)
            if provider_class is None:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # Every real provider needs a key or credentials value.
            api_key = stt_provider_config.api_key
            if not api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to credentials JSON
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                # Azure packs two values into one field: subscription_key|region.
                # Strip surrounding whitespace and reject empty halves so that
                # inputs such as "key|" or "|region" are reported as a format
                # error instead of reaching the provider with a blank value.
                parts = [part.strip() for part in api_key.split('|')]
                if len(parts) != 2 or not all(parts):
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = parts
                return provider_class(subscription_key=subscription_key, region=region)

            # Flicker and any other registered engine take the key directly.
            return provider_class(api_key=api_key)

        except Exception as e:
            # Best-effort boundary: a broken STT setup must not take the
            # caller down, so log and fall back to the no-op provider.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def get_available_providers():
        """Return the registered engine names followed by "no_stt"."""
        return [*stt_providers, "no_stt"]