Spaces:
Build error
Build error
"""Integration tests for provider integration and switching.""" | |
import pytest | |
from unittest.mock import Mock, patch, MagicMock | |
from typing import Dict, Any, List | |
from src.infrastructure.config.dependency_container import DependencyContainer | |
from src.infrastructure.config.app_config import AppConfig | |
from src.infrastructure.tts.provider_factory import TTSProviderFactory | |
from src.infrastructure.stt.provider_factory import STTProviderFactory | |
from src.infrastructure.translation.provider_factory import TranslationProviderFactory | |
from src.domain.models.audio_content import AudioContent | |
from src.domain.models.text_content import TextContent | |
from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
from src.domain.models.translation_request import TranslationRequest | |
from src.domain.models.voice_settings import VoiceSettings | |
from src.domain.exceptions import ( | |
SpeechRecognitionException, | |
TranslationFailedException, | |
SpeechSynthesisException, | |
ProviderNotAvailableException | |
) | |
class TestProviderIntegration:
    """Integration tests for provider integration and switching.

    The fixtures build a ``DependencyContainer`` around a ``Mock`` specced on
    ``AppConfig``. Individual tests patch provider modules or provider
    factories and then resolve TTS / STT / translation providers through the
    container.
    """

    @pytest.fixture
    def mock_config(self):
        """Create mock configuration for testing.

        Returns a ``Mock(spec=AppConfig)`` with the TTS, STT and translation
        settings populated below.
        """
        config = Mock(spec=AppConfig)

        # TTS configuration
        config.tts.preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
        config.tts.fallback_enabled = True
        config.tts.provider_timeout = 30.0

        # STT configuration
        config.stt.default_model = 'whisper-small'
        config.stt.fallback_models = ['whisper-medium', 'parakeet']
        config.stt.provider_timeout = 60.0

        # Translation configuration
        config.translation.default_provider = 'nllb'
        config.translation.fallback_enabled = True
        config.translation.chunk_size = 512

        return config

    @pytest.fixture
    def dependency_container(self, mock_config):
        """Create dependency container with mock configuration."""
        return DependencyContainer(mock_config)

    @pytest.fixture
    def sample_audio_content(self):
        """Create sample audio content for testing (2.5 s of fake 16 kHz WAV)."""
        return AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=2.5
        )

    @pytest.fixture
    def sample_text_content(self):
        """Create sample English text content for testing."""
        return TextContent(
            text="Hello, this is a test message.",
            language="en"
        )

    def test_tts_provider_switching(self, dependency_container, sample_text_content):
        """Test switching between different TTS providers.

        For each provider name, the provider module is patched with a mock
        ``<Name>Provider`` class whose ``synthesize`` returns audio tagged with
        that provider's name. The test then checks that the provider resolved
        from the container produced it.
        """
        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )

        # Test each TTS provider
        providers_to_test = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
        for provider_name in providers_to_test:
            with patch(f'src.infrastructure.tts.{provider_name}_provider') as mock_provider_module:
                # Mock the provider class
                mock_provider_class = Mock()
                mock_provider_instance = Mock()
                mock_provider_instance.synthesize.return_value = AudioContent(
                    data=f"{provider_name}_audio_data".encode(),
                    format="wav",
                    sample_rate=22050,
                    duration=2.0
                )
                mock_provider_class.return_value = mock_provider_instance
                # NOTE(review): str.title() yields names like 'Cosyvoice2Provider';
                # confirm these match the real class names in each provider module.
                setattr(mock_provider_module, f'{provider_name.title()}Provider', mock_provider_class)

                # Get provider from container
                provider = dependency_container.get_tts_provider(provider_name)

                # Test synthesis
                result = provider.synthesize(synthesis_request)

                assert isinstance(result, AudioContent)
                assert provider_name.encode() in result.data
                mock_provider_instance.synthesize.assert_called_once()

    def test_tts_provider_fallback(self, dependency_container, sample_text_content):
        """Test TTS provider fallback mechanism."""
        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )

        # NOTE(review): the patch below only takes effect if the container looks up
        # TTSProviderFactory through this module attribute when the provider is
        # requested; confirm against DependencyContainer.
        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock first provider to fail, second to succeed
            # NOTE(review): mock_provider1 is never handed to the factory, so the
            # "first provider fails" path is not actually exercised here.
            mock_provider1 = Mock()
            mock_provider1.synthesize.side_effect = SpeechSynthesisException("Provider 1 failed")

            mock_provider2 = Mock()
            mock_provider2.synthesize.return_value = AudioContent(
                data=b"fallback_audio_data",
                format="wav",
                sample_rate=22050,
                duration=2.0
            )

            mock_factory.get_provider_with_fallback.return_value = mock_provider2

            # Get provider with fallback
            provider = dependency_container.get_tts_provider()
            result = provider.synthesize(synthesis_request)

            assert isinstance(result, AudioContent)
            assert b"fallback_audio_data" in result.data

    def test_stt_provider_switching(self, dependency_container, sample_audio_content):
        """Test switching between different STT providers."""
        providers_to_test = ['whisper-small', 'whisper-medium', 'parakeet']

        for provider_name in providers_to_test:
            with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class:
                mock_factory = Mock()
                mock_factory_class.return_value = mock_factory

                mock_provider = Mock()
                mock_provider.transcribe.return_value = TextContent(
                    text=f"Transcription from {provider_name}",
                    language="en"
                )
                mock_factory.create_provider.return_value = mock_provider

                # Get provider from container
                provider = dependency_container.get_stt_provider(provider_name)

                # Test transcription
                result = provider.transcribe(sample_audio_content, provider_name)

                assert isinstance(result, TextContent)
                assert provider_name in result.text
                mock_provider.transcribe.assert_called_once()

    def test_stt_provider_fallback(self, dependency_container, sample_audio_content):
        """Test STT provider fallback mechanism."""
        with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock first provider to fail, fallback to succeed
            # NOTE(review): mock_provider1 is never handed to the factory, so the
            # failure-then-fallback sequence is not actually exercised here.
            mock_provider1 = Mock()
            mock_provider1.transcribe.side_effect = SpeechRecognitionException("Provider 1 failed")

            mock_provider2 = Mock()
            mock_provider2.transcribe.return_value = TextContent(
                text="Fallback transcription successful",
                language="en"
            )

            mock_factory.create_provider_with_fallback.return_value = mock_provider2

            # Get provider with fallback
            provider = dependency_container.get_stt_provider()
            result = provider.transcribe(sample_audio_content, "whisper-small")

            assert isinstance(result, TextContent)
            assert "Fallback transcription successful" in result.text

    def test_translation_provider_integration(self, dependency_container):
        """Test translation provider integration (en -> es)."""
        translation_request = TranslationRequest(
            text="Hello, how are you?",
            source_language="en",
            target_language="es"
        )

        with patch('src.infrastructure.translation.provider_factory.TranslationProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            mock_provider = Mock()
            mock_provider.translate.return_value = TextContent(
                text="Hola, ¿cómo estás?",
                language="es"
            )
            mock_factory.get_default_provider.return_value = mock_provider

            # Get translation provider
            provider = dependency_container.get_translation_provider()
            result = provider.translate(translation_request)

            assert isinstance(result, TextContent)
            assert result.text == "Hola, ¿cómo estás?"
            assert result.language == "es"

    def test_provider_availability_checking(self, dependency_container):
        """Test provider availability checking."""
        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock availability checking
            mock_factory.is_provider_available.side_effect = lambda name: name in ['kokoro', 'dummy']
            mock_factory.get_available_providers.return_value = ['kokoro', 'dummy']

            # NOTE(review): this queries the mock directly, so it verifies only the
            # mock's configuration, not the container or the real factory.
            available_providers = mock_factory.get_available_providers()

            assert 'kokoro' in available_providers
            assert 'dummy' in available_providers
            assert 'dia' not in available_providers  # Not available in mock

    def test_provider_configuration_loading(self, dependency_container, mock_config):
        """Test provider configuration loading and validation."""
        # Test TTS configuration
        tts_provider = dependency_container.get_tts_provider('dummy')
        assert tts_provider is not None

        # Test STT configuration
        stt_provider = dependency_container.get_stt_provider('whisper-small')
        assert stt_provider is not None

        # Test translation configuration
        translation_provider = dependency_container.get_translation_provider()
        assert translation_provider is not None

    def test_provider_error_handling(self, dependency_container, sample_audio_content):
        """Test that a provider failure surfaces as SpeechRecognitionException."""
        with patch('src.infrastructure.stt.provider_factory.STTProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock provider that always fails
            mock_provider = Mock()
            mock_provider.transcribe.side_effect = SpeechRecognitionException("Provider unavailable")
            mock_factory.create_provider.return_value = mock_provider

            # Test error handling
            provider = dependency_container.get_stt_provider('whisper-small')

            with pytest.raises(SpeechRecognitionException):
                provider.transcribe(sample_audio_content, 'whisper-small')

    def test_provider_performance_monitoring(self, dependency_container, sample_text_content):
        """Test provider performance monitoring.

        Elapsed time is measured with ``time.perf_counter()``, a monotonic
        clock. ``time.time()`` can jump if the system clock is adjusted during
        the test.
        """
        import time

        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )

        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            mock_provider = Mock()

            def slow_synthesize(request):
                time.sleep(0.1)  # Simulate processing time
                return AudioContent(
                    data=b"slow_audio_data",
                    format="wav",
                    sample_rate=22050,
                    duration=2.0
                )

            mock_provider.synthesize.side_effect = slow_synthesize
            mock_factory.create_provider.return_value = mock_provider

            # NOTE(review): the >= 0.1 s bound holds only if the container builds
            # the 'dummy' provider through the patched factory; confirm.
            start_time = time.perf_counter()
            provider = dependency_container.get_tts_provider('dummy')
            result = provider.synthesize(synthesis_request)
            processing_time = time.perf_counter() - start_time

            assert isinstance(result, AudioContent)
            assert processing_time >= 0.1  # Should take at least the sleep time

    def test_provider_resource_cleanup(self, dependency_container):
        """Test provider resource cleanup."""
        # Get multiple providers
        tts_provider = dependency_container.get_tts_provider('dummy')
        stt_provider = dependency_container.get_stt_provider('whisper-small')
        translation_provider = dependency_container.get_translation_provider()

        assert tts_provider is not None
        assert stt_provider is not None
        assert translation_provider is not None

        # Test cleanup
        dependency_container.cleanup()

        # Verify cleanup was called (would need to mock the actual providers)
        # This is more of a smoke test to ensure cleanup doesn't crash

    def test_provider_concurrent_access(self, dependency_container, sample_text_content):
        """Test concurrent access to providers.

        Three worker threads each resolve the 'dummy' TTS provider from the
        shared container and call ``synthesize``. ``synthesize`` is patched
        once, on the provider's class, before any worker starts. Per-thread
        ``patch.object`` is not thread-safe: if the container hands out a
        cached instance, overlapping patch/unpatch calls can let a thread run
        the real method, or leave a stale mock installed afterwards.
        Patching the class covers every instance the workers may receive.
        """
        from concurrent.futures import ThreadPoolExecutor

        voice_settings = VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="en"
        )
        synthesis_request = SpeechSynthesisRequest(
            text=sample_text_content.text,
            voice_settings=voice_settings
        )
        canned_audio = AudioContent(
            data=b"concurrent_audio_data",
            format="wav",
            sample_rate=22050,
            duration=2.0
        )

        # Resolve once up front only to learn which concrete class to patch.
        provider_cls = type(dependency_container.get_tts_provider('dummy'))

        def synthesize_audio():
            provider = dependency_container.get_tts_provider('dummy')
            return provider.synthesize(synthesis_request)

        with patch.object(provider_cls, 'synthesize', return_value=canned_audio):
            with ThreadPoolExecutor(max_workers=3) as executor:
                futures = [executor.submit(synthesize_audio) for _ in range(3)]
                # Future.result() re-raises any exception from the worker thread,
                # so a concurrent-access failure fails the test with its traceback.
                results = [future.result() for future in futures]

        assert len(results) == 3
        for result in results:
            assert isinstance(result, AudioContent)

    def test_provider_configuration_updates(self, dependency_container, mock_config):
        """Test dynamic provider configuration updates."""
        # Initial configuration
        initial_providers = mock_config.tts.preferred_providers
        assert 'kokoro' in initial_providers

        # Update configuration
        mock_config.tts.preferred_providers = ['dia', 'dummy']

        # Verify configuration update affects provider selection
        # (This would require actual implementation of dynamic config updates)
        # NOTE(review): as written this only reads back the value just assigned
        # on the mock; the container is not consulted.
        updated_providers = mock_config.tts.preferred_providers
        assert 'dia' in updated_providers
        assert 'dummy' in updated_providers

    def test_provider_health_checking(self, dependency_container):
        """Test provider health checking mechanisms."""
        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock health check methods
            mock_factory.check_provider_health.return_value = {
                'kokoro': {'status': 'healthy', 'response_time': 0.1},
                'dia': {'status': 'unhealthy', 'error': 'Connection timeout'},
                'dummy': {'status': 'healthy', 'response_time': 0.05}
            }

            # NOTE(review): this calls the mock directly, so it verifies only the
            # canned return value, not any real health-check logic.
            health_status = mock_factory.check_provider_health()

            assert health_status['kokoro']['status'] == 'healthy'
            assert health_status['dia']['status'] == 'unhealthy'
            assert health_status['dummy']['status'] == 'healthy'

    def test_provider_load_balancing(self, dependency_container):
        """Test provider load balancing mechanisms."""
        with patch('src.infrastructure.tts.provider_factory.TTSProviderFactory') as mock_factory_class:
            mock_factory = Mock()
            mock_factory_class.return_value = mock_factory

            # Mock load balancing
            provider_calls = {'kokoro': 0, 'dia': 0, 'dummy': 0}

            def mock_get_provider(name=None):
                if name is None:
                    # Least-used selection; min() breaks ties by list order, so
                    # successive calls cycle kokoro -> dia -> dummy (round-robin).
                    providers = ['kokoro', 'dia', 'dummy']
                    selected = min(providers, key=lambda p: provider_calls[p])
                    provider_calls[selected] += 1
                    name = selected

                mock_provider = Mock()
                mock_provider.name = name
                return mock_provider

            mock_factory.create_provider.side_effect = mock_get_provider

            # NOTE(review): the balancing policy lives entirely in the side_effect
            # above; this verifies that helper, not a real load balancer.
            providers = []
            for _ in range(6):
                provider = mock_factory.create_provider()
                providers.append(provider)

            # Verify load distribution
            provider_names = [p.name for p in providers]
            assert provider_names.count('kokoro') == 2
            assert provider_names.count('dia') == 2
            assert provider_names.count('dummy') == 2