Spaces:
Build error
Build error
"""Dependency injection container for managing component lifecycles.""" | |
import logging | |
from typing import Dict, Any, Optional, TypeVar, Type, Callable, Union | |
from enum import Enum | |
from threading import RLock | |
import weakref | |
from .app_config import AppConfig | |
from ..tts.provider_factory import TTSProviderFactory | |
from ..stt.provider_factory import STTProviderFactory | |
from ..translation.provider_factory import TranslationProviderFactory, TranslationProviderType | |
from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService | |
from ...domain.interfaces.speech_recognition import ISpeechRecognitionService | |
from ...domain.interfaces.translation import ITranslationService | |
from ...domain.interfaces.audio_processing import IAudioProcessingService | |
logger = logging.getLogger(__name__)  # Module-scoped logger shared by every class below
T = TypeVar("T")  # Generic placeholder for the service type being registered/resolved
class ServiceLifetime(Enum):
    """How long the container keeps reusing a resolved service instance.

    The string values are what ``DependencyContainer.get_registered_services``
    reports for registered descriptors.
    """

    SINGLETON = "singleton"  # one shared instance per container
    TRANSIENT = "transient"  # a brand-new instance on every resolve
    SCOPED = "scoped"        # one instance per scope
class ServiceDescriptor:
    """Plain record describing how one service type is built and reused.

    The container reads these fields when it creates instances; the record
    itself has no behavior.
    """

    def __init__(
        self,
        service_type: Type[T],
        implementation: Union[Type[T], Callable[..., T]],
        lifetime: ServiceLifetime = ServiceLifetime.TRANSIENT,
        factory_args: Optional[Dict[str, Any]] = None
    ):
        """
        Capture a single service registration.

        Args:
            service_type: Key the service is registered and resolved under
            implementation: Class to instantiate, or factory callable to invoke
            lifetime: Reuse policy for created instances (transient by default)
            factory_args: Keyword arguments forwarded to ``implementation``;
                stored as an empty dict when omitted or empty
        """
        self.service_type = service_type
        self.implementation = implementation
        self.lifetime = lifetime
        # Normalized so consumers can always unpack it with ``**``.
        self.factory_args = factory_args if factory_args else {}
class DependencyContainer:
    """Dependency injection container for managing component lifecycles.

    Services are registered against a type key with a lifetime (singleton,
    transient or scoped) and obtained through :meth:`resolve`. Registration,
    resolution and cleanup all run under one re-entrant lock, so a factory may
    call back into the container while it is being invoked.
    """

    def __init__(self, config: Optional[AppConfig] = None):
        """
        Initialize the dependency container.

        Args:
            config: Application configuration instance; ``AppConfig()`` is
                constructed when it is omitted (or falsy).
        """
        self._config = config or AppConfig()
        self._services: Dict[Type, ServiceDescriptor] = {}
        self._singletons: Dict[Type, Any] = {}
        self._scoped_instances: Dict[Type, Any] = {}
        self._lock = RLock()  # Re-entrant: factories may resolve other services

        # Provider factories, created lazily by the _get_*_factory helpers
        self._tts_factory: Optional[TTSProviderFactory] = None
        self._stt_factory: Optional[STTProviderFactory] = None
        self._translation_factory: Optional[TranslationProviderFactory] = None

        # Register default services
        self._register_default_services()

    def _register_default_services(self) -> None:
        """Register the configuration instance and the three provider factories."""
        # Configuration is a ready-made instance
        self.register_singleton(AppConfig, self._config)

        # Provider factories are singletons built lazily by the getters below
        self.register_singleton(TTSProviderFactory, self._get_tts_factory)
        self.register_singleton(STTProviderFactory, self._get_stt_factory)
        self.register_singleton(TranslationProviderFactory, self._get_translation_factory)

    def _forget_cached_instance(self, service_type: Type) -> None:
        """Drop cached singleton/scoped instances so a new registration takes effect.

        Must be called with ``self._lock`` held. The dropped instance is not
        cleaned up, because callers that already resolved it may still use it.
        """
        self._singletons.pop(service_type, None)
        self._scoped_instances.pop(service_type, None)

    def _register_descriptor(
        self,
        service_type: Type[T],
        implementation: Union[Type[T], Callable[..., T]],
        lifetime: ServiceLifetime,
        factory_args: Optional[Dict[str, Any]]
    ) -> None:
        """Store a descriptor for ``service_type``, replacing any earlier registration."""
        with self._lock:
            # Without this, a previously cached instance would keep winning in
            # resolve() and the new registration would be silently ignored.
            self._forget_cached_instance(service_type)
            self._services[service_type] = ServiceDescriptor(
                service_type=service_type,
                implementation=implementation,
                lifetime=lifetime,
                factory_args=factory_args
            )
            logger.info(f"Registered {lifetime.value} service: {service_type.__name__}")

    def register_singleton(
        self,
        service_type: Type[T],
        implementation: Union[Type[T], Callable[..., T], T],
        factory_args: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Register a service as singleton.

        Args:
            service_type: The service interface type
            implementation: The implementation class, factory function, or instance.
                Anything that is not callable is stored as the singleton itself;
                a callable object (e.g. one defining ``__call__``) is treated
                as a factory, not as the instance.
            factory_args: Arguments to pass to the factory/constructor
        """
        with self._lock:
            if not callable(implementation):
                # Ready-made instance: it replaces any earlier registration.
                self._forget_cached_instance(service_type)
                self._services.pop(service_type, None)
                self._singletons[service_type] = implementation
                logger.info(f"Registered singleton instance for {service_type.__name__}")
                return

            self._register_descriptor(
                service_type, implementation, ServiceLifetime.SINGLETON, factory_args
            )

    def register_transient(
        self,
        service_type: Type[T],
        implementation: Union[Type[T], Callable[..., T]],
        factory_args: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Register a service as transient (new instance each time).

        Args:
            service_type: The service interface type
            implementation: The implementation class or factory function
            factory_args: Arguments to pass to the factory/constructor
        """
        self._register_descriptor(
            service_type, implementation, ServiceLifetime.TRANSIENT, factory_args
        )

    def register_scoped(
        self,
        service_type: Type[T],
        implementation: Union[Type[T], Callable[..., T]],
        factory_args: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Register a service as scoped (one instance per scope).

        Args:
            service_type: The service interface type
            implementation: The implementation class or factory function
            factory_args: Arguments to pass to the factory/constructor
        """
        self._register_descriptor(
            service_type, implementation, ServiceLifetime.SCOPED, factory_args
        )

    def resolve(self, service_type: Type[T]) -> T:
        """
        Resolve a service instance.

        Args:
            service_type: The service type to resolve

        Returns:
            T: The service instance

        Raises:
            ValueError: If service is not registered
            Exception: If service creation fails (re-raised unchanged)
        """
        logger.info(f"Starting resolve for service: {service_type.__name__}")
        with self._lock:
            logger.info(f"Acquired lock for resolving: {service_type.__name__}")

            # Ready-made or already-built singletons take precedence
            if service_type in self._singletons:
                logger.info(f"Found existing singleton instance for: {service_type.__name__}")
                return self._singletons[service_type]

            descriptor = self._services.get(service_type)
            if descriptor is None:
                logger.error(f"Service {service_type.__name__} is not registered")
                raise ValueError(f"Service {service_type.__name__} is not registered")

            logger.info(f"Found service descriptor for {service_type.__name__} with lifetime: {descriptor.lifetime.value}")

            try:
                if descriptor.lifetime == ServiceLifetime.SINGLETON:
                    logger.info(f"Creating singleton for: {service_type.__name__}")
                    result = self._create_singleton(service_type, descriptor)
                    logger.info(f"Successfully created singleton for: {service_type.__name__}")
                    return result
                elif descriptor.lifetime == ServiceLifetime.SCOPED:
                    logger.info(f"Creating scoped instance for: {service_type.__name__}")
                    result = self._create_scoped(service_type, descriptor)
                    logger.info(f"Successfully created scoped instance for: {service_type.__name__}")
                    return result
                else:  # TRANSIENT
                    logger.info(f"Creating transient instance for: {service_type.__name__}")
                    result = self._create_transient(descriptor)
                    logger.info(f"Successfully created transient instance for: {service_type.__name__}")
                    return result
            except Exception as e:
                # The traceback was already logged in _create_instance. Note that
                # Logger.error() rejects unknown keywords such as `exception=`;
                # passing one would replace `e` with a TypeError.
                logger.error(f"Failed to resolve service {service_type.__name__}: {e}")
                raise

    def _create_singleton(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
        """Create or return existing singleton instance."""
        if service_type in self._singletons:
            return self._singletons[service_type]

        instance = self._create_instance(descriptor)
        self._singletons[service_type] = instance
        logger.info(f"Created singleton instance for {service_type.__name__}")
        return instance

    def _create_scoped(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
        """Create or return the container-level scoped instance."""
        if service_type in self._scoped_instances:
            return self._scoped_instances[service_type]

        instance = self._create_instance(descriptor)
        self._scoped_instances[service_type] = instance
        logger.info(f"Created scoped instance for {service_type.__name__}")
        return instance

    def _create_transient(self, descriptor: ServiceDescriptor) -> T:
        """Create new transient instance."""
        instance = self._create_instance(descriptor)
        logger.info(f"Created transient instance for {descriptor.service_type.__name__}")
        return instance

    def _create_instance(self, descriptor: ServiceDescriptor) -> T:
        """
        Create a service instance from its descriptor.

        Calls ``descriptor.implementation(**descriptor.factory_args)``, whether
        the implementation is a factory function or a class.

        Raises:
            ValueError: If the implementation is neither callable nor a class
            Exception: Whatever the factory/constructor raises (re-raised)
        """
        logger.info(f"Creating instance for {descriptor.service_type.__name__}")

        implementation = descriptor.implementation
        factory_args = descriptor.factory_args

        logger.info(f"Implementation type: {type(implementation)}, Factory args: {factory_args}")

        # Factory function (any callable that is not a class)
        if callable(implementation) and not isinstance(implementation, type):
            logger.info(f"Calling factory function for {descriptor.service_type.__name__}")
            try:
                result = implementation(**factory_args)
                logger.info(f"Factory function completed for {descriptor.service_type.__name__}")
                return result
            except Exception as e:
                logger.error(f"Factory function failed for {descriptor.service_type.__name__}: {e}", exc_info=True)
                raise

        # Class to instantiate
        if isinstance(implementation, type):
            logger.info(f"Instantiating class {implementation.__name__} for {descriptor.service_type.__name__}")
            try:
                result = implementation(**factory_args)
                logger.info(f"Class instantiation completed for {descriptor.service_type.__name__}")
                return result
            except Exception as e:
                logger.error(f"Class instantiation failed for {descriptor.service_type.__name__}: {e}", exc_info=True)
                raise

        logger.error(f"Invalid implementation type for {descriptor.service_type.__name__}: {type(implementation)}")
        raise ValueError(f"Invalid implementation type for {descriptor.service_type.__name__}")

    def _get_tts_factory(self) -> TTSProviderFactory:
        """Get or create TTS provider factory."""
        if self._tts_factory is None:
            self._tts_factory = TTSProviderFactory()
            logger.info("Created TTS provider factory")
        return self._tts_factory

    def _get_stt_factory(self) -> STTProviderFactory:
        """Get or create STT provider factory."""
        if self._stt_factory is None:
            self._stt_factory = STTProviderFactory()
            logger.info("Created STT provider factory")
        return self._stt_factory

    def _get_translation_factory(self) -> TranslationProviderFactory:
        """Get or create translation provider factory."""
        if self._translation_factory is None:
            self._translation_factory = TranslationProviderFactory()
            logger.info("Created translation provider factory")
        return self._translation_factory

    def get_tts_provider(self, provider_name: Optional[str] = None, **kwargs) -> ISpeechSynthesisService:
        """
        Get TTS provider with fallback logic.

        A named provider is tried first. If it cannot be created (or no name is
        given), the configured preferred providers are tried via the factory's
        fallback selection.

        Args:
            provider_name: Specific provider name or None for default
            **kwargs: Additional provider arguments

        Returns:
            ISpeechSynthesisService: TTS provider instance
        """
        logger.info(f"🎯 Requesting TTS provider: {provider_name or 'default'}")
        factory = self.resolve(TTSProviderFactory)

        if provider_name:
            logger.info(f"🔧 Attempting to create specific TTS provider: {provider_name}")
            try:
                provider = factory.create_provider(provider_name, **kwargs)
            except Exception as e:
                logger.warning(f"❌ Failed to create specific TTS provider {provider_name}: {e}")
                logger.info("🔄 Falling back to default provider selection")
            else:
                logger.info(f"✅ Successfully created TTS provider: {provider_name}")
                return provider
            preferred_providers = self._config.tts.preferred_providers
            logger.info(f"📋 Preferred providers for fallback: {preferred_providers}")
        else:
            preferred_providers = self._config.tts.preferred_providers
            logger.info(f"📋 Using preferred providers: {preferred_providers}")

        return factory.get_provider_with_fallback(preferred_providers, **kwargs)

    def get_stt_provider(self, provider_name: Optional[str] = None) -> ISpeechRecognitionService:
        """
        Get STT provider with fallback logic.

        Args:
            provider_name: Specific provider name or None for default

        Returns:
            ISpeechRecognitionService: STT provider instance
        """
        factory = self.resolve(STTProviderFactory)

        if provider_name:
            return factory.create_provider(provider_name)
        # No explicit name: start from the configured default model
        preferred_provider = self._config.stt.default_model
        return factory.create_provider_with_fallback(preferred_provider)

    def get_translation_provider(
        self,
        provider_type: Optional[TranslationProviderType] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> ITranslationService:
        """
        Get translation provider with fallback logic.

        Args:
            provider_type: Specific provider type or None for default
            config: Optional provider configuration

        Returns:
            ITranslationService: Translation provider instance
        """
        factory = self.resolve(TranslationProviderFactory)

        if provider_type:
            return factory.create_provider(provider_type, config)
        return factory.get_default_provider(config)

    def clear_scoped_instances(self) -> None:
        """Run cleanup on every container-level scoped instance, then forget them."""
        with self._lock:
            # Snapshot: a cleanup method could call back into the container.
            for instance in list(self._scoped_instances.values()):
                self._cleanup_instance(instance)
            self._scoped_instances.clear()
            logger.info("Cleared scoped instances")

    def _cleanup_instance(self, instance: Any) -> None:
        """
        Call the first available cleanup method on ``instance``, if any.

        Tries ``cleanup``, ``dispose`` and ``close`` in that order and calls
        only the first one that exists and is callable. Failures are logged as
        warnings and never propagate, so one misbehaving instance cannot abort
        a container-wide cleanup.
        """
        try:
            # __del__ is intentionally not called: the interpreter may call it
            # again when the object is finalized, running it twice.
            for method_name in ('cleanup', 'dispose', 'close'):
                method = getattr(instance, method_name, None)
                if callable(method):
                    method()
                    logger.info(f"Called {method_name} on {type(instance).__name__}")
                    break
        except Exception as e:
            logger.warning(f"Failed to cleanup instance {type(instance).__name__}: {e}")

    def cleanup(self) -> None:
        """
        Cleanup all managed resources.

        Disposes scoped and singleton instances, asks the TTS and translation
        factories to release their providers, then drops every cached instance.
        Registered descriptors are kept, so factory-built services are rebuilt
        on the next resolve. Singletons registered as ready-made instances
        (e.g. the AppConfig) are not rebuilt.
        """
        with self._lock:
            logger.info("Starting dependency container cleanup")

            # Cleanup scoped instances
            self.clear_scoped_instances()

            # Cleanup singleton instances (snapshot: cleanup may re-enter)
            for instance in list(self._singletons.values()):
                self._cleanup_instance(instance)

            # Cleanup provider factories
            if self._tts_factory:
                try:
                    self._tts_factory.cleanup_providers()
                except Exception as e:
                    logger.warning(f"Failed to cleanup TTS factory: {e}")

            if self._translation_factory:
                try:
                    self._translation_factory.clear_cache()
                except Exception as e:
                    logger.warning(f"Failed to cleanup translation factory: {e}")

            # Clear all references
            self._singletons.clear()
            self._tts_factory = None
            self._stt_factory = None
            self._translation_factory = None

            logger.info("Dependency container cleanup completed")

    def is_registered(self, service_type: Type) -> bool:
        """
        Check if a service type is registered.

        Args:
            service_type: The service type to check

        Returns:
            bool: True if registered as a descriptor or a stored singleton
        """
        with self._lock:
            return service_type in self._services or service_type in self._singletons

    def get_registered_services(self) -> Dict[str, str]:
        """
        Get information about all registered services.

        Returns:
            Dict[str, str]: Mapping of service names to their lifetime; types
            that currently hold a singleton instance report "singleton (instance)"
        """
        with self._lock:
            services_info = {
                service_type.__name__: "singleton (instance)"
                for service_type in self._singletons
            }
            for service_type, descriptor in self._services.items():
                if service_type not in self._singletons:
                    services_info[service_type.__name__] = descriptor.lifetime.value
            return services_info

    def create_scope(self) -> 'DependencyScope':
        """
        Create a new dependency scope.

        Returns:
            DependencyScope: New scope instance backed by this container
        """
        return DependencyScope(self)

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with cleanup."""
        self.cleanup()
class DependencyScope:
    """Scoped dependency container for managing scoped service lifetimes.

    Services registered on the parent with ``ServiceLifetime.SCOPED`` get one
    instance per scope, created on first resolve and disposed by
    :meth:`cleanup`. All other services are delegated to the parent container.
    """

    def __init__(self, parent_container: DependencyContainer):
        """
        Initialize dependency scope.

        Args:
            parent_container: Parent dependency container
        """
        self._parent = parent_container
        self._scoped_instances: Dict[Type, Any] = {}
        self._lock = RLock()  # Re-entrant: a factory may resolve through this scope

    def resolve(self, service_type: Type[T]) -> T:
        """
        Resolve service within this scope.

        Args:
            service_type: The service type to resolve

        Returns:
            T: The service instance

        Raises:
            ValueError: If the service is not registered with the parent
        """
        with self._lock:
            # Check if we already built it in this scope
            if service_type in self._scoped_instances:
                return self._scoped_instances[service_type]

            parent = self._parent
            with parent._lock:
                descriptor = parent._services.get(service_type)
                # Mirror the parent's precedence: a stored singleton wins.
                is_scoped = (
                    descriptor is not None
                    and descriptor.lifetime == ServiceLifetime.SCOPED
                    and service_type not in parent._singletons
                )

            if not is_scoped:
                return parent.resolve(service_type)

            # Build the instance directly rather than via parent.resolve(): the
            # parent caches scoped instances container-wide, which would hand
            # the same object to every scope, including after this scope has
            # already cleaned it up.
            instance = parent._create_instance(descriptor)
            self._scoped_instances[service_type] = instance
            return instance

    def cleanup(self) -> None:
        """Run cleanup on every instance created by this scope, then forget them."""
        with self._lock:
            for instance in list(self._scoped_instances.values()):
                self._parent._cleanup_instance(instance)
            self._scoped_instances.clear()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with cleanup."""
        self.cleanup()
# Process-wide container handed out by get_container() and replaced by
# set_container(); both globals are only touched while holding the lock.
_global_container: Optional[DependencyContainer] = None
_container_lock = RLock()  # Re-entrant lock guarding _global_container
def get_container() -> DependencyContainer:
    """
    Return the global dependency container, creating it on first use.

    Returns:
        DependencyContainer: Global container instance
    """
    global _global_container
    with _container_lock:
        if _global_container is not None:
            return _global_container
        # First caller: lazily build the shared container.
        _global_container = DependencyContainer()
        logger.info("Created global dependency container")
        return _global_container
def set_container(container: DependencyContainer) -> None:
    """
    Set the global dependency container instance.

    The previously installed container is cleaned up before it is replaced,
    unless it is the very same object being installed again.

    Args:
        container: Container instance to set as global
    """
    global _global_container
    with _container_lock:
        previous = _global_container
        # Re-installing the current container must not tear it down first;
        # otherwise the caller would end up with a cleaned-up global container.
        if previous is not None and previous is not container:
            previous.cleanup()
        _global_container = container
        logger.info("Set global dependency container")
def cleanup_container() -> None:
    """Clean up and forget the global dependency container, if one exists."""
    global _global_container
    with _container_lock:
        current = _global_container
        if current is None:
            return
        current.cleanup()
        # Only forget the container once its cleanup has completed.
        _global_container = None
        logger.info("Cleaned up global dependency container")