teachingAssistant / tests /unit /domain /services /test_audio_processing_service.py
Michael Hu
Implement domain services
6aea21a
"""Tests for AudioProcessingService."""
import pytest
from unittest.mock import Mock, MagicMock
from src.domain.services.audio_processing_service import AudioProcessingService
from src.domain.models.audio_content import AudioContent
from src.domain.models.text_content import TextContent
from src.domain.models.voice_settings import VoiceSettings
from src.domain.models.translation_request import TranslationRequest
from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest
from src.domain.exceptions import (
AudioProcessingException,
SpeechRecognitionException,
TranslationFailedException,
SpeechSynthesisException
)
class TestAudioProcessingService:
    """Test cases for AudioProcessingService.

    The service under test is built from three ``Mock`` collaborators
    (speech recognition, translation, speech synthesis). pytest creates each
    function-scoped fixture once per test, so the mock a test requests is the
    same object that was injected into ``audio_processing_service``. Return
    values and side effects configured on it therefore drive the pipeline.
    """

    @pytest.fixture
    def mock_stt_service(self):
        """Mock speech recognition service."""
        return Mock()

    @pytest.fixture
    def mock_translation_service(self):
        """Mock translation service."""
        return Mock()

    @pytest.fixture
    def mock_tts_service(self):
        """Mock speech synthesis service."""
        return Mock()

    @pytest.fixture
    def audio_processing_service(
        self, mock_stt_service, mock_translation_service, mock_tts_service
    ):
        """AudioProcessingService instance with mocked dependencies."""
        return AudioProcessingService(
            speech_recognition_service=mock_stt_service,
            translation_service=mock_translation_service,
            speech_synthesis_service=mock_tts_service,
        )

    @pytest.fixture
    def sample_audio(self):
        """Sample 10-second WAV audio content for testing."""
        return AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=22050,
            duration=10.0,
            filename="test.wav",
        )

    @pytest.fixture
    def sample_voice_settings(self):
        """Sample voice settings targeting Spanish ("es")."""
        return VoiceSettings(
            voice_id="test_voice",
            speed=1.0,
            language="es",
        )

    @pytest.fixture
    def sample_text_content(self):
        """Sample English text content ("Hello world", "en")."""
        return TextContent(
            text="Hello world",
            language="en",
        )

    def test_successful_pipeline_processing(
        self,
        audio_processing_service,
        mock_stt_service,
        mock_translation_service,
        mock_tts_service,
        sample_audio,
        sample_voice_settings,
        sample_text_content,
    ):
        """Test successful processing through the complete pipeline."""
        # Arrange: STT yields English text, which differs from the "es" target,
        # so every stage (STT -> translation -> TTS) is expected to run.
        original_text = sample_text_content
        translated_text = TextContent(text="Hola mundo", language="es")
        output_audio = AudioContent(
            data=b"synthesized_audio",
            format="wav",
            sample_rate=22050,
            duration=5.0,
        )
        mock_stt_service.transcribe.return_value = original_text
        mock_translation_service.translate.return_value = translated_text
        mock_tts_service.synthesize.return_value = output_audio

        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=sample_audio,
            target_language="es",
            voice_settings=sample_voice_settings,
        )

        # Assert
        assert result.success is True
        assert result.original_text == original_text
        assert result.translated_text == translated_text
        assert result.audio_output == output_audio
        assert result.error_message is None
        assert result.processing_time >= 0

        # Verify service calls. No model is passed to process_audio_pipeline,
        # so this pins the service's default STT model name.
        mock_stt_service.transcribe.assert_called_once_with(sample_audio, "whisper-base")
        mock_translation_service.translate.assert_called_once()
        mock_tts_service.synthesize.assert_called_once()

    def test_no_translation_needed_same_language(
        self,
        audio_processing_service,
        mock_stt_service,
        mock_translation_service,
        mock_tts_service,
        sample_audio,
    ):
        """Test pipeline when no translation is needed (same language)."""
        # Arrange: the transcript is already in the target language ("es").
        original_text = TextContent(text="Hola mundo", language="es")
        voice_settings = VoiceSettings(voice_id="test_voice", speed=1.0, language="es")
        output_audio = AudioContent(
            data=b"synthesized_audio",
            format="wav",
            sample_rate=22050,
            duration=5.0,
        )
        mock_stt_service.transcribe.return_value = original_text
        mock_tts_service.synthesize.return_value = output_audio

        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=sample_audio,
            target_language="es",
            voice_settings=voice_settings,
        )

        # Assert: same success contract as the full-pipeline test.
        assert result.success is True
        assert result.original_text == original_text
        assert result.translated_text == original_text  # Same as original
        assert result.audio_output == output_audio
        assert result.error_message is None
        assert result.processing_time >= 0

        # STT and TTS still run; translation service should not be called.
        mock_stt_service.transcribe.assert_called_once_with(sample_audio, "whisper-base")
        mock_translation_service.translate.assert_not_called()
        mock_tts_service.synthesize.assert_called_once()

    def test_validation_error_none_audio(self, audio_processing_service, sample_voice_settings):
        """Test validation error when audio is None."""
        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=None,
            target_language="es",
            voice_settings=sample_voice_settings,
        )

        # Assert
        assert result.success is False
        assert "Audio content cannot be None" in result.error_message

    def test_validation_error_empty_target_language(
        self, audio_processing_service, sample_audio, sample_voice_settings
    ):
        """Test validation error when target language is empty."""
        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=sample_audio,
            target_language="",
            voice_settings=sample_voice_settings,
        )

        # Assert
        assert result.success is False
        assert "Target language cannot be empty" in result.error_message

    def test_validation_error_language_mismatch(self, audio_processing_service, sample_audio):
        """Test validation error when voice settings language doesn't match target language."""
        # Arrange: voice is English while the requested target is Spanish.
        voice_settings = VoiceSettings(voice_id="test_voice", speed=1.0, language="en")

        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=sample_audio,
            target_language="es",
            voice_settings=voice_settings,
        )

        # Assert
        assert result.success is False
        assert "Voice settings language (en) must match target language (es)" in result.error_message

    def test_validation_error_audio_too_long(self, audio_processing_service, sample_voice_settings):
        """Test validation error when audio is too long."""
        # Arrange
        long_audio = AudioContent(
            data=b"fake_audio_data",
            format="wav",
            sample_rate=22050,
            duration=400.0,  # Exceeds 300s limit
        )

        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=long_audio,
            target_language="es",
            voice_settings=sample_voice_settings,
        )

        # Assert
        assert result.success is False
        assert "exceeds maximum allowed duration" in result.error_message

    def test_stt_failure_handling(
        self,
        audio_processing_service,
        mock_stt_service,
        mock_translation_service,
        mock_tts_service,
        sample_audio,
        sample_voice_settings,
    ):
        """Test handling of STT service failure."""
        # Arrange
        mock_stt_service.transcribe.side_effect = Exception("STT service unavailable")

        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=sample_audio,
            target_language="es",
            voice_settings=sample_voice_settings,
        )

        # Assert
        assert result.success is False
        assert "Speech recognition failed" in result.error_message
        assert result.processing_time >= 0

        # Without a transcript there is nothing to translate or synthesize,
        # so the pipeline must stop at the first stage.
        mock_translation_service.translate.assert_not_called()
        mock_tts_service.synthesize.assert_not_called()

    def test_translation_failure_handling(
        self,
        audio_processing_service,
        mock_stt_service,
        mock_translation_service,
        mock_tts_service,
        sample_audio,
        sample_voice_settings,
    ):
        """Test handling of translation service failure."""
        # Arrange: English transcript with an "es" target forces translation.
        original_text = TextContent(text="Hello world", language="en")
        mock_stt_service.transcribe.return_value = original_text
        mock_translation_service.translate.side_effect = Exception("Translation service unavailable")

        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=sample_audio,
            target_language="es",
            voice_settings=sample_voice_settings,
        )

        # Assert
        assert result.success is False
        assert "Translation failed" in result.error_message
        assert result.processing_time >= 0

        # A failed translation must not be followed by speech synthesis.
        mock_tts_service.synthesize.assert_not_called()

    def test_tts_failure_handling(
        self,
        audio_processing_service,
        mock_stt_service,
        mock_translation_service,
        mock_tts_service,
        sample_audio,
        sample_voice_settings,
    ):
        """Test handling of TTS service failure."""
        # Arrange: STT and translation succeed; only the final stage fails.
        original_text = TextContent(text="Hello world", language="en")
        translated_text = TextContent(text="Hola mundo", language="es")
        mock_stt_service.transcribe.return_value = original_text
        mock_translation_service.translate.return_value = translated_text
        mock_tts_service.synthesize.side_effect = Exception("TTS service unavailable")

        # Act
        result = audio_processing_service.process_audio_pipeline(
            audio=sample_audio,
            target_language="es",
            voice_settings=sample_voice_settings,
        )

        # Assert
        assert result.success is False
        assert "Speech synthesis failed" in result.error_message
        assert result.processing_time >= 0