ciyidogan commited on
Commit
4846929
·
verified ·
1 Parent(s): bdb6ec7

Update stt/stt_factory.py

Browse files
Files changed (1) hide show
  1. stt/stt_factory.py +104 -105
stt/stt_factory.py CHANGED
@@ -1,106 +1,105 @@
1
- """
2
- STT Provider Factory for Flare
3
- """
4
- from typing import Optional
5
- from stt_interface import STTInterface, STTEngineType
6
- from utils.logger import log_info, log_error, log_warning, log_debug
7
- from stt_google import GoogleCloudSTT
8
- from config.config_provider import ConfigProvider
9
- from stt_interface import STTInterface
10
-
11
- # Import providers conditionally
12
- stt_providers = {}
13
-
14
- try:
15
- from stt_google import GoogleCloudSTT
16
- stt_providers['google'] = GoogleCloudSTT
17
- except ImportError:
18
- log_info("⚠️ Google Cloud STT not available")
19
-
20
- try:
21
- from stt_azure import AzureSTT
22
- stt_providers['azure'] = AzureSTT
23
- except ImportError:
24
- log_error("⚠️ Azure STT not available")
25
-
26
- try:
27
- from stt_flicker import FlickerSTT
28
- stt_providers['flicker'] = FlickerSTT
29
- except ImportError:
30
- log_error("⚠️ Flicker STT not available")
31
-
32
- class NoSTT(STTInterface):
33
- """Dummy STT provider when STT is disabled"""
34
-
35
- async def start_streaming(self, config) -> None:
36
- pass
37
-
38
- async def stream_audio(self, audio_chunk: bytes):
39
- return
40
- yield # Make it a generator
41
-
42
- async def stop_streaming(self):
43
- return None
44
-
45
- def supports_realtime(self) -> bool:
46
- return False
47
-
48
- def get_supported_languages(self):
49
- return []
50
-
51
- def get_provider_name(self) -> str:
52
- return "no_stt"
53
-
54
- class STTFactory:
55
- """Factory for creating STT providers"""
56
-
57
- @staticmethod
58
- def create_provider() -> Optional[STTInterface]:
59
- """Create STT provider based on configuration"""
60
- try:
61
- cfg = ConfigProvider.get()
62
- stt_provider_config = cfg.global_config.stt_provider
63
- stt_engine = stt_provider_config.name
64
-
65
- log_info(f"🎤 Creating STT provider: {stt_engine}")
66
-
67
- if stt_engine == "no_stt":
68
- return NoSTT()
69
-
70
- # Get provider class
71
- provider_class = stt_providers.get(stt_engine)
72
- if not provider_class:
73
- log_warning(f"⚠️ STT provider '{stt_engine}' not available")
74
- return NoSTT()
75
-
76
- # Get API key or credentials
77
- api_key = stt_provider_config.api_key
78
-
79
- if not api_key:
80
- log_warning(f"⚠️ No API key configured for {stt_engine}")
81
- return NoSTT()
82
-
83
- # Create provider instance
84
- if stt_engine == "google":
85
- # For Google, api_key is the path to credentials JSON
86
- return provider_class(credentials_path=api_key)
87
- elif stt_engine == "azure":
88
- # For Azure, parse the key format
89
- parts = api_key.split('|')
90
- if len(parts) != 2:
91
- log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
92
- return NoSTT()
93
- return provider_class(subscription_key=parts[0], region=parts[1])
94
- elif stt_engine == "flicker":
95
- return provider_class(api_key=api_key)
96
- else:
97
- return provider_class(api_key=api_key)
98
-
99
- except Exception as e:
100
- log_error("❌ Failed to create STT provider", e)
101
- return NoSTT()
102
-
103
- @staticmethod
104
- def get_available_providers():
105
- """Get list of available STT providers"""
106
  return list(stt_providers.keys()) + ["no_stt"]
 
1
+ """
2
+ STT Provider Factory for Flare
3
+ """
4
+ from typing import Optional
5
+ from stt_interface import STTInterface, STTEngineType
6
+ from utils.logger import log_info, log_error, log_warning, log_debug
7
+ from stt_google import GoogleCloudSTT
8
+ from config.config_provider import ConfigProvider
9
+
10
+ # Import providers conditionally
11
+ stt_providers = {}
12
+
13
+ try:
14
+ from stt_google import GoogleCloudSTT
15
+ stt_providers['google'] = GoogleCloudSTT
16
+ except ImportError:
17
+ log_info("⚠️ Google Cloud STT not available")
18
+
19
+ try:
20
+ from stt_azure import AzureSTT
21
+ stt_providers['azure'] = AzureSTT
22
+ except ImportError:
23
+ log_error("⚠️ Azure STT not available")
24
+
25
+ try:
26
+ from stt_flicker import FlickerSTT
27
+ stt_providers['flicker'] = FlickerSTT
28
+ except ImportError:
29
+ log_error("⚠️ Flicker STT not available")
30
+
31
+ class NoSTT(STTInterface):
32
+ """Dummy STT provider when STT is disabled"""
33
+
34
+ async def start_streaming(self, config) -> None:
35
+ pass
36
+
37
+ async def stream_audio(self, audio_chunk: bytes):
38
+ return
39
+ yield # Make it a generator
40
+
41
+ async def stop_streaming(self):
42
+ return None
43
+
44
+ def supports_realtime(self) -> bool:
45
+ return False
46
+
47
+ def get_supported_languages(self):
48
+ return []
49
+
50
+ def get_provider_name(self) -> str:
51
+ return "no_stt"
52
+
53
+ class STTFactory:
54
+ """Factory for creating STT providers"""
55
+
56
+ @staticmethod
57
+ def create_provider() -> Optional[STTInterface]:
58
+ """Create STT provider based on configuration"""
59
+ try:
60
+ cfg = ConfigProvider.get()
61
+ stt_provider_config = cfg.global_config.stt_provider
62
+ stt_engine = stt_provider_config.name
63
+
64
+ log_info(f"🎤 Creating STT provider: {stt_engine}")
65
+
66
+ if stt_engine == "no_stt":
67
+ return NoSTT()
68
+
69
+ # Get provider class
70
+ provider_class = stt_providers.get(stt_engine)
71
+ if not provider_class:
72
+ log_warning(f"⚠️ STT provider '{stt_engine}' not available")
73
+ return NoSTT()
74
+
75
+ # Get API key or credentials
76
+ api_key = stt_provider_config.api_key
77
+
78
+ if not api_key:
79
+ log_warning(f"⚠️ No API key configured for {stt_engine}")
80
+ return NoSTT()
81
+
82
+ # Create provider instance
83
+ if stt_engine == "google":
84
+ # For Google, api_key is the path to credentials JSON
85
+ return provider_class(credentials_path=api_key)
86
+ elif stt_engine == "azure":
87
+ # For Azure, parse the key format
88
+ parts = api_key.split('|')
89
+ if len(parts) != 2:
90
+ log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
91
+ return NoSTT()
92
+ return provider_class(subscription_key=parts[0], region=parts[1])
93
+ elif stt_engine == "flicker":
94
+ return provider_class(api_key=api_key)
95
+ else:
96
+ return provider_class(api_key=api_key)
97
+
98
+ except Exception as e:
99
+ log_error("❌ Failed to create STT provider", e)
100
+ return NoSTT()
101
+
102
+ @staticmethod
103
+ def get_available_providers():
104
+ """Get list of available STT providers"""
 
105
  return list(stt_providers.keys()) + ["no_stt"]