Spaces:
Sleeping
Sleeping
| """Unit tests for TTSProviderBase abstract class.""" | |
| import pytest | |
| from unittest.mock import Mock, patch, MagicMock | |
| import tempfile | |
| from pathlib import Path | |
| import time | |
| from src.infrastructure.base.tts_provider_base import TTSProviderBase | |
| from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
| from src.domain.models.text_content import TextContent | |
| from src.domain.models.voice_settings import VoiceSettings | |
| from src.domain.models.audio_content import AudioContent | |
| from src.domain.models.audio_chunk import AudioChunk | |
| from src.domain.exceptions import SpeechSynthesisException | |
class ConcreteTTSProvider(TTSProviderBase):
    """Concrete implementation of TTSProviderBase for testing.

    Supplies canned audio/chunks and configurable availability/voices so the
    base-class behaviour can be exercised without a real TTS engine.
    """

    def __init__(self, provider_name="test", supported_languages=None, available=True, voices=None):
        """Create a fake provider.

        Args:
            provider_name: Name forwarded to the base class.
            supported_languages: Forwarded unchanged to the base class.
            available: Value returned by ``is_available()``.
            voices: Value returned by ``get_available_voices()``. ``None``
                selects the default ``["voice1", "voice2"]``; an explicit empty
                list is preserved so the "no voices" case can be tested.
        """
        super().__init__(provider_name, supported_languages)
        self._available = available
        # Compare against None instead of relying on truthiness: `voices or
        # default` would silently replace an explicit [] with the defaults.
        self._voices = ["voice1", "voice2"] if voices is None else voices
        self._should_fail = False

    def _generate_audio(self, request):
        """Return fixed ``(audio_bytes, sample_rate)`` for any request.

        Raises:
            Exception: ``"Test error"`` when failure mode is enabled.
        """
        if self._should_fail:
            raise Exception("Test error")
        return b"fake_audio_data", 44100

    def _generate_audio_stream(self, request):
        """Yield three fake ``(data, sample_rate, is_final)`` tuples.

        Only the last tuple has ``is_final`` set. Because this is a generator,
        the simulated failure is raised on the first iteration, not at call time.

        Raises:
            Exception: ``"Test stream error"`` when failure mode is enabled.
        """
        if self._should_fail:
            raise Exception("Test stream error")
        yield from (
            (b"chunk1", 44100, False),
            (b"chunk2", 44100, False),
            (b"chunk3", 44100, True),
        )

    def is_available(self):
        """Return the availability flag given at construction."""
        return self._available

    def get_available_voices(self):
        """Return the configured voice list."""
        return self._voices

    def set_should_fail(self, should_fail):
        """Enable or disable the simulated provider failure."""
        self._should_fail = should_fail
class TestTTSProviderBase:
    """Test cases for TTSProviderBase abstract class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.provider = ConcreteTTSProvider()
        self.text_content = TextContent(text="Hello world", language="en")
        self.voice_settings = VoiceSettings(voice_id="voice1", speed=1.0, pitch=1.0)
        self.request = SpeechSynthesisRequest(
            text_content=self.text_content,
            voice_settings=self.voice_settings
        )

    def test_provider_initialization(self):
        """Test provider initialization with default values."""
        provider = ConcreteTTSProvider("test_provider", ["en", "es"])
        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == ["en", "es"]
        assert isinstance(provider._output_dir, Path)
        assert provider._output_dir.exists()

    def test_provider_initialization_no_languages(self):
        """Test provider initialization without supported languages."""
        provider = ConcreteTTSProvider("test_provider")
        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == []

    def test_synthesize_success(self):
        """Test successful speech synthesis."""
        result = self.provider.synthesize(self.request)
        assert isinstance(result, AudioContent)
        assert result.data == b"fake_audio_data"
        assert result.format == "wav"
        assert result.sample_rate == 44100
        assert result.duration > 0
        assert "test_" in result.filename

    def test_synthesize_with_language_validation(self):
        """Test synthesis with language validation."""
        provider = ConcreteTTSProvider("test", ["en", "es"])
        # Valid language should work
        result = provider.synthesize(self.request)
        assert isinstance(result, AudioContent)
        # Invalid language should fail
        invalid_request = SpeechSynthesisRequest(
            text_content=TextContent(text="Hola", language="fr"),
            voice_settings=self.voice_settings
        )
        with pytest.raises(SpeechSynthesisException, match="Language fr not supported"):
            provider.synthesize(invalid_request)

    def test_synthesize_with_voice_validation(self):
        """Test synthesis with voice validation."""
        provider = ConcreteTTSProvider("test", voices=["voice1", "voice2"])
        # Valid voice should work
        result = provider.synthesize(self.request)
        assert isinstance(result, AudioContent)
        # Invalid voice should fail
        invalid_request = SpeechSynthesisRequest(
            text_content=self.text_content,
            voice_settings=VoiceSettings(voice_id="invalid_voice", speed=1.0, pitch=1.0)
        )
        with pytest.raises(SpeechSynthesisException, match="Voice invalid_voice not available"):
            provider.synthesize(invalid_request)

    def test_synthesize_empty_text_fails(self):
        """Test that empty text raises exception."""
        empty_request = SpeechSynthesisRequest(
            text_content=TextContent(text="", language="en"),
            voice_settings=self.voice_settings
        )
        with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
            self.provider.synthesize(empty_request)

    def test_synthesize_whitespace_text_fails(self):
        """Test that whitespace-only text raises exception."""
        whitespace_request = SpeechSynthesisRequest(
            text_content=TextContent(text=" ", language="en"),
            voice_settings=self.voice_settings
        )
        with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
            self.provider.synthesize(whitespace_request)

    def test_synthesize_provider_error(self):
        """Test handling of provider-specific errors."""
        self.provider.set_should_fail(True)
        with pytest.raises(SpeechSynthesisException, match="TTS synthesis failed"):
            self.provider.synthesize(self.request)

    def test_synthesize_stream_success(self):
        """Test successful streaming synthesis."""
        chunks = list(self.provider.synthesize_stream(self.request))
        assert len(chunks) == 3
        for i, chunk in enumerate(chunks):
            assert isinstance(chunk, AudioChunk)
            assert chunk.data == f"chunk{i+1}".encode()
            assert chunk.format == "wav"
            assert chunk.sample_rate == 44100
            assert chunk.chunk_index == i
            assert chunk.timestamp > 0
        # Last chunk should be final
        assert chunks[-1].is_final is True
        assert chunks[0].is_final is False
        assert chunks[1].is_final is False

    def test_synthesize_stream_provider_error(self):
        """Test handling of provider errors in streaming."""
        self.provider.set_should_fail(True)
        with pytest.raises(SpeechSynthesisException, match="TTS streaming synthesis failed"):
            list(self.provider.synthesize_stream(self.request))

    def test_calculate_duration(self):
        """Test audio duration calculation."""
        # Test with standard parameters
        audio_data = b"x" * 88200  # 1 second at 44100 Hz, 16-bit, mono
        duration = self.provider._calculate_duration(audio_data, 44100)
        assert duration == pytest.approx(1.0)
        # Test with different sample rate
        duration = self.provider._calculate_duration(audio_data, 22050)
        assert duration == pytest.approx(2.0)
        # Test with stereo
        duration = self.provider._calculate_duration(audio_data, 44100, channels=2)
        assert duration == pytest.approx(0.5)
        # Test with empty data
        duration = self.provider._calculate_duration(b"", 44100)
        assert duration == 0.0
        # Test with zero sample rate
        duration = self.provider._calculate_duration(audio_data, 0)
        assert duration == 0.0

    def test_ensure_output_directory(self):
        """Test output directory creation."""
        output_dir = self.provider._ensure_output_directory()
        assert isinstance(output_dir, Path)
        assert output_dir.exists()
        assert output_dir.is_dir()
        assert "tts_output" in str(output_dir)

    def test_generate_output_path(self):
        """Test output path generation."""
        path1 = self.provider._generate_output_path()
        path2 = self.provider._generate_output_path()
        # Paths should be different (due to timestamp).
        # NOTE(review): may be flaky if the timestamp resolution is coarse — confirm.
        assert path1 != path2
        assert path1.suffix == ".wav"
        assert path2.suffix == ".wav"
        assert "test_" in path1.name
        assert "test_" in path2.name
        # Test with custom prefix and extension
        path3 = self.provider._generate_output_path("custom", "mp3")
        assert path3.suffix == ".mp3"
        assert "custom_" in path3.name

    @patch("time.time")
    @patch.object(Path, "glob")
    @patch.object(Path, "stat")
    @patch.object(Path, "unlink")
    def test_cleanup_temp_files(self, mock_unlink, mock_stat, mock_glob, mock_time):
        """Test temporary file cleanup.

        Stacked patch decorators are applied bottom-up, so the mocks arrive in
        reverse order: ``unlink``, ``stat``, ``glob``, then ``time``. Patching
        ``Path.stat``/``Path.unlink`` keeps real files from being touched. The
        entries returned by the patched ``glob`` are plain Mocks.
        """
        # Mock current time (assumes the base class calls time.time() via the
        # time module — TODO confirm against the implementation)
        mock_time.return_value = 1000000
        # Mock old file
        old_file = Mock()
        old_file.is_file.return_value = True
        old_file.stat.return_value.st_mtime = 900000  # 100000 seconds old
        # Mock recent file
        recent_file = Mock()
        recent_file.is_file.return_value = True
        recent_file.stat.return_value.st_mtime = 999000  # 1000 seconds old
        mock_glob.return_value = [old_file, recent_file]
        # Cleanup with 24 hour limit (86400 seconds)
        self.provider._cleanup_temp_files(24)
        # Old file should be deleted, recent file should not
        old_file.unlink.assert_called_once()
        recent_file.unlink.assert_not_called()

    def test_cleanup_temp_files_error_handling(self):
        """Test cleanup error handling."""
        # Patch on the class: pathlib objects use __slots__, so patching the
        # `_output_dir` instance attribute directly would fail with
        # AttributeError before cleanup even runs.
        with patch.object(Path, "glob", side_effect=Exception("Test error")):
            self.provider._cleanup_temp_files()  # Should not raise

    def test_handle_provider_error(self):
        """Test provider error handling."""
        original_error = ValueError("Original error")
        with pytest.raises(SpeechSynthesisException) as exc_info:
            self.provider._handle_provider_error(original_error, "testing")
        assert "test error during testing: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_handle_provider_error_no_context(self):
        """Test provider error handling without context."""
        original_error = ValueError("Original error")
        with pytest.raises(SpeechSynthesisException) as exc_info:
            self.provider._handle_provider_error(original_error)
        assert "test error: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_abstract_methods_not_implemented(self):
        """Test that the abstract base class cannot be instantiated (TypeError)."""
        # Create instance of base class directly (should fail)
        with pytest.raises(TypeError):
            TTSProviderBase("test")

    def test_provider_unavailable(self):
        """Test behavior when provider is unavailable."""
        provider = ConcreteTTSProvider(available=False)
        assert provider.is_available() is False

    def test_no_voices_available(self):
        """Test behavior when no voices are available."""
        provider = ConcreteTTSProvider(voices=[])
        assert provider.get_available_voices() == []