Michael Hu commited on
Commit
48ba7e8
Β·
1 Parent(s): 9626844

Create provider factories with dependency injection

Browse files
src/infrastructure/config/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Configuration and dependency injection module."""
2
+
3
+ from .dependency_container import DependencyContainer
4
+ from .app_config import AppConfig
5
+
6
+ __all__ = ['DependencyContainer', 'AppConfig']
src/infrastructure/config/app_config.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Application configuration management."""
2
+
3
+ import os
4
+ import logging
5
+ from typing import Dict, List, Optional, Any
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ @dataclass
13
+ class TTSConfig:
14
+ """Configuration for TTS providers."""
15
+ preferred_providers: List[str] = field(default_factory=lambda: ['kokoro', 'dia', 'cosyvoice2', 'dummy'])
16
+ default_voice: str = 'default'
17
+ default_speed: float = 1.0
18
+ default_language: str = 'en'
19
+ enable_streaming: bool = True
20
+ max_text_length: int = 10000
21
+
22
+
23
+ @dataclass
24
+ class STTConfig:
25
+ """Configuration for STT providers."""
26
+ preferred_providers: List[str] = field(default_factory=lambda: ['whisper', 'parakeet'])
27
+ default_model: str = 'whisper'
28
+ chunk_length_s: int = 30
29
+ batch_size: int = 16
30
+ enable_vad: bool = True
31
+
32
+
33
+ @dataclass
34
+ class TranslationConfig:
35
+ """Configuration for translation providers."""
36
+ default_provider: str = 'nllb'
37
+ model_name: str = 'facebook/nllb-200-3.3B'
38
+ max_chunk_length: int = 1000
39
+ batch_size: int = 8
40
+ cache_translations: bool = True
41
+
42
+
43
+ @dataclass
44
+ class ProcessingConfig:
45
+ """Configuration for audio processing pipeline."""
46
+ temp_dir: str = '/tmp/audio_processing'
47
+ cleanup_temp_files: bool = True
48
+ max_file_size_mb: int = 100
49
+ supported_audio_formats: List[str] = field(default_factory=lambda: ['wav', 'mp3', 'flac', 'ogg'])
50
+ processing_timeout_seconds: int = 300
51
+
52
+
53
+ @dataclass
54
+ class LoggingConfig:
55
+ """Configuration for logging."""
56
+ level: str = 'INFO'
57
+ format: str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
58
+ enable_file_logging: bool = False
59
+ log_file_path: str = 'app.log'
60
+ max_log_file_size_mb: int = 10
61
+ backup_count: int = 5
62
+
63
+
64
+ class AppConfig:
65
+ """Centralized application configuration management."""
66
+
67
+ def __init__(self, config_file: Optional[str] = None):
68
+ """
69
+ Initialize application configuration.
70
+
71
+ Args:
72
+ config_file: Optional path to configuration file
73
+ """
74
+ self.config_file = config_file
75
+ self._config_data: Dict[str, Any] = {}
76
+
77
+ # Initialize configuration sections
78
+ self.tts = TTSConfig()
79
+ self.stt = STTConfig()
80
+ self.translation = TranslationConfig()
81
+ self.processing = ProcessingConfig()
82
+ self.logging = LoggingConfig()
83
+
84
+ # Load configuration
85
+ self._load_configuration()
86
+
87
+ def _load_configuration(self) -> None:
88
+ """Load configuration from environment variables and config file."""
89
+ try:
90
+ # Load from environment variables first
91
+ self._load_from_environment()
92
+
93
+ # Load from config file if provided
94
+ if self.config_file and os.path.exists(self.config_file):
95
+ self._load_from_file()
96
+
97
+ # Validate configuration
98
+ self._validate_configuration()
99
+
100
+ logger.info("Configuration loaded successfully")
101
+
102
+ except Exception as e:
103
+ logger.error(f"Failed to load configuration: {e}")
104
+ # Use default configuration
105
+ logger.info("Using default configuration")
106
+
107
+ def _load_from_environment(self) -> None:
108
+ """Load configuration from environment variables."""
109
+ # TTS Configuration
110
+ if os.getenv('TTS_PREFERRED_PROVIDERS'):
111
+ self.tts.preferred_providers = os.getenv('TTS_PREFERRED_PROVIDERS').split(',')
112
+ if os.getenv('TTS_DEFAULT_VOICE'):
113
+ self.tts.default_voice = os.getenv('TTS_DEFAULT_VOICE')
114
+ if os.getenv('TTS_DEFAULT_SPEED'):
115
+ self.tts.default_speed = float(os.getenv('TTS_DEFAULT_SPEED'))
116
+ if os.getenv('TTS_DEFAULT_LANGUAGE'):
117
+ self.tts.default_language = os.getenv('TTS_DEFAULT_LANGUAGE')
118
+ if os.getenv('TTS_ENABLE_STREAMING'):
119
+ self.tts.enable_streaming = os.getenv('TTS_ENABLE_STREAMING').lower() == 'true'
120
+
121
+ # STT Configuration
122
+ if os.getenv('STT_PREFERRED_PROVIDERS'):
123
+ self.stt.preferred_providers = os.getenv('STT_PREFERRED_PROVIDERS').split(',')
124
+ if os.getenv('STT_DEFAULT_MODEL'):
125
+ self.stt.default_model = os.getenv('STT_DEFAULT_MODEL')
126
+ if os.getenv('STT_CHUNK_LENGTH'):
127
+ self.stt.chunk_length_s = int(os.getenv('STT_CHUNK_LENGTH'))
128
+ if os.getenv('STT_BATCH_SIZE'):
129
+ self.stt.batch_size = int(os.getenv('STT_BATCH_SIZE'))
130
+
131
+ # Translation Configuration
132
+ if os.getenv('TRANSLATION_DEFAULT_PROVIDER'):
133
+ self.translation.default_provider = os.getenv('TRANSLATION_DEFAULT_PROVIDER')
134
+ if os.getenv('TRANSLATION_MODEL_NAME'):
135
+ self.translation.model_name = os.getenv('TRANSLATION_MODEL_NAME')
136
+ if os.getenv('TRANSLATION_MAX_CHUNK_LENGTH'):
137
+ self.translation.max_chunk_length = int(os.getenv('TRANSLATION_MAX_CHUNK_LENGTH'))
138
+
139
+ # Processing Configuration
140
+ if os.getenv('PROCESSING_TEMP_DIR'):
141
+ self.processing.temp_dir = os.getenv('PROCESSING_TEMP_DIR')
142
+ if os.getenv('PROCESSING_CLEANUP_TEMP_FILES'):
143
+ self.processing.cleanup_temp_files = os.getenv('PROCESSING_CLEANUP_TEMP_FILES').lower() == 'true'
144
+ if os.getenv('PROCESSING_MAX_FILE_SIZE_MB'):
145
+ self.processing.max_file_size_mb = int(os.getenv('PROCESSING_MAX_FILE_SIZE_MB'))
146
+ if os.getenv('PROCESSING_TIMEOUT_SECONDS'):
147
+ self.processing.processing_timeout_seconds = int(os.getenv('PROCESSING_TIMEOUT_SECONDS'))
148
+
149
+ # Logging Configuration
150
+ if os.getenv('LOG_LEVEL'):
151
+ self.logging.level = os.getenv('LOG_LEVEL')
152
+ if os.getenv('LOG_FORMAT'):
153
+ self.logging.format = os.getenv('LOG_FORMAT')
154
+ if os.getenv('LOG_FILE_PATH'):
155
+ self.logging.log_file_path = os.getenv('LOG_FILE_PATH')
156
+
157
+ def _load_from_file(self) -> None:
158
+ """Load configuration from file (JSON or YAML)."""
159
+ try:
160
+ import json
161
+
162
+ with open(self.config_file, 'r') as f:
163
+ if self.config_file.endswith('.json'):
164
+ self._config_data = json.load(f)
165
+ elif self.config_file.endswith(('.yml', '.yaml')):
166
+ try:
167
+ import yaml
168
+ self._config_data = yaml.safe_load(f)
169
+ except ImportError:
170
+ logger.warning("PyYAML not installed, cannot load YAML config file")
171
+ return
172
+ else:
173
+ logger.warning(f"Unsupported config file format: {self.config_file}")
174
+ return
175
+
176
+ # Apply configuration from file
177
+ self._apply_config_data()
178
+
179
+ except Exception as e:
180
+ logger.error(f"Failed to load config file {self.config_file}: {e}")
181
+
182
+ def _apply_config_data(self) -> None:
183
+ """Apply configuration data from loaded file."""
184
+ if not self._config_data:
185
+ return
186
+
187
+ # Apply TTS configuration
188
+ tts_config = self._config_data.get('tts', {})
189
+ if 'preferred_providers' in tts_config:
190
+ self.tts.preferred_providers = tts_config['preferred_providers']
191
+ if 'default_voice' in tts_config:
192
+ self.tts.default_voice = tts_config['default_voice']
193
+ if 'default_speed' in tts_config:
194
+ self.tts.default_speed = tts_config['default_speed']
195
+ if 'default_language' in tts_config:
196
+ self.tts.default_language = tts_config['default_language']
197
+ if 'enable_streaming' in tts_config:
198
+ self.tts.enable_streaming = tts_config['enable_streaming']
199
+
200
+ # Apply STT configuration
201
+ stt_config = self._config_data.get('stt', {})
202
+ if 'preferred_providers' in stt_config:
203
+ self.stt.preferred_providers = stt_config['preferred_providers']
204
+ if 'default_model' in stt_config:
205
+ self.stt.default_model = stt_config['default_model']
206
+ if 'chunk_length_s' in stt_config:
207
+ self.stt.chunk_length_s = stt_config['chunk_length_s']
208
+ if 'batch_size' in stt_config:
209
+ self.stt.batch_size = stt_config['batch_size']
210
+
211
+ # Apply Translation configuration
212
+ translation_config = self._config_data.get('translation', {})
213
+ if 'default_provider' in translation_config:
214
+ self.translation.default_provider = translation_config['default_provider']
215
+ if 'model_name' in translation_config:
216
+ self.translation.model_name = translation_config['model_name']
217
+ if 'max_chunk_length' in translation_config:
218
+ self.translation.max_chunk_length = translation_config['max_chunk_length']
219
+
220
+ # Apply Processing configuration
221
+ processing_config = self._config_data.get('processing', {})
222
+ if 'temp_dir' in processing_config:
223
+ self.processing.temp_dir = processing_config['temp_dir']
224
+ if 'cleanup_temp_files' in processing_config:
225
+ self.processing.cleanup_temp_files = processing_config['cleanup_temp_files']
226
+ if 'max_file_size_mb' in processing_config:
227
+ self.processing.max_file_size_mb = processing_config['max_file_size_mb']
228
+ if 'processing_timeout_seconds' in processing_config:
229
+ self.processing.processing_timeout_seconds = processing_config['processing_timeout_seconds']
230
+
231
+ # Apply Logging configuration
232
+ logging_config = self._config_data.get('logging', {})
233
+ if 'level' in logging_config:
234
+ self.logging.level = logging_config['level']
235
+ if 'format' in logging_config:
236
+ self.logging.format = logging_config['format']
237
+ if 'log_file_path' in logging_config:
238
+ self.logging.log_file_path = logging_config['log_file_path']
239
+
240
+ def _validate_configuration(self) -> None:
241
+ """Validate configuration values."""
242
+ # Validate TTS configuration
243
+ if not (0.1 <= self.tts.default_speed <= 3.0):
244
+ logger.warning(f"Invalid TTS speed {self.tts.default_speed}, using default 1.0")
245
+ self.tts.default_speed = 1.0
246
+
247
+ if self.tts.max_text_length <= 0:
248
+ logger.warning(f"Invalid max text length {self.tts.max_text_length}, using default 10000")
249
+ self.tts.max_text_length = 10000
250
+
251
+ # Validate STT configuration
252
+ if self.stt.chunk_length_s <= 0:
253
+ logger.warning(f"Invalid chunk length {self.stt.chunk_length_s}, using default 30")
254
+ self.stt.chunk_length_s = 30
255
+
256
+ if self.stt.batch_size <= 0:
257
+ logger.warning(f"Invalid batch size {self.stt.batch_size}, using default 16")
258
+ self.stt.batch_size = 16
259
+
260
+ # Validate Translation configuration
261
+ if self.translation.max_chunk_length <= 0:
262
+ logger.warning(f"Invalid max chunk length {self.translation.max_chunk_length}, using default 1000")
263
+ self.translation.max_chunk_length = 1000
264
+
265
+ # Validate Processing configuration
266
+ if self.processing.max_file_size_mb <= 0:
267
+ logger.warning(f"Invalid max file size {self.processing.max_file_size_mb}, using default 100")
268
+ self.processing.max_file_size_mb = 100
269
+
270
+ if self.processing.processing_timeout_seconds <= 0:
271
+ logger.warning(f"Invalid timeout {self.processing.processing_timeout_seconds}, using default 300")
272
+ self.processing.processing_timeout_seconds = 300
273
+
274
+ # Ensure temp directory exists
275
+ try:
276
+ Path(self.processing.temp_dir).mkdir(parents=True, exist_ok=True)
277
+ except Exception as e:
278
+ logger.warning(f"Failed to create temp directory {self.processing.temp_dir}: {e}")
279
+ self.processing.temp_dir = '/tmp/audio_processing'
280
+ Path(self.processing.temp_dir).mkdir(parents=True, exist_ok=True)
281
+
282
+ def get_tts_config(self) -> Dict[str, Any]:
283
+ """Get TTS configuration as dictionary."""
284
+ return {
285
+ 'preferred_providers': self.tts.preferred_providers,
286
+ 'default_voice': self.tts.default_voice,
287
+ 'default_speed': self.tts.default_speed,
288
+ 'default_language': self.tts.default_language,
289
+ 'enable_streaming': self.tts.enable_streaming,
290
+ 'max_text_length': self.tts.max_text_length
291
+ }
292
+
293
+ def get_stt_config(self) -> Dict[str, Any]:
294
+ """Get STT configuration as dictionary."""
295
+ return {
296
+ 'preferred_providers': self.stt.preferred_providers,
297
+ 'default_model': self.stt.default_model,
298
+ 'chunk_length_s': self.stt.chunk_length_s,
299
+ 'batch_size': self.stt.batch_size,
300
+ 'enable_vad': self.stt.enable_vad
301
+ }
302
+
303
+ def get_translation_config(self) -> Dict[str, Any]:
304
+ """Get translation configuration as dictionary."""
305
+ return {
306
+ 'default_provider': self.translation.default_provider,
307
+ 'model_name': self.translation.model_name,
308
+ 'max_chunk_length': self.translation.max_chunk_length,
309
+ 'batch_size': self.translation.batch_size,
310
+ 'cache_translations': self.translation.cache_translations
311
+ }
312
+
313
+ def get_processing_config(self) -> Dict[str, Any]:
314
+ """Get processing configuration as dictionary."""
315
+ return {
316
+ 'temp_dir': self.processing.temp_dir,
317
+ 'cleanup_temp_files': self.processing.cleanup_temp_files,
318
+ 'max_file_size_mb': self.processing.max_file_size_mb,
319
+ 'supported_audio_formats': self.processing.supported_audio_formats,
320
+ 'processing_timeout_seconds': self.processing.processing_timeout_seconds
321
+ }
322
+
323
+ def get_logging_config(self) -> Dict[str, Any]:
324
+ """Get logging configuration as dictionary."""
325
+ return {
326
+ 'level': self.logging.level,
327
+ 'format': self.logging.format,
328
+ 'enable_file_logging': self.logging.enable_file_logging,
329
+ 'log_file_path': self.logging.log_file_path,
330
+ 'max_log_file_size_mb': self.logging.max_log_file_size_mb,
331
+ 'backup_count': self.logging.backup_count
332
+ }
333
+
334
+ def reload_configuration(self) -> None:
335
+ """Reload configuration from sources."""
336
+ logger.info("Reloading configuration")
337
+ self._load_configuration()
338
+
339
+ def save_configuration(self, output_file: str) -> None:
340
+ """
341
+ Save current configuration to file.
342
+
343
+ Args:
344
+ output_file: Path to output configuration file
345
+ """
346
+ try:
347
+ config_dict = {
348
+ 'tts': self.get_tts_config(),
349
+ 'stt': self.get_stt_config(),
350
+ 'translation': self.get_translation_config(),
351
+ 'processing': self.get_processing_config(),
352
+ 'logging': self.get_logging_config()
353
+ }
354
+
355
+ import json
356
+ with open(output_file, 'w') as f:
357
+ json.dump(config_dict, f, indent=2)
358
+
359
+ logger.info(f"Configuration saved to {output_file}")
360
+
361
+ except Exception as e:
362
+ logger.error(f"Failed to save configuration to {output_file}: {e}")
363
+ raise
364
+
365
+ def __str__(self) -> str:
366
+ """String representation of configuration."""
367
+ return (
368
+ f"AppConfig(\n"
369
+ f" TTS: {self.tts}\n"
370
+ f" STT: {self.stt}\n"
371
+ f" Translation: {self.translation}\n"
372
+ f" Processing: {self.processing}\n"
373
+ f" Logging: {self.logging}\n"
374
+ f")"
375
+ )
src/infrastructure/config/dependency_container.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dependency injection container for managing component lifecycles."""
2
+
3
+ import logging
4
+ from typing import Dict, Any, Optional, TypeVar, Type, Callable, Union
5
+ from enum import Enum
6
+ from threading import Lock
7
+ import weakref
8
+
9
+ from .app_config import AppConfig
10
+ from ..tts.provider_factory import TTSProviderFactory
11
+ from ..stt.provider_factory import STTProviderFactory
12
+ from ..translation.provider_factory import TranslationProviderFactory, TranslationProviderType
13
+ from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
14
+ from ...domain.interfaces.speech_recognition import ISpeechRecognitionService
15
+ from ...domain.interfaces.translation import ITranslationService
16
+ from ...domain.interfaces.audio_processing import IAudioProcessingService
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ T = TypeVar('T')
21
+
22
+
23
+ class ServiceLifetime(Enum):
24
+ """Service lifetime management options."""
25
+ SINGLETON = "singleton"
26
+ TRANSIENT = "transient"
27
+ SCOPED = "scoped"
28
+
29
+
30
+ class ServiceDescriptor:
31
+ """Describes how a service should be created and managed."""
32
+
33
+ def __init__(
34
+ self,
35
+ service_type: Type[T],
36
+ implementation: Union[Type[T], Callable[..., T]],
37
+ lifetime: ServiceLifetime = ServiceLifetime.TRANSIENT,
38
+ factory_args: Optional[Dict[str, Any]] = None
39
+ ):
40
+ """
41
+ Initialize service descriptor.
42
+
43
+ Args:
44
+ service_type: The service interface type
45
+ implementation: The implementation class or factory function
46
+ lifetime: Service lifetime management
47
+ factory_args: Arguments to pass to the factory/constructor
48
+ """
49
+ self.service_type = service_type
50
+ self.implementation = implementation
51
+ self.lifetime = lifetime
52
+ self.factory_args = factory_args or {}
53
+
54
+
55
+ class DependencyContainer:
56
+ """Dependency injection container for managing component lifecycles."""
57
+
58
+ def __init__(self, config: Optional[AppConfig] = None):
59
+ """
60
+ Initialize the dependency container.
61
+
62
+ Args:
63
+ config: Application configuration instance
64
+ """
65
+ self._config = config or AppConfig()
66
+ self._services: Dict[Type, ServiceDescriptor] = {}
67
+ self._singletons: Dict[Type, Any] = {}
68
+ self._scoped_instances: Dict[Type, Any] = {}
69
+ self._lock = Lock()
70
+
71
+ # Provider factories
72
+ self._tts_factory: Optional[TTSProviderFactory] = None
73
+ self._stt_factory: Optional[STTProviderFactory] = None
74
+ self._translation_factory: Optional[TranslationProviderFactory] = None
75
+
76
+ # Register default services
77
+ self._register_default_services()
78
+
79
+ def _register_default_services(self) -> None:
80
+ """Register default service implementations."""
81
+ # Register configuration as singleton
82
+ self.register_singleton(AppConfig, self._config)
83
+
84
+ # Register provider factories as singletons
85
+ self.register_singleton(TTSProviderFactory, self._get_tts_factory)
86
+ self.register_singleton(STTProviderFactory, self._get_stt_factory)
87
+ self.register_singleton(TranslationProviderFactory, self._get_translation_factory)
88
+
89
+ def register_singleton(
90
+ self,
91
+ service_type: Type[T],
92
+ implementation: Union[Type[T], Callable[..., T], T],
93
+ factory_args: Optional[Dict[str, Any]] = None
94
+ ) -> None:
95
+ """
96
+ Register a service as singleton.
97
+
98
+ Args:
99
+ service_type: The service interface type
100
+ implementation: The implementation class, factory function, or instance
101
+ factory_args: Arguments to pass to the factory/constructor
102
+ """
103
+ with self._lock:
104
+ # If implementation is already an instance, store it directly
105
+ if not (isinstance(implementation, type) or callable(implementation)):
106
+ self._singletons[service_type] = implementation
107
+ logger.debug(f"Registered singleton instance for {service_type.__name__}")
108
+ return
109
+
110
+ descriptor = ServiceDescriptor(
111
+ service_type=service_type,
112
+ implementation=implementation,
113
+ lifetime=ServiceLifetime.SINGLETON,
114
+ factory_args=factory_args
115
+ )
116
+ self._services[service_type] = descriptor
117
+ logger.debug(f"Registered singleton service: {service_type.__name__}")
118
+
119
+ def register_transient(
120
+ self,
121
+ service_type: Type[T],
122
+ implementation: Union[Type[T], Callable[..., T]],
123
+ factory_args: Optional[Dict[str, Any]] = None
124
+ ) -> None:
125
+ """
126
+ Register a service as transient (new instance each time).
127
+
128
+ Args:
129
+ service_type: The service interface type
130
+ implementation: The implementation class or factory function
131
+ factory_args: Arguments to pass to the factory/constructor
132
+ """
133
+ with self._lock:
134
+ descriptor = ServiceDescriptor(
135
+ service_type=service_type,
136
+ implementation=implementation,
137
+ lifetime=ServiceLifetime.TRANSIENT,
138
+ factory_args=factory_args
139
+ )
140
+ self._services[service_type] = descriptor
141
+ logger.debug(f"Registered transient service: {service_type.__name__}")
142
+
143
+ def register_scoped(
144
+ self,
145
+ service_type: Type[T],
146
+ implementation: Union[Type[T], Callable[..., T]],
147
+ factory_args: Optional[Dict[str, Any]] = None
148
+ ) -> None:
149
+ """
150
+ Register a service as scoped (one instance per scope).
151
+
152
+ Args:
153
+ service_type: The service interface type
154
+ implementation: The implementation class or factory function
155
+ factory_args: Arguments to pass to the factory/constructor
156
+ """
157
+ with self._lock:
158
+ descriptor = ServiceDescriptor(
159
+ service_type=service_type,
160
+ implementation=implementation,
161
+ lifetime=ServiceLifetime.SCOPED,
162
+ factory_args=factory_args
163
+ )
164
+ self._services[service_type] = descriptor
165
+ logger.debug(f"Registered scoped service: {service_type.__name__}")
166
+
167
+ def resolve(self, service_type: Type[T]) -> T:
168
+ """
169
+ Resolve a service instance.
170
+
171
+ Args:
172
+ service_type: The service type to resolve
173
+
174
+ Returns:
175
+ T: The service instance
176
+
177
+ Raises:
178
+ ValueError: If service is not registered
179
+ Exception: If service creation fails
180
+ """
181
+ with self._lock:
182
+ # Check if already a singleton instance
183
+ if service_type in self._singletons:
184
+ return self._singletons[service_type]
185
+
186
+ # Check if service is registered
187
+ if service_type not in self._services:
188
+ raise ValueError(f"Service {service_type.__name__} is not registered")
189
+
190
+ descriptor = self._services[service_type]
191
+
192
+ try:
193
+ if descriptor.lifetime == ServiceLifetime.SINGLETON:
194
+ return self._create_singleton(service_type, descriptor)
195
+ elif descriptor.lifetime == ServiceLifetime.SCOPED:
196
+ return self._create_scoped(service_type, descriptor)
197
+ else: # TRANSIENT
198
+ return self._create_transient(descriptor)
199
+
200
+ except Exception as e:
201
+ logger.error(f"Failed to resolve service {service_type.__name__}: {e}")
202
+ raise
203
+
204
+ def _create_singleton(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
205
+ """Create or return existing singleton instance."""
206
+ if service_type in self._singletons:
207
+ return self._singletons[service_type]
208
+
209
+ instance = self._create_instance(descriptor)
210
+ self._singletons[service_type] = instance
211
+ logger.debug(f"Created singleton instance for {service_type.__name__}")
212
+ return instance
213
+
214
+ def _create_scoped(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
215
+ """Create or return existing scoped instance."""
216
+ if service_type in self._scoped_instances:
217
+ return self._scoped_instances[service_type]
218
+
219
+ instance = self._create_instance(descriptor)
220
+ self._scoped_instances[service_type] = instance
221
+ logger.debug(f"Created scoped instance for {service_type.__name__}")
222
+ return instance
223
+
224
+ def _create_transient(self, descriptor: ServiceDescriptor) -> T:
225
+ """Create new transient instance."""
226
+ instance = self._create_instance(descriptor)
227
+ logger.debug(f"Created transient instance for {descriptor.service_type.__name__}")
228
+ return instance
229
+
230
+ def _create_instance(self, descriptor: ServiceDescriptor) -> T:
231
+ """Create service instance using descriptor."""
232
+ implementation = descriptor.implementation
233
+ factory_args = descriptor.factory_args
234
+
235
+ # If implementation is a callable (factory function)
236
+ if callable(implementation) and not isinstance(implementation, type):
237
+ return implementation(**factory_args)
238
+
239
+ # If implementation is a class
240
+ if isinstance(implementation, type):
241
+ return implementation(**factory_args)
242
+
243
+ raise ValueError(f"Invalid implementation type for {descriptor.service_type.__name__}")
244
+
245
+ def _get_tts_factory(self) -> TTSProviderFactory:
246
+ """Get or create TTS provider factory."""
247
+ if self._tts_factory is None:
248
+ self._tts_factory = TTSProviderFactory()
249
+ logger.debug("Created TTS provider factory")
250
+ return self._tts_factory
251
+
252
+ def _get_stt_factory(self) -> STTProviderFactory:
253
+ """Get or create STT provider factory."""
254
+ if self._stt_factory is None:
255
+ self._stt_factory = STTProviderFactory()
256
+ logger.debug("Created STT provider factory")
257
+ return self._stt_factory
258
+
259
+ def _get_translation_factory(self) -> TranslationProviderFactory:
260
+ """Get or create translation provider factory."""
261
+ if self._translation_factory is None:
262
+ self._translation_factory = TranslationProviderFactory()
263
+ logger.debug("Created translation provider factory")
264
+ return self._translation_factory
265
+
266
+ def get_tts_provider(self, provider_name: Optional[str] = None, **kwargs) -> ISpeechSynthesisService:
267
+ """
268
+ Get TTS provider with fallback logic.
269
+
270
+ Args:
271
+ provider_name: Specific provider name or None for default
272
+ **kwargs: Additional provider arguments
273
+
274
+ Returns:
275
+ ISpeechSynthesisService: TTS provider instance
276
+ """
277
+ factory = self.resolve(TTSProviderFactory)
278
+
279
+ if provider_name:
280
+ return factory.create_provider(provider_name, **kwargs)
281
+ else:
282
+ preferred_providers = self._config.tts.preferred_providers
283
+ return factory.get_provider_with_fallback(preferred_providers, **kwargs)
284
+
285
+ def get_stt_provider(self, provider_name: Optional[str] = None) -> ISpeechRecognitionService:
286
+ """
287
+ Get STT provider with fallback logic.
288
+
289
+ Args:
290
+ provider_name: Specific provider name or None for default
291
+
292
+ Returns:
293
+ ISpeechRecognitionService: STT provider instance
294
+ """
295
+ factory = self.resolve(STTProviderFactory)
296
+
297
+ if provider_name:
298
+ return factory.create_provider(provider_name)
299
+ else:
300
+ preferred_provider = self._config.stt.default_model
301
+ return factory.create_provider_with_fallback(preferred_provider)
302
+
303
+ def get_translation_provider(
304
+ self,
305
+ provider_type: Optional[TranslationProviderType] = None,
306
+ config: Optional[Dict[str, Any]] = None
307
+ ) -> ITranslationService:
308
+ """
309
+ Get translation provider with fallback logic.
310
+
311
+ Args:
312
+ provider_type: Specific provider type or None for default
313
+ config: Optional provider configuration
314
+
315
+ Returns:
316
+ ITranslationService: Translation provider instance
317
+ """
318
+ factory = self.resolve(TranslationProviderFactory)
319
+
320
+ if provider_type:
321
+ return factory.create_provider(provider_type, config)
322
+ else:
323
+ return factory.get_default_provider(config)
324
+
325
+ def clear_scoped_instances(self) -> None:
326
+ """Clear all scoped instances."""
327
+ with self._lock:
328
+ # Cleanup scoped instances if they have cleanup methods
329
+ for instance in self._scoped_instances.values():
330
+ self._cleanup_instance(instance)
331
+
332
+ self._scoped_instances.clear()
333
+ logger.debug("Cleared scoped instances")
334
+
335
+ def _cleanup_instance(self, instance: Any) -> None:
336
+ """Cleanup instance if it has cleanup methods."""
337
+ try:
338
+ # Try common cleanup method names
339
+ cleanup_methods = ['cleanup', 'dispose', 'close', '__del__']
340
+ for method_name in cleanup_methods:
341
+ if hasattr(instance, method_name):
342
+ method = getattr(instance, method_name)
343
+ if callable(method):
344
+ method()
345
+ logger.debug(f"Called {method_name} on {type(instance).__name__}")
346
+ break
347
+ except Exception as e:
348
+ logger.warning(f"Failed to cleanup instance {type(instance).__name__}: {e}")
349
+
350
+ def cleanup(self) -> None:
351
+ """Cleanup all managed resources."""
352
+ with self._lock:
353
+ logger.info("Starting dependency container cleanup")
354
+
355
+ # Cleanup scoped instances
356
+ self.clear_scoped_instances()
357
+
358
+ # Cleanup singleton instances
359
+ for instance in self._singletons.values():
360
+ self._cleanup_instance(instance)
361
+
362
+ # Cleanup provider factories
363
+ if self._tts_factory:
364
+ try:
365
+ self._tts_factory.cleanup_providers()
366
+ except Exception as e:
367
+ logger.warning(f"Failed to cleanup TTS factory: {e}")
368
+
369
+ if self._translation_factory:
370
+ try:
371
+ self._translation_factory.clear_cache()
372
+ except Exception as e:
373
+ logger.warning(f"Failed to cleanup translation factory: {e}")
374
+
375
+ # Clear all references
376
+ self._singletons.clear()
377
+ self._tts_factory = None
378
+ self._stt_factory = None
379
+ self._translation_factory = None
380
+
381
+ logger.info("Dependency container cleanup completed")
382
+
383
+ def is_registered(self, service_type: Type) -> bool:
384
+ """
385
+ Check if a service type is registered.
386
+
387
+ Args:
388
+ service_type: The service type to check
389
+
390
+ Returns:
391
+ bool: True if registered, False otherwise
392
+ """
393
+ with self._lock:
394
+ return service_type in self._services or service_type in self._singletons
395
+
396
+ def get_registered_services(self) -> Dict[str, str]:
397
+ """
398
+ Get information about all registered services.
399
+
400
+ Returns:
401
+ Dict[str, str]: Mapping of service names to their lifetime
402
+ """
403
+ with self._lock:
404
+ services_info = {}
405
+
406
+ # Add singleton instances
407
+ for service_type in self._singletons.keys():
408
+ services_info[service_type.__name__] = "singleton (instance)"
409
+
410
+ # Add registered services
411
+ for service_type, descriptor in self._services.items():
412
+ if service_type not in self._singletons:
413
+ services_info[service_type.__name__] = descriptor.lifetime.value
414
+
415
+ return services_info
416
+
417
+ def create_scope(self) -> 'DependencyScope':
418
+ """
419
+ Create a new dependency scope.
420
+
421
+ Returns:
422
+ DependencyScope: New scope instance
423
+ """
424
+ return DependencyScope(self)
425
+
426
+ def __enter__(self):
427
+ """Context manager entry."""
428
+ return self
429
+
430
+ def __exit__(self, exc_type, exc_val, exc_tb):
431
+ """Context manager exit with cleanup."""
432
+ self.cleanup()
433
+
434
+
435
+ class DependencyScope:
436
+ """Scoped dependency container for managing scoped service lifetimes."""
437
+
438
+ def __init__(self, parent_container: DependencyContainer):
439
+ """
440
+ Initialize dependency scope.
441
+
442
+ Args:
443
+ parent_container: Parent dependency container
444
+ """
445
+ self._parent = parent_container
446
+ self._scoped_instances: Dict[Type, Any] = {}
447
+ self._lock = Lock()
448
+
449
+ def resolve(self, service_type: Type[T]) -> T:
450
+ """
451
+ Resolve service within this scope.
452
+
453
+ Args:
454
+ service_type: The service type to resolve
455
+
456
+ Returns:
457
+ T: The service instance
458
+ """
459
+ with self._lock:
460
+ # Check if we have a scoped instance
461
+ if service_type in self._scoped_instances:
462
+ return self._scoped_instances[service_type]
463
+
464
+ # Resolve from parent container
465
+ instance = self._parent.resolve(service_type)
466
+
467
+ # If it's a scoped service, store it in this scope
468
+ if (service_type in self._parent._services and
469
+ self._parent._services[service_type].lifetime == ServiceLifetime.SCOPED):
470
+ self._scoped_instances[service_type] = instance
471
+
472
+ return instance
473
+
474
+ def cleanup(self) -> None:
475
+ """Cleanup scoped instances."""
476
+ with self._lock:
477
+ for instance in self._scoped_instances.values():
478
+ self._parent._cleanup_instance(instance)
479
+ self._scoped_instances.clear()
480
+
481
+ def __enter__(self):
482
+ """Context manager entry."""
483
+ return self
484
+
485
+ def __exit__(self, exc_type, exc_val, exc_tb):
486
+ """Context manager exit with cleanup."""
487
+ self.cleanup()
488
+
489
+
490
+ # Global container instance
491
+ _global_container: Optional[DependencyContainer] = None
492
+ _container_lock = Lock()
493
+
494
+
495
+ def get_container() -> DependencyContainer:
496
+ """
497
+ Get the global dependency container instance.
498
+
499
+ Returns:
500
+ DependencyContainer: Global container instance
501
+ """
502
+ global _global_container
503
+
504
+ with _container_lock:
505
+ if _global_container is None:
506
+ _global_container = DependencyContainer()
507
+ logger.info("Created global dependency container")
508
+
509
+ return _global_container
510
+
511
+
512
+ def set_container(container: DependencyContainer) -> None:
513
+ """
514
+ Set the global dependency container instance.
515
+
516
+ Args:
517
+ container: Container instance to set as global
518
+ """
519
+ global _global_container
520
+
521
+ with _container_lock:
522
+ if _global_container is not None:
523
+ _global_container.cleanup()
524
+
525
+ _global_container = container
526
+ logger.info("Set global dependency container")
527
+
528
+
529
+ def cleanup_container() -> None:
530
+ """Cleanup the global dependency container."""
531
+ global _global_container
532
+
533
+ with _container_lock:
534
+ if _global_container is not None:
535
+ _global_container.cleanup()
536
+ _global_container = None
537
+ logger.info("Cleaned up global dependency container")
test_dependency_injection.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Test script for dependency injection implementation."""
3
+
4
+ import sys
5
+ import os
6
+
7
+ # Add src to path
8
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
9
+
10
+ try:
11
+ from infrastructure.config.dependency_container import DependencyContainer, get_container
12
+ from infrastructure.config.app_config import AppConfig
13
+ from infrastructure.tts.provider_factory import TTSProviderFactory
14
+ from infrastructure.stt.provider_factory import STTProviderFactory
15
+ from infrastructure.translation.provider_factory import TranslationProviderFactory
16
+
17
+ print("πŸ§ͺ Testing Dependency Injection Implementation...")
18
+ print()
19
+
20
+ # Test basic container functionality
21
+ container = DependencyContainer()
22
+ print('βœ… DependencyContainer created successfully')
23
+
24
+ # Test configuration resolution
25
+ config = container.resolve(AppConfig)
26
+ print(f'βœ… AppConfig resolved: {type(config).__name__}')
27
+
28
+ # Test factory resolution
29
+ tts_factory = container.resolve(TTSProviderFactory)
30
+ stt_factory = container.resolve(STTProviderFactory)
31
+ translation_factory = container.resolve(TranslationProviderFactory)
32
+
33
+ print(f'βœ… TTSProviderFactory resolved: {type(tts_factory).__name__}')
34
+ print(f'βœ… STTProviderFactory resolved: {type(stt_factory).__name__}')
35
+ print(f'βœ… TranslationProviderFactory resolved: {type(translation_factory).__name__}')
36
+
37
+ # Test global container
38
+ global_container = get_container()
39
+ print(f'βœ… Global container retrieved: {type(global_container).__name__}')
40
+
41
+ # Test service registration info
42
+ services = container.get_registered_services()
43
+ print(f'βœ… Registered services: {len(services)} services')
44
+ for service_name, lifetime in services.items():
45
+ print(f' - {service_name}: {lifetime}')
46
+
47
+ # Test cleanup
48
+ container.cleanup()
49
+ print('βœ… Container cleanup completed')
50
+
51
+ print()
52
+ print('πŸŽ‰ All dependency injection tests passed!')
53
+
54
+ except Exception as e:
55
+ print(f'❌ Test failed: {e}')
56
+ import traceback
57
+ traceback.print_exc()
58
+ sys.exit(1)