# teachingAssistant/tests/unit/infrastructure/config/test_dependency_container.py
# Author: Michael Hu
# Commit message: Create unit tests for infrastructure layer
# Commit: 93dc283
"""Unit tests for DependencyContainer."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from threading import Thread
import time
from src.infrastructure.config.dependency_container import (
DependencyContainer,
DependencyScope,
ServiceDescriptor,
ServiceLifetime,
get_container,
set_container,
cleanup_container
)
from src.infrastructure.config.app_config import AppConfig
from src.infrastructure.tts.provider_factory import TTSProviderFactory
from src.infrastructure.stt.provider_factory import STTProviderFactory
from src.infrastructure.translation.provider_factory import TranslationProviderFactory, TranslationProviderType
from src.domain.interfaces.speech_synthesis import ISpeechSynthesisService
from src.domain.interfaces.speech_recognition import ISpeechRecognitionService
from src.domain.interfaces.translation import ITranslationService
class MockService:
    """Stand-in service that records its construction arguments and cleanup.

    Args:
        name: Stored on ``self.name``; defaults to ``"mock"``.
        **kwargs: Any extra keyword arguments, kept verbatim on ``self.kwargs``.
    """

    def __init__(self, name="mock", **kwargs):
        self.name, self.kwargs = name, kwargs
        # Flipped by cleanup() so tests can check that the container released us.
        self.cleanup_called = False

    def cleanup(self):
        """Mark this instance as cleaned up."""
        self.cleanup_called = True
class MockServiceWithDispose:
    """Stand-in service that exposes ``dispose()`` rather than ``cleanup()``.

    Args:
        name: Stored on ``self.name``; defaults to ``"mock"``.
    """

    def __init__(self, name="mock"):
        # Set to True once dispose() has run.
        self.dispose_called = False
        self.name = name

    def dispose(self):
        """Record that dispose() was invoked."""
        self.dispose_called = True
def mock_factory(**kwargs):
    """Factory function used to register ``MockService`` through a callable.

    Returns a ``MockService`` named ``"factory_created"`` whose ``kwargs``
    attribute holds every keyword argument this factory received.

    The keyword arguments are attached after construction instead of being
    forwarded to ``MockService(...)``. ``test_register_singleton_factory``
    registers this factory with ``{'name': ...}``, and forwarding ``name=...``
    alongside the positional ``"factory_created"`` would raise
    ``TypeError: got multiple values for argument 'name'``.
    """
    service = MockService("factory_created")
    service.kwargs = kwargs
    return service
class TestServiceDescriptor:
    """Tests for the ServiceDescriptor record."""

    def test_service_descriptor_creation(self):
        """Every explicitly supplied field is kept on the descriptor."""
        descriptor = ServiceDescriptor(
            service_type=MockService,
            implementation=MockService,
            lifetime=ServiceLifetime.SINGLETON,
            factory_args={'name': 'test'},
        )
        expected = {
            'service_type': MockService,
            'implementation': MockService,
            'lifetime': ServiceLifetime.SINGLETON,
            'factory_args': {'name': 'test'},
        }
        for attribute, value in expected.items():
            assert getattr(descriptor, attribute) == value, attribute

    def test_service_descriptor_defaults(self):
        """Omitted lifetime and factory_args fall back to TRANSIENT and {}."""
        descriptor = ServiceDescriptor(service_type=MockService, implementation=MockService)
        assert (descriptor.lifetime, descriptor.factory_args) == (ServiceLifetime.TRANSIENT, {})
class TestDependencyContainer:
    """Test cases for DependencyContainer."""

    def setup_method(self):
        """Set up test fixtures."""
        self.container = DependencyContainer()

    def teardown_method(self):
        """Clean up after tests."""
        self.container.cleanup()

    def test_container_initialization(self):
        """Test container initialization."""
        assert isinstance(self.container._config, AppConfig)
        assert isinstance(self.container._services, dict)
        assert isinstance(self.container._singletons, dict)
        assert isinstance(self.container._scoped_instances, dict)
        # Should have default services registered
        assert AppConfig in self.container._singletons

    def test_container_initialization_with_config(self):
        """Test container initialization with custom config.

        The extra container is used as a context manager so it is cleaned up
        even if an assertion fails; teardown only disposes ``self.container``.
        """
        config = AppConfig()
        with DependencyContainer(config) as container:
            assert container._config is config
            assert AppConfig in container._singletons
            assert container._singletons[AppConfig] is config

    def test_register_singleton_class(self):
        """Test registering singleton service with class."""
        self.container.register_singleton(MockService, MockService, {'name': 'test'})
        assert MockService in self.container._services
        descriptor = self.container._services[MockService]
        assert descriptor.lifetime == ServiceLifetime.SINGLETON
        assert descriptor.factory_args == {'name': 'test'}

    def test_register_singleton_instance(self):
        """Test registering singleton service with instance."""
        instance = MockService("test_instance")
        self.container.register_singleton(MockService, instance)
        assert MockService in self.container._singletons
        assert self.container._singletons[MockService] is instance

    def test_register_singleton_factory(self):
        """Test registering singleton service with factory function."""
        self.container.register_singleton(MockService, mock_factory, {'name': 'factory_test'})
        service = self.container.resolve(MockService)
        assert isinstance(service, MockService)
        assert service.name == "factory_created"
        assert service.kwargs == {'name': 'factory_test'}

    def test_register_transient(self):
        """Test registering transient service."""
        self.container.register_transient(MockService, MockService, {'name': 'transient'})
        assert MockService in self.container._services
        descriptor = self.container._services[MockService]
        assert descriptor.lifetime == ServiceLifetime.TRANSIENT

    def test_register_scoped(self):
        """Test registering scoped service."""
        self.container.register_scoped(MockService, MockService, {'name': 'scoped'})
        assert MockService in self.container._services
        descriptor = self.container._services[MockService]
        assert descriptor.lifetime == ServiceLifetime.SCOPED

    def test_resolve_singleton(self):
        """Test resolving singleton service."""
        self.container.register_singleton(MockService, MockService, {'name': 'singleton'})
        service1 = self.container.resolve(MockService)
        service2 = self.container.resolve(MockService)
        assert service1 is service2
        assert service1.name == 'singleton'

    def test_resolve_transient(self):
        """Test resolving transient service."""
        self.container.register_transient(MockService, MockService, {'name': 'transient'})
        service1 = self.container.resolve(MockService)
        service2 = self.container.resolve(MockService)
        assert service1 is not service2
        assert service1.name == 'transient'
        assert service2.name == 'transient'

    def test_resolve_scoped(self):
        """Test resolving scoped service."""
        self.container.register_scoped(MockService, MockService, {'name': 'scoped'})
        service1 = self.container.resolve(MockService)
        service2 = self.container.resolve(MockService)
        assert service1 is service2  # Same instance within scope
        assert service1.name == 'scoped'

    def test_resolve_unregistered_service(self):
        """Test resolving unregistered service raises error."""
        class UnregisteredService:
            pass

        with pytest.raises(ValueError, match="Service UnregisteredService is not registered"):
            self.container.resolve(UnregisteredService)

    def test_resolve_service_creation_error(self):
        """Test handling service creation errors."""
        def failing_factory():
            raise Exception("Creation failed")

        self.container.register_singleton(MockService, failing_factory)
        with pytest.raises(Exception, match="Creation failed"):
            self.container.resolve(MockService)

    def test_thread_safety(self):
        """Test container thread safety.

        A barrier releases every worker at the same moment so the first
        singleton resolution genuinely races. Exceptions raised inside workers
        are collected because ``Thread`` would otherwise swallow them.
        """
        from threading import Barrier

        self.container.register_singleton(MockService, MockService, {'name': 'thread_test'})
        thread_count = 10
        start_gate = Barrier(thread_count)
        results = []
        errors = []

        def resolve_service():
            try:
                # Timeout guards against a hang if a worker fails to start.
                start_gate.wait(timeout=5)
                results.append(self.container.resolve(MockService))
            except Exception as exc:  # surface worker failures to the test thread
                errors.append(exc)

        threads = [Thread(target=resolve_service) for _ in range(thread_count)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join(timeout=10)

        assert not errors, errors
        # All threads should get the same singleton instance
        assert len(results) == thread_count
        assert all(service is results[0] for service in results)

    def test_get_tts_provider_default(self):
        """Test getting TTS provider with default settings."""
        with patch.object(TTSProviderFactory, 'get_provider_with_fallback') as mock_get:
            mock_provider = Mock()
            mock_get.return_value = mock_provider
            provider = self.container.get_tts_provider()
            assert provider is mock_provider
            mock_get.assert_called_once()

    def test_get_tts_provider_specific(self):
        """Test getting specific TTS provider."""
        with patch.object(TTSProviderFactory, 'create_provider') as mock_create:
            mock_provider = Mock()
            mock_create.return_value = mock_provider
            provider = self.container.get_tts_provider('kokoro', lang_code='en')
            assert provider is mock_provider
            mock_create.assert_called_once_with('kokoro', lang_code='en')

    def test_get_stt_provider_default(self):
        """Test getting STT provider with default settings."""
        with patch.object(STTProviderFactory, 'create_provider_with_fallback') as mock_get:
            mock_provider = Mock()
            mock_get.return_value = mock_provider
            provider = self.container.get_stt_provider()
            assert provider is mock_provider
            mock_get.assert_called_once()

    def test_get_stt_provider_specific(self):
        """Test getting specific STT provider."""
        with patch.object(STTProviderFactory, 'create_provider') as mock_create:
            mock_provider = Mock()
            mock_create.return_value = mock_provider
            provider = self.container.get_stt_provider('whisper')
            assert provider is mock_provider
            mock_create.assert_called_once_with('whisper')

    def test_get_translation_provider_default(self):
        """Test getting translation provider with default settings."""
        with patch.object(TranslationProviderFactory, 'get_default_provider') as mock_get:
            mock_provider = Mock()
            mock_get.return_value = mock_provider
            provider = self.container.get_translation_provider()
            assert provider is mock_provider
            mock_get.assert_called_once_with(None)

    def test_get_translation_provider_specific(self):
        """Test getting specific translation provider."""
        with patch.object(TranslationProviderFactory, 'create_provider') as mock_create:
            mock_provider = Mock()
            mock_create.return_value = mock_provider
            config = {'model': 'test'}
            provider = self.container.get_translation_provider(TranslationProviderType.NLLB, config)
            assert provider is mock_provider
            mock_create.assert_called_once_with(TranslationProviderType.NLLB, config)

    def test_clear_scoped_instances(self):
        """Test clearing scoped instances."""
        self.container.register_scoped(MockService, MockService)
        # Create scoped instance
        service = self.container.resolve(MockService)
        assert MockService in self.container._scoped_instances
        self.container.clear_scoped_instances()
        assert len(self.container._scoped_instances) == 0
        assert service.cleanup_called is True

    def test_cleanup_instance_with_cleanup_method(self):
        """Test cleanup of instance with cleanup method."""
        instance = MockService()
        self.container._cleanup_instance(instance)
        assert instance.cleanup_called is True

    def test_cleanup_instance_with_dispose_method(self):
        """Test cleanup of instance with dispose method."""
        instance = MockServiceWithDispose()
        self.container._cleanup_instance(instance)
        assert instance.dispose_called is True

    def test_cleanup_instance_no_cleanup_method(self):
        """Test cleanup of instance without cleanup method."""
        instance = object()
        # Should not raise exception
        self.container._cleanup_instance(instance)

    def test_cleanup_instance_error_handling(self):
        """Test cleanup error handling."""
        instance = Mock()
        instance.cleanup.side_effect = Exception("Cleanup error")
        # Should not raise exception
        self.container._cleanup_instance(instance)

    def test_cleanup_container(self):
        """Test full container cleanup."""
        # Register services
        self.container.register_singleton(MockService, MockService)
        self.container.register_scoped(MockServiceWithDispose, MockServiceWithDispose)
        # Create instances
        singleton = self.container.resolve(MockService)
        scoped = self.container.resolve(MockServiceWithDispose)
        # Mock factories
        mock_tts_factory = Mock()
        mock_translation_factory = Mock()
        self.container._tts_factory = mock_tts_factory
        self.container._translation_factory = mock_translation_factory
        self.container.cleanup()
        # Check cleanup was called
        assert singleton.cleanup_called is True
        assert scoped.dispose_called is True
        mock_tts_factory.cleanup_providers.assert_called_once()
        mock_translation_factory.clear_cache.assert_called_once()
        # Check instances were cleared
        assert len(self.container._singletons) == 0
        assert len(self.container._scoped_instances) == 0
        assert self.container._tts_factory is None
        assert self.container._translation_factory is None

    def test_cleanup_factory_error_handling(self):
        """Test cleanup error handling for factories."""
        mock_tts_factory = Mock()
        mock_tts_factory.cleanup_providers.side_effect = Exception("TTS cleanup error")
        self.container._tts_factory = mock_tts_factory
        # Should not raise exception
        self.container.cleanup()

    def test_is_registered(self):
        """Test checking if service is registered."""
        assert self.container.is_registered(AppConfig) is True  # Default registration
        assert self.container.is_registered(MockService) is False
        self.container.register_singleton(MockService, MockService)
        assert self.container.is_registered(MockService) is True

    def test_get_registered_services(self):
        """Test getting registered services info."""
        self.container.register_singleton(MockService, MockService)
        self.container.register_transient(MockServiceWithDispose, MockServiceWithDispose)
        services = self.container.get_registered_services()
        assert 'AppConfig' in services
        assert 'MockService' in services
        assert 'MockServiceWithDispose' in services
        assert services['MockService'] == 'singleton'
        assert services['MockServiceWithDispose'] == 'transient'

    def test_create_scope(self):
        """Test creating dependency scope."""
        scope = self.container.create_scope()
        assert isinstance(scope, DependencyScope)
        assert scope._parent is self.container

    def test_context_manager(self):
        """Test container as context manager."""
        with DependencyContainer() as container:
            container.register_singleton(MockService, MockService)
            service = container.resolve(MockService)
            assert isinstance(service, MockService)
        # Cleanup should have been called
        assert service.cleanup_called is True
class TestDependencyScope:
    """Tests for resolving services through a DependencyScope."""

    def setup_method(self):
        """Build a fresh container and a scope bound to it for every test."""
        self.container = DependencyContainer()
        self.scope = DependencyScope(self.container)

    def teardown_method(self):
        """Dispose of the scope first, then its parent container."""
        self.scope.cleanup()
        self.container.cleanup()

    def test_scope_initialization(self):
        """A new scope points at its parent and starts with a dict cache."""
        assert self.scope._parent is self.container
        assert isinstance(self.scope._scoped_instances, dict)

    def test_resolve_singleton_from_parent(self):
        """Singletons resolved via the scope come from the parent container."""
        self.container.register_singleton(MockService, MockService)
        first, second = (self.scope.resolve(MockService) for _ in range(2))
        assert first is second
        assert isinstance(first, MockService)

    def test_resolve_scoped_service(self):
        """Scoped services are cached per scope and reused inside it."""
        self.container.register_scoped(MockService, MockService)
        first = self.scope.resolve(MockService)
        second = self.scope.resolve(MockService)
        # Repeated resolution inside one scope must hand back the cached object.
        assert first is second
        assert MockService in self.scope._scoped_instances

    def test_resolve_transient_service(self):
        """Transient services produce a brand-new object on every resolve."""
        self.container.register_transient(MockService, MockService)
        first = self.scope.resolve(MockService)
        second = self.scope.resolve(MockService)
        assert first is not second

    def test_scope_cleanup(self):
        """cleanup() empties the scope cache and cleans up its instances."""
        self.container.register_scoped(MockService, MockService)
        scoped = self.scope.resolve(MockService)
        assert MockService in self.scope._scoped_instances

        self.scope.cleanup()

        assert not self.scope._scoped_instances
        assert scoped.cleanup_called is True

    def test_scope_context_manager(self):
        """Leaving a ``with`` block over a scope cleans up what it created."""
        self.container.register_scoped(MockService, MockService)
        with self.container.create_scope() as inner_scope:
            resolved = inner_scope.resolve(MockService)
            assert isinstance(resolved, MockService)
        assert resolved.cleanup_called is True
class TestGlobalContainer:
    """Tests for the module-level get/set/cleanup container helpers."""

    def teardown_method(self):
        """Reset the process-wide container so tests do not leak state."""
        cleanup_container()

    def test_get_container_creates_global(self):
        """get_container() lazily creates one shared container."""
        first = get_container()
        assert isinstance(first, DependencyContainer)
        # Later lookups must return that very same object.
        assert get_container() is first

    def test_set_container(self):
        """set_container() installs the given container as the global one."""
        replacement = DependencyContainer()
        set_container(replacement)
        assert get_container() is replacement

    def test_set_container_cleans_up_previous(self):
        """Replacing the global container disposes of the old one."""
        previous = get_container()
        previous.register_singleton(MockService, MockService)
        old_service = previous.resolve(MockService)

        replacement = DependencyContainer()
        set_container(replacement)

        assert old_service.cleanup_called is True
        assert get_container() is replacement

    def test_cleanup_container(self):
        """cleanup_container() disposes the global and forces a fresh one."""
        original = get_container()
        original.register_singleton(MockService, MockService)
        service = original.resolve(MockService)

        cleanup_container()

        assert service.cleanup_called is True
        assert get_container() is not original