"""Unit tests for TTSProviderBase abstract class."""

import tempfile
import time
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock

import pytest

from src.infrastructure.base.tts_provider_base import TTSProviderBase
from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest
from src.domain.models.text_content import TextContent
from src.domain.models.voice_settings import VoiceSettings
from src.domain.models.audio_content import AudioContent
from src.domain.models.audio_chunk import AudioChunk
from src.domain.exceptions import SpeechSynthesisException


class ConcreteTTSProvider(TTSProviderBase):
    """Concrete implementation for testing.

    Args:
        provider_name: Name forwarded to ``TTSProviderBase``.
        supported_languages: Language codes forwarded to ``TTSProviderBase``.
        available: Value returned by ``is_available``.
        voices: Voice ids returned by ``get_available_voices``. ``None`` selects
            the defaults ``["voice1", "voice2"]``; an explicit empty list is
            kept as-is so "no voices" scenarios can be tested.
    """

    def __init__(self, provider_name="test", supported_languages=None, available=True, voices=None):
        super().__init__(provider_name, supported_languages)
        self._available = available
        # Compare against None instead of using `voices or [...]`: truthiness
        # would silently replace an explicit empty list with the defaults.
        self._voices = ["voice1", "voice2"] if voices is None else voices
        self._should_fail = False

    def _generate_audio(self, request):
        """Return canned ``(audio_bytes, sample_rate)``, or raise when failure is armed."""
        if self._should_fail:
            raise Exception("Test error")
        return b"fake_audio_data", 44100

    def _generate_audio_stream(self, request):
        """Yield three canned ``(data, sample_rate, is_final)`` tuples; only the last is final.

        Because this is a generator, the armed failure is raised on the first
        ``next()`` call, not when the method is called.
        """
        if self._should_fail:
            raise Exception("Test stream error")
        yield from (
            (b"chunk1", 44100, False),
            (b"chunk2", 44100, False),
            (b"chunk3", 44100, True),
        )

    def is_available(self):
        return self._available

    def get_available_voices(self):
        return self._voices

    def set_should_fail(self, should_fail):
        """Arm or disarm the failure path of both generation methods."""
        self._should_fail = should_fail


class TestTTSProviderBase:
    """Test cases for TTSProviderBase abstract class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.provider = ConcreteTTSProvider()
        self.text_content = TextContent(text="Hello world", language="en")
        self.voice_settings = VoiceSettings(voice_id="voice1", speed=1.0, pitch=1.0)
        self.request = SpeechSynthesisRequest(
            text_content=self.text_content,
            voice_settings=self.voice_settings,
        )

    def test_provider_initialization(self):
        """Test provider initialization with default values."""
        provider = ConcreteTTSProvider("test_provider", ["en", "es"])

        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == ["en", "es"]
        assert isinstance(provider._output_dir, Path)
        assert provider._output_dir.exists()

    def test_provider_initialization_no_languages(self):
        """Test provider initialization without supported languages."""
        provider = ConcreteTTSProvider("test_provider")

        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == []

    def test_synthesize_success(self):
        """Test successful speech synthesis."""
        result = self.provider.synthesize(self.request)

        assert isinstance(result, AudioContent)
        assert result.data == b"fake_audio_data"
        assert result.format == "wav"
        assert result.sample_rate == 44100
        assert result.duration > 0
        assert "test_" in result.filename

    def test_synthesize_with_language_validation(self):
        """Test synthesis with language validation."""
        provider = ConcreteTTSProvider("test", ["en", "es"])

        # Valid language should work
        result = provider.synthesize(self.request)
        assert isinstance(result, AudioContent)

        # Invalid language should fail
        invalid_request = SpeechSynthesisRequest(
            text_content=TextContent(text="Hola", language="fr"),
            voice_settings=self.voice_settings,
        )
        with pytest.raises(SpeechSynthesisException, match="Language fr not supported"):
            provider.synthesize(invalid_request)

    def test_synthesize_with_voice_validation(self):
        """Test synthesis with voice validation."""
        provider = ConcreteTTSProvider("test", voices=["voice1", "voice2"])

        # Valid voice should work
        result = provider.synthesize(self.request)
        assert isinstance(result, AudioContent)

        # Invalid voice should fail
        invalid_request = SpeechSynthesisRequest(
            text_content=self.text_content,
            voice_settings=VoiceSettings(voice_id="invalid_voice", speed=1.0, pitch=1.0),
        )
        with pytest.raises(SpeechSynthesisException, match="Voice invalid_voice not available"):
            provider.synthesize(invalid_request)

    def test_synthesize_empty_text_fails(self):
        """Test that empty text raises exception."""
        empty_request = SpeechSynthesisRequest(
            text_content=TextContent(text="", language="en"),
            voice_settings=self.voice_settings,
        )
        with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
            self.provider.synthesize(empty_request)

    def test_synthesize_whitespace_text_fails(self):
        """Test that whitespace-only text raises exception."""
        whitespace_request = SpeechSynthesisRequest(
            text_content=TextContent(text=" ", language="en"),
            voice_settings=self.voice_settings,
        )
        with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
            self.provider.synthesize(whitespace_request)

    def test_synthesize_provider_error(self):
        """Test handling of provider-specific errors."""
        self.provider.set_should_fail(True)

        with pytest.raises(SpeechSynthesisException, match="TTS synthesis failed"):
            self.provider.synthesize(self.request)

    def test_synthesize_stream_success(self):
        """Test successful streaming synthesis."""
        chunks = list(self.provider.synthesize_stream(self.request))

        assert len(chunks) == 3
        for i, chunk in enumerate(chunks):
            assert isinstance(chunk, AudioChunk)
            assert chunk.data == f"chunk{i+1}".encode()
            assert chunk.format == "wav"
            assert chunk.sample_rate == 44100
            assert chunk.chunk_index == i
            assert chunk.timestamp > 0

        # Last chunk should be final
        assert chunks[-1].is_final is True
        assert chunks[0].is_final is False
        assert chunks[1].is_final is False

    def test_synthesize_stream_provider_error(self):
        """Test handling of provider errors in streaming."""
        self.provider.set_should_fail(True)

        with pytest.raises(SpeechSynthesisException, match="TTS streaming synthesis failed"):
            list(self.provider.synthesize_stream(self.request))

    def test_calculate_duration(self):
        """Test audio duration calculation."""
        # Test with standard parameters
        audio_data = b"x" * 88200  # 1 second at 44100 Hz, 16-bit, mono
        duration = self.provider._calculate_duration(audio_data, 44100)
        assert duration == pytest.approx(1.0)

        # Test with different sample rate
        duration = self.provider._calculate_duration(audio_data, 22050)
        assert duration == pytest.approx(2.0)

        # Test with stereo
        duration = self.provider._calculate_duration(audio_data, 44100, channels=2)
        assert duration == pytest.approx(0.5)

        # Test with empty data
        duration = self.provider._calculate_duration(b"", 44100)
        assert duration == 0.0

        # Test with zero sample rate
        duration = self.provider._calculate_duration(audio_data, 0)
        assert duration == 0.0

    def test_ensure_output_directory(self):
        """Test output directory creation."""
        output_dir = self.provider._ensure_output_directory()

        assert isinstance(output_dir, Path)
        assert output_dir.exists()
        assert output_dir.is_dir()
        assert "tts_output" in str(output_dir)

    def test_generate_output_path(self):
        """Test output path generation."""
        path1 = self.provider._generate_output_path()
        path2 = self.provider._generate_output_path()

        # Paths should be different (due to timestamp).
        # NOTE(review): this relies on back-to-back calls producing distinct
        # names; if the timestamp is coarse, this can be flaky. Confirm.
        assert path1 != path2
        assert path1.suffix == ".wav"
        assert path2.suffix == ".wav"
        assert "test_" in path1.name
        assert "test_" in path2.name

        # Test with custom prefix and extension
        path3 = self.provider._generate_output_path("custom", "mp3")
        assert path3.suffix == ".mp3"
        assert "custom_" in path3.name

    @patch('time.time')
    @patch('pathlib.Path.glob')
    @patch('pathlib.Path.stat')
    @patch('pathlib.Path.unlink')
    def test_cleanup_temp_files(self, mock_unlink, mock_stat, mock_glob, mock_time):
        """Test temporary file cleanup.

        ``mock_stat`` and ``mock_unlink`` are not asserted on. They replace
        ``Path.stat``/``Path.unlink`` for the duration of the test so that real
        ``Path`` objects cannot stat or delete files on disk.
        """
        # Mock current time
        mock_time.return_value = 1000000

        # Mock old file
        old_file = Mock()
        old_file.is_file.return_value = True
        old_file.stat.return_value.st_mtime = 900000  # 100000 seconds old

        # Mock recent file
        recent_file = Mock()
        recent_file.is_file.return_value = True
        recent_file.stat.return_value.st_mtime = 999000  # 1000 seconds old

        mock_glob.return_value = [old_file, recent_file]

        # Cleanup with 24 hour limit (86400 seconds)
        self.provider._cleanup_temp_files(24)

        # Old file should be deleted, recent file should not
        old_file.unlink.assert_called_once()
        recent_file.unlink.assert_not_called()

    def test_cleanup_temp_files_error_handling(self):
        """Test that cleanup does not raise when scanning the output dir fails."""
        # Patch on the class, not the instance: pathlib.Path instances use
        # __slots__ and have no __dict__, so patch.object(<Path instance>, "glob")
        # fails with "attribute 'glob' is read-only" before cleanup runs.
        with patch.object(type(self.provider._output_dir), 'glob', side_effect=Exception("Test error")):
            self.provider._cleanup_temp_files()  # Should not raise

    def test_handle_provider_error(self):
        """Test provider error handling."""
        original_error = ValueError("Original error")

        with pytest.raises(SpeechSynthesisException) as exc_info:
            self.provider._handle_provider_error(original_error, "testing")

        assert "test error during testing: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_handle_provider_error_no_context(self):
        """Test provider error handling without context."""
        original_error = ValueError("Original error")

        with pytest.raises(SpeechSynthesisException) as exc_info:
            self.provider._handle_provider_error(original_error)

        assert "test error: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_abstract_methods_not_implemented(self):
        """Test that TTSProviderBase cannot be instantiated directly.

        Python's ABC machinery raises ``TypeError`` when a class with
        unimplemented abstract methods is instantiated.
        """
        with pytest.raises(TypeError):
            TTSProviderBase("test")

    def test_provider_unavailable(self):
        """Test behavior when provider is unavailable."""
        provider = ConcreteTTSProvider(available=False)
        assert provider.is_available() is False

    def test_no_voices_available(self):
        """Test behavior when no voices are available."""
        provider = ConcreteTTSProvider(voices=[])
        assert provider.get_available_voices() == []