Spaces:
Build error
Build error
"""Unit tests for STTProviderBase abstract class.""" | |
import pytest | |
from unittest.mock import Mock, patch, MagicMock | |
import tempfile | |
from pathlib import Path | |
from src.infrastructure.base.stt_provider_base import STTProviderBase | |
from src.domain.models.audio_content import AudioContent | |
from src.domain.models.text_content import TextContent | |
from src.domain.exceptions import SpeechRecognitionException | |
class ConcreteSTTProvider(STTProviderBase):
    """Concrete implementation of ``STTProviderBase`` for testing.

    Implements the abstract hooks with in-memory, configurable behavior so
    tests can control availability, the reported model list, the
    transcription text, and whether transcription fails.
    """

    def __init__(self, provider_name="test", supported_languages=None, available=True, models=None):
        """Create the test provider.

        Args:
            provider_name: Name forwarded to ``STTProviderBase.__init__``.
            supported_languages: Forwarded unchanged to ``STTProviderBase.__init__``.
            available: Value returned by :meth:`is_available`.
            models: Models returned by :meth:`get_available_models`. ``None``
                selects ``["model1", "model2"]``; an explicit empty list is
                kept so the "no models available" path can be exercised.
        """
        super().__init__(provider_name, supported_languages)
        self._available = available
        # Test against None, not truthiness: `models or [...]` silently replaced
        # an explicit [] with the defaults. Copy to avoid aliasing the caller's list.
        self._models = ["model1", "model2"] if models is None else list(models)
        self._should_fail = False
        self._transcription_result = "Hello world"

    def _perform_transcription(self, audio_path, model):
        """Return the configured transcription text.

        Raises:
            Exception: If :meth:`set_should_fail` was last called with True.
        """
        if self._should_fail:
            raise Exception("Test transcription error")
        return self._transcription_result

    def is_available(self):
        """Return the availability flag given at construction."""
        return self._available

    def get_available_models(self):
        """Return the configured model list."""
        return self._models

    def get_default_model(self):
        """Return the first configured model, or ``"default"`` when there are none."""
        return self._models[0] if self._models else "default"

    def set_should_fail(self, should_fail):
        """Control whether :meth:`_perform_transcription` raises."""
        self._should_fail = should_fail

    def set_transcription_result(self, result):
        """Set the text returned by :meth:`_perform_transcription`."""
        self._transcription_result = result
class TestSTTProviderBase:
    """Test cases for STTProviderBase abstract class."""

    def setup_method(self):
        """Create a fresh provider and a valid 5-second WAV clip for each test."""
        self.provider = ConcreteSTTProvider()
        self.audio_content = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=5.0,
            filename="test.wav"
        )

    def test_provider_initialization(self):
        """Test provider initialization with an explicit name and language list."""
        provider = ConcreteSTTProvider("test_provider", ["en", "es"])
        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == ["en", "es"]
        assert isinstance(provider._temp_dir, Path)
        assert provider._temp_dir.exists()

    def test_provider_initialization_no_languages(self):
        """Test provider initialization without supported languages."""
        provider = ConcreteSTTProvider("test_provider")
        assert provider.provider_name == "test_provider"
        assert provider.supported_languages == []

    @patch('builtins.open')
    def test_transcribe_success(self, mock_open):
        """Test successful transcription."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file
        result = self.provider.transcribe(self.audio_content, "model1")
        assert isinstance(result, TextContent)
        assert result.text == "Hello world"
        assert result.language == "en"
        assert result.encoding == "utf-8"

    def test_transcribe_empty_audio_fails(self):
        """Test that empty audio data raises exception."""
        empty_audio = AudioContent(
            data=b"",
            format="wav",
            sample_rate=16000,
            duration=0.1
        )
        with pytest.raises(SpeechRecognitionException, match="Audio data cannot be empty"):
            self.provider.transcribe(empty_audio, "model1")

    def test_transcribe_audio_too_long_fails(self):
        """Test that audio longer than 1 hour raises exception."""
        long_audio = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=3601.0  # Over 1 hour
        )
        with pytest.raises(SpeechRecognitionException, match="Audio duration exceeds maximum limit"):
            self.provider.transcribe(long_audio, "model1")

    def test_transcribe_audio_too_short_fails(self):
        """Test that audio shorter than 100ms raises exception."""
        short_audio = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=0.05  # 50ms
        )
        with pytest.raises(SpeechRecognitionException, match="Audio duration too short"):
            self.provider.transcribe(short_audio, "model1")

    def test_transcribe_invalid_format_fails(self):
        """Test that invalid audio format raises exception."""
        # Create audio with invalid format by mocking is_valid_format
        invalid_audio = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=16000,
            duration=5.0
        )
        # NOTE(review): if is_valid_format is a read-only property (or AudioContent
        # is a frozen dataclass), instance-level patching raises AttributeError;
        # in that case patch it on the AudioContent class instead — confirm.
        with patch.object(invalid_audio, 'is_valid_format', False):
            with pytest.raises(SpeechRecognitionException, match="Unsupported audio format"):
                self.provider.transcribe(invalid_audio, "model1")

    @patch('builtins.open')
    def test_transcribe_provider_error(self, mock_open):
        """Test handling of provider-specific errors."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file
        self.provider.set_should_fail(True)
        with pytest.raises(SpeechRecognitionException, match="STT transcription failed"):
            self.provider.transcribe(self.audio_content, "model1")

    # Decorators apply bottom-up: the lowest @patch supplies the first mock argument.
    @patch('builtins.open')
    @patch('pathlib.Path.unlink')
    def test_transcribe_cleanup_temp_file(self, mock_unlink, mock_open):
        """Test that temporary files are cleaned up."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file
        self.provider.transcribe(self.audio_content, "model1")
        # Verify cleanup was attempted
        mock_unlink.assert_called()

    @patch('builtins.open')
    def test_preprocess_audio(self, mock_open):
        """Test audio preprocessing."""
        mock_file = MagicMock()
        mock_open.return_value.__enter__.return_value = mock_file
        processed_path = self.provider._preprocess_audio(self.audio_content)
        assert isinstance(processed_path, Path)
        assert processed_path.suffix == ".wav"
        mock_file.write.assert_called_once_with(self.audio_content.data)

    def test_preprocess_audio_error(self):
        """Test audio preprocessing error handling."""
        with patch('builtins.open', side_effect=IOError("Test error")):
            with pytest.raises(SpeechRecognitionException, match="Audio preprocessing failed"):
                self.provider._preprocess_audio(self.audio_content)

    @patch('pydub.AudioSegment.from_wav')
    def test_convert_audio_format_wav(self, mock_from_wav):
        """Test audio format conversion for WAV."""
        mock_audio = Mock()
        # set_frame_rate(...).set_channels(...) resolves back to mock_audio, so
        # export() ends up being called on mock_audio.
        mock_audio.set_frame_rate.return_value.set_channels.return_value = mock_audio
        mock_from_wav.return_value = mock_audio
        test_path = Path("/tmp/test.wav")

        self.provider._convert_audio_format(test_path, self.audio_content)

        mock_from_wav.assert_called_once_with(test_path)
        mock_audio.set_frame_rate.assert_called_once_with(16000)
        # set_channels is invoked on set_frame_rate's return value, not on mock_audio.
        mock_audio.set_frame_rate.return_value.set_channels.assert_called_once_with(1)
        # export is an instance method, so it is called on the converted segment.
        mock_audio.export.assert_called_once()

    @patch('pydub.AudioSegment.from_mp3')
    def test_convert_audio_format_mp3(self, mock_from_mp3):
        """Test audio format conversion for MP3."""
        mp3_audio = AudioContent(
            data=b"fake_mp3_data",
            format="mp3",
            sample_rate=44100,
            duration=5.0
        )
        mock_audio = Mock()
        mock_audio.set_frame_rate.return_value.set_channels.return_value = mock_audio
        mock_from_mp3.return_value = mock_audio
        test_path = Path("/tmp/test.mp3")
        self.provider._convert_audio_format(test_path, mp3_audio)
        mock_from_mp3.assert_called_once_with(test_path)

    def test_convert_audio_format_no_pydub(self):
        """Test audio format conversion when pydub is not available."""
        test_path = Path("/tmp/test.wav")
        # A None entry in sys.modules makes `import pydub` raise ImportError,
        # which genuinely simulates the package being absent. Patching
        # pydub.AudioSegment with side_effect only affected calling the class.
        with patch.dict('sys.modules', {'pydub': None}):
            result_path = self.provider._convert_audio_format(test_path, self.audio_content)
        # Should return original path when pydub is not available
        assert result_path == test_path

    def test_convert_audio_format_error(self):
        """Test audio format conversion error handling."""
        test_path = Path("/tmp/test.wav")
        with patch('pydub.AudioSegment.from_wav', side_effect=Exception("Conversion error")):
            result_path = self.provider._convert_audio_format(test_path, self.audio_content)
        # Should return original path on error
        assert result_path == test_path

    def test_detect_language_english(self):
        """Test language detection for English text."""
        english_text = "The quick brown fox jumps over the lazy dog and it is very nice"
        language = self.provider._detect_language(english_text)
        assert language == "en"

    def test_detect_language_few_indicators(self):
        """Test language detection with few English indicators."""
        text = "Hello world"
        language = self.provider._detect_language(text)
        assert language == "en"

    def test_detect_language_no_indicators(self):
        """Test language detection with no clear indicators."""
        text = "xyz abc def"
        language = self.provider._detect_language(text)
        assert language == "en"  # Should default to English

    def test_detect_language_error(self):
        """Test that an internal failure during language detection yields None."""
        # Previously this patched _detect_language itself, so the call hit the mock
        # and raised instead of exercising the provider's own error handling.
        # NOTE(review): assumes the implementation catches the error raised when
        # processing a non-string input and returns None — confirm.
        language = self.provider._detect_language(None)
        assert language is None

    def test_ensure_temp_directory(self):
        """Test temporary directory creation."""
        temp_dir = self.provider._ensure_temp_directory()
        assert isinstance(temp_dir, Path)
        assert temp_dir.exists()
        assert temp_dir.is_dir()
        assert "stt_temp" in str(temp_dir)

    def test_cleanup_temp_file(self):
        """Test temporary file cleanup."""
        # Create a temporary file
        temp_file = self.provider._temp_dir / "test_file.wav"
        temp_file.touch()
        assert temp_file.exists()
        self.provider._cleanup_temp_file(temp_file)
        assert not temp_file.exists()

    def test_cleanup_temp_file_not_exists(self, tmp_path):
        """Test cleanup of non-existent file."""
        # tmp_path is a fresh per-test directory, so this file is guaranteed absent
        # (a hardcoded /tmp path might exist or be unavailable on some platforms).
        non_existent = tmp_path / "non_existent_file.wav"
        # Should not raise exception
        self.provider._cleanup_temp_file(non_existent)

    def test_cleanup_temp_file_error(self, tmp_path):
        """Test cleanup error handling."""
        # Create the file so cleanup actually reaches unlink and hits the OSError.
        temp_file = tmp_path / "test.wav"
        temp_file.touch()
        with patch('pathlib.Path.unlink', side_effect=OSError("Permission denied")):
            # Should not raise exception
            self.provider._cleanup_temp_file(temp_file)

    @patch('time.time')
    @patch('pathlib.Path.glob')
    def test_cleanup_old_temp_files(self, mock_glob, mock_time):
        """Test cleanup of old temporary files."""
        mock_time.return_value = 1000000
        # Mock old file (100000 s ≈ 27.8 h old, beyond the 24 h limit)
        old_file = Mock()
        old_file.is_file.return_value = True
        old_file.stat.return_value.st_mtime = 900000
        # Mock recent file (1000 s old)
        recent_file = Mock()
        recent_file.is_file.return_value = True
        recent_file.stat.return_value.st_mtime = 999000
        mock_glob.return_value = [old_file, recent_file]
        self.provider._cleanup_old_temp_files(24)
        # Old file should be deleted
        old_file.unlink.assert_called_once()
        recent_file.unlink.assert_not_called()

    def test_cleanup_old_temp_files_error(self):
        """Test cleanup error handling."""
        # Path instances are slotted, so patch.object on the instance cannot set
        # 'glob'; patch the method on the class instead.
        with patch('pathlib.Path.glob', side_effect=Exception("Test error")):
            # Should not raise exception
            self.provider._cleanup_old_temp_files()

    def test_handle_provider_error(self):
        """Test provider error handling."""
        original_error = ValueError("Original error")
        with pytest.raises(SpeechRecognitionException) as exc_info:
            self.provider._handle_provider_error(original_error, "testing")
        assert "test error during testing: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_handle_provider_error_no_context(self):
        """Test provider error handling without context."""
        original_error = ValueError("Original error")
        with pytest.raises(SpeechRecognitionException) as exc_info:
            self.provider._handle_provider_error(original_error)
        assert "test error: Original error" in str(exc_info.value)
        assert exc_info.value.__cause__ is original_error

    def test_abstract_methods_not_implemented(self):
        """Test that the abstract base class cannot be instantiated directly."""
        # Instantiating an ABC with unimplemented abstract methods raises TypeError
        with pytest.raises(TypeError):
            STTProviderBase("test")

    def test_provider_unavailable(self):
        """Test behavior when provider is unavailable."""
        provider = ConcreteSTTProvider(available=False)
        assert provider.is_available() is False

    def test_no_models_available(self):
        """Test behavior when no models are available."""
        provider = ConcreteSTTProvider(models=[])
        assert provider.get_available_models() == []
        assert provider.get_default_model() == "default"