Spaces:
Build error
Build error
Michael Hu
commited on
Commit
Β·
48ba7e8
1
Parent(s):
9626844
Create provider factories with dependency injection
Browse files
src/infrastructure/config/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Configuration and dependency injection module."""
|
2 |
+
|
3 |
+
from .dependency_container import DependencyContainer
|
4 |
+
from .app_config import AppConfig
|
5 |
+
|
6 |
+
__all__ = ['DependencyContainer', 'AppConfig']
|
src/infrastructure/config/app_config.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Application configuration management."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import logging
|
5 |
+
from typing import Dict, List, Optional, Any
|
6 |
+
from dataclasses import dataclass, field
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class TTSConfig:
|
14 |
+
"""Configuration for TTS providers."""
|
15 |
+
preferred_providers: List[str] = field(default_factory=lambda: ['kokoro', 'dia', 'cosyvoice2', 'dummy'])
|
16 |
+
default_voice: str = 'default'
|
17 |
+
default_speed: float = 1.0
|
18 |
+
default_language: str = 'en'
|
19 |
+
enable_streaming: bool = True
|
20 |
+
max_text_length: int = 10000
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class STTConfig:
|
25 |
+
"""Configuration for STT providers."""
|
26 |
+
preferred_providers: List[str] = field(default_factory=lambda: ['whisper', 'parakeet'])
|
27 |
+
default_model: str = 'whisper'
|
28 |
+
chunk_length_s: int = 30
|
29 |
+
batch_size: int = 16
|
30 |
+
enable_vad: bool = True
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class TranslationConfig:
|
35 |
+
"""Configuration for translation providers."""
|
36 |
+
default_provider: str = 'nllb'
|
37 |
+
model_name: str = 'facebook/nllb-200-3.3B'
|
38 |
+
max_chunk_length: int = 1000
|
39 |
+
batch_size: int = 8
|
40 |
+
cache_translations: bool = True
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class ProcessingConfig:
|
45 |
+
"""Configuration for audio processing pipeline."""
|
46 |
+
temp_dir: str = '/tmp/audio_processing'
|
47 |
+
cleanup_temp_files: bool = True
|
48 |
+
max_file_size_mb: int = 100
|
49 |
+
supported_audio_formats: List[str] = field(default_factory=lambda: ['wav', 'mp3', 'flac', 'ogg'])
|
50 |
+
processing_timeout_seconds: int = 300
|
51 |
+
|
52 |
+
|
53 |
+
@dataclass
|
54 |
+
class LoggingConfig:
|
55 |
+
"""Configuration for logging."""
|
56 |
+
level: str = 'INFO'
|
57 |
+
format: str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
58 |
+
enable_file_logging: bool = False
|
59 |
+
log_file_path: str = 'app.log'
|
60 |
+
max_log_file_size_mb: int = 10
|
61 |
+
backup_count: int = 5
|
62 |
+
|
63 |
+
|
64 |
+
class AppConfig:
|
65 |
+
"""Centralized application configuration management."""
|
66 |
+
|
67 |
+
def __init__(self, config_file: Optional[str] = None):
|
68 |
+
"""
|
69 |
+
Initialize application configuration.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
config_file: Optional path to configuration file
|
73 |
+
"""
|
74 |
+
self.config_file = config_file
|
75 |
+
self._config_data: Dict[str, Any] = {}
|
76 |
+
|
77 |
+
# Initialize configuration sections
|
78 |
+
self.tts = TTSConfig()
|
79 |
+
self.stt = STTConfig()
|
80 |
+
self.translation = TranslationConfig()
|
81 |
+
self.processing = ProcessingConfig()
|
82 |
+
self.logging = LoggingConfig()
|
83 |
+
|
84 |
+
# Load configuration
|
85 |
+
self._load_configuration()
|
86 |
+
|
87 |
+
def _load_configuration(self) -> None:
|
88 |
+
"""Load configuration from environment variables and config file."""
|
89 |
+
try:
|
90 |
+
# Load from environment variables first
|
91 |
+
self._load_from_environment()
|
92 |
+
|
93 |
+
# Load from config file if provided
|
94 |
+
if self.config_file and os.path.exists(self.config_file):
|
95 |
+
self._load_from_file()
|
96 |
+
|
97 |
+
# Validate configuration
|
98 |
+
self._validate_configuration()
|
99 |
+
|
100 |
+
logger.info("Configuration loaded successfully")
|
101 |
+
|
102 |
+
except Exception as e:
|
103 |
+
logger.error(f"Failed to load configuration: {e}")
|
104 |
+
# Use default configuration
|
105 |
+
logger.info("Using default configuration")
|
106 |
+
|
107 |
+
def _load_from_environment(self) -> None:
|
108 |
+
"""Load configuration from environment variables."""
|
109 |
+
# TTS Configuration
|
110 |
+
if os.getenv('TTS_PREFERRED_PROVIDERS'):
|
111 |
+
self.tts.preferred_providers = os.getenv('TTS_PREFERRED_PROVIDERS').split(',')
|
112 |
+
if os.getenv('TTS_DEFAULT_VOICE'):
|
113 |
+
self.tts.default_voice = os.getenv('TTS_DEFAULT_VOICE')
|
114 |
+
if os.getenv('TTS_DEFAULT_SPEED'):
|
115 |
+
self.tts.default_speed = float(os.getenv('TTS_DEFAULT_SPEED'))
|
116 |
+
if os.getenv('TTS_DEFAULT_LANGUAGE'):
|
117 |
+
self.tts.default_language = os.getenv('TTS_DEFAULT_LANGUAGE')
|
118 |
+
if os.getenv('TTS_ENABLE_STREAMING'):
|
119 |
+
self.tts.enable_streaming = os.getenv('TTS_ENABLE_STREAMING').lower() == 'true'
|
120 |
+
|
121 |
+
# STT Configuration
|
122 |
+
if os.getenv('STT_PREFERRED_PROVIDERS'):
|
123 |
+
self.stt.preferred_providers = os.getenv('STT_PREFERRED_PROVIDERS').split(',')
|
124 |
+
if os.getenv('STT_DEFAULT_MODEL'):
|
125 |
+
self.stt.default_model = os.getenv('STT_DEFAULT_MODEL')
|
126 |
+
if os.getenv('STT_CHUNK_LENGTH'):
|
127 |
+
self.stt.chunk_length_s = int(os.getenv('STT_CHUNK_LENGTH'))
|
128 |
+
if os.getenv('STT_BATCH_SIZE'):
|
129 |
+
self.stt.batch_size = int(os.getenv('STT_BATCH_SIZE'))
|
130 |
+
|
131 |
+
# Translation Configuration
|
132 |
+
if os.getenv('TRANSLATION_DEFAULT_PROVIDER'):
|
133 |
+
self.translation.default_provider = os.getenv('TRANSLATION_DEFAULT_PROVIDER')
|
134 |
+
if os.getenv('TRANSLATION_MODEL_NAME'):
|
135 |
+
self.translation.model_name = os.getenv('TRANSLATION_MODEL_NAME')
|
136 |
+
if os.getenv('TRANSLATION_MAX_CHUNK_LENGTH'):
|
137 |
+
self.translation.max_chunk_length = int(os.getenv('TRANSLATION_MAX_CHUNK_LENGTH'))
|
138 |
+
|
139 |
+
# Processing Configuration
|
140 |
+
if os.getenv('PROCESSING_TEMP_DIR'):
|
141 |
+
self.processing.temp_dir = os.getenv('PROCESSING_TEMP_DIR')
|
142 |
+
if os.getenv('PROCESSING_CLEANUP_TEMP_FILES'):
|
143 |
+
self.processing.cleanup_temp_files = os.getenv('PROCESSING_CLEANUP_TEMP_FILES').lower() == 'true'
|
144 |
+
if os.getenv('PROCESSING_MAX_FILE_SIZE_MB'):
|
145 |
+
self.processing.max_file_size_mb = int(os.getenv('PROCESSING_MAX_FILE_SIZE_MB'))
|
146 |
+
if os.getenv('PROCESSING_TIMEOUT_SECONDS'):
|
147 |
+
self.processing.processing_timeout_seconds = int(os.getenv('PROCESSING_TIMEOUT_SECONDS'))
|
148 |
+
|
149 |
+
# Logging Configuration
|
150 |
+
if os.getenv('LOG_LEVEL'):
|
151 |
+
self.logging.level = os.getenv('LOG_LEVEL')
|
152 |
+
if os.getenv('LOG_FORMAT'):
|
153 |
+
self.logging.format = os.getenv('LOG_FORMAT')
|
154 |
+
if os.getenv('LOG_FILE_PATH'):
|
155 |
+
self.logging.log_file_path = os.getenv('LOG_FILE_PATH')
|
156 |
+
|
157 |
+
def _load_from_file(self) -> None:
|
158 |
+
"""Load configuration from file (JSON or YAML)."""
|
159 |
+
try:
|
160 |
+
import json
|
161 |
+
|
162 |
+
with open(self.config_file, 'r') as f:
|
163 |
+
if self.config_file.endswith('.json'):
|
164 |
+
self._config_data = json.load(f)
|
165 |
+
elif self.config_file.endswith(('.yml', '.yaml')):
|
166 |
+
try:
|
167 |
+
import yaml
|
168 |
+
self._config_data = yaml.safe_load(f)
|
169 |
+
except ImportError:
|
170 |
+
logger.warning("PyYAML not installed, cannot load YAML config file")
|
171 |
+
return
|
172 |
+
else:
|
173 |
+
logger.warning(f"Unsupported config file format: {self.config_file}")
|
174 |
+
return
|
175 |
+
|
176 |
+
# Apply configuration from file
|
177 |
+
self._apply_config_data()
|
178 |
+
|
179 |
+
except Exception as e:
|
180 |
+
logger.error(f"Failed to load config file {self.config_file}: {e}")
|
181 |
+
|
182 |
+
def _apply_config_data(self) -> None:
|
183 |
+
"""Apply configuration data from loaded file."""
|
184 |
+
if not self._config_data:
|
185 |
+
return
|
186 |
+
|
187 |
+
# Apply TTS configuration
|
188 |
+
tts_config = self._config_data.get('tts', {})
|
189 |
+
if 'preferred_providers' in tts_config:
|
190 |
+
self.tts.preferred_providers = tts_config['preferred_providers']
|
191 |
+
if 'default_voice' in tts_config:
|
192 |
+
self.tts.default_voice = tts_config['default_voice']
|
193 |
+
if 'default_speed' in tts_config:
|
194 |
+
self.tts.default_speed = tts_config['default_speed']
|
195 |
+
if 'default_language' in tts_config:
|
196 |
+
self.tts.default_language = tts_config['default_language']
|
197 |
+
if 'enable_streaming' in tts_config:
|
198 |
+
self.tts.enable_streaming = tts_config['enable_streaming']
|
199 |
+
|
200 |
+
# Apply STT configuration
|
201 |
+
stt_config = self._config_data.get('stt', {})
|
202 |
+
if 'preferred_providers' in stt_config:
|
203 |
+
self.stt.preferred_providers = stt_config['preferred_providers']
|
204 |
+
if 'default_model' in stt_config:
|
205 |
+
self.stt.default_model = stt_config['default_model']
|
206 |
+
if 'chunk_length_s' in stt_config:
|
207 |
+
self.stt.chunk_length_s = stt_config['chunk_length_s']
|
208 |
+
if 'batch_size' in stt_config:
|
209 |
+
self.stt.batch_size = stt_config['batch_size']
|
210 |
+
|
211 |
+
# Apply Translation configuration
|
212 |
+
translation_config = self._config_data.get('translation', {})
|
213 |
+
if 'default_provider' in translation_config:
|
214 |
+
self.translation.default_provider = translation_config['default_provider']
|
215 |
+
if 'model_name' in translation_config:
|
216 |
+
self.translation.model_name = translation_config['model_name']
|
217 |
+
if 'max_chunk_length' in translation_config:
|
218 |
+
self.translation.max_chunk_length = translation_config['max_chunk_length']
|
219 |
+
|
220 |
+
# Apply Processing configuration
|
221 |
+
processing_config = self._config_data.get('processing', {})
|
222 |
+
if 'temp_dir' in processing_config:
|
223 |
+
self.processing.temp_dir = processing_config['temp_dir']
|
224 |
+
if 'cleanup_temp_files' in processing_config:
|
225 |
+
self.processing.cleanup_temp_files = processing_config['cleanup_temp_files']
|
226 |
+
if 'max_file_size_mb' in processing_config:
|
227 |
+
self.processing.max_file_size_mb = processing_config['max_file_size_mb']
|
228 |
+
if 'processing_timeout_seconds' in processing_config:
|
229 |
+
self.processing.processing_timeout_seconds = processing_config['processing_timeout_seconds']
|
230 |
+
|
231 |
+
# Apply Logging configuration
|
232 |
+
logging_config = self._config_data.get('logging', {})
|
233 |
+
if 'level' in logging_config:
|
234 |
+
self.logging.level = logging_config['level']
|
235 |
+
if 'format' in logging_config:
|
236 |
+
self.logging.format = logging_config['format']
|
237 |
+
if 'log_file_path' in logging_config:
|
238 |
+
self.logging.log_file_path = logging_config['log_file_path']
|
239 |
+
|
240 |
+
def _validate_configuration(self) -> None:
|
241 |
+
"""Validate configuration values."""
|
242 |
+
# Validate TTS configuration
|
243 |
+
if not (0.1 <= self.tts.default_speed <= 3.0):
|
244 |
+
logger.warning(f"Invalid TTS speed {self.tts.default_speed}, using default 1.0")
|
245 |
+
self.tts.default_speed = 1.0
|
246 |
+
|
247 |
+
if self.tts.max_text_length <= 0:
|
248 |
+
logger.warning(f"Invalid max text length {self.tts.max_text_length}, using default 10000")
|
249 |
+
self.tts.max_text_length = 10000
|
250 |
+
|
251 |
+
# Validate STT configuration
|
252 |
+
if self.stt.chunk_length_s <= 0:
|
253 |
+
logger.warning(f"Invalid chunk length {self.stt.chunk_length_s}, using default 30")
|
254 |
+
self.stt.chunk_length_s = 30
|
255 |
+
|
256 |
+
if self.stt.batch_size <= 0:
|
257 |
+
logger.warning(f"Invalid batch size {self.stt.batch_size}, using default 16")
|
258 |
+
self.stt.batch_size = 16
|
259 |
+
|
260 |
+
# Validate Translation configuration
|
261 |
+
if self.translation.max_chunk_length <= 0:
|
262 |
+
logger.warning(f"Invalid max chunk length {self.translation.max_chunk_length}, using default 1000")
|
263 |
+
self.translation.max_chunk_length = 1000
|
264 |
+
|
265 |
+
# Validate Processing configuration
|
266 |
+
if self.processing.max_file_size_mb <= 0:
|
267 |
+
logger.warning(f"Invalid max file size {self.processing.max_file_size_mb}, using default 100")
|
268 |
+
self.processing.max_file_size_mb = 100
|
269 |
+
|
270 |
+
if self.processing.processing_timeout_seconds <= 0:
|
271 |
+
logger.warning(f"Invalid timeout {self.processing.processing_timeout_seconds}, using default 300")
|
272 |
+
self.processing.processing_timeout_seconds = 300
|
273 |
+
|
274 |
+
# Ensure temp directory exists
|
275 |
+
try:
|
276 |
+
Path(self.processing.temp_dir).mkdir(parents=True, exist_ok=True)
|
277 |
+
except Exception as e:
|
278 |
+
logger.warning(f"Failed to create temp directory {self.processing.temp_dir}: {e}")
|
279 |
+
self.processing.temp_dir = '/tmp/audio_processing'
|
280 |
+
Path(self.processing.temp_dir).mkdir(parents=True, exist_ok=True)
|
281 |
+
|
282 |
+
def get_tts_config(self) -> Dict[str, Any]:
|
283 |
+
"""Get TTS configuration as dictionary."""
|
284 |
+
return {
|
285 |
+
'preferred_providers': self.tts.preferred_providers,
|
286 |
+
'default_voice': self.tts.default_voice,
|
287 |
+
'default_speed': self.tts.default_speed,
|
288 |
+
'default_language': self.tts.default_language,
|
289 |
+
'enable_streaming': self.tts.enable_streaming,
|
290 |
+
'max_text_length': self.tts.max_text_length
|
291 |
+
}
|
292 |
+
|
293 |
+
def get_stt_config(self) -> Dict[str, Any]:
|
294 |
+
"""Get STT configuration as dictionary."""
|
295 |
+
return {
|
296 |
+
'preferred_providers': self.stt.preferred_providers,
|
297 |
+
'default_model': self.stt.default_model,
|
298 |
+
'chunk_length_s': self.stt.chunk_length_s,
|
299 |
+
'batch_size': self.stt.batch_size,
|
300 |
+
'enable_vad': self.stt.enable_vad
|
301 |
+
}
|
302 |
+
|
303 |
+
def get_translation_config(self) -> Dict[str, Any]:
|
304 |
+
"""Get translation configuration as dictionary."""
|
305 |
+
return {
|
306 |
+
'default_provider': self.translation.default_provider,
|
307 |
+
'model_name': self.translation.model_name,
|
308 |
+
'max_chunk_length': self.translation.max_chunk_length,
|
309 |
+
'batch_size': self.translation.batch_size,
|
310 |
+
'cache_translations': self.translation.cache_translations
|
311 |
+
}
|
312 |
+
|
313 |
+
def get_processing_config(self) -> Dict[str, Any]:
|
314 |
+
"""Get processing configuration as dictionary."""
|
315 |
+
return {
|
316 |
+
'temp_dir': self.processing.temp_dir,
|
317 |
+
'cleanup_temp_files': self.processing.cleanup_temp_files,
|
318 |
+
'max_file_size_mb': self.processing.max_file_size_mb,
|
319 |
+
'supported_audio_formats': self.processing.supported_audio_formats,
|
320 |
+
'processing_timeout_seconds': self.processing.processing_timeout_seconds
|
321 |
+
}
|
322 |
+
|
323 |
+
def get_logging_config(self) -> Dict[str, Any]:
|
324 |
+
"""Get logging configuration as dictionary."""
|
325 |
+
return {
|
326 |
+
'level': self.logging.level,
|
327 |
+
'format': self.logging.format,
|
328 |
+
'enable_file_logging': self.logging.enable_file_logging,
|
329 |
+
'log_file_path': self.logging.log_file_path,
|
330 |
+
'max_log_file_size_mb': self.logging.max_log_file_size_mb,
|
331 |
+
'backup_count': self.logging.backup_count
|
332 |
+
}
|
333 |
+
|
334 |
+
def reload_configuration(self) -> None:
|
335 |
+
"""Reload configuration from sources."""
|
336 |
+
logger.info("Reloading configuration")
|
337 |
+
self._load_configuration()
|
338 |
+
|
339 |
+
def save_configuration(self, output_file: str) -> None:
|
340 |
+
"""
|
341 |
+
Save current configuration to file.
|
342 |
+
|
343 |
+
Args:
|
344 |
+
output_file: Path to output configuration file
|
345 |
+
"""
|
346 |
+
try:
|
347 |
+
config_dict = {
|
348 |
+
'tts': self.get_tts_config(),
|
349 |
+
'stt': self.get_stt_config(),
|
350 |
+
'translation': self.get_translation_config(),
|
351 |
+
'processing': self.get_processing_config(),
|
352 |
+
'logging': self.get_logging_config()
|
353 |
+
}
|
354 |
+
|
355 |
+
import json
|
356 |
+
with open(output_file, 'w') as f:
|
357 |
+
json.dump(config_dict, f, indent=2)
|
358 |
+
|
359 |
+
logger.info(f"Configuration saved to {output_file}")
|
360 |
+
|
361 |
+
except Exception as e:
|
362 |
+
logger.error(f"Failed to save configuration to {output_file}: {e}")
|
363 |
+
raise
|
364 |
+
|
365 |
+
def __str__(self) -> str:
|
366 |
+
"""String representation of configuration."""
|
367 |
+
return (
|
368 |
+
f"AppConfig(\n"
|
369 |
+
f" TTS: {self.tts}\n"
|
370 |
+
f" STT: {self.stt}\n"
|
371 |
+
f" Translation: {self.translation}\n"
|
372 |
+
f" Processing: {self.processing}\n"
|
373 |
+
f" Logging: {self.logging}\n"
|
374 |
+
f")"
|
375 |
+
)
|
src/infrastructure/config/dependency_container.py
ADDED
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Dependency injection container for managing component lifecycles."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from typing import Dict, Any, Optional, TypeVar, Type, Callable, Union
|
5 |
+
from enum import Enum
|
6 |
+
from threading import Lock
|
7 |
+
import weakref
|
8 |
+
|
9 |
+
from .app_config import AppConfig
|
10 |
+
from ..tts.provider_factory import TTSProviderFactory
|
11 |
+
from ..stt.provider_factory import STTProviderFactory
|
12 |
+
from ..translation.provider_factory import TranslationProviderFactory, TranslationProviderType
|
13 |
+
from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
|
14 |
+
from ...domain.interfaces.speech_recognition import ISpeechRecognitionService
|
15 |
+
from ...domain.interfaces.translation import ITranslationService
|
16 |
+
from ...domain.interfaces.audio_processing import IAudioProcessingService
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
T = TypeVar('T')
|
21 |
+
|
22 |
+
|
23 |
+
class ServiceLifetime(Enum):
|
24 |
+
"""Service lifetime management options."""
|
25 |
+
SINGLETON = "singleton"
|
26 |
+
TRANSIENT = "transient"
|
27 |
+
SCOPED = "scoped"
|
28 |
+
|
29 |
+
|
30 |
+
class ServiceDescriptor:
|
31 |
+
"""Describes how a service should be created and managed."""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
service_type: Type[T],
|
36 |
+
implementation: Union[Type[T], Callable[..., T]],
|
37 |
+
lifetime: ServiceLifetime = ServiceLifetime.TRANSIENT,
|
38 |
+
factory_args: Optional[Dict[str, Any]] = None
|
39 |
+
):
|
40 |
+
"""
|
41 |
+
Initialize service descriptor.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
service_type: The service interface type
|
45 |
+
implementation: The implementation class or factory function
|
46 |
+
lifetime: Service lifetime management
|
47 |
+
factory_args: Arguments to pass to the factory/constructor
|
48 |
+
"""
|
49 |
+
self.service_type = service_type
|
50 |
+
self.implementation = implementation
|
51 |
+
self.lifetime = lifetime
|
52 |
+
self.factory_args = factory_args or {}
|
53 |
+
|
54 |
+
|
55 |
+
class DependencyContainer:
|
56 |
+
"""Dependency injection container for managing component lifecycles."""
|
57 |
+
|
58 |
+
def __init__(self, config: Optional[AppConfig] = None):
|
59 |
+
"""
|
60 |
+
Initialize the dependency container.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
config: Application configuration instance
|
64 |
+
"""
|
65 |
+
self._config = config or AppConfig()
|
66 |
+
self._services: Dict[Type, ServiceDescriptor] = {}
|
67 |
+
self._singletons: Dict[Type, Any] = {}
|
68 |
+
self._scoped_instances: Dict[Type, Any] = {}
|
69 |
+
self._lock = Lock()
|
70 |
+
|
71 |
+
# Provider factories
|
72 |
+
self._tts_factory: Optional[TTSProviderFactory] = None
|
73 |
+
self._stt_factory: Optional[STTProviderFactory] = None
|
74 |
+
self._translation_factory: Optional[TranslationProviderFactory] = None
|
75 |
+
|
76 |
+
# Register default services
|
77 |
+
self._register_default_services()
|
78 |
+
|
79 |
+
def _register_default_services(self) -> None:
|
80 |
+
"""Register default service implementations."""
|
81 |
+
# Register configuration as singleton
|
82 |
+
self.register_singleton(AppConfig, self._config)
|
83 |
+
|
84 |
+
# Register provider factories as singletons
|
85 |
+
self.register_singleton(TTSProviderFactory, self._get_tts_factory)
|
86 |
+
self.register_singleton(STTProviderFactory, self._get_stt_factory)
|
87 |
+
self.register_singleton(TranslationProviderFactory, self._get_translation_factory)
|
88 |
+
|
89 |
+
def register_singleton(
|
90 |
+
self,
|
91 |
+
service_type: Type[T],
|
92 |
+
implementation: Union[Type[T], Callable[..., T], T],
|
93 |
+
factory_args: Optional[Dict[str, Any]] = None
|
94 |
+
) -> None:
|
95 |
+
"""
|
96 |
+
Register a service as singleton.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
service_type: The service interface type
|
100 |
+
implementation: The implementation class, factory function, or instance
|
101 |
+
factory_args: Arguments to pass to the factory/constructor
|
102 |
+
"""
|
103 |
+
with self._lock:
|
104 |
+
# If implementation is already an instance, store it directly
|
105 |
+
if not (isinstance(implementation, type) or callable(implementation)):
|
106 |
+
self._singletons[service_type] = implementation
|
107 |
+
logger.debug(f"Registered singleton instance for {service_type.__name__}")
|
108 |
+
return
|
109 |
+
|
110 |
+
descriptor = ServiceDescriptor(
|
111 |
+
service_type=service_type,
|
112 |
+
implementation=implementation,
|
113 |
+
lifetime=ServiceLifetime.SINGLETON,
|
114 |
+
factory_args=factory_args
|
115 |
+
)
|
116 |
+
self._services[service_type] = descriptor
|
117 |
+
logger.debug(f"Registered singleton service: {service_type.__name__}")
|
118 |
+
|
119 |
+
def register_transient(
|
120 |
+
self,
|
121 |
+
service_type: Type[T],
|
122 |
+
implementation: Union[Type[T], Callable[..., T]],
|
123 |
+
factory_args: Optional[Dict[str, Any]] = None
|
124 |
+
) -> None:
|
125 |
+
"""
|
126 |
+
Register a service as transient (new instance each time).
|
127 |
+
|
128 |
+
Args:
|
129 |
+
service_type: The service interface type
|
130 |
+
implementation: The implementation class or factory function
|
131 |
+
factory_args: Arguments to pass to the factory/constructor
|
132 |
+
"""
|
133 |
+
with self._lock:
|
134 |
+
descriptor = ServiceDescriptor(
|
135 |
+
service_type=service_type,
|
136 |
+
implementation=implementation,
|
137 |
+
lifetime=ServiceLifetime.TRANSIENT,
|
138 |
+
factory_args=factory_args
|
139 |
+
)
|
140 |
+
self._services[service_type] = descriptor
|
141 |
+
logger.debug(f"Registered transient service: {service_type.__name__}")
|
142 |
+
|
143 |
+
def register_scoped(
|
144 |
+
self,
|
145 |
+
service_type: Type[T],
|
146 |
+
implementation: Union[Type[T], Callable[..., T]],
|
147 |
+
factory_args: Optional[Dict[str, Any]] = None
|
148 |
+
) -> None:
|
149 |
+
"""
|
150 |
+
Register a service as scoped (one instance per scope).
|
151 |
+
|
152 |
+
Args:
|
153 |
+
service_type: The service interface type
|
154 |
+
implementation: The implementation class or factory function
|
155 |
+
factory_args: Arguments to pass to the factory/constructor
|
156 |
+
"""
|
157 |
+
with self._lock:
|
158 |
+
descriptor = ServiceDescriptor(
|
159 |
+
service_type=service_type,
|
160 |
+
implementation=implementation,
|
161 |
+
lifetime=ServiceLifetime.SCOPED,
|
162 |
+
factory_args=factory_args
|
163 |
+
)
|
164 |
+
self._services[service_type] = descriptor
|
165 |
+
logger.debug(f"Registered scoped service: {service_type.__name__}")
|
166 |
+
|
167 |
+
def resolve(self, service_type: Type[T]) -> T:
|
168 |
+
"""
|
169 |
+
Resolve a service instance.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
service_type: The service type to resolve
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
T: The service instance
|
176 |
+
|
177 |
+
Raises:
|
178 |
+
ValueError: If service is not registered
|
179 |
+
Exception: If service creation fails
|
180 |
+
"""
|
181 |
+
with self._lock:
|
182 |
+
# Check if already a singleton instance
|
183 |
+
if service_type in self._singletons:
|
184 |
+
return self._singletons[service_type]
|
185 |
+
|
186 |
+
# Check if service is registered
|
187 |
+
if service_type not in self._services:
|
188 |
+
raise ValueError(f"Service {service_type.__name__} is not registered")
|
189 |
+
|
190 |
+
descriptor = self._services[service_type]
|
191 |
+
|
192 |
+
try:
|
193 |
+
if descriptor.lifetime == ServiceLifetime.SINGLETON:
|
194 |
+
return self._create_singleton(service_type, descriptor)
|
195 |
+
elif descriptor.lifetime == ServiceLifetime.SCOPED:
|
196 |
+
return self._create_scoped(service_type, descriptor)
|
197 |
+
else: # TRANSIENT
|
198 |
+
return self._create_transient(descriptor)
|
199 |
+
|
200 |
+
except Exception as e:
|
201 |
+
logger.error(f"Failed to resolve service {service_type.__name__}: {e}")
|
202 |
+
raise
|
203 |
+
|
204 |
+
def _create_singleton(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
|
205 |
+
"""Create or return existing singleton instance."""
|
206 |
+
if service_type in self._singletons:
|
207 |
+
return self._singletons[service_type]
|
208 |
+
|
209 |
+
instance = self._create_instance(descriptor)
|
210 |
+
self._singletons[service_type] = instance
|
211 |
+
logger.debug(f"Created singleton instance for {service_type.__name__}")
|
212 |
+
return instance
|
213 |
+
|
214 |
+
def _create_scoped(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
|
215 |
+
"""Create or return existing scoped instance."""
|
216 |
+
if service_type in self._scoped_instances:
|
217 |
+
return self._scoped_instances[service_type]
|
218 |
+
|
219 |
+
instance = self._create_instance(descriptor)
|
220 |
+
self._scoped_instances[service_type] = instance
|
221 |
+
logger.debug(f"Created scoped instance for {service_type.__name__}")
|
222 |
+
return instance
|
223 |
+
|
224 |
+
def _create_transient(self, descriptor: ServiceDescriptor) -> T:
|
225 |
+
"""Create new transient instance."""
|
226 |
+
instance = self._create_instance(descriptor)
|
227 |
+
logger.debug(f"Created transient instance for {descriptor.service_type.__name__}")
|
228 |
+
return instance
|
229 |
+
|
230 |
+
def _create_instance(self, descriptor: ServiceDescriptor) -> T:
|
231 |
+
"""Create service instance using descriptor."""
|
232 |
+
implementation = descriptor.implementation
|
233 |
+
factory_args = descriptor.factory_args
|
234 |
+
|
235 |
+
# If implementation is a callable (factory function)
|
236 |
+
if callable(implementation) and not isinstance(implementation, type):
|
237 |
+
return implementation(**factory_args)
|
238 |
+
|
239 |
+
# If implementation is a class
|
240 |
+
if isinstance(implementation, type):
|
241 |
+
return implementation(**factory_args)
|
242 |
+
|
243 |
+
raise ValueError(f"Invalid implementation type for {descriptor.service_type.__name__}")
|
244 |
+
|
245 |
+
def _get_tts_factory(self) -> TTSProviderFactory:
|
246 |
+
"""Get or create TTS provider factory."""
|
247 |
+
if self._tts_factory is None:
|
248 |
+
self._tts_factory = TTSProviderFactory()
|
249 |
+
logger.debug("Created TTS provider factory")
|
250 |
+
return self._tts_factory
|
251 |
+
|
252 |
+
def _get_stt_factory(self) -> STTProviderFactory:
|
253 |
+
"""Get or create STT provider factory."""
|
254 |
+
if self._stt_factory is None:
|
255 |
+
self._stt_factory = STTProviderFactory()
|
256 |
+
logger.debug("Created STT provider factory")
|
257 |
+
return self._stt_factory
|
258 |
+
|
259 |
+
def _get_translation_factory(self) -> TranslationProviderFactory:
|
260 |
+
"""Get or create translation provider factory."""
|
261 |
+
if self._translation_factory is None:
|
262 |
+
self._translation_factory = TranslationProviderFactory()
|
263 |
+
logger.debug("Created translation provider factory")
|
264 |
+
return self._translation_factory
|
265 |
+
|
266 |
+
def get_tts_provider(self, provider_name: Optional[str] = None, **kwargs) -> ISpeechSynthesisService:
|
267 |
+
"""
|
268 |
+
Get TTS provider with fallback logic.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
provider_name: Specific provider name or None for default
|
272 |
+
**kwargs: Additional provider arguments
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
ISpeechSynthesisService: TTS provider instance
|
276 |
+
"""
|
277 |
+
factory = self.resolve(TTSProviderFactory)
|
278 |
+
|
279 |
+
if provider_name:
|
280 |
+
return factory.create_provider(provider_name, **kwargs)
|
281 |
+
else:
|
282 |
+
preferred_providers = self._config.tts.preferred_providers
|
283 |
+
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
284 |
+
|
285 |
+
def get_stt_provider(self, provider_name: Optional[str] = None) -> ISpeechRecognitionService:
|
286 |
+
"""
|
287 |
+
Get STT provider with fallback logic.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
provider_name: Specific provider name or None for default
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
ISpeechRecognitionService: STT provider instance
|
294 |
+
"""
|
295 |
+
factory = self.resolve(STTProviderFactory)
|
296 |
+
|
297 |
+
if provider_name:
|
298 |
+
return factory.create_provider(provider_name)
|
299 |
+
else:
|
300 |
+
preferred_provider = self._config.stt.default_model
|
301 |
+
return factory.create_provider_with_fallback(preferred_provider)
|
302 |
+
|
303 |
+
def get_translation_provider(
|
304 |
+
self,
|
305 |
+
provider_type: Optional[TranslationProviderType] = None,
|
306 |
+
config: Optional[Dict[str, Any]] = None
|
307 |
+
) -> ITranslationService:
|
308 |
+
"""
|
309 |
+
Get translation provider with fallback logic.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
provider_type: Specific provider type or None for default
|
313 |
+
config: Optional provider configuration
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
ITranslationService: Translation provider instance
|
317 |
+
"""
|
318 |
+
factory = self.resolve(TranslationProviderFactory)
|
319 |
+
|
320 |
+
if provider_type:
|
321 |
+
return factory.create_provider(provider_type, config)
|
322 |
+
else:
|
323 |
+
return factory.get_default_provider(config)
|
324 |
+
|
325 |
+
def clear_scoped_instances(self) -> None:
|
326 |
+
"""Clear all scoped instances."""
|
327 |
+
with self._lock:
|
328 |
+
# Cleanup scoped instances if they have cleanup methods
|
329 |
+
for instance in self._scoped_instances.values():
|
330 |
+
self._cleanup_instance(instance)
|
331 |
+
|
332 |
+
self._scoped_instances.clear()
|
333 |
+
logger.debug("Cleared scoped instances")
|
334 |
+
|
335 |
+
def _cleanup_instance(self, instance: Any) -> None:
|
336 |
+
"""Cleanup instance if it has cleanup methods."""
|
337 |
+
try:
|
338 |
+
# Try common cleanup method names
|
339 |
+
cleanup_methods = ['cleanup', 'dispose', 'close', '__del__']
|
340 |
+
for method_name in cleanup_methods:
|
341 |
+
if hasattr(instance, method_name):
|
342 |
+
method = getattr(instance, method_name)
|
343 |
+
if callable(method):
|
344 |
+
method()
|
345 |
+
logger.debug(f"Called {method_name} on {type(instance).__name__}")
|
346 |
+
break
|
347 |
+
except Exception as e:
|
348 |
+
logger.warning(f"Failed to cleanup instance {type(instance).__name__}: {e}")
|
349 |
+
|
350 |
+
def cleanup(self) -> None:
|
351 |
+
"""Cleanup all managed resources."""
|
352 |
+
with self._lock:
|
353 |
+
logger.info("Starting dependency container cleanup")
|
354 |
+
|
355 |
+
# Cleanup scoped instances
|
356 |
+
self.clear_scoped_instances()
|
357 |
+
|
358 |
+
# Cleanup singleton instances
|
359 |
+
for instance in self._singletons.values():
|
360 |
+
self._cleanup_instance(instance)
|
361 |
+
|
362 |
+
# Cleanup provider factories
|
363 |
+
if self._tts_factory:
|
364 |
+
try:
|
365 |
+
self._tts_factory.cleanup_providers()
|
366 |
+
except Exception as e:
|
367 |
+
logger.warning(f"Failed to cleanup TTS factory: {e}")
|
368 |
+
|
369 |
+
if self._translation_factory:
|
370 |
+
try:
|
371 |
+
self._translation_factory.clear_cache()
|
372 |
+
except Exception as e:
|
373 |
+
logger.warning(f"Failed to cleanup translation factory: {e}")
|
374 |
+
|
375 |
+
# Clear all references
|
376 |
+
self._singletons.clear()
|
377 |
+
self._tts_factory = None
|
378 |
+
self._stt_factory = None
|
379 |
+
self._translation_factory = None
|
380 |
+
|
381 |
+
logger.info("Dependency container cleanup completed")
|
382 |
+
|
383 |
+
def is_registered(self, service_type: Type) -> bool:
|
384 |
+
"""
|
385 |
+
Check if a service type is registered.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
service_type: The service type to check
|
389 |
+
|
390 |
+
Returns:
|
391 |
+
bool: True if registered, False otherwise
|
392 |
+
"""
|
393 |
+
with self._lock:
|
394 |
+
return service_type in self._services or service_type in self._singletons
|
395 |
+
|
396 |
+
def get_registered_services(self) -> Dict[str, str]:
|
397 |
+
"""
|
398 |
+
Get information about all registered services.
|
399 |
+
|
400 |
+
Returns:
|
401 |
+
Dict[str, str]: Mapping of service names to their lifetime
|
402 |
+
"""
|
403 |
+
with self._lock:
|
404 |
+
services_info = {}
|
405 |
+
|
406 |
+
# Add singleton instances
|
407 |
+
for service_type in self._singletons.keys():
|
408 |
+
services_info[service_type.__name__] = "singleton (instance)"
|
409 |
+
|
410 |
+
# Add registered services
|
411 |
+
for service_type, descriptor in self._services.items():
|
412 |
+
if service_type not in self._singletons:
|
413 |
+
services_info[service_type.__name__] = descriptor.lifetime.value
|
414 |
+
|
415 |
+
return services_info
|
416 |
+
|
417 |
+
def create_scope(self) -> 'DependencyScope':
|
418 |
+
"""
|
419 |
+
Create a new dependency scope.
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
DependencyScope: New scope instance
|
423 |
+
"""
|
424 |
+
return DependencyScope(self)
|
425 |
+
|
426 |
+
def __enter__(self):
|
427 |
+
"""Context manager entry."""
|
428 |
+
return self
|
429 |
+
|
430 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
431 |
+
"""Context manager exit with cleanup."""
|
432 |
+
self.cleanup()
|
433 |
+
|
434 |
+
|
435 |
+
class DependencyScope:
|
436 |
+
"""Scoped dependency container for managing scoped service lifetimes."""
|
437 |
+
|
438 |
+
def __init__(self, parent_container: DependencyContainer):
|
439 |
+
"""
|
440 |
+
Initialize dependency scope.
|
441 |
+
|
442 |
+
Args:
|
443 |
+
parent_container: Parent dependency container
|
444 |
+
"""
|
445 |
+
self._parent = parent_container
|
446 |
+
self._scoped_instances: Dict[Type, Any] = {}
|
447 |
+
self._lock = Lock()
|
448 |
+
|
449 |
+
def resolve(self, service_type: Type[T]) -> T:
|
450 |
+
"""
|
451 |
+
Resolve service within this scope.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
service_type: The service type to resolve
|
455 |
+
|
456 |
+
Returns:
|
457 |
+
T: The service instance
|
458 |
+
"""
|
459 |
+
with self._lock:
|
460 |
+
# Check if we have a scoped instance
|
461 |
+
if service_type in self._scoped_instances:
|
462 |
+
return self._scoped_instances[service_type]
|
463 |
+
|
464 |
+
# Resolve from parent container
|
465 |
+
instance = self._parent.resolve(service_type)
|
466 |
+
|
467 |
+
# If it's a scoped service, store it in this scope
|
468 |
+
if (service_type in self._parent._services and
|
469 |
+
self._parent._services[service_type].lifetime == ServiceLifetime.SCOPED):
|
470 |
+
self._scoped_instances[service_type] = instance
|
471 |
+
|
472 |
+
return instance
|
473 |
+
|
474 |
+
def cleanup(self) -> None:
|
475 |
+
"""Cleanup scoped instances."""
|
476 |
+
with self._lock:
|
477 |
+
for instance in self._scoped_instances.values():
|
478 |
+
self._parent._cleanup_instance(instance)
|
479 |
+
self._scoped_instances.clear()
|
480 |
+
|
481 |
+
def __enter__(self):
|
482 |
+
"""Context manager entry."""
|
483 |
+
return self
|
484 |
+
|
485 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
486 |
+
"""Context manager exit with cleanup."""
|
487 |
+
self.cleanup()
|
488 |
+
|
489 |
+
|
490 |
+
# Global container instance
|
491 |
+
_global_container: Optional[DependencyContainer] = None
|
492 |
+
_container_lock = Lock()
|
493 |
+
|
494 |
+
|
495 |
+
def get_container() -> DependencyContainer:
|
496 |
+
"""
|
497 |
+
Get the global dependency container instance.
|
498 |
+
|
499 |
+
Returns:
|
500 |
+
DependencyContainer: Global container instance
|
501 |
+
"""
|
502 |
+
global _global_container
|
503 |
+
|
504 |
+
with _container_lock:
|
505 |
+
if _global_container is None:
|
506 |
+
_global_container = DependencyContainer()
|
507 |
+
logger.info("Created global dependency container")
|
508 |
+
|
509 |
+
return _global_container
|
510 |
+
|
511 |
+
|
512 |
+
def set_container(container: DependencyContainer) -> None:
|
513 |
+
"""
|
514 |
+
Set the global dependency container instance.
|
515 |
+
|
516 |
+
Args:
|
517 |
+
container: Container instance to set as global
|
518 |
+
"""
|
519 |
+
global _global_container
|
520 |
+
|
521 |
+
with _container_lock:
|
522 |
+
if _global_container is not None:
|
523 |
+
_global_container.cleanup()
|
524 |
+
|
525 |
+
_global_container = container
|
526 |
+
logger.info("Set global dependency container")
|
527 |
+
|
528 |
+
|
529 |
+
def cleanup_container() -> None:
|
530 |
+
"""Cleanup the global dependency container."""
|
531 |
+
global _global_container
|
532 |
+
|
533 |
+
with _container_lock:
|
534 |
+
if _global_container is not None:
|
535 |
+
_global_container.cleanup()
|
536 |
+
_global_container = None
|
537 |
+
logger.info("Cleaned up global dependency container")
|
test_dependency_injection.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""Test script for dependency injection implementation."""
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Add src to path
|
8 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
9 |
+
|
10 |
+
try:
|
11 |
+
from infrastructure.config.dependency_container import DependencyContainer, get_container
|
12 |
+
from infrastructure.config.app_config import AppConfig
|
13 |
+
from infrastructure.tts.provider_factory import TTSProviderFactory
|
14 |
+
from infrastructure.stt.provider_factory import STTProviderFactory
|
15 |
+
from infrastructure.translation.provider_factory import TranslationProviderFactory
|
16 |
+
|
17 |
+
print("π§ͺ Testing Dependency Injection Implementation...")
|
18 |
+
print()
|
19 |
+
|
20 |
+
# Test basic container functionality
|
21 |
+
container = DependencyContainer()
|
22 |
+
print('β
DependencyContainer created successfully')
|
23 |
+
|
24 |
+
# Test configuration resolution
|
25 |
+
config = container.resolve(AppConfig)
|
26 |
+
print(f'β
AppConfig resolved: {type(config).__name__}')
|
27 |
+
|
28 |
+
# Test factory resolution
|
29 |
+
tts_factory = container.resolve(TTSProviderFactory)
|
30 |
+
stt_factory = container.resolve(STTProviderFactory)
|
31 |
+
translation_factory = container.resolve(TranslationProviderFactory)
|
32 |
+
|
33 |
+
print(f'β
TTSProviderFactory resolved: {type(tts_factory).__name__}')
|
34 |
+
print(f'β
STTProviderFactory resolved: {type(stt_factory).__name__}')
|
35 |
+
print(f'β
TranslationProviderFactory resolved: {type(translation_factory).__name__}')
|
36 |
+
|
37 |
+
# Test global container
|
38 |
+
global_container = get_container()
|
39 |
+
print(f'β
Global container retrieved: {type(global_container).__name__}')
|
40 |
+
|
41 |
+
# Test service registration info
|
42 |
+
services = container.get_registered_services()
|
43 |
+
print(f'β
Registered services: {len(services)} services')
|
44 |
+
for service_name, lifetime in services.items():
|
45 |
+
print(f' - {service_name}: {lifetime}')
|
46 |
+
|
47 |
+
# Test cleanup
|
48 |
+
container.cleanup()
|
49 |
+
print('β
Container cleanup completed')
|
50 |
+
|
51 |
+
print()
|
52 |
+
print('π All dependency injection tests passed!')
|
53 |
+
|
54 |
+
except Exception as e:
|
55 |
+
print(f'β Test failed: {e}')
|
56 |
+
import traceback
|
57 |
+
traceback.print_exc()
|
58 |
+
sys.exit(1)
|