Spaces:
Build error
Build error
| """Integration tests for provider integration and switching.""" | |
| import pytest | |
| from unittest.mock import Mock, patch, MagicMock | |
| from typing import Dict, Any, List | |
| from src.infrastructure.config.dependency_container import DependencyContainer | |
| from src.infrastructure.config.app_config import AppConfig | |
| from src.infrastructure.tts.provider_factory import TTSProviderFactory | |
| from src.infrastructure.stt.provider_factory import STTProviderFactory | |
| from src.infrastructure.translation.provider_factory import TranslationProviderFactory | |
| from src.domain.models.audio_content import AudioContent | |
| from src.domain.models.text_content import TextContent | |
| from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
| from src.domain.models.translation_request import TranslationRequest | |
| from src.domain.models.voice_settings import VoiceSettings | |
| from src.domain.exceptions import ( | |
| SpeechRecognitionException, | |
| TranslationFailedException, | |
| SpeechSynthesisException, | |
| ProviderNotAvailableException | |
| ) | |
class TestProviderIntegration:
    """Integration tests for provider integration and switching.

    NOTE(review): many tests patch a factory class at its defining module
    (e.g. ``src.infrastructure.tts.provider_factory.TTSProviderFactory``).
    Such a patch only reaches ``DependencyContainer`` if the container looks
    the factory up through that module attribute *after* the patch is
    entered. The ``dependency_container`` fixture is built before any test
    body runs. So if the container binds the factory at import time
    (``from ... import TTSProviderFactory``) or instantiates it in
    ``__init__``, the real factory is used instead. Confirm against the
    container implementation.
    """

    @pytest.fixture
    def mock_config(self):
        """Return a ``Mock(spec=AppConfig)`` pre-populated with provider settings.

        NOTE(review): ``spec=AppConfig`` only permits ``config.tts``,
        ``config.stt`` and ``config.translation`` if those names exist on the
        ``AppConfig`` class itself (not only as instance fields) — verify.
        """
        config = Mock(spec=AppConfig)

        # TTS configuration
        config.tts.preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
        config.tts.fallback_enabled = True
        config.tts.provider_timeout = 30.0

        # STT configuration
        config.stt.default_model = 'whisper-small'
        config.stt.fallback_models = ['whisper-medium', 'parakeet']
        config.stt.provider_timeout = 60.0

        # Translation configuration
        config.translation.default_provider = 'nllb'
        config.translation.fallback_enabled = True
        config.translation.chunk_size = 512

        return config

    @pytest.fixture
    def dependency_container(self, mock_config):
        """Return a ``DependencyContainer`` built from the ``mock_config`` fixture."""
        container = DependencyContainer(mock_config)
        return container

    @pytest.fixture
    def sample_audio_content(self):
        """Return a small fake 16 kHz WAV ``AudioContent`` of 2.5 s."""
        return AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=2.5
        )

    @pytest.fixture
    def sample_text_content(self):
        """Return an English ``TextContent`` used as synthesis input."""
        return TextContent(
            text="Hello, this is a test message.",
            language="en"
        )

    def test_tts_provider_switching(self, dependency_container, sample_text_content):
        """Test switching between different TTS providers."""
        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )

        # Test each TTS provider
        providers_to_test = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
        for provider_name in providers_to_test:
            # NOTE(review): this replaces the ``<name>_provider`` attribute on the
            # ``src.infrastructure.tts`` package; an ``import`` of that submodule
            # resolves through ``sys.modules`` and may bypass the patch — verify.
            with patch(f'src.infrastructure.tts.{provider_name}_provider') as mock_provider_module:
                # Mock the provider class
                mock_provider_class = Mock()
                mock_provider_instance = Mock()
                mock_provider_instance.synthesize.return_value = AudioContent(
                    data=f"{provider_name}_audio_data".encode(),
                    format="wav",
                    sample_rate=22050,
                    duration=2.0
                )
                mock_provider_class.return_value = mock_provider_instance
                # str.title() yields e.g. 'Cosyvoice2Provider'; confirm this matches
                # the real class names (e.g. a 'CosyVoice2Provider' would not).
                setattr(mock_provider_module, f'{provider_name.title()}Provider', mock_provider_class)

                # Get provider from container
                provider = dependency_container.get_tts_provider(provider_name)

                # Test synthesis
                result = provider.synthesize(synthesis_request)

                assert isinstance(result, AudioContent)
                assert provider_name.encode() in result.data
                mock_provider_instance.synthesize.assert_called_once()

    def test_tts_provider_fallback(self, dependency_container, sample_text_content):
        """Test TTS provider fallback mechanism."""
        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )

        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # The factory's fallback lookup resolves straight to the working
            # provider; the failing primary provider is never observable here,
            # so it is not modelled.
            fallback_provider = Mock()
            fallback_provider.synthesize.return_value = AudioContent(
                data=b"fallback_audio_data",
                format="wav",
                sample_rate=22050,
                duration=2.0
            )
            mock_factory.get_provider_with_fallback.return_value = fallback_provider

            # Get provider with fallback
            provider = dependency_container.get_tts_provider()
            result = provider.synthesize(synthesis_request)

            assert isinstance(result, AudioContent)
            assert b"fallback_audio_data" in result.data

    def test_stt_provider_switching(self, dependency_container, sample_audio_content):
        """Test switching between different STT providers."""
        providers_to_test = ['whisper-small', 'whisper-medium', 'parakeet']

        for provider_name in providers_to_test:
            with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class:
                mock_factory = Mock()
                mock_factory_class.return_value = mock_factory

                mock_provider = Mock()
                mock_provider.transcribe.return_value = TextContent(
                    text=f"Transcription from {provider_name}",
                    language="en"
                )
                mock_factory.create_provider.return_value = mock_provider

                # Get provider from container
                provider = dependency_container.get_stt_provider(provider_name)

                # Test transcription
                result = provider.transcribe(sample_audio_content, provider_name)

                assert isinstance(result, TextContent)
                assert provider_name in result.text
                mock_provider.transcribe.assert_called_once()

    def test_stt_provider_fallback(self, dependency_container, sample_audio_content):
        """Test STT provider fallback mechanism."""
        with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Only the provider returned after fallback is exercised; the failing
            # primary provider is never consulted by this test, so it is not modelled.
            fallback_provider = Mock()
            fallback_provider.transcribe.return_value = TextContent(
                text="Fallback transcription successful",
                language="en"
            )
            mock_factory.create_provider_with_fallback.return_value = fallback_provider

            # Get provider with fallback
            provider = dependency_container.get_stt_provider()
            result = provider.transcribe(sample_audio_content, "whisper-small")

            assert isinstance(result, TextContent)
            assert "Fallback transcription successful" in result.text

    def test_translation_provider_integration(self, dependency_container):
        """Test translation provider integration."""
        translation_request = TranslationRequest(
            text="Hello, how are you?",
            source_language="en",
            target_language="es"
        )

        with patch('src.infrastructure.translation.provider_factory.TranslationProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            mock_provider = Mock()
            mock_provider.translate.return_value = TextContent(
                text="Hola, ¿cómo estás?",
                language="es"
            )
            mock_factory.get_default_provider.return_value = mock_provider

            # Get translation provider
            provider = dependency_container.get_translation_provider()
            result = provider.translate(translation_request)

            assert isinstance(result, TextContent)
            assert result.text == "Hola, ¿cómo estás?"
            assert result.language == "es"

    def test_provider_availability_checking(self, dependency_container):
        """Test provider availability checking.

        NOTE(review): this only queries the mock configured below and never
        goes through ``dependency_container``, so it cannot fail because of
        production code. Wire it to the real availability API once known.
        """
        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock availability checking
            mock_factory.is_provider_available.side_effect = lambda name: name in ['kokoro', 'dummy']
            mock_factory.get_available_providers.return_value = ['kokoro', 'dummy']

            # Test availability
            available_providers = mock_factory.get_available_providers()
            assert 'kokoro' in available_providers
            assert 'dummy' in available_providers
            assert 'dia' not in available_providers  # Not available in mock

    def test_provider_configuration_loading(self, dependency_container, mock_config):
        """Test provider configuration loading and validation."""
        # Test TTS configuration
        tts_provider = dependency_container.get_tts_provider('dummy')
        assert tts_provider is not None

        # Test STT configuration
        stt_provider = dependency_container.get_stt_provider('whisper-small')
        assert stt_provider is not None

        # Test translation configuration
        translation_provider = dependency_container.get_translation_provider()
        assert translation_provider is not None

    def test_provider_error_handling(self, dependency_container, sample_audio_content):
        """Test provider error handling and recovery."""
        with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock provider that always fails
            mock_provider = Mock()
            mock_provider.transcribe.side_effect = SpeechRecognitionException("Provider unavailable")
            mock_factory.create_provider.return_value = mock_provider

            # Test error handling
            provider = dependency_container.get_stt_provider('whisper-small')

            with pytest.raises(SpeechRecognitionException):
                provider.transcribe(sample_audio_content, 'whisper-small')

    def test_provider_performance_monitoring(self, dependency_container, sample_text_content):
        """Test provider performance monitoring."""
        import time

        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )

        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            mock_provider = Mock()

            def slow_synthesize(request):
                time.sleep(0.1)  # Simulate processing time
                return AudioContent(
                    data=b"slow_audio_data",
                    format="wav",
                    sample_rate=22050,
                    duration=2.0
                )

            mock_provider.synthesize.side_effect = slow_synthesize
            mock_factory.create_provider.return_value = mock_provider

            # Measure performance with a monotonic clock: time.time() follows
            # the wall clock and can jump (NTP, DST), skewing the interval.
            start_time = time.perf_counter()
            provider = dependency_container.get_tts_provider('dummy')
            result = provider.synthesize(synthesis_request)
            processing_time = time.perf_counter() - start_time

            assert isinstance(result, AudioContent)
            assert processing_time >= 0.1  # Should take at least the sleep time

    def test_provider_resource_cleanup(self, dependency_container):
        """Test provider resource cleanup."""
        # Get multiple providers
        tts_provider = dependency_container.get_tts_provider('dummy')
        stt_provider = dependency_container.get_stt_provider('whisper-small')
        translation_provider = dependency_container.get_translation_provider()

        assert tts_provider is not None
        assert stt_provider is not None
        assert translation_provider is not None

        # Test cleanup
        dependency_container.cleanup()

        # Verify cleanup was called (would need to mock the actual providers)
        # This is more of a smoke test to ensure cleanup doesn't crash

    def test_provider_concurrent_access(self, dependency_container, sample_text_content):
        """Test concurrent access to providers.

        ``synthesize`` is patched once, on the provider's class, from the main
        thread. ``patch.object`` is not thread-safe. Patching the same
        (possibly cached) provider instance from several threads can
        interleave the save/restore steps and leave a stale mock installed
        after the test.
        """
        import threading
        import queue

        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )

        results_queue = queue.Queue()

        def synthesize_audio():
            # Exceptions raised in worker threads are invisible to pytest, so
            # they are forwarded to the main thread through the queue.
            try:
                provider = dependency_container.get_tts_provider('dummy')
                results_queue.put(provider.synthesize(synthesis_request))
            except Exception as e:
                results_queue.put(e)

        # Patching the class (not one instance) covers both a cached provider
        # and a fresh instance per get_tts_provider() call.
        provider_type = type(dependency_container.get_tts_provider('dummy'))
        canned_audio = AudioContent(
            data=b"concurrent_audio_data",
            format="wav",
            sample_rate=22050,
            duration=2.0
        )

        with patch.object(provider_type, 'synthesize', return_value=canned_audio):
            # Start multiple threads
            threads = [threading.Thread(target=synthesize_audio) for _ in range(3)]
            for thread in threads:
                thread.start()

            # Wait for completion before the patch is undone
            for thread in threads:
                thread.join()

        # Verify results
        results = []
        while not results_queue.empty():
            result = results_queue.get()
            if isinstance(result, Exception):
                pytest.fail(f"Concurrent access failed: {result}")
            results.append(result)

        assert len(results) == 3
        for result in results:
            assert isinstance(result, AudioContent)

    def test_provider_configuration_updates(self, dependency_container, mock_config):
        """Test dynamic provider configuration updates.

        NOTE(review): this currently only round-trips attributes on the mock
        config; it does not observe ``dependency_container`` reacting to them.
        """
        # Initial configuration
        initial_providers = mock_config.tts.preferred_providers
        assert 'kokoro' in initial_providers

        # Update configuration
        mock_config.tts.preferred_providers = ['dia', 'dummy']

        # Verify configuration update affects provider selection
        # (This would require actual implementation of dynamic config updates)
        updated_providers = mock_config.tts.preferred_providers
        assert 'dia' in updated_providers
        assert 'dummy' in updated_providers

    def test_provider_health_checking(self, dependency_container):
        """Test provider health checking mechanisms.

        NOTE(review): this asserts on the return value configured on the mock
        a few lines earlier, so it does not exercise production code.
        """
        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock health check methods
            mock_factory.check_provider_health.return_value = {
                'kokoro': {'status': 'healthy', 'response_time': 0.1},
                'dia': {'status': 'unhealthy', 'error': 'Connection timeout'},
                'dummy': {'status': 'healthy', 'response_time': 0.05}
            }

            health_status = mock_factory.check_provider_health()

            assert health_status['kokoro']['status'] == 'healthy'
            assert health_status['dia']['status'] == 'unhealthy'
            assert health_status['dummy']['status'] == 'healthy'

    def test_provider_load_balancing(self, dependency_container):
        """Test provider load balancing mechanisms.

        NOTE(review): the balancing policy lives in ``mock_get_provider``
        below, so this verifies the test's own helper rather than production
        code.
        """
        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock load balancing
            provider_calls = {'kokoro': 0, 'dia': 0, 'dummy': 0}

            def mock_get_provider(name=None):
                if name is None:
                    # Least-used selection; min() breaks ties by list order,
                    # which makes repeated calls cycle round-robin.
                    providers = ['kokoro', 'dia', 'dummy']
                    selected = min(providers, key=lambda p: provider_calls[p])
                    provider_calls[selected] += 1
                    name = selected

                mock_provider = Mock()
                mock_provider.name = name
                return mock_provider

            mock_factory.create_provider.side_effect = mock_get_provider

            # Get multiple providers to test load balancing
            providers = [mock_factory.create_provider() for _ in range(6)]

            # Verify load distribution
            provider_names = [p.name for p in providers]
            assert provider_names.count('kokoro') == 2
            assert provider_names.count('dia') == 2
            assert provider_names.count('dummy') == 2