"""Integration tests for provider integration and switching.""" import pytest from unittest.mock import Mock, patch, MagicMock from typing import Dict, Any, List from src.infrastructure.config.dependency_container import DependencyContainer from src.infrastructure.config.app_config import AppConfig from src.infrastructure.tts.provider_factory import TTSProviderFactory from src.infrastructure.stt.provider_factory import STTProviderFactory from src.infrastructure.translation.provider_factory import TranslationProviderFactory from src.domain.models.audio_content import AudioContent from src.domain.models.text_content import TextContent from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest from src.domain.models.translation_request import TranslationRequest from src.domain.models.voice_settings import VoiceSettings from src.domain.exceptions import ( SpeechRecognitionException, TranslationFailedException, SpeechSynthesisException, ProviderNotAvailableException ) class TestProviderIntegration: """Integration tests for provider integration and switching.""" @pytest.fixture def mock_config(self): """Create mock configuration for testing.""" config = Mock(spec=AppConfig) # TTS configuration config.tts.preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy'] config.tts.fallback_enabled = True config.tts.provider_timeout = 30.0 # STT configuration config.stt.default_model = 'whisper-small' config.stt.fallback_models = ['whisper-medium', 'parakeet'] config.stt.provider_timeout = 60.0 # Translation configuration config.translation.default_provider = 'nllb' config.translation.fallback_enabled = True config.translation.chunk_size = 512 return config @pytest.fixture def dependency_container(self, mock_config): """Create dependency container with mock configuration.""" container = DependencyContainer(mock_config) return container @pytest.fixture def sample_audio_content(self): """Create sample audio content for testing.""" return AudioContent( data=b"fake_audio_data", format="wav", sample_rate=16000, duration=2.5 ) @pytest.fixture def sample_text_content(self): """Create sample text content for testing.""" return TextContent( text="Hello, this is a test message.", language="en" ) def test_tts_provider_switching(self, dependency_container, sample_text_content): """Test switching between different TTS providers.""" voice_settings = VoiceSettings( voice_id="test_voice", speed=1.0, language="en" ) synthesis_request = SpeechSynthesisRequest( text=sample_text_content.text, voice_settings=voice_settings ) # Test each TTS provider providers_to_test = ['kokoro', 'dia', 'cosyvoice2', 'dummy'] for provider_name in providers_to_test: with patch(f'src.infrastructure.tts.{provider_name}_provider') as mock_provider_module: # Mock the provider class mock_provider_class = Mock() mock_provider_instance = Mock() mock_provider_instance.synthesize.return_value = AudioContent( data=f"{provider_name}_audio_data".encode(), format="wav", sample_rate=22050, duration=2.0 ) mock_provider_class.return_value = mock_provider_instance setattr(mock_provider_module, f'{provider_name.title()}Provider', mock_provider_class) # Get provider from container provider = dependency_container.get_tts_provider(provider_name) # Test synthesis result = provider.synthesize(synthesis_request) assert isinstance(result, AudioContent) assert provider_name.encode() in result.data mock_provider_instance.synthesize.assert_called_once() def test_tts_provider_fallback(self, dependency_container, sample_text_content): """Test TTS provider fallback mechanism.""" voice_settings = VoiceSettings( voice_id="test_voice", speed=1.0, language="en" ) synthesis_request = SpeechSynthesisRequest( text=sample_text_content.text, voice_settings=voice_settings ) with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory # Mock first provider to fail, second to succeed mock_provider1 = Mock() mock_provider1.synthesize.side_effect = SpeechSynthesisException("Provider 1 failed") mock_provider2 = Mock() mock_provider2.synthesize.return_value = AudioContent( data=b"fallback_audio_data", format="wav", sample_rate=22050, duration=2.0 ) mock_factory.get_provider_with_fallback.return_value = mock_provider2 # Get provider with fallback provider = dependency_container.get_tts_provider() result = provider.synthesize(synthesis_request) assert isinstance(result, AudioContent) assert b"fallback_audio_data" in result.data def test_stt_provider_switching(self, dependency_container, sample_audio_content): """Test switching between different STT providers.""" providers_to_test = ['whisper-small', 'whisper-medium', 'parakeet'] for provider_name in providers_to_test: with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory mock_provider = Mock() mock_provider.transcribe.return_value = TextContent( text=f"Transcription from {provider_name}", language="en" ) mock_factory.create_provider.return_value = mock_provider # Get provider from container provider = dependency_container.get_stt_provider(provider_name) # Test transcription result = provider.transcribe(sample_audio_content, provider_name) assert isinstance(result, TextContent) assert provider_name in result.text mock_provider.transcribe.assert_called_once() def test_stt_provider_fallback(self, dependency_container, sample_audio_content): """Test STT provider fallback mechanism.""" with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory # Mock first provider to fail, fallback to succeed mock_provider1 = Mock() mock_provider1.transcribe.side_effect = SpeechRecognitionException("Provider 1 failed") mock_provider2 = Mock() mock_provider2.transcribe.return_value = TextContent( text="Fallback transcription successful", language="en" ) mock_factory.create_provider_with_fallback.return_value = mock_provider2 # Get provider with fallback provider = dependency_container.get_stt_provider() result = provider.transcribe(sample_audio_content, "whisper-small") assert isinstance(result, TextContent) assert "Fallback transcription successful" in result.text def test_translation_provider_integration(self, dependency_container): """Test translation provider integration.""" translation_request = TranslationRequest( text="Hello, how are you?", source_language="en", target_language="es" ) with patch('src.infrastructure.translation.provider_factory.TranslationProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory mock_provider = Mock() mock_provider.translate.return_value = TextContent( text="Hola, ¿cómo estás?", language="es" ) mock_factory.get_default_provider.return_value = mock_provider # Get translation provider provider = dependency_container.get_translation_provider() result = provider.translate(translation_request) assert isinstance(result, TextContent) assert result.text == "Hola, ¿cómo estás?" assert result.language == "es" def test_provider_availability_checking(self, dependency_container): """Test provider availability checking.""" with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory # Mock availability checking mock_factory.is_provider_available.side_effect = lambda name: name in ['kokoro', 'dummy'] mock_factory.get_available_providers.return_value = ['kokoro', 'dummy'] # Test availability available_providers = mock_factory.get_available_providers() assert 'kokoro' in available_providers assert 'dummy' in available_providers assert 'dia' not in available_providers # Not available in mock def test_provider_configuration_loading(self, dependency_container, mock_config): """Test provider configuration loading and validation.""" # Test TTS configuration tts_provider = dependency_container.get_tts_provider('dummy') assert tts_provider is not None # Test STT configuration stt_provider = dependency_container.get_stt_provider('whisper-small') assert stt_provider is not None # Test translation configuration translation_provider = dependency_container.get_translation_provider() assert translation_provider is not None def test_provider_error_handling(self, dependency_container, sample_audio_content): """Test provider error handling and recovery.""" with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory # Mock provider that always fails mock_provider = Mock() mock_provider.transcribe.side_effect = SpeechRecognitionException("Provider unavailable") mock_factory.create_provider.return_value = mock_provider # Test error handling provider = dependency_container.get_stt_provider('whisper-small') with pytest.raises(SpeechRecognitionException): provider.transcribe(sample_audio_content, 'whisper-small') def test_provider_performance_monitoring(self, dependency_container, sample_text_content): """Test provider performance monitoring.""" import time voice_settings = VoiceSettings( voice_id="test_voice", speed=1.0, language="en" ) synthesis_request = SpeechSynthesisRequest( text=sample_text_content.text, voice_settings=voice_settings ) with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory mock_provider = Mock() def slow_synthesize(request): time.sleep(0.1) # Simulate processing time return AudioContent( data=b"slow_audio_data", format="wav", sample_rate=22050, duration=2.0 ) mock_provider.synthesize.side_effect = slow_synthesize mock_factory.create_provider.return_value = mock_provider # Measure performance start_time = time.time() provider = dependency_container.get_tts_provider('dummy') result = provider.synthesize(synthesis_request) end_time = time.time() processing_time = end_time - start_time assert isinstance(result, AudioContent) assert processing_time >= 0.1 # Should take at least the sleep time def test_provider_resource_cleanup(self, dependency_container): """Test provider resource cleanup.""" # Get multiple providers tts_provider = dependency_container.get_tts_provider('dummy') stt_provider = dependency_container.get_stt_provider('whisper-small') translation_provider = dependency_container.get_translation_provider() assert tts_provider is not None assert stt_provider is not None assert translation_provider is not None # Test cleanup dependency_container.cleanup() # Verify cleanup was called (would need to mock the actual providers) # This is more of a smoke test to ensure cleanup doesn't crash def test_provider_concurrent_access(self, dependency_container, sample_text_content): """Test concurrent access to providers.""" import threading import queue voice_settings = VoiceSettings( voice_id="test_voice", speed=1.0, language="en" ) synthesis_request = SpeechSynthesisRequest( text=sample_text_content.text, voice_settings=voice_settings ) results_queue = queue.Queue() def synthesize_audio(): try: provider = dependency_container.get_tts_provider('dummy') with patch.object(provider, 'synthesize') as mock_synthesize: mock_synthesize.return_value = AudioContent( data=b"concurrent_audio_data", format="wav", sample_rate=22050, duration=2.0 ) result = provider.synthesize(synthesis_request) results_queue.put(result) except Exception as e: results_queue.put(e) # Start multiple threads threads = [] for _ in range(3): thread = threading.Thread(target=synthesize_audio) threads.append(thread) thread.start() # Wait for completion for thread in threads: thread.join() # Verify results results = [] while not results_queue.empty(): result = results_queue.get() if isinstance(result, Exception): pytest.fail(f"Concurrent access failed: {result}") results.append(result) assert len(results) == 3 for result in results: assert isinstance(result, AudioContent) def test_provider_configuration_updates(self, dependency_container, mock_config): """Test dynamic provider configuration updates.""" # Initial configuration initial_providers = mock_config.tts.preferred_providers assert 'kokoro' in initial_providers # Update configuration mock_config.tts.preferred_providers = ['dia', 'dummy'] # Verify configuration update affects provider selection # (This would require actual implementation of dynamic config updates) updated_providers = mock_config.tts.preferred_providers assert 'dia' in updated_providers assert 'dummy' in updated_providers def test_provider_health_checking(self, dependency_container): """Test provider health checking mechanisms.""" with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory # Mock health check methods mock_factory.check_provider_health.return_value = { 'kokoro': {'status': 'healthy', 'response_time': 0.1}, 'dia': {'status': 'unhealthy', 'error': 'Connection timeout'}, 'dummy': {'status': 'healthy', 'response_time': 0.05} } health_status = mock_factory.check_provider_health() assert health_status['kokoro']['status'] == 'healthy' assert health_status['dia']['status'] == 'unhealthy' assert health_status['dummy']['status'] == 'healthy' def test_provider_load_balancing(self, dependency_container): """Test provider load balancing mechanisms.""" with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class: mock_factory = Mock() mock_factory_class.return_value = mock_factory # Mock load balancing provider_calls = {'kokoro': 0, 'dia': 0, 'dummy': 0} def mock_get_provider(name=None): if name is None: # Round-robin selection providers = ['kokoro', 'dia', 'dummy'] selected = min(providers, key=lambda p: provider_calls[p]) provider_calls[selected] += 1 name = selected mock_provider = Mock() mock_provider.name = name return mock_provider mock_factory.create_provider.side_effect = mock_get_provider # Get multiple providers to test load balancing providers = [] for _ in range(6): provider = mock_factory.create_provider() providers.append(provider) # Verify load distribution provider_names = [p.name for p in providers] assert provider_names.count('kokoro') == 2 assert provider_names.count('dia') == 2 assert provider_names.count('dummy') == 2