"""Unit tests for STTProviderBase abstract class."""

import tempfile
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock

import pytest

from src.infrastructure.base.stt_provider_base import STTProviderBase
from src.domain.models.audio_content import AudioContent
from src.domain.models.text_content import TextContent
from src.domain.exceptions import SpeechRecognitionException


class ConcreteSTTProvider(STTProviderBase):
    """Minimal concrete STT provider used to exercise the abstract base class.

    The transcription result, availability flag, model list and failure mode
    are all controllable from the tests so no real STT engine is needed.
    """

    def __init__(self, provider_name="test", supported_languages=None, available=True, models=None):
        """Create the test double.

        Args:
            provider_name: Name forwarded to the base class.
            supported_languages: Language list forwarded to the base class.
            available: Value reported by ``is_available``.
            models: Model names reported by ``get_available_models``. ``None``
                selects the default ``["model1", "model2"]``; an explicit
                empty list is kept as-is.
        """
        super().__init__(provider_name, supported_languages)
        self._available = available
        # Only fall back to the defaults when no list was given at all:
        # ``models or [...]`` would silently replace an explicit empty list.
        self._models = ["model1", "model2"] if models is None else models
        self._should_fail = False
        self._transcription_result = "Hello world"

    def _perform_transcription(self, audio_path, model):
        """Return the canned transcription, or raise when failure is enabled."""
        if self._should_fail:
            raise Exception("Test transcription error")
        return self._transcription_result

    def is_available(self):
        """Return the availability flag given at construction time."""
        return self._available

    def get_available_models(self):
        """Return the configured list of model names."""
        return self._models

    def get_default_model(self):
        """Return the first configured model, or ``"default"`` when there is none."""
        return self._models[0] if self._models else "default"

    def set_should_fail(self, should_fail):
        """Toggle whether ``_perform_transcription`` raises."""
        self._should_fail = should_fail

    def set_transcription_result(self, result):
        """Set the text returned by ``_perform_transcription``."""
        self._transcription_result = result


class TestSTTProviderBase:
    """Test cases for STTProviderBase abstract class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.provider = ConcreteSTTProvider()
        self.audio_content = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=5.0,
            filename="test.wav"
        )

    def test_provider_initialization(self):
        """Test provider initialization with default values."""
        provider = ConcreteSTTProvider("test_provider", ["en", "es"])

        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == ["en", "es"]
        assert isinstance(provider._temp_dir, Path)
        assert provider._temp_dir.exists()

    def test_provider_initialization_no_languages(self):
        """Test provider initialization without supported languages."""
        provider = ConcreteSTTProvider("test_provider")

        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == []

    @patch('builtins.open', create=True)
    def test_transcribe_success(self, mock_open):
        """Test successful transcription."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file

        result = self.provider.transcribe(self.audio_content, "model1")

        assert isinstance(result, TextContent)
        assert result.text == "Hello world"
        assert result.language == "en"
        assert result.encoding == "utf-8"

    def test_transcribe_empty_audio_fails(self):
        """Test that empty audio data raises exception."""
        empty_audio = AudioContent(
            data=b"",
            format="wav",
            sample_rate=16000,
            duration=0.1
        )

        with pytest.raises(SpeechRecognitionException, match="Audio data cannot be empty"):
            self.provider.transcribe(empty_audio, "model1")

    def test_transcribe_audio_too_long_fails(self):
        """Test that audio longer than 1 hour raises exception."""
        long_audio = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=3601.0  # Over 1 hour
        )

        with pytest.raises(SpeechRecognitionException, match="Audio duration exceeds maximum limit"):
            self.provider.transcribe(long_audio, "model1")

    def test_transcribe_audio_too_short_fails(self):
        """Test that audio shorter than 100ms raises exception."""
        short_audio = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=0.05  # 50ms
        )

        with pytest.raises(SpeechRecognitionException, match="Audio duration too short"):
            self.provider.transcribe(short_audio, "model1")

    def test_transcribe_invalid_format_fails(self):
        """Test that invalid audio format raises exception."""
        # Create audio with invalid format by mocking is_valid_format
        invalid_audio = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=5.0
        )

        with patch.object(invalid_audio, 'is_valid_format', False):
            with pytest.raises(SpeechRecognitionException, match="Unsupported audio format"):
                self.provider.transcribe(invalid_audio, "model1")

    @patch('builtins.open', create=True)
    def test_transcribe_provider_error(self, mock_open):
        """Test handling of provider-specific errors."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file

        self.provider.set_should_fail(True)

        with pytest.raises(SpeechRecognitionException, match="STT transcription failed"):
            self.provider.transcribe(self.audio_content, "model1")

    @patch('builtins.open', create=True)
    @patch('pathlib.Path.unlink')
    def test_transcribe_cleanup_temp_file(self, mock_unlink, mock_open):
        """Test that temporary files are cleaned up."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file

        self.provider.transcribe(self.audio_content, "model1")

        # Verify cleanup was attempted
        mock_unlink.assert_called()

    @patch('builtins.open', create=True)
    def test_preprocess_audio(self, mock_open):
        """Test audio preprocessing."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file

        processed_path = self.provider._preprocess_audio(self.audio_content)

        assert isinstance(processed_path, Path)
        assert processed_path.suffix == ".wav"
        mock_file.write.assert_called_once_with(self.audio_content.data)

    def test_preprocess_audio_error(self):
        """Test audio preprocessing error handling."""
        with patch('builtins.open', side_effect=IOError("Test error")):
            with pytest.raises(SpeechRecognitionException, match="Audio preprocessing failed"):
                self.provider._preprocess_audio(self.audio_content)

    @patch('pydub.AudioSegment.from_wav')
    @patch('pydub.AudioSegment.export')
    def test_convert_audio_format_wav(self, mock_export, mock_from_wav):
        """Test audio format conversion for WAV."""
        mock_audio = Mock()
        mock_audio.set_frame_rate.return_value.set_channels.return_value = mock_audio
        mock_from_wav.return_value = mock_audio

        test_path = Path("/tmp/test.wav")

        self.provider._convert_audio_format(test_path, self.audio_content)

        mock_from_wav.assert_called_once_with(test_path)
        mock_audio.set_frame_rate.assert_called_once_with(16000)
        # NOTE(review): the mock is wired for a chained
        # ``set_frame_rate(...).set_channels(...)`` call, in which case
        # set_channels is invoked on ``set_frame_rate.return_value`` and export
        # on ``mock_audio`` itself - confirm these two assertions against the
        # implementation.
        mock_audio.set_channels.assert_called_once_with(1)
        mock_export.assert_called_once()

    @patch('pydub.AudioSegment.from_mp3')
    def test_convert_audio_format_mp3(self, mock_from_mp3):
        """Test audio format conversion for MP3."""
        mp3_audio = AudioContent(
            data=b"fake_mp3_data",
            format="mp3",
            sample_rate=44100,
            duration=5.0
        )

        mock_audio = Mock()
        mock_audio.set_frame_rate.return_value.set_channels.return_value = mock_audio
        mock_from_mp3.return_value = mock_audio

        test_path = Path("/tmp/test.mp3")

        self.provider._convert_audio_format(test_path, mp3_audio)

        mock_from_mp3.assert_called_once_with(test_path)

    def test_convert_audio_format_no_pydub(self):
        """Test audio format conversion when pydub is not available."""
        test_path = Path("/tmp/test.wav")

        # A ``None`` entry in sys.modules makes any ``import pydub`` raise
        # ImportError. Patching ``pydub.AudioSegment`` with a side_effect would
        # only fire when the class is *called*, never on import.
        # NOTE(review): presumably the provider imports pydub inside
        # _convert_audio_format - confirm.
        with patch.dict('sys.modules', {'pydub': None}):
            result_path = self.provider._convert_audio_format(test_path, self.audio_content)

        # Should return original path when pydub is not available
        assert result_path == test_path

    def test_convert_audio_format_error(self):
        """Test audio format conversion error handling."""
        test_path = Path("/tmp/test.wav")

        with patch('pydub.AudioSegment.from_wav', side_effect=Exception("Conversion error")):
            result_path = self.provider._convert_audio_format(test_path, self.audio_content)

        # Should return original path on error
        assert result_path == test_path

    def test_detect_language_english(self):
        """Test language detection for English text."""
        english_text = "The quick brown fox jumps over the lazy dog and it is very nice"

        language = self.provider._detect_language(english_text)

        assert language == "en"

    def test_detect_language_few_indicators(self):
        """Test language detection with few English indicators."""
        text = "Hello world"

        language = self.provider._detect_language(text)

        assert language == "en"

    def test_detect_language_no_indicators(self):
        """Test language detection with no clear indicators."""
        text = "xyz abc def"

        language = self.provider._detect_language(text)

        assert language == "en"  # Should default to English

    def test_detect_language_error(self):
        """Test language detection error handling."""
        # Call the real method with a truthy non-string so its text handling
        # fails internally. Patching _detect_language itself with a raising
        # mock would only exercise the mock and let the exception escape.
        # NOTE(review): assumes _detect_language swallows internal errors and
        # returns None, as this test has always asserted - confirm.
        language = self.provider._detect_language(object())

        assert language is None

    def test_ensure_temp_directory(self):
        """Test temporary directory creation."""
        temp_dir = self.provider._ensure_temp_directory()

        assert isinstance(temp_dir, Path)
        assert temp_dir.exists()
        assert temp_dir.is_dir()
        assert "stt_temp" in str(temp_dir)

    def test_cleanup_temp_file(self):
        """Test temporary file cleanup."""
        # Create a temporary file
        temp_file = self.provider._temp_dir / "test_file.wav"
        temp_file.touch()

        assert temp_file.exists()

        self.provider._cleanup_temp_file(temp_file)

        assert not temp_file.exists()

    def test_cleanup_temp_file_not_exists(self):
        """Test cleanup of non-existent file."""
        non_existent = Path("/tmp/non_existent_file.wav")

        # Should not raise exception
        self.provider._cleanup_temp_file(non_existent)

    def test_cleanup_temp_file_error(self):
        """Test cleanup error handling."""
        with patch('pathlib.Path.unlink', side_effect=OSError("Permission denied")):
            temp_file = Path("/tmp/test.wav")

            # Should not raise exception
            self.provider._cleanup_temp_file(temp_file)

    @patch('time.time')
    @patch('pathlib.Path.glob')
    def test_cleanup_old_temp_files(self, mock_glob, mock_time):
        """Test cleanup of old temporary files."""
        mock_time.return_value = 1000000

        # Mock old file
        old_file = Mock()
        old_file.is_file.return_value = True
        old_file.stat.return_value.st_mtime = 900000  # Old file

        # Mock recent file
        recent_file = Mock()
        recent_file.is_file.return_value = True
        recent_file.stat.return_value.st_mtime = 999000  # Recent file

        mock_glob.return_value = [old_file, recent_file]

        self.provider._cleanup_old_temp_files(24)

        # Old file should be deleted
        old_file.unlink.assert_called_once()
        recent_file.unlink.assert_not_called()

    def test_cleanup_old_temp_files_error(self):
        """Test cleanup error handling."""
        # Patch at class level: Path instances use __slots__, so an attribute
        # cannot be replaced on the ``_temp_dir`` instance itself.
        with patch('pathlib.Path.glob', side_effect=Exception("Test error")):
            # Should not raise exception
            self.provider._cleanup_old_temp_files()

    def test_handle_provider_error(self):
        """Test provider error handling."""
        original_error = ValueError("Original error")

        with pytest.raises(SpeechRecognitionException) as exc_info:
            self.provider._handle_provider_error(original_error, "testing")

        # Checked after the ``raises`` block: statements following the raising
        # call inside the block would never execute.
        assert "test error during testing: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_handle_provider_error_no_context(self):
        """Test provider error handling without context."""
        original_error = ValueError("Original error")

        with pytest.raises(SpeechRecognitionException) as exc_info:
            self.provider._handle_provider_error(original_error)

        assert "test error: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_abstract_methods_not_implemented(self):
        """Test that the abstract base class cannot be instantiated directly."""
        # Create instance of base class directly (should fail)
        with pytest.raises(TypeError):
            STTProviderBase("test")

    def test_provider_unavailable(self):
        """Test behavior when provider is unavailable."""
        provider = ConcreteSTTProvider(available=False)
        assert provider.is_available() is False

    def test_no_models_available(self):
        """Test behavior when no models are available."""
        provider = ConcreteSTTProvider(models=[])
        assert provider.get_available_models() == []
        assert provider.get_default_model() == "default"