Spaces:
Build error
Build error
"""Unit tests for TTSProviderBase abstract class.""" | |
import pytest | |
from unittest.mock import Mock, patch, MagicMock | |
import tempfile | |
from pathlib import Path | |
import time | |
from src.infrastructure.base.tts_provider_base import TTSProviderBase | |
from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
from src.domain.models.text_content import TextContent | |
from src.domain.models.voice_settings import VoiceSettings | |
from src.domain.models.audio_content import AudioContent | |
from src.domain.models.audio_chunk import AudioChunk | |
from src.domain.exceptions import SpeechSynthesisException | |
class ConcreteTTSProvider(TTSProviderBase):
    """Minimal concrete TTSProviderBase subclass used as a test double.

    Returns canned audio bytes and a fixed three-chunk stream, and can be
    switched into a failing mode to exercise the base class's error handling.
    """

    def __init__(self, provider_name="test", supported_languages=None, available=True, voices=None):
        """Configure the test double.

        Args:
            provider_name: Name forwarded to the base class constructor.
            supported_languages: Language list forwarded to the base class.
            available: Value reported by ``is_available()``.
            voices: Voice ids reported by ``get_available_voices()``. ``None``
                selects the defaults ``["voice1", "voice2"]``; an explicit
                empty list is kept as-is.
        """
        super().__init__(provider_name, supported_languages)
        self._available = available
        # Compare against None rather than relying on truthiness: ``voices or
        # [...]`` would silently replace an explicitly empty list with the
        # defaults, making a "no voices" provider impossible to construct.
        self._voices = ["voice1", "voice2"] if voices is None else voices
        self._should_fail = False

    def _generate_audio(self, request):
        """Return fixed ``(audio_bytes, sample_rate)`` or raise in failing mode."""
        if self._should_fail:
            raise Exception("Test error")
        return b"fake_audio_data", 44100

    def _generate_audio_stream(self, request):
        """Yield three ``(data, sample_rate, is_final)`` tuples or raise in failing mode.

        Only the last tuple carries ``is_final=True``.
        """
        if self._should_fail:
            raise Exception("Test stream error")
        chunks = [
            (b"chunk1", 44100, False),
            (b"chunk2", 44100, False),
            (b"chunk3", 44100, True),
        ]
        yield from chunks

    def is_available(self):
        """Return the availability flag given at construction."""
        return self._available

    def get_available_voices(self):
        """Return the configured list of voice ids."""
        return self._voices

    def set_should_fail(self, should_fail):
        """Toggle failing mode for ``_generate_audio`` and ``_generate_audio_stream``."""
        self._should_fail = should_fail
class TestTTSProviderBase:
    """Test cases for TTSProviderBase abstract class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.provider = ConcreteTTSProvider()
        self.text_content = TextContent(text="Hello world", language="en")
        self.voice_settings = VoiceSettings(voice_id="voice1", speed=1.0, pitch=1.0)
        self.request = SpeechSynthesisRequest(
            text_content=self.text_content,
            voice_settings=self.voice_settings
        )

    def test_provider_initialization(self):
        """Test provider initialization with explicit name and languages."""
        provider = ConcreteTTSProvider("test_provider", ["en", "es"])
        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == ["en", "es"]
        assert isinstance(provider._output_dir, Path)
        assert provider._output_dir.exists()

    def test_provider_initialization_no_languages(self):
        """Test provider initialization without supported languages."""
        provider = ConcreteTTSProvider("test_provider")
        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == []

    def test_synthesize_success(self):
        """Test successful speech synthesis."""
        result = self.provider.synthesize(self.request)
        assert isinstance(result, AudioContent)
        assert result.data == b"fake_audio_data"
        assert result.format == "wav"
        assert result.sample_rate == 44100
        assert result.duration > 0
        assert "test_" in result.filename

    def test_synthesize_with_language_validation(self):
        """Test synthesis with language validation."""
        provider = ConcreteTTSProvider("test", ["en", "es"])

        # Valid language should work
        result = provider.synthesize(self.request)
        assert isinstance(result, AudioContent)

        # Invalid language should fail
        invalid_request = SpeechSynthesisRequest(
            text_content=TextContent(text="Hola", language="fr"),
            voice_settings=self.voice_settings
        )
        with pytest.raises(SpeechSynthesisException, match="Language fr not supported"):
            provider.synthesize(invalid_request)

    def test_synthesize_with_voice_validation(self):
        """Test synthesis with voice validation."""
        provider = ConcreteTTSProvider("test", voices=["voice1", "voice2"])

        # Valid voice should work
        result = provider.synthesize(self.request)
        assert isinstance(result, AudioContent)

        # Invalid voice should fail
        invalid_request = SpeechSynthesisRequest(
            text_content=self.text_content,
            voice_settings=VoiceSettings(voice_id="invalid_voice", speed=1.0, pitch=1.0)
        )
        with pytest.raises(SpeechSynthesisException, match="Voice invalid_voice not available"):
            provider.synthesize(invalid_request)

    def test_synthesize_empty_text_fails(self):
        """Test that empty text raises exception."""
        empty_request = SpeechSynthesisRequest(
            text_content=TextContent(text="", language="en"),
            voice_settings=self.voice_settings
        )
        with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
            self.provider.synthesize(empty_request)

    def test_synthesize_whitespace_text_fails(self):
        """Test that whitespace-only text raises exception."""
        whitespace_request = SpeechSynthesisRequest(
            text_content=TextContent(text=" ", language="en"),
            voice_settings=self.voice_settings
        )
        with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
            self.provider.synthesize(whitespace_request)

    def test_synthesize_provider_error(self):
        """Test handling of provider-specific errors."""
        self.provider.set_should_fail(True)
        with pytest.raises(SpeechSynthesisException, match="TTS synthesis failed"):
            self.provider.synthesize(self.request)

    def test_synthesize_stream_success(self):
        """Test successful streaming synthesis."""
        chunks = list(self.provider.synthesize_stream(self.request))
        assert len(chunks) == 3
        for i, chunk in enumerate(chunks):
            assert isinstance(chunk, AudioChunk)
            assert chunk.data == f"chunk{i+1}".encode()
            assert chunk.format == "wav"
            assert chunk.sample_rate == 44100
            assert chunk.chunk_index == i
            assert chunk.timestamp > 0

        # Last chunk should be final
        assert chunks[-1].is_final is True
        assert chunks[0].is_final is False
        assert chunks[1].is_final is False

    def test_synthesize_stream_provider_error(self):
        """Test handling of provider errors in streaming."""
        self.provider.set_should_fail(True)
        with pytest.raises(SpeechSynthesisException, match="TTS streaming synthesis failed"):
            list(self.provider.synthesize_stream(self.request))

    def test_calculate_duration(self):
        """Test audio duration calculation."""
        # Test with standard parameters
        audio_data = b"x" * 88200  # 1 second at 44100 Hz, 16-bit, mono
        duration = self.provider._calculate_duration(audio_data, 44100)
        assert duration == 1.0

        # Test with different sample rate
        duration = self.provider._calculate_duration(audio_data, 22050)
        assert duration == 2.0

        # Test with stereo
        duration = self.provider._calculate_duration(audio_data, 44100, channels=2)
        assert duration == 0.5

        # Test with empty data
        duration = self.provider._calculate_duration(b"", 44100)
        assert duration == 0.0

        # Test with zero sample rate
        duration = self.provider._calculate_duration(audio_data, 0)
        assert duration == 0.0

    def test_ensure_output_directory(self):
        """Test output directory creation."""
        output_dir = self.provider._ensure_output_directory()
        assert isinstance(output_dir, Path)
        assert output_dir.exists()
        assert output_dir.is_dir()
        assert "tts_output" in str(output_dir)

    def test_generate_output_path(self):
        """Test output path generation."""
        path1 = self.provider._generate_output_path()
        path2 = self.provider._generate_output_path()

        # Paths should be different (due to timestamp)
        assert path1 != path2
        assert path1.suffix == ".wav"
        assert path2.suffix == ".wav"
        assert "test_" in path1.name
        assert "test_" in path2.name

        # Test with custom prefix and extension
        path3 = self.provider._generate_output_path("custom", "mp3")
        assert path3.suffix == ".mp3"
        assert "custom_" in path3.name

    # Without these decorators pytest treats the four mock_* parameters as
    # fixtures and errors out with "fixture not found". Stacked @patch
    # decorators inject mocks bottom-up, so the lowest decorator maps to the
    # first parameter after ``self``.
    # NOTE(review): the targets assume the implementation calls ``time.time()``
    # and ``self._output_dir.glob(...)`` -- confirm against TTSProviderBase.
    @patch("time.time")
    @patch("pathlib.Path.glob")
    @patch("pathlib.Path.stat")
    @patch("pathlib.Path.unlink")
    def test_cleanup_temp_files(self, mock_unlink, mock_stat, mock_glob, mock_time):
        """Test temporary file cleanup."""
        # Mock current time
        mock_time.return_value = 1000000

        # Mock old file
        old_file = Mock()
        old_file.is_file.return_value = True
        old_file.stat.return_value.st_mtime = 900000  # 100000 seconds old

        # Mock recent file
        recent_file = Mock()
        recent_file.is_file.return_value = True
        recent_file.stat.return_value.st_mtime = 999000  # 1000 seconds old

        mock_glob.return_value = [old_file, recent_file]

        # Cleanup with 24 hour limit (86400 seconds)
        self.provider._cleanup_temp_files(24)

        # Old file should be deleted, recent file should not
        old_file.unlink.assert_called_once()
        recent_file.unlink.assert_not_called()

    def test_cleanup_temp_files_error_handling(self):
        """Test cleanup error handling."""
        # Should not raise exception even if cleanup fails.
        # Patch ``glob`` on the Path class: Path instances use __slots__, so
        # setting the attribute on the ``_output_dir`` instance itself raises
        # AttributeError before the code under test ever runs.
        with patch.object(Path, "glob", side_effect=Exception("Test error")):
            self.provider._cleanup_temp_files()  # Should not raise

    def test_handle_provider_error(self):
        """Test provider error handling."""
        original_error = ValueError("Original error")
        with pytest.raises(SpeechSynthesisException) as exc_info:
            self.provider._handle_provider_error(original_error, "testing")

        # These checks must sit outside the ``with`` block; statements after
        # the raising call inside it would never execute.
        assert "test error during testing: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_handle_provider_error_no_context(self):
        """Test provider error handling without context."""
        original_error = ValueError("Original error")
        with pytest.raises(SpeechSynthesisException) as exc_info:
            self.provider._handle_provider_error(original_error)

        assert "test error: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_abstract_methods_not_implemented(self):
        """Test that the abstract base class cannot be instantiated directly."""
        # Create instance of base class directly (should fail with TypeError)
        with pytest.raises(TypeError):
            TTSProviderBase("test")

    def test_provider_unavailable(self):
        """Test behavior when provider is unavailable."""
        provider = ConcreteTTSProvider(available=False)
        assert provider.is_available() is False

    def test_no_voices_available(self):
        """Test behavior when no voices are available."""
        provider = ConcreteTTSProvider(voices=[])
        assert provider.get_available_voices() == []