Michael Hu commited on
Commit
93dc283
·
1 Parent(s): acd758a

Create unit tests for infrastructure layer

Browse files
tests/unit/infrastructure/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Infrastructure layer unit tests."""
tests/unit/infrastructure/base/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Base provider unit tests."""
tests/unit/infrastructure/base/test_stt_provider_base.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for STTProviderBase abstract class."""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch, MagicMock
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ from src.infrastructure.base.stt_provider_base import STTProviderBase
9
+ from src.domain.models.audio_content import AudioContent
10
+ from src.domain.models.text_content import TextContent
11
+ from src.domain.exceptions import SpeechRecognitionException
12
+
13
+
14
+ class ConcreteSTTProvider(STTProviderBase):
15
+ """Concrete implementation for testing."""
16
+
17
+ def __init__(self, provider_name="test", supported_languages=None, available=True, models=None):
18
+ super().__init__(provider_name, supported_languages)
19
+ self._available = available
20
+ self._models = models or ["model1", "model2"]
21
+ self._should_fail = False
22
+ self._transcription_result = "Hello world"
23
+
24
+ def _perform_transcription(self, audio_path, model):
25
+ if self._should_fail:
26
+ raise Exception("Test transcription error")
27
+ return self._transcription_result
28
+
29
+ def is_available(self):
30
+ return self._available
31
+
32
+ def get_available_models(self):
33
+ return self._models
34
+
35
+ def get_default_model(self):
36
+ return self._models[0] if self._models else "default"
37
+
38
+ def set_should_fail(self, should_fail):
39
+ self._should_fail = should_fail
40
+
41
+ def set_transcription_result(self, result):
42
+ self._transcription_result = result
43
+
44
+
45
+ class TestSTTProviderBase:
46
+ """Test cases for STTProviderBase abstract class."""
47
+
48
+ def setup_method(self):
49
+ """Set up test fixtures."""
50
+ self.provider = ConcreteSTTProvider()
51
+ self.audio_content = AudioContent(
52
+ data=b"fake_audio_data",
53
+ format="wav",
54
+ sample_rate=16000,
55
+ duration=5.0,
56
+ filename="test.wav"
57
+ )
58
+
59
+ def test_provider_initialization(self):
60
+ """Test provider initialization with default values."""
61
+ provider = ConcreteSTTProvider("test_provider", ["en", "es"])
62
+
63
+ assert provider.provider_name == "test_provider"
64
+ assert provider.supported_languages == ["en", "es"]
65
+ assert isinstance(provider._temp_dir, Path)
66
+ assert provider._temp_dir.exists()
67
+
68
+ def test_provider_initialization_no_languages(self):
69
+ """Test provider initialization without supported languages."""
70
+ provider = ConcreteSTTProvider("test_provider")
71
+
72
+ assert provider.provider_name == "test_provider"
73
+ assert provider.supported_languages == []
74
+
75
+ @patch('builtins.open', create=True)
76
+ def test_transcribe_success(self, mock_open):
77
+ """Test successful transcription."""
78
+ mock_file = MagicMock()
79
+ mock_open.return_value.__enter__.return_value = mock_file
80
+
81
+ result = self.provider.transcribe(self.audio_content, "model1")
82
+
83
+ assert isinstance(result, TextContent)
84
+ assert result.text == "Hello world"
85
+ assert result.language == "en"
86
+ assert result.encoding == "utf-8"
87
+
88
+ def test_transcribe_empty_audio_fails(self):
89
+ """Test that empty audio data raises exception."""
90
+ empty_audio = AudioContent(
91
+ data=b"",
92
+ format="wav",
93
+ sample_rate=16000,
94
+ duration=0.1
95
+ )
96
+
97
+ with pytest.raises(SpeechRecognitionException, match="Audio data cannot be empty"):
98
+ self.provider.transcribe(empty_audio, "model1")
99
+
100
+ def test_transcribe_audio_too_long_fails(self):
101
+ """Test that audio longer than 1 hour raises exception."""
102
+ long_audio = AudioContent(
103
+ data=b"fake_audio_data",
104
+ format="wav",
105
+ sample_rate=16000,
106
+ duration=3601.0 # Over 1 hour
107
+ )
108
+
109
+ with pytest.raises(SpeechRecognitionException, match="Audio duration exceeds maximum limit"):
110
+ self.provider.transcribe(long_audio, "model1")
111
+
112
+ def test_transcribe_audio_too_short_fails(self):
113
+ """Test that audio shorter than 100ms raises exception."""
114
+ short_audio = AudioContent(
115
+ data=b"fake_audio_data",
116
+ format="wav",
117
+ sample_rate=16000,
118
+ duration=0.05 # 50ms
119
+ )
120
+
121
+ with pytest.raises(SpeechRecognitionException, match="Audio duration too short"):
122
+ self.provider.transcribe(short_audio, "model1")
123
+
124
+ def test_transcribe_invalid_format_fails(self):
125
+ """Test that invalid audio format raises exception."""
126
+ # Create audio with invalid format by mocking is_valid_format
127
+ invalid_audio = AudioContent(
128
+ data=b"fake_audio_data",
129
+ format="wav",
130
+ sample_rate=16000,
131
+ duration=5.0
132
+ )
133
+
134
+ with patch.object(invalid_audio, 'is_valid_format', False):
135
+ with pytest.raises(SpeechRecognitionException, match="Unsupported audio format"):
136
+ self.provider.transcribe(invalid_audio, "model1")
137
+
138
+ @patch('builtins.open', create=True)
139
+ def test_transcribe_provider_error(self, mock_open):
140
+ """Test handling of provider-specific errors."""
141
+ mock_file = MagicMock()
142
+ mock_open.return_value.__enter__.return_value = mock_file
143
+
144
+ self.provider.set_should_fail(True)
145
+
146
+ with pytest.raises(SpeechRecognitionException, match="STT transcription failed"):
147
+ self.provider.transcribe(self.audio_content, "model1")
148
+
149
+ @patch('builtins.open', create=True)
150
+ @patch('pathlib.Path.unlink')
151
+ def test_transcribe_cleanup_temp_file(self, mock_unlink, mock_open):
152
+ """Test that temporary files are cleaned up."""
153
+ mock_file = MagicMock()
154
+ mock_open.return_value.__enter__.return_value = mock_file
155
+
156
+ self.provider.transcribe(self.audio_content, "model1")
157
+
158
+ # Verify cleanup was attempted
159
+ mock_unlink.assert_called()
160
+
161
+ @patch('builtins.open', create=True)
162
+ def test_preprocess_audio(self, mock_open):
163
+ """Test audio preprocessing."""
164
+ mock_file = MagicMock()
165
+ mock_open.return_value.__enter__.return_value = mock_file
166
+
167
+ processed_path = self.provider._preprocess_audio(self.audio_content)
168
+
169
+ assert isinstance(processed_path, Path)
170
+ assert processed_path.suffix == ".wav"
171
+ mock_file.write.assert_called_once_with(self.audio_content.data)
172
+
173
+ def test_preprocess_audio_error(self):
174
+ """Test audio preprocessing error handling."""
175
+ with patch('builtins.open', side_effect=IOError("Test error")):
176
+ with pytest.raises(SpeechRecognitionException, match="Audio preprocessing failed"):
177
+ self.provider._preprocess_audio(self.audio_content)
178
+
179
+ @patch('pydub.AudioSegment.from_wav')
180
+ @patch('pydub.AudioSegment.export')
181
+ def test_convert_audio_format_wav(self, mock_export, mock_from_wav):
182
+ """Test audio format conversion for WAV."""
183
+ mock_audio = Mock()
184
+ mock_audio.set_frame_rate.return_value.set_channels.return_value = mock_audio
185
+ mock_from_wav.return_value = mock_audio
186
+
187
+ test_path = Path("/tmp/test.wav")
188
+ result_path = self.provider._convert_audio_format(test_path, self.audio_content)
189
+
190
+ mock_from_wav.assert_called_once_with(test_path)
191
+ mock_audio.set_frame_rate.assert_called_once_with(16000)
192
+ mock_audio.set_channels.assert_called_once_with(1)
193
+ mock_export.assert_called_once()
194
+
195
+ @patch('pydub.AudioSegment.from_mp3')
196
+ def test_convert_audio_format_mp3(self, mock_from_mp3):
197
+ """Test audio format conversion for MP3."""
198
+ mp3_audio = AudioContent(
199
+ data=b"fake_mp3_data",
200
+ format="mp3",
201
+ sample_rate=44100,
202
+ duration=5.0
203
+ )
204
+
205
+ mock_audio = Mock()
206
+ mock_audio.set_frame_rate.return_value.set_channels.return_value = mock_audio
207
+ mock_from_mp3.return_value = mock_audio
208
+
209
+ test_path = Path("/tmp/test.mp3")
210
+ self.provider._convert_audio_format(test_path, mp3_audio)
211
+
212
+ mock_from_mp3.assert_called_once_with(test_path)
213
+
214
+ def test_convert_audio_format_no_pydub(self):
215
+ """Test audio format conversion when pydub is not available."""
216
+ test_path = Path("/tmp/test.wav")
217
+
218
+ with patch('pydub.AudioSegment', side_effect=ImportError("pydub not available")):
219
+ result_path = self.provider._convert_audio_format(test_path, self.audio_content)
220
+
221
+ # Should return original path when pydub is not available
222
+ assert result_path == test_path
223
+
224
+ def test_convert_audio_format_error(self):
225
+ """Test audio format conversion error handling."""
226
+ test_path = Path("/tmp/test.wav")
227
+
228
+ with patch('pydub.AudioSegment.from_wav', side_effect=Exception("Conversion error")):
229
+ result_path = self.provider._convert_audio_format(test_path, self.audio_content)
230
+
231
+ # Should return original path on error
232
+ assert result_path == test_path
233
+
234
+ def test_detect_language_english(self):
235
+ """Test language detection for English text."""
236
+ english_text = "The quick brown fox jumps over the lazy dog and it is very nice"
237
+ language = self.provider._detect_language(english_text)
238
+ assert language == "en"
239
+
240
+ def test_detect_language_few_indicators(self):
241
+ """Test language detection with few English indicators."""
242
+ text = "Hello world"
243
+ language = self.provider._detect_language(text)
244
+ assert language == "en"
245
+
246
+ def test_detect_language_no_indicators(self):
247
+ """Test language detection with no clear indicators."""
248
+ text = "xyz abc def"
249
+ language = self.provider._detect_language(text)
250
+ assert language == "en" # Should default to English
251
+
252
+ def test_detect_language_error(self):
253
+ """Test language detection error handling."""
254
+ with patch.object(self.provider, '_detect_language', side_effect=Exception("Detection error")):
255
+ language = self.provider._detect_language("test")
256
+ assert language is None
257
+
258
+ def test_ensure_temp_directory(self):
259
+ """Test temporary directory creation."""
260
+ temp_dir = self.provider._ensure_temp_directory()
261
+
262
+ assert isinstance(temp_dir, Path)
263
+ assert temp_dir.exists()
264
+ assert temp_dir.is_dir()
265
+ assert "stt_temp" in str(temp_dir)
266
+
267
+ def test_cleanup_temp_file(self):
268
+ """Test temporary file cleanup."""
269
+ # Create a temporary file
270
+ temp_file = self.provider._temp_dir / "test_file.wav"
271
+ temp_file.touch()
272
+
273
+ assert temp_file.exists()
274
+
275
+ self.provider._cleanup_temp_file(temp_file)
276
+
277
+ assert not temp_file.exists()
278
+
279
+ def test_cleanup_temp_file_not_exists(self):
280
+ """Test cleanup of non-existent file."""
281
+ non_existent = Path("/tmp/non_existent_file.wav")
282
+
283
+ # Should not raise exception
284
+ self.provider._cleanup_temp_file(non_existent)
285
+
286
+ def test_cleanup_temp_file_error(self):
287
+ """Test cleanup error handling."""
288
+ with patch('pathlib.Path.unlink', side_effect=OSError("Permission denied")):
289
+ temp_file = Path("/tmp/test.wav")
290
+
291
+ # Should not raise exception
292
+ self.provider._cleanup_temp_file(temp_file)
293
+
294
+ @patch('time.time')
295
+ @patch('pathlib.Path.glob')
296
+ def test_cleanup_old_temp_files(self, mock_glob, mock_time):
297
+ """Test cleanup of old temporary files."""
298
+ mock_time.return_value = 1000000
299
+
300
+ # Mock old file
301
+ old_file = Mock()
302
+ old_file.is_file.return_value = True
303
+ old_file.stat.return_value.st_mtime = 900000 # Old file
304
+
305
+ # Mock recent file
306
+ recent_file = Mock()
307
+ recent_file.is_file.return_value = True
308
+ recent_file.stat.return_value.st_mtime = 999000 # Recent file
309
+
310
+ mock_glob.return_value = [old_file, recent_file]
311
+
312
+ self.provider._cleanup_old_temp_files(24)
313
+
314
+ # Old file should be deleted
315
+ old_file.unlink.assert_called_once()
316
+ recent_file.unlink.assert_not_called()
317
+
318
+ def test_cleanup_old_temp_files_error(self):
319
+ """Test cleanup error handling."""
320
+ with patch.object(self.provider._temp_dir, 'glob', side_effect=Exception("Test error")):
321
+ # Should not raise exception
322
+ self.provider._cleanup_old_temp_files()
323
+
324
+ def test_handle_provider_error(self):
325
+ """Test provider error handling."""
326
+ original_error = ValueError("Original error")
327
+
328
+ with pytest.raises(SpeechRecognitionException) as exc_info:
329
+ self.provider._handle_provider_error(original_error, "testing")
330
+
331
+ assert "test error during testing: Original error" in str(exc_info.value)
332
+ assert exc_info.value.__cause__ is original_error
333
+
334
+ def test_handle_provider_error_no_context(self):
335
+ """Test provider error handling without context."""
336
+ original_error = ValueError("Original error")
337
+
338
+ with pytest.raises(SpeechRecognitionException) as exc_info:
339
+ self.provider._handle_provider_error(original_error)
340
+
341
+ assert "test error: Original error" in str(exc_info.value)
342
+ assert exc_info.value.__cause__ is original_error
343
+
344
+ def test_abstract_methods_not_implemented(self):
345
+ """Test that abstract methods raise NotImplementedError."""
346
+ # Create instance of base class directly (should fail)
347
+ with pytest.raises(TypeError):
348
+ STTProviderBase("test")
349
+
350
+ def test_provider_unavailable(self):
351
+ """Test behavior when provider is unavailable."""
352
+ provider = ConcreteSTTProvider(available=False)
353
+ assert provider.is_available() is False
354
+
355
+ def test_no_models_available(self):
356
+ """Test behavior when no models are available."""
357
+ provider = ConcreteSTTProvider(models=[])
358
+ assert provider.get_available_models() == []
359
+ assert provider.get_default_model() == "default"
tests/unit/infrastructure/base/test_translation_provider_base.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for TranslationProviderBase abstract class."""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch
5
+
6
+ from src.infrastructure.base.translation_provider_base import TranslationProviderBase
7
+ from src.domain.models.translation_request import TranslationRequest
8
+ from src.domain.models.text_content import TextContent
9
+ from src.domain.exceptions import TranslationFailedException
10
+
11
+
12
+ class ConcreteTranslationProvider(TranslationProviderBase):
13
+ """Concrete implementation for testing."""
14
+
15
+ def __init__(self, provider_name="test", supported_languages=None, available=True):
16
+ super().__init__(provider_name, supported_languages)
17
+ self._available = available
18
+ self._should_fail = False
19
+ self._translation_result = "Translated text"
20
+
21
+ def _translate_chunk(self, text, source_language, target_language):
22
+ if self._should_fail:
23
+ raise Exception("Test translation error")
24
+ return f"{self._translation_result} ({source_language}->{target_language})"
25
+
26
+ def is_available(self):
27
+ return self._available
28
+
29
+ def get_supported_languages(self):
30
+ return self.supported_languages
31
+
32
+ def set_should_fail(self, should_fail):
33
+ self._should_fail = should_fail
34
+
35
+ def set_translation_result(self, result):
36
+ self._translation_result = result
37
+
38
+
39
+ class TestTranslationProviderBase:
40
+ """Test cases for TranslationProviderBase abstract class."""
41
+
42
+ def setup_method(self):
43
+ """Set up test fixtures."""
44
+ self.provider = ConcreteTranslationProvider()
45
+ self.source_text = TextContent(text="Hello world", language="en")
46
+ self.request = TranslationRequest(
47
+ source_text=self.source_text,
48
+ target_language="es"
49
+ )
50
+
51
+ def test_provider_initialization(self):
52
+ """Test provider initialization with default values."""
53
+ supported_langs = {"en": ["es", "fr"], "es": ["en"]}
54
+ provider = ConcreteTranslationProvider("test_provider", supported_langs)
55
+
56
+ assert provider.provider_name == "test_provider"
57
+ assert provider.supported_languages == supported_langs
58
+ assert provider.max_chunk_length == 1000
59
+
60
+ def test_provider_initialization_no_languages(self):
61
+ """Test provider initialization without supported languages."""
62
+ provider = ConcreteTranslationProvider("test_provider")
63
+
64
+ assert provider.provider_name == "test_provider"
65
+ assert provider.supported_languages == {}
66
+
67
+ def test_translate_success(self):
68
+ """Test successful translation."""
69
+ result = self.provider.translate(self.request)
70
+
71
+ assert isinstance(result, TextContent)
72
+ assert result.text == "Translated text (en->es)"
73
+ assert result.language == "es"
74
+ assert result.encoding == "utf-8"
75
+
76
+ def test_translate_with_language_validation(self):
77
+ """Test translation with language validation."""
78
+ supported_langs = {"en": ["es", "fr"], "es": ["en"]}
79
+ provider = ConcreteTranslationProvider("test", supported_langs)
80
+
81
+ # Valid language pair should work
82
+ result = provider.translate(self.request)
83
+ assert isinstance(result, TextContent)
84
+
85
+ # Invalid source language should fail
86
+ invalid_request = TranslationRequest(
87
+ source_text=TextContent(text="Hello", language="de"),
88
+ target_language="es"
89
+ )
90
+
91
+ with pytest.raises(TranslationFailedException, match="Source language de not supported"):
92
+ provider.translate(invalid_request)
93
+
94
+ # Invalid target language should fail
95
+ invalid_request2 = TranslationRequest(
96
+ source_text=self.source_text,
97
+ target_language="de"
98
+ )
99
+
100
+ with pytest.raises(TranslationFailedException, match="Translation from en to de not supported"):
101
+ provider.translate(invalid_request2)
102
+
103
+ def test_translate_empty_text_fails(self):
104
+ """Test that empty text raises exception."""
105
+ empty_request = TranslationRequest(
106
+ source_text=TextContent(text="", language="en"),
107
+ target_language="es"
108
+ )
109
+
110
+ with pytest.raises(TranslationFailedException, match="Source text cannot be empty"):
111
+ self.provider.translate(empty_request)
112
+
113
+ def test_translate_whitespace_text_fails(self):
114
+ """Test that whitespace-only text raises exception."""
115
+ whitespace_request = TranslationRequest(
116
+ source_text=TextContent(text=" ", language="en"),
117
+ target_language="es"
118
+ )
119
+
120
+ with pytest.raises(TranslationFailedException, match="Source text cannot be empty"):
121
+ self.provider.translate(whitespace_request)
122
+
123
+ def test_translate_same_language_fails(self):
124
+ """Test that same source and target language raises exception."""
125
+ same_lang_request = TranslationRequest(
126
+ source_text=self.source_text,
127
+ target_language="en"
128
+ )
129
+
130
+ with pytest.raises(TranslationFailedException, match="Source and target languages cannot be the same"):
131
+ self.provider.translate(same_lang_request)
132
+
133
+ def test_translate_provider_error(self):
134
+ """Test handling of provider-specific errors."""
135
+ self.provider.set_should_fail(True)
136
+
137
+ with pytest.raises(TranslationFailedException, match="Translation failed"):
138
+ self.provider.translate(self.request)
139
+
140
+ def test_translate_long_text_chunking(self):
141
+ """Test translation of long text with chunking."""
142
+ # Create long text that will be chunked
143
+ long_text = "This is a sentence. " * 100 # Much longer than default chunk size
144
+ long_request = TranslationRequest(
145
+ source_text=TextContent(text=long_text, language="en"),
146
+ target_language="es"
147
+ )
148
+
149
+ result = self.provider.translate(long_request)
150
+
151
+ assert isinstance(result, TextContent)
152
+ # Should contain multiple translated chunks
153
+ assert "Translated text (en->es)" in result.text
154
+
155
+ def test_chunk_text_short_text(self):
156
+ """Test text chunking with short text."""
157
+ short_text = "Hello world"
158
+ chunks = self.provider._chunk_text(short_text)
159
+
160
+ assert len(chunks) == 1
161
+ assert chunks[0] == short_text
162
+
163
+ def test_chunk_text_long_text(self):
164
+ """Test text chunking with long text."""
165
+ # Create text longer than chunk size
166
+ long_text = "This is a sentence. " * 100
167
+ self.provider.max_chunk_length = 50 # Small chunk size for testing
168
+
169
+ chunks = self.provider._chunk_text(long_text)
170
+
171
+ assert len(chunks) > 1
172
+ for chunk in chunks:
173
+ assert len(chunk) <= self.provider.max_chunk_length
174
+
175
+ def test_split_into_sentences(self):
176
+ """Test sentence splitting."""
177
+ text = "First sentence. Second sentence! Third sentence? Fourth sentence."
178
+ sentences = self.provider._split_into_sentences(text)
179
+
180
+ assert len(sentences) == 4
181
+ assert "First sentence" in sentences[0]
182
+ assert "Second sentence" in sentences[1]
183
+ assert "Third sentence" in sentences[2]
184
+ assert "Fourth sentence" in sentences[3]
185
+
186
+ def test_split_into_sentences_no_punctuation(self):
187
+ """Test sentence splitting with no punctuation."""
188
+ text = "Just one long sentence without proper punctuation"
189
+ sentences = self.provider._split_into_sentences(text)
190
+
191
+ assert len(sentences) == 1
192
+ assert sentences[0] == text
193
+
194
+ def test_split_long_sentence(self):
195
+ """Test splitting of long sentences by words."""
196
+ long_sentence = "word " * 100 # Very long sentence
197
+ self.provider.max_chunk_length = 20 # Small chunk size
198
+
199
+ chunks = self.provider._split_long_sentence(long_sentence)
200
+
201
+ assert len(chunks) > 1
202
+ for chunk in chunks:
203
+ assert len(chunk) <= self.provider.max_chunk_length
204
+
205
+ def test_split_long_sentence_single_long_word(self):
206
+ """Test splitting with a single very long word."""
207
+ long_word = "a" * 100
208
+ self.provider.max_chunk_length = 20
209
+
210
+ chunks = self.provider._split_long_sentence(long_word)
211
+
212
+ assert len(chunks) == 1
213
+ assert chunks[0] == long_word # Should include the long word as-is
214
+
215
+ def test_reassemble_chunks(self):
216
+ """Test reassembling translated chunks."""
217
+ chunks = ["First chunk", "Second chunk", "Third chunk"]
218
+ result = self.provider._reassemble_chunks(chunks)
219
+
220
+ assert result == "First chunk Second chunk Third chunk"
221
+
222
+ def test_reassemble_chunks_with_empty(self):
223
+ """Test reassembling chunks with empty strings."""
224
+ chunks = ["First chunk", "", "Third chunk", " "]
225
+ result = self.provider._reassemble_chunks(chunks)
226
+
227
+ assert result == "First chunk Third chunk"
228
+
229
+ def test_preprocess_text(self):
230
+ """Test text preprocessing."""
231
+ messy_text = " Hello world \n\n with extra spaces "
232
+ processed = self.provider._preprocess_text(messy_text)
233
+
234
+ assert processed == "Hello world with extra spaces"
235
+
236
+ def test_postprocess_text(self):
237
+ """Test text postprocessing."""
238
+ messy_text = "Hello world . This is a test ! Another sentence ?"
239
+ processed = self.provider._postprocess_text(messy_text)
240
+
241
+ assert processed == "Hello world. This is a test! Another sentence?"
242
+
243
+ def test_postprocess_text_sentence_spacing(self):
244
+ """Test postprocessing fixes sentence spacing."""
245
+ text = "First sentence.Second sentence!Third sentence?"
246
+ processed = self.provider._postprocess_text(text)
247
+
248
+ assert processed == "First sentence. Second sentence! Third sentence?"
249
+
250
+ def test_handle_provider_error(self):
251
+ """Test provider error handling."""
252
+ original_error = ValueError("Original error")
253
+
254
+ with pytest.raises(TranslationFailedException) as exc_info:
255
+ self.provider._handle_provider_error(original_error, "testing")
256
+
257
+ assert "test error during testing: Original error" in str(exc_info.value)
258
+ assert exc_info.value.__cause__ is original_error
259
+
260
+ def test_handle_provider_error_no_context(self):
261
+ """Test provider error handling without context."""
262
+ original_error = ValueError("Original error")
263
+
264
+ with pytest.raises(TranslationFailedException) as exc_info:
265
+ self.provider._handle_provider_error(original_error)
266
+
267
+ assert "test error: Original error" in str(exc_info.value)
268
+ assert exc_info.value.__cause__ is original_error
269
+
270
+ def test_set_chunk_size(self):
271
+ """Test setting chunk size."""
272
+ self.provider.set_chunk_size(500)
273
+ assert self.provider.max_chunk_length == 500
274
+
275
+ def test_set_chunk_size_invalid(self):
276
+ """Test setting invalid chunk size."""
277
+ with pytest.raises(ValueError, match="Chunk size must be positive"):
278
+ self.provider.set_chunk_size(0)
279
+
280
+ with pytest.raises(ValueError, match="Chunk size must be positive"):
281
+ self.provider.set_chunk_size(-1)
282
+
283
+ def test_get_translation_stats(self):
284
+ """Test getting translation statistics."""
285
+ stats = self.provider.get_translation_stats(self.request)
286
+
287
+ assert stats['provider'] == 'test'
288
+ assert stats['source_language'] == 'en'
289
+ assert stats['target_language'] == 'es'
290
+ assert stats['text_length'] == len(self.request.source_text.text)
291
+ assert stats['word_count'] == len(self.request.source_text.text.split())
292
+ assert stats['chunk_count'] >= 1
293
+ assert 'max_chunk_length' in stats
294
+ assert 'avg_chunk_length' in stats
295
+
296
+ def test_get_translation_stats_empty_text(self):
297
+ """Test getting translation statistics for empty text."""
298
+ empty_request = TranslationRequest(
299
+ source_text=TextContent(text="", language="en"),
300
+ target_language="es"
301
+ )
302
+
303
+ stats = self.provider.get_translation_stats(empty_request)
304
+
305
+ assert stats['text_length'] == 0
306
+ assert stats['word_count'] == 0
307
+ assert stats['chunk_count'] == 0
308
+ assert stats['max_chunk_length'] == 0
309
+ assert stats['avg_chunk_length'] == 0
310
+
311
+ def test_abstract_methods_not_implemented(self):
312
+ """Test that abstract methods raise NotImplementedError."""
313
+ # Create instance of base class directly (should fail)
314
+ with pytest.raises(TypeError):
315
+ TranslationProviderBase("test")
316
+
317
+ def test_provider_unavailable(self):
318
+ """Test behavior when provider is unavailable."""
319
+ provider = ConcreteTranslationProvider(available=False)
320
+ assert provider.is_available() is False
321
+
322
+ def test_no_supported_languages(self):
323
+ """Test behavior when no languages are supported."""
324
+ provider = ConcreteTranslationProvider(supported_languages={})
325
+ assert provider.get_supported_languages() == {}
tests/unit/infrastructure/base/test_tts_provider_base.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for TTSProviderBase abstract class."""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch, MagicMock
5
+ import tempfile
6
+ from pathlib import Path
7
+ import time
8
+
9
+ from src.infrastructure.base.tts_provider_base import TTSProviderBase
10
+ from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest
11
+ from src.domain.models.text_content import TextContent
12
+ from src.domain.models.voice_settings import VoiceSettings
13
+ from src.domain.models.audio_content import AudioContent
14
+ from src.domain.models.audio_chunk import AudioChunk
15
+ from src.domain.exceptions import SpeechSynthesisException
16
+
17
+
18
+ class ConcreteTTSProvider(TTSProviderBase):
19
+ """Concrete implementation for testing."""
20
+
21
+ def __init__(self, provider_name="test", supported_languages=None, available=True, voices=None):
22
+ super().__init__(provider_name, supported_languages)
23
+ self._available = available
24
+ self._voices = voices or ["voice1", "voice2"]
25
+ self._should_fail = False
26
+
27
+ def _generate_audio(self, request):
28
+ if self._should_fail:
29
+ raise Exception("Test error")
30
+ return b"fake_audio_data", 44100
31
+
32
+ def _generate_audio_stream(self, request):
33
+ if self._should_fail:
34
+ raise Exception("Test stream error")
35
+ chunks = [
36
+ (b"chunk1", 44100, False),
37
+ (b"chunk2", 44100, False),
38
+ (b"chunk3", 44100, True)
39
+ ]
40
+ for chunk in chunks:
41
+ yield chunk
42
+
43
+ def is_available(self):
44
+ return self._available
45
+
46
+ def get_available_voices(self):
47
+ return self._voices
48
+
49
+ def set_should_fail(self, should_fail):
50
+ self._should_fail = should_fail
51
+
52
+
53
+ class TestTTSProviderBase:
54
+ """Test cases for TTSProviderBase abstract class."""
55
+
56
+ def setup_method(self):
57
+ """Set up test fixtures."""
58
+ self.provider = ConcreteTTSProvider()
59
+ self.text_content = TextContent(text="Hello world", language="en")
60
+ self.voice_settings = VoiceSettings(voice_id="voice1", speed=1.0, pitch=1.0)
61
+ self.request = SpeechSynthesisRequest(
62
+ text_content=self.text_content,
63
+ voice_settings=self.voice_settings
64
+ )
65
+
66
+ def test_provider_initialization(self):
67
+ """Test provider initialization with default values."""
68
+ provider = ConcreteTTSProvider("test_provider", ["en", "es"])
69
+
70
+ assert provider.provider_name == "test_provider"
71
+ assert provider.supported_languages == ["en", "es"]
72
+ assert isinstance(provider._output_dir, Path)
73
+ assert provider._output_dir.exists()
74
+
75
+ def test_provider_initialization_no_languages(self):
76
+ """Test provider initialization without supported languages."""
77
+ provider = ConcreteTTSProvider("test_provider")
78
+
79
+ assert provider.provider_name == "test_provider"
80
+ assert provider.supported_languages == []
81
+
82
+ def test_synthesize_success(self):
83
+ """Test successful speech synthesis."""
84
+ result = self.provider.synthesize(self.request)
85
+
86
+ assert isinstance(result, AudioContent)
87
+ assert result.data == b"fake_audio_data"
88
+ assert result.format == "wav"
89
+ assert result.sample_rate == 44100
90
+ assert result.duration > 0
91
+ assert "test_" in result.filename
92
+
93
+ def test_synthesize_with_language_validation(self):
94
+ """Test synthesis with language validation."""
95
+ provider = ConcreteTTSProvider("test", ["en", "es"])
96
+
97
+ # Valid language should work
98
+ result = provider.synthesize(self.request)
99
+ assert isinstance(result, AudioContent)
100
+
101
+ # Invalid language should fail
102
+ invalid_request = SpeechSynthesisRequest(
103
+ text_content=TextContent(text="Hola", language="fr"),
104
+ voice_settings=self.voice_settings
105
+ )
106
+
107
+ with pytest.raises(SpeechSynthesisException, match="Language fr not supported"):
108
+ provider.synthesize(invalid_request)
109
+
110
+ def test_synthesize_with_voice_validation(self):
111
+ """Test synthesis with voice validation."""
112
+ provider = ConcreteTTSProvider("test", voices=["voice1", "voice2"])
113
+
114
+ # Valid voice should work
115
+ result = provider.synthesize(self.request)
116
+ assert isinstance(result, AudioContent)
117
+
118
+ # Invalid voice should fail
119
+ invalid_request = SpeechSynthesisRequest(
120
+ text_content=self.text_content,
121
+ voice_settings=VoiceSettings(voice_id="invalid_voice", speed=1.0, pitch=1.0)
122
+ )
123
+
124
+ with pytest.raises(SpeechSynthesisException, match="Voice invalid_voice not available"):
125
+ provider.synthesize(invalid_request)
126
+
127
+ def test_synthesize_empty_text_fails(self):
128
+ """Test that empty text raises exception."""
129
+ empty_request = SpeechSynthesisRequest(
130
+ text_content=TextContent(text="", language="en"),
131
+ voice_settings=self.voice_settings
132
+ )
133
+
134
+ with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
135
+ self.provider.synthesize(empty_request)
136
+
137
+ def test_synthesize_whitespace_text_fails(self):
138
+ """Test that whitespace-only text raises exception."""
139
+ whitespace_request = SpeechSynthesisRequest(
140
+ text_content=TextContent(text=" ", language="en"),
141
+ voice_settings=self.voice_settings
142
+ )
143
+
144
+ with pytest.raises(SpeechSynthesisException, match="Text content cannot be empty"):
145
+ self.provider.synthesize(whitespace_request)
146
+
147
+ def test_synthesize_provider_error(self):
148
+ """Test handling of provider-specific errors."""
149
+ self.provider.set_should_fail(True)
150
+
151
+ with pytest.raises(SpeechSynthesisException, match="TTS synthesis failed"):
152
+ self.provider.synthesize(self.request)
153
+
154
+ def test_synthesize_stream_success(self):
155
+ """Test successful streaming synthesis."""
156
+ chunks = list(self.provider.synthesize_stream(self.request))
157
+
158
+ assert len(chunks) == 3
159
+
160
+ for i, chunk in enumerate(chunks):
161
+ assert isinstance(chunk, AudioChunk)
162
+ assert chunk.data == f"chunk{i+1}".encode()
163
+ assert chunk.format == "wav"
164
+ assert chunk.sample_rate == 44100
165
+ assert chunk.chunk_index == i
166
+ assert chunk.timestamp > 0
167
+
168
+ # Last chunk should be final
169
+ assert chunks[-1].is_final is True
170
+ assert chunks[0].is_final is False
171
+ assert chunks[1].is_final is False
172
+
173
+ def test_synthesize_stream_provider_error(self):
174
+ """Test handling of provider errors in streaming."""
175
+ self.provider.set_should_fail(True)
176
+
177
+ with pytest.raises(SpeechSynthesisException, match="TTS streaming synthesis failed"):
178
+ list(self.provider.synthesize_stream(self.request))
179
+
180
+ def test_calculate_duration(self):
181
+ """Test audio duration calculation."""
182
+ # Test with standard parameters
183
+ audio_data = b"x" * 88200 # 1 second at 44100 Hz, 16-bit, mono
184
+ duration = self.provider._calculate_duration(audio_data, 44100)
185
+ assert duration == 1.0
186
+
187
+ # Test with different sample rate
188
+ duration = self.provider._calculate_duration(audio_data, 22050)
189
+ assert duration == 2.0
190
+
191
+ # Test with stereo
192
+ duration = self.provider._calculate_duration(audio_data, 44100, channels=2)
193
+ assert duration == 0.5
194
+
195
+ # Test with empty data
196
+ duration = self.provider._calculate_duration(b"", 44100)
197
+ assert duration == 0.0
198
+
199
+ # Test with zero sample rate
200
+ duration = self.provider._calculate_duration(audio_data, 0)
201
+ assert duration == 0.0
202
+
203
+ def test_ensure_output_directory(self):
204
+ """Test output directory creation."""
205
+ output_dir = self.provider._ensure_output_directory()
206
+
207
+ assert isinstance(output_dir, Path)
208
+ assert output_dir.exists()
209
+ assert output_dir.is_dir()
210
+ assert "tts_output" in str(output_dir)
211
+
212
+ def test_generate_output_path(self):
213
+ """Test output path generation."""
214
+ path1 = self.provider._generate_output_path()
215
+ path2 = self.provider._generate_output_path()
216
+
217
+ # Paths should be different (due to timestamp)
218
+ assert path1 != path2
219
+ assert path1.suffix == ".wav"
220
+ assert path2.suffix == ".wav"
221
+ assert "test_" in path1.name
222
+ assert "test_" in path2.name
223
+
224
+ # Test with custom prefix and extension
225
+ path3 = self.provider._generate_output_path("custom", "mp3")
226
+ assert path3.suffix == ".mp3"
227
+ assert "custom_" in path3.name
228
+
229
+ @patch('time.time')
230
+ @patch('pathlib.Path.glob')
231
+ @patch('pathlib.Path.stat')
232
+ @patch('pathlib.Path.unlink')
233
+ def test_cleanup_temp_files(self, mock_unlink, mock_stat, mock_glob, mock_time):
234
+ """Test temporary file cleanup."""
235
+ # Mock current time
236
+ mock_time.return_value = 1000000
237
+
238
+ # Mock old file
239
+ old_file = Mock()
240
+ old_file.is_file.return_value = True
241
+ old_file.stat.return_value.st_mtime = 900000 # 100000 seconds old
242
+
243
+ # Mock recent file
244
+ recent_file = Mock()
245
+ recent_file.is_file.return_value = True
246
+ recent_file.stat.return_value.st_mtime = 999000 # 1000 seconds old
247
+
248
+ mock_glob.return_value = [old_file, recent_file]
249
+
250
+ # Cleanup with 24 hour limit (86400 seconds)
251
+ self.provider._cleanup_temp_files(24)
252
+
253
+ # Old file should be deleted, recent file should not
254
+ old_file.unlink.assert_called_once()
255
+ recent_file.unlink.assert_not_called()
256
+
257
+ def test_cleanup_temp_files_error_handling(self):
258
+ """Test cleanup error handling."""
259
+ # Should not raise exception even if cleanup fails
260
+ with patch.object(self.provider._output_dir, 'glob', side_effect=Exception("Test error")):
261
+ self.provider._cleanup_temp_files() # Should not raise
262
+
263
+ def test_handle_provider_error(self):
264
+ """Test provider error handling."""
265
+ original_error = ValueError("Original error")
266
+
267
+ with pytest.raises(SpeechSynthesisException) as exc_info:
268
+ self.provider._handle_provider_error(original_error, "testing")
269
+
270
+ assert "test error during testing: Original error" in str(exc_info.value)
271
+ assert exc_info.value.__cause__ is original_error
272
+
273
+ def test_handle_provider_error_no_context(self):
274
+ """Test provider error handling without context."""
275
+ original_error = ValueError("Original error")
276
+
277
+ with pytest.raises(SpeechSynthesisException) as exc_info:
278
+ self.provider._handle_provider_error(original_error)
279
+
280
+ assert "test error: Original error" in str(exc_info.value)
281
+ assert exc_info.value.__cause__ is original_error
282
+
283
+ def test_abstract_methods_not_implemented(self):
284
+ """Test that abstract methods raise NotImplementedError."""
285
+ # Create instance of base class directly (should fail)
286
+ with pytest.raises(TypeError):
287
+ TTSProviderBase("test")
288
+
289
+ def test_provider_unavailable(self):
290
+ """Test behavior when provider is unavailable."""
291
+ provider = ConcreteTTSProvider(available=False)
292
+ assert provider.is_available() is False
293
+
294
+ def test_no_voices_available(self):
295
+ """Test behavior when no voices are available."""
296
+ provider = ConcreteTTSProvider(voices=[])
297
+ assert provider.get_available_voices() == []
tests/unit/infrastructure/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Configuration unit tests."""
tests/unit/infrastructure/config/test_dependency_container.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for DependencyContainer."""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch, MagicMock
5
+ from threading import Thread
6
+ import time
7
+
8
+ from src.infrastructure.config.dependency_container import (
9
+ DependencyContainer,
10
+ DependencyScope,
11
+ ServiceDescriptor,
12
+ ServiceLifetime,
13
+ get_container,
14
+ set_container,
15
+ cleanup_container
16
+ )
17
+ from src.infrastructure.config.app_config import AppConfig
18
+ from src.infrastructure.tts.provider_factory import TTSProviderFactory
19
+ from src.infrastructure.stt.provider_factory import STTProviderFactory
20
+ from src.infrastructure.translation.provider_factory import TranslationProviderFactory, TranslationProviderType
21
+ from src.domain.interfaces.speech_synthesis import ISpeechSynthesisService
22
+ from src.domain.interfaces.speech_recognition import ISpeechRecognitionService
23
+ from src.domain.interfaces.translation import ITranslationService
24
+
25
+
26
+ class MockService:
27
+ """Mock service for testing."""
28
+
29
+ def __init__(self, name="mock", **kwargs):
30
+ self.name = name
31
+ self.kwargs = kwargs
32
+ self.cleanup_called = False
33
+
34
+ def cleanup(self):
35
+ self.cleanup_called = True
36
+
37
+
38
+ class MockServiceWithDispose:
39
+ """Mock service with dispose method."""
40
+
41
+ def __init__(self, name="mock"):
42
+ self.name = name
43
+ self.dispose_called = False
44
+
45
+ def dispose(self):
46
+ self.dispose_called = True
47
+
48
+
49
+ def mock_factory(**kwargs):
50
+ """Mock factory function."""
51
+ return MockService("factory_created", **kwargs)
52
+
53
+
54
+ class TestServiceDescriptor:
55
+ """Test cases for ServiceDescriptor."""
56
+
57
+ def test_service_descriptor_creation(self):
58
+ """Test service descriptor creation."""
59
+ descriptor = ServiceDescriptor(
60
+ service_type=MockService,
61
+ implementation=MockService,
62
+ lifetime=ServiceLifetime.SINGLETON,
63
+ factory_args={'name': 'test'}
64
+ )
65
+
66
+ assert descriptor.service_type == MockService
67
+ assert descriptor.implementation == MockService
68
+ assert descriptor.lifetime == ServiceLifetime.SINGLETON
69
+ assert descriptor.factory_args == {'name': 'test'}
70
+
71
+ def test_service_descriptor_defaults(self):
72
+ """Test service descriptor with default values."""
73
+ descriptor = ServiceDescriptor(
74
+ service_type=MockService,
75
+ implementation=MockService
76
+ )
77
+
78
+ assert descriptor.lifetime == ServiceLifetime.TRANSIENT
79
+ assert descriptor.factory_args == {}
80
+
81
+
82
+ class TestDependencyContainer:
83
+ """Test cases for DependencyContainer."""
84
+
85
+ def setup_method(self):
86
+ """Set up test fixtures."""
87
+ self.container = DependencyContainer()
88
+
89
+ def teardown_method(self):
90
+ """Clean up after tests."""
91
+ self.container.cleanup()
92
+
93
+ def test_container_initialization(self):
94
+ """Test container initialization."""
95
+ assert isinstance(self.container._config, AppConfig)
96
+ assert isinstance(self.container._services, dict)
97
+ assert isinstance(self.container._singletons, dict)
98
+ assert isinstance(self.container._scoped_instances, dict)
99
+
100
+ # Should have default services registered
101
+ assert AppConfig in self.container._singletons
102
+
103
+ def test_container_initialization_with_config(self):
104
+ """Test container initialization with custom config."""
105
+ config = AppConfig()
106
+ container = DependencyContainer(config)
107
+
108
+ assert container._config is config
109
+ assert AppConfig in container._singletons
110
+ assert container._singletons[AppConfig] is config
111
+
112
+ def test_register_singleton_class(self):
113
+ """Test registering singleton service with class."""
114
+ self.container.register_singleton(MockService, MockService, {'name': 'test'})
115
+
116
+ assert MockService in self.container._services
117
+ descriptor = self.container._services[MockService]
118
+ assert descriptor.lifetime == ServiceLifetime.SINGLETON
119
+ assert descriptor.factory_args == {'name': 'test'}
120
+
121
+ def test_register_singleton_instance(self):
122
+ """Test registering singleton service with instance."""
123
+ instance = MockService("test_instance")
124
+ self.container.register_singleton(MockService, instance)
125
+
126
+ assert MockService in self.container._singletons
127
+ assert self.container._singletons[MockService] is instance
128
+
129
+ def test_register_singleton_factory(self):
130
+ """Test registering singleton service with factory function."""
131
+ self.container.register_singleton(MockService, mock_factory, {'name': 'factory_test'})
132
+
133
+ service = self.container.resolve(MockService)
134
+ assert isinstance(service, MockService)
135
+ assert service.name == "factory_created"
136
+ assert service.kwargs == {'name': 'factory_test'}
137
+
138
+ def test_register_transient(self):
139
+ """Test registering transient service."""
140
+ self.container.register_transient(MockService, MockService, {'name': 'transient'})
141
+
142
+ assert MockService in self.container._services
143
+ descriptor = self.container._services[MockService]
144
+ assert descriptor.lifetime == ServiceLifetime.TRANSIENT
145
+
146
+ def test_register_scoped(self):
147
+ """Test registering scoped service."""
148
+ self.container.register_scoped(MockService, MockService, {'name': 'scoped'})
149
+
150
+ assert MockService in self.container._services
151
+ descriptor = self.container._services[MockService]
152
+ assert descriptor.lifetime == ServiceLifetime.SCOPED
153
+
154
+ def test_resolve_singleton(self):
155
+ """Test resolving singleton service."""
156
+ self.container.register_singleton(MockService, MockService, {'name': 'singleton'})
157
+
158
+ service1 = self.container.resolve(MockService)
159
+ service2 = self.container.resolve(MockService)
160
+
161
+ assert service1 is service2
162
+ assert service1.name == 'singleton'
163
+
164
+ def test_resolve_transient(self):
165
+ """Test resolving transient service."""
166
+ self.container.register_transient(MockService, MockService, {'name': 'transient'})
167
+
168
+ service1 = self.container.resolve(MockService)
169
+ service2 = self.container.resolve(MockService)
170
+
171
+ assert service1 is not service2
172
+ assert service1.name == 'transient'
173
+ assert service2.name == 'transient'
174
+
175
+ def test_resolve_scoped(self):
176
+ """Test resolving scoped service."""
177
+ self.container.register_scoped(MockService, MockService, {'name': 'scoped'})
178
+
179
+ service1 = self.container.resolve(MockService)
180
+ service2 = self.container.resolve(MockService)
181
+
182
+ assert service1 is service2 # Same instance within scope
183
+ assert service1.name == 'scoped'
184
+
185
+ def test_resolve_unregistered_service(self):
186
+ """Test resolving unregistered service raises error."""
187
+ class UnregisteredService:
188
+ pass
189
+
190
+ with pytest.raises(ValueError, match="Service UnregisteredService is not registered"):
191
+ self.container.resolve(UnregisteredService)
192
+
193
+ def test_resolve_service_creation_error(self):
194
+ """Test handling service creation errors."""
195
+ def failing_factory():
196
+ raise Exception("Creation failed")
197
+
198
+ self.container.register_singleton(MockService, failing_factory)
199
+
200
+ with pytest.raises(Exception, match="Creation failed"):
201
+ self.container.resolve(MockService)
202
+
203
+ def test_thread_safety(self):
204
+ """Test container thread safety."""
205
+ self.container.register_singleton(MockService, MockService, {'name': 'thread_test'})
206
+
207
+ results = []
208
+
209
+ def resolve_service():
210
+ service = self.container.resolve(MockService)
211
+ results.append(service)
212
+
213
+ threads = [Thread(target=resolve_service) for _ in range(10)]
214
+
215
+ for thread in threads:
216
+ thread.start()
217
+
218
+ for thread in threads:
219
+ thread.join()
220
+
221
+ # All threads should get the same singleton instance
222
+ assert len(results) == 10
223
+ assert all(service is results[0] for service in results)
224
+
225
+ def test_get_tts_provider_default(self):
226
+ """Test getting TTS provider with default settings."""
227
+ with patch.object(TTSProviderFactory, 'get_provider_with_fallback') as mock_get:
228
+ mock_provider = Mock()
229
+ mock_get.return_value = mock_provider
230
+
231
+ provider = self.container.get_tts_provider()
232
+
233
+ assert provider is mock_provider
234
+ mock_get.assert_called_once()
235
+
236
+ def test_get_tts_provider_specific(self):
237
+ """Test getting specific TTS provider."""
238
+ with patch.object(TTSProviderFactory, 'create_provider') as mock_create:
239
+ mock_provider = Mock()
240
+ mock_create.return_value = mock_provider
241
+
242
+ provider = self.container.get_tts_provider('kokoro', lang_code='en')
243
+
244
+ assert provider is mock_provider
245
+ mock_create.assert_called_once_with('kokoro', lang_code='en')
246
+
247
+ def test_get_stt_provider_default(self):
248
+ """Test getting STT provider with default settings."""
249
+ with patch.object(STTProviderFactory, 'create_provider_with_fallback') as mock_get:
250
+ mock_provider = Mock()
251
+ mock_get.return_value = mock_provider
252
+
253
+ provider = self.container.get_stt_provider()
254
+
255
+ assert provider is mock_provider
256
+ mock_get.assert_called_once()
257
+
258
+ def test_get_stt_provider_specific(self):
259
+ """Test getting specific STT provider."""
260
+ with patch.object(STTProviderFactory, 'create_provider') as mock_create:
261
+ mock_provider = Mock()
262
+ mock_create.return_value = mock_provider
263
+
264
+ provider = self.container.get_stt_provider('whisper')
265
+
266
+ assert provider is mock_provider
267
+ mock_create.assert_called_once_with('whisper')
268
+
269
+ def test_get_translation_provider_default(self):
270
+ """Test getting translation provider with default settings."""
271
+ with patch.object(TranslationProviderFactory, 'get_default_provider') as mock_get:
272
+ mock_provider = Mock()
273
+ mock_get.return_value = mock_provider
274
+
275
+ provider = self.container.get_translation_provider()
276
+
277
+ assert provider is mock_provider
278
+ mock_get.assert_called_once_with(None)
279
+
280
+ def test_get_translation_provider_specific(self):
281
+ """Test getting specific translation provider."""
282
+ with patch.object(TranslationProviderFactory, 'create_provider') as mock_create:
283
+ mock_provider = Mock()
284
+ mock_create.return_value = mock_provider
285
+
286
+ config = {'model': 'test'}
287
+ provider = self.container.get_translation_provider(TranslationProviderType.NLLB, config)
288
+
289
+ assert provider is mock_provider
290
+ mock_create.assert_called_once_with(TranslationProviderType.NLLB, config)
291
+
292
+ def test_clear_scoped_instances(self):
293
+ """Test clearing scoped instances."""
294
+ self.container.register_scoped(MockService, MockService)
295
+
296
+ # Create scoped instance
297
+ service = self.container.resolve(MockService)
298
+ assert MockService in self.container._scoped_instances
299
+
300
+ self.container.clear_scoped_instances()
301
+
302
+ assert len(self.container._scoped_instances) == 0
303
+ assert service.cleanup_called is True
304
+
305
+ def test_cleanup_instance_with_cleanup_method(self):
306
+ """Test cleanup of instance with cleanup method."""
307
+ instance = MockService()
308
+ self.container._cleanup_instance(instance)
309
+
310
+ assert instance.cleanup_called is True
311
+
312
+ def test_cleanup_instance_with_dispose_method(self):
313
+ """Test cleanup of instance with dispose method."""
314
+ instance = MockServiceWithDispose()
315
+ self.container._cleanup_instance(instance)
316
+
317
+ assert instance.dispose_called is True
318
+
319
+ def test_cleanup_instance_no_cleanup_method(self):
320
+ """Test cleanup of instance without cleanup method."""
321
+ instance = object()
322
+
323
+ # Should not raise exception
324
+ self.container._cleanup_instance(instance)
325
+
326
+ def test_cleanup_instance_error_handling(self):
327
+ """Test cleanup error handling."""
328
+ instance = Mock()
329
+ instance.cleanup.side_effect = Exception("Cleanup error")
330
+
331
+ # Should not raise exception
332
+ self.container._cleanup_instance(instance)
333
+
334
+ def test_cleanup_container(self):
335
+ """Test full container cleanup."""
336
+ # Register services
337
+ self.container.register_singleton(MockService, MockService)
338
+ self.container.register_scoped(MockServiceWithDispose, MockServiceWithDispose)
339
+
340
+ # Create instances
341
+ singleton = self.container.resolve(MockService)
342
+ scoped = self.container.resolve(MockServiceWithDispose)
343
+
344
+ # Mock factories
345
+ mock_tts_factory = Mock()
346
+ mock_translation_factory = Mock()
347
+ self.container._tts_factory = mock_tts_factory
348
+ self.container._translation_factory = mock_translation_factory
349
+
350
+ self.container.cleanup()
351
+
352
+ # Check cleanup was called
353
+ assert singleton.cleanup_called is True
354
+ assert scoped.dispose_called is True
355
+ mock_tts_factory.cleanup_providers.assert_called_once()
356
+ mock_translation_factory.clear_cache.assert_called_once()
357
+
358
+ # Check instances were cleared
359
+ assert len(self.container._singletons) == 0
360
+ assert len(self.container._scoped_instances) == 0
361
+ assert self.container._tts_factory is None
362
+ assert self.container._translation_factory is None
363
+
364
+ def test_cleanup_factory_error_handling(self):
365
+ """Test cleanup error handling for factories."""
366
+ mock_tts_factory = Mock()
367
+ mock_tts_factory.cleanup_providers.side_effect = Exception("TTS cleanup error")
368
+ self.container._tts_factory = mock_tts_factory
369
+
370
+ # Should not raise exception
371
+ self.container.cleanup()
372
+
373
+ def test_is_registered(self):
374
+ """Test checking if service is registered."""
375
+ assert self.container.is_registered(AppConfig) is True # Default registration
376
+ assert self.container.is_registered(MockService) is False
377
+
378
+ self.container.register_singleton(MockService, MockService)
379
+ assert self.container.is_registered(MockService) is True
380
+
381
+ def test_get_registered_services(self):
382
+ """Test getting registered services info."""
383
+ self.container.register_singleton(MockService, MockService)
384
+ self.container.register_transient(MockServiceWithDispose, MockServiceWithDispose)
385
+
386
+ services = self.container.get_registered_services()
387
+
388
+ assert 'AppConfig' in services
389
+ assert 'MockService' in services
390
+ assert 'MockServiceWithDispose' in services
391
+ assert services['MockService'] == 'singleton'
392
+ assert services['MockServiceWithDispose'] == 'transient'
393
+
394
+ def test_create_scope(self):
395
+ """Test creating dependency scope."""
396
+ scope = self.container.create_scope()
397
+
398
+ assert isinstance(scope, DependencyScope)
399
+ assert scope._parent is self.container
400
+
401
+ def test_context_manager(self):
402
+ """Test container as context manager."""
403
+ with DependencyContainer() as container:
404
+ container.register_singleton(MockService, MockService)
405
+ service = container.resolve(MockService)
406
+
407
+ assert isinstance(service, MockService)
408
+
409
+ # Cleanup should have been called
410
+ assert service.cleanup_called is True
411
+
412
+
413
+ class TestDependencyScope:
414
+ """Test cases for DependencyScope."""
415
+
416
+ def setup_method(self):
417
+ """Set up test fixtures."""
418
+ self.container = DependencyContainer()
419
+ self.scope = DependencyScope(self.container)
420
+
421
+ def teardown_method(self):
422
+ """Clean up after tests."""
423
+ self.scope.cleanup()
424
+ self.container.cleanup()
425
+
426
+ def test_scope_initialization(self):
427
+ """Test scope initialization."""
428
+ assert self.scope._parent is self.container
429
+ assert isinstance(self.scope._scoped_instances, dict)
430
+
431
+ def test_resolve_singleton_from_parent(self):
432
+ """Test resolving singleton from parent container."""
433
+ self.container.register_singleton(MockService, MockService)
434
+
435
+ service1 = self.scope.resolve(MockService)
436
+ service2 = self.scope.resolve(MockService)
437
+
438
+ assert service1 is service2
439
+ assert isinstance(service1, MockService)
440
+
441
+ def test_resolve_scoped_service(self):
442
+ """Test resolving scoped service within scope."""
443
+ self.container.register_scoped(MockService, MockService)
444
+
445
+ service1 = self.scope.resolve(MockService)
446
+ service2 = self.scope.resolve(MockService)
447
+
448
+ assert service1 is service2 # Same within scope
449
+ assert MockService in self.scope._scoped_instances
450
+
451
+ def test_resolve_transient_service(self):
452
+ """Test resolving transient service."""
453
+ self.container.register_transient(MockService, MockService)
454
+
455
+ service1 = self.scope.resolve(MockService)
456
+ service2 = self.scope.resolve(MockService)
457
+
458
+ assert service1 is not service2 # Different instances
459
+
460
+ def test_scope_cleanup(self):
461
+ """Test scope cleanup."""
462
+ self.container.register_scoped(MockService, MockService)
463
+
464
+ service = self.scope.resolve(MockService)
465
+ assert MockService in self.scope._scoped_instances
466
+
467
+ self.scope.cleanup()
468
+
469
+ assert len(self.scope._scoped_instances) == 0
470
+ assert service.cleanup_called is True
471
+
472
+ def test_scope_context_manager(self):
473
+ """Test scope as context manager."""
474
+ self.container.register_scoped(MockService, MockService)
475
+
476
+ with self.container.create_scope() as scope:
477
+ service = scope.resolve(MockService)
478
+ assert isinstance(service, MockService)
479
+
480
+ # Cleanup should have been called
481
+ assert service.cleanup_called is True
482
+
483
+
484
+ class TestGlobalContainer:
485
+ """Test cases for global container functions."""
486
+
487
+ def teardown_method(self):
488
+ """Clean up after tests."""
489
+ cleanup_container()
490
+
491
+ def test_get_container_creates_global(self):
492
+ """Test getting global container creates it if not exists."""
493
+ container = get_container()
494
+
495
+ assert isinstance(container, DependencyContainer)
496
+
497
+ # Second call should return same instance
498
+ container2 = get_container()
499
+ assert container is container2
500
+
501
+ def test_set_container(self):
502
+ """Test setting global container."""
503
+ custom_container = DependencyContainer()
504
+ set_container(custom_container)
505
+
506
+ container = get_container()
507
+ assert container is custom_container
508
+
509
+ def test_set_container_cleans_up_previous(self):
510
+ """Test setting container cleans up previous one."""
511
+ # Get initial container and register service
512
+ container1 = get_container()
513
+ container1.register_singleton(MockService, MockService)
514
+ service = container1.resolve(MockService)
515
+
516
+ # Set new container
517
+ container2 = DependencyContainer()
518
+ set_container(container2)
519
+
520
+ # Previous container should be cleaned up
521
+ assert service.cleanup_called is True
522
+
523
+ # New container should be active
524
+ assert get_container() is container2
525
+
526
+ def test_cleanup_container(self):
527
+ """Test cleaning up global container."""
528
+ container = get_container()
529
+ container.register_singleton(MockService, MockService)
530
+ service = container.resolve(MockService)
531
+
532
+ cleanup_container()
533
+
534
+ # Service should be cleaned up
535
+ assert service.cleanup_called is True
536
+
537
+ # New container should be created on next get
538
+ new_container = get_container()
539
+ assert new_container is not container
tests/unit/infrastructure/factories/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Factory unit tests."""
tests/unit/infrastructure/factories/test_stt_provider_factory.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for STTProviderFactory."""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch
5
+
6
+ from src.infrastructure.stt.provider_factory import STTProviderFactory, ASRFactory
7
+ from src.infrastructure.base.stt_provider_base import STTProviderBase
8
+ from src.domain.exceptions import SpeechRecognitionException
9
+
10
+
11
+ class MockSTTProvider(STTProviderBase):
12
+ """Mock STT provider for testing."""
13
+
14
+ def __init__(self, provider_name="mock", available=True, models=None):
15
+ super().__init__(provider_name)
16
+ self._available = available
17
+ self._models = models or ["model1", "model2"]
18
+
19
+ def _perform_transcription(self, audio_path, model):
20
+ return "Mock transcription"
21
+
22
+ def is_available(self):
23
+ return self._available
24
+
25
+ def get_available_models(self):
26
+ return self._models
27
+
28
+ def get_default_model(self):
29
+ return self._models[0] if self._models else "default"
30
+
31
+
32
+ class TestSTTProviderFactory:
33
+ """Test cases for STTProviderFactory."""
34
+
35
+ def setup_method(self):
36
+ """Set up test fixtures."""
37
+ # Patch the providers registry for testing
38
+ self.original_providers = STTProviderFactory._providers.copy()
39
+ STTProviderFactory._providers = {'mock': MockSTTProvider}
40
+
41
+ def teardown_method(self):
42
+ """Clean up after tests."""
43
+ STTProviderFactory._providers = self.original_providers
44
+
45
+ def test_create_provider_success(self):
46
+ """Test successful provider creation."""
47
+ with patch.object(MockSTTProvider, 'is_available', return_value=True):
48
+ provider = STTProviderFactory.create_provider('mock')
49
+
50
+ assert isinstance(provider, MockSTTProvider)
51
+ assert provider.provider_name == 'mock'
52
+
53
+ def test_create_provider_case_insensitive(self):
54
+ """Test provider creation is case insensitive."""
55
+ with patch.object(MockSTTProvider, 'is_available', return_value=True):
56
+ provider = STTProviderFactory.create_provider('MOCK')
57
+
58
+ assert isinstance(provider, MockSTTProvider)
59
+
60
+ def test_create_provider_unknown(self):
61
+ """Test creating unknown provider raises exception."""
62
+ with pytest.raises(SpeechRecognitionException, match="Unknown STT provider: unknown"):
63
+ STTProviderFactory.create_provider('unknown')
64
+
65
+ def test_create_provider_unavailable(self):
66
+ """Test creating unavailable provider raises exception."""
67
+ with patch.object(MockSTTProvider, 'is_available', return_value=False):
68
+ with pytest.raises(SpeechRecognitionException, match="STT provider mock is not available"):
69
+ STTProviderFactory.create_provider('mock')
70
+
71
+ def test_create_provider_creation_error(self):
72
+ """Test handling provider creation errors."""
73
+ with patch.object(MockSTTProvider, '__init__', side_effect=Exception("Creation error")):
74
+ with pytest.raises(SpeechRecognitionException, match="Failed to create STT provider mock"):
75
+ STTProviderFactory.create_provider('mock')
76
+
77
+ def test_create_provider_with_fallback_success(self):
78
+ """Test creating provider with fallback logic."""
79
+ STTProviderFactory._providers = {
80
+ 'available': MockSTTProvider,
81
+ 'unavailable': MockSTTProvider
82
+ }
83
+
84
+ def mock_is_available(self):
85
+ return self.provider_name == 'available'
86
+
87
+ with patch.object(MockSTTProvider, 'is_available', mock_is_available):
88
+ provider = STTProviderFactory.create_provider_with_fallback('unavailable')
89
+ assert provider.provider_name == 'available'
90
+
91
+ def test_create_provider_with_fallback_preferred_works(self):
92
+ """Test fallback when preferred provider works."""
93
+ with patch.object(MockSTTProvider, 'is_available', return_value=True):
94
+ provider = STTProviderFactory.create_provider_with_fallback('mock')
95
+ assert provider.provider_name == 'mock'
96
+
97
+ def test_create_provider_with_fallback_none_available(self):
98
+ """Test fallback when no providers are available."""
99
+ with patch.object(MockSTTProvider, 'is_available', return_value=False):
100
+ with pytest.raises(SpeechRecognitionException, match="No STT providers are available"):
101
+ STTProviderFactory.create_provider_with_fallback('mock')
102
+
103
+ def test_create_provider_with_fallback_uses_fallback_order(self):
104
+ """Test that fallback uses the predefined fallback order."""
105
+ STTProviderFactory._providers = {
106
+ 'whisper': MockSTTProvider,
107
+ 'parakeet': MockSTTProvider
108
+ }
109
+ STTProviderFactory._fallback_order = ['whisper', 'parakeet']
110
+
111
+ def mock_is_available(self):
112
+ return self.provider_name == 'parakeet' # Only parakeet is available
113
+
114
+ with patch.object(MockSTTProvider, 'is_available', mock_is_available):
115
+ provider = STTProviderFactory.create_provider_with_fallback('unavailable')
116
+ assert provider.provider_name == 'parakeet'
117
+
118
+ def test_get_available_providers(self):
119
+ """Test getting list of available providers."""
120
+ STTProviderFactory._providers = {
121
+ 'available1': MockSTTProvider,
122
+ 'available2': MockSTTProvider,
123
+ 'unavailable': MockSTTProvider
124
+ }
125
+
126
+ def mock_is_available(self):
127
+ return self.provider_name in ['available1', 'available2']
128
+
129
+ with patch.object(MockSTTProvider, 'is_available', mock_is_available):
130
+ available = STTProviderFactory.get_available_providers()
131
+
132
+ assert 'available1' in available
133
+ assert 'available2' in available
134
+ assert 'unavailable' not in available
135
+
136
+ def test_get_available_providers_error_handling(self):
137
+ """Test error handling when checking provider availability."""
138
+ def mock_init_error(self):
139
+ if self.provider_name == 'error':
140
+ raise Exception("Test error")
141
+ super(MockSTTProvider, self).__init__(self.provider_name)
142
+
143
+ STTProviderFactory._providers = {
144
+ 'good': MockSTTProvider,
145
+ 'error': MockSTTProvider
146
+ }
147
+
148
+ with patch.object(MockSTTProvider, '__init__', mock_init_error):
149
+ available = STTProviderFactory.get_available_providers()
150
+
151
+ # Should handle error gracefully and not include error provider
152
+ assert 'error' not in available
153
+
154
+ def test_get_provider_info_success(self):
155
+ """Test getting provider information."""
156
+ with patch.object(MockSTTProvider, 'is_available', return_value=True):
157
+ info = STTProviderFactory.get_provider_info('mock')
158
+
159
+ assert info is not None
160
+ assert info['name'] == 'mock'
161
+ assert info['available'] is True
162
+ assert 'supported_languages' in info
163
+ assert 'available_models' in info
164
+ assert 'default_model' in info
165
+
166
+ def test_get_provider_info_unavailable(self):
167
+ """Test getting info for unavailable provider."""
168
+ with patch.object(MockSTTProvider, 'is_available', return_value=False):
169
+ info = STTProviderFactory.get_provider_info('mock')
170
+
171
+ assert info['available'] is False
172
+ assert info['available_models'] == []
173
+ assert info['default_model'] is None
174
+
175
+ def test_get_provider_info_unknown(self):
176
+ """Test getting info for unknown provider."""
177
+ info = STTProviderFactory.get_provider_info('unknown')
178
+ assert info is None
179
+
180
+ def test_get_provider_info_error(self):
181
+ """Test handling errors when getting provider info."""
182
+ with patch.object(MockSTTProvider, '__init__', side_effect=Exception("Test error")):
183
+ info = STTProviderFactory.get_provider_info('mock')
184
+
185
+ assert info['available'] is False
186
+ assert info['error'] == 'Test error'
187
+
188
+ def test_register_provider(self):
189
+ """Test registering a new provider."""
190
+ class NewProvider(STTProviderBase):
191
+ def _perform_transcription(self, audio_path, model):
192
+ return "New provider"
193
+ def is_available(self):
194
+ return True
195
+ def get_available_models(self):
196
+ return ["new_model"]
197
+ def get_default_model(self):
198
+ return "new_model"
199
+
200
+ STTProviderFactory.register_provider('new', NewProvider)
201
+
202
+ assert 'new' in STTProviderFactory._providers
203
+ assert STTProviderFactory._providers['new'] == NewProvider
204
+
205
+ def test_register_provider_case_insensitive(self):
206
+ """Test provider registration is case insensitive."""
207
+ class NewProvider(STTProviderBase):
208
+ def _perform_transcription(self, audio_path, model):
209
+ return "New provider"
210
+ def is_available(self):
211
+ return True
212
+ def get_available_models(self):
213
+ return ["new_model"]
214
+ def get_default_model(self):
215
+ return "new_model"
216
+
217
+ STTProviderFactory.register_provider('NEW', NewProvider)
218
+
219
+ assert 'new' in STTProviderFactory._providers
220
+
221
+
222
+ class TestASRFactory:
223
+ """Test cases for legacy ASRFactory."""
224
+
225
+ def setup_method(self):
226
+ """Set up test fixtures."""
227
+ self.original_providers = STTProviderFactory._providers.copy()
228
+ STTProviderFactory._providers = {'mock': MockSTTProvider}
229
+
230
+ def teardown_method(self):
231
+ """Clean up after tests."""
232
+ STTProviderFactory._providers = self.original_providers
233
+
234
+ def test_get_model_default(self):
235
+ """Test getting model with default name."""
236
+ STTProviderFactory._providers = {'parakeet': MockSTTProvider}
237
+
238
+ with patch.object(MockSTTProvider, 'is_available', return_value=True):
239
+ provider = ASRFactory.get_model()
240
+ assert provider.provider_name == 'parakeet'
241
+
242
+ def test_get_model_specific(self):
243
+ """Test getting specific model."""
244
+ STTProviderFactory._providers = {'whisper': MockSTTProvider}
245
+
246
+ with patch.object(MockSTTProvider, 'is_available', return_value=True):
247
+ provider = ASRFactory.get_model('whisper')
248
+ assert provider.provider_name == 'whisper'
249
+
250
+ def test_get_model_legacy_mapping(self):
251
+ """Test legacy model name mapping."""
252
+ STTProviderFactory._providers = {'whisper': MockSTTProvider}
253
+
254
+ with patch.object(MockSTTProvider, 'is_available', return_value=True):
255
+ # Test faster-whisper maps to whisper
256
+ provider = ASRFactory.get_model('faster-whisper')
257
+ assert provider.provider_name == 'whisper'
258
+
259
+ def test_get_model_fallback(self):
260
+ """Test fallback when requested model is unavailable."""
261
+ STTProviderFactory._providers = {
262
+ 'whisper': MockSTTProvider,
263
+ 'parakeet': MockSTTProvider
264
+ }
265
+
266
+ def mock_is_available(self):
267
+ return self.provider_name == 'parakeet' # Only parakeet available
268
+
269
+ with patch.object(MockSTTProvider, 'is_available', mock_is_available):
270
+ with patch.object(STTProviderFactory, 'create_provider_with_fallback') as mock_fallback:
271
+ mock_fallback.return_value = MockSTTProvider('parakeet')
272
+
273
+ provider = ASRFactory.get_model('whisper')
274
+ mock_fallback.assert_called_once_with('whisper')
275
+
276
+ def test_get_model_unknown_fallback(self):
277
+ """Test fallback for unknown model names."""
278
+ STTProviderFactory._providers = {'parakeet': MockSTTProvider}
279
+
280
+ with patch.object(STTProviderFactory, 'create_provider_with_fallback') as mock_fallback:
281
+ mock_fallback.return_value = MockSTTProvider('parakeet')
282
+
283
+ provider = ASRFactory.get_model('unknown_model')
284
+ mock_fallback.assert_called_once_with('unknown_model')
tests/unit/infrastructure/factories/test_translation_provider_factory.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for TranslationProviderFactory."""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch
5
+ from enum import Enum
6
+
7
+ from src.infrastructure.translation.provider_factory import (
8
+ TranslationProviderFactory,
9
+ TranslationProviderType,
10
+ create_translation_provider,
11
+ get_default_translation_provider,
12
+ translation_provider_factory
13
+ )
14
+ from src.infrastructure.base.translation_provider_base import TranslationProviderBase
15
+ from src.domain.exceptions import TranslationFailedException
16
+
17
+
18
+ class MockTranslationProvider(TranslationProviderBase):
19
+ """Mock translation provider for testing."""
20
+
21
+ def __init__(self, provider_name="mock", available=True, **kwargs):
22
+ super().__init__(provider_name)
23
+ self._available = available
24
+ self._config = kwargs
25
+
26
+ def _translate_chunk(self, text, source_language, target_language):
27
+ return f"Translated: {text}"
28
+
29
+ def is_available(self):
30
+ return self._available
31
+
32
+ def get_supported_languages(self):
33
+ return {"en": ["es", "fr"], "es": ["en"]}
34
+
35
+
36
+ class TestTranslationProviderFactory:
37
+ """Test cases for TranslationProviderFactory."""
38
+
39
+ def setup_method(self):
40
+ """Set up test fixtures."""
41
+ self.factory = TranslationProviderFactory()
42
+
43
+ # Patch the provider registry for testing
44
+ self.original_registry = TranslationProviderFactory._PROVIDER_REGISTRY.copy()
45
+ TranslationProviderFactory._PROVIDER_REGISTRY = {
46
+ TranslationProviderType.NLLB: MockTranslationProvider
47
+ }
48
+
49
+ def teardown_method(self):
50
+ """Clean up after tests."""
51
+ TranslationProviderFactory._PROVIDER_REGISTRY = self.original_registry
52
+
53
+ def test_factory_initialization(self):
54
+ """Test factory initialization."""
55
+ factory = TranslationProviderFactory()
56
+
57
+ assert isinstance(factory._provider_cache, dict)
58
+ assert isinstance(factory._availability_cache, dict)
59
+ assert len(factory._provider_cache) == 0
60
+ assert len(factory._availability_cache) == 0
61
+
62
+ def test_create_provider_success(self):
63
+ """Test successful provider creation."""
64
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
65
+ provider = self.factory.create_provider(TranslationProviderType.NLLB)
66
+
67
+ assert isinstance(provider, MockTranslationProvider)
68
+ assert provider.provider_name == 'mock'
69
+
70
+ def test_create_provider_with_config(self):
71
+ """Test provider creation with custom config."""
72
+ config = {'model_name': 'custom_model', 'max_chunk_length': 500}
73
+
74
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
75
+ provider = self.factory.create_provider(TranslationProviderType.NLLB, config)
76
+
77
+ assert isinstance(provider, MockTranslationProvider)
78
+ assert provider._config['model_name'] == 'custom_model'
79
+ assert provider._config['max_chunk_length'] == 500
80
+
81
+ def test_create_provider_unknown_type(self):
82
+ """Test creating provider with unknown type."""
83
+ # Create a new enum value that's not in registry
84
+ class UnknownType(Enum):
85
+ UNKNOWN = "unknown"
86
+
87
+ with pytest.raises(TranslationFailedException, match="Unknown translation provider type"):
88
+ self.factory.create_provider(UnknownType.UNKNOWN)
89
+
90
+ def test_create_provider_creation_error(self):
91
+ """Test handling provider creation errors."""
92
+ with patch.object(MockTranslationProvider, '__init__', side_effect=Exception("Creation error")):
93
+ with pytest.raises(TranslationFailedException, match="Failed to create nllb provider"):
94
+ self.factory.create_provider(TranslationProviderType.NLLB)
95
+
96
+ def test_create_provider_caching(self):
97
+ """Test provider instance caching."""
98
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
99
+ provider1 = self.factory.create_provider(TranslationProviderType.NLLB, use_cache=True)
100
+ provider2 = self.factory.create_provider(TranslationProviderType.NLLB, use_cache=True)
101
+
102
+ # Should return the same cached instance
103
+ assert provider1 is provider2
104
+
105
+ def test_create_provider_no_caching(self):
106
+ """Test provider creation without caching."""
107
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
108
+ provider1 = self.factory.create_provider(TranslationProviderType.NLLB, use_cache=False)
109
+ provider2 = self.factory.create_provider(TranslationProviderType.NLLB, use_cache=False)
110
+
111
+ # Should return different instances
112
+ assert provider1 is not provider2
113
+
114
+ def test_get_available_providers(self):
115
+ """Test getting available providers."""
116
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
117
+ available = self.factory.get_available_providers()
118
+
119
+ assert TranslationProviderType.NLLB in available
120
+
121
+ def test_get_available_providers_unavailable(self):
122
+ """Test getting available providers when provider is unavailable."""
123
+ with patch.object(MockTranslationProvider, 'is_available', return_value=False):
124
+ available = self.factory.get_available_providers()
125
+
126
+ assert TranslationProviderType.NLLB not in available
127
+
128
+ def test_get_available_providers_force_check(self):
129
+ """Test forcing availability check ignores cache."""
130
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
131
+ # First call caches result
132
+ available1 = self.factory.get_available_providers()
133
+
134
+ # Change availability
135
+ with patch.object(MockTranslationProvider, 'is_available', return_value=False):
136
+ # Without force_check, should use cached result
137
+ available2 = self.factory.get_available_providers(force_check=False)
138
+ assert available1 == available2
139
+
140
+ # With force_check, should get updated result
141
+ available3 = self.factory.get_available_providers(force_check=True)
142
+ assert available3 != available1
143
+
144
+ def test_get_default_provider(self):
145
+ """Test getting default provider."""
146
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
147
+ provider = self.factory.get_default_provider()
148
+
149
+ assert isinstance(provider, MockTranslationProvider)
150
+
151
+ def test_get_default_provider_no_available(self):
152
+ """Test getting default provider when none are available."""
153
+ with patch.object(MockTranslationProvider, 'is_available', return_value=False):
154
+ with pytest.raises(TranslationFailedException, match="No translation providers are available"):
155
+ self.factory.get_default_provider()
156
+
157
+ def test_get_provider_with_fallback_preferred_available(self):
158
+ """Test fallback when preferred provider is available."""
159
+ preferred = [TranslationProviderType.NLLB]
160
+
161
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
162
+ provider = self.factory.get_provider_with_fallback(preferred)
163
+
164
+ assert isinstance(provider, MockTranslationProvider)
165
+
166
+ def test_get_provider_with_fallback_to_any_available(self):
167
+ """Test fallback to any available provider."""
168
+ # Create mock enum for testing
169
+ class TestType(Enum):
170
+ TEST = "test"
171
+
172
+ preferred = [TestType.TEST] # Not in registry
173
+
174
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
175
+ provider = self.factory.get_provider_with_fallback(preferred)
176
+
177
+ # Should fallback to NLLB since it's available
178
+ assert isinstance(provider, MockTranslationProvider)
179
+
180
+ def test_get_provider_with_fallback_none_available(self):
181
+ """Test fallback when no providers are available."""
182
+ preferred = [TranslationProviderType.NLLB]
183
+
184
+ with patch.object(MockTranslationProvider, 'is_available', return_value=False):
185
+ with pytest.raises(TranslationFailedException, match="None of the preferred translation providers are available"):
186
+ self.factory.get_provider_with_fallback(preferred)
187
+
188
+ def test_clear_cache(self):
189
+ """Test clearing provider cache."""
190
+ # Create cached provider
191
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
192
+ self.factory.create_provider(TranslationProviderType.NLLB)
193
+
194
+ assert len(self.factory._provider_cache) > 0
195
+ assert len(self.factory._availability_cache) > 0
196
+
197
+ self.factory.clear_cache()
198
+
199
+ assert len(self.factory._provider_cache) == 0
200
+ assert len(self.factory._availability_cache) == 0
201
+
202
+ def test_get_provider_info(self):
203
+ """Test getting provider information."""
204
+ info = self.factory.get_provider_info(TranslationProviderType.NLLB)
205
+
206
+ assert info['type'] == 'nllb'
207
+ assert info['class_name'] == 'MockTranslationProvider'
208
+ assert 'available' in info
209
+ assert 'default_config' in info
210
+ assert 'description' in info
211
+
212
+ def test_get_provider_info_unknown_type(self):
213
+ """Test getting info for unknown provider type."""
214
+ class UnknownType(Enum):
215
+ UNKNOWN = "unknown"
216
+
217
+ with pytest.raises(TranslationFailedException, match="Unknown provider type"):
218
+ self.factory.get_provider_info(UnknownType.UNKNOWN)
219
+
220
+ def test_get_all_providers_info(self):
221
+ """Test getting information about all providers."""
222
+ info = self.factory.get_all_providers_info()
223
+
224
+ assert 'nllb' in info
225
+ assert info['nllb']['type'] == 'nllb'
226
+ assert info['nllb']['class_name'] == 'MockTranslationProvider'
227
+
228
+ def test_generate_cache_key(self):
229
+ """Test cache key generation."""
230
+ key1 = self.factory._generate_cache_key(TranslationProviderType.NLLB, None)
231
+ key2 = self.factory._generate_cache_key(TranslationProviderType.NLLB, {})
232
+ key3 = self.factory._generate_cache_key(TranslationProviderType.NLLB, {'a': 1, 'b': 2})
233
+ key4 = self.factory._generate_cache_key(TranslationProviderType.NLLB, {'b': 2, 'a': 1})
234
+
235
+ assert key1 == key2 # None and empty dict should be same
236
+ assert key3 == key4 # Order shouldn't matter
237
+ assert key1 != key3 # Different configs should be different
238
+
239
+ def test_register_provider(self):
240
+ """Test registering new provider type."""
241
+ class NewType(Enum):
242
+ NEW = "new"
243
+
244
+ class NewProvider(TranslationProviderBase):
245
+ def _translate_chunk(self, text, source_language, target_language):
246
+ return f"New: {text}"
247
+ def is_available(self):
248
+ return True
249
+ def get_supported_languages(self):
250
+ return {}
251
+
252
+ TranslationProviderFactory.register_provider(
253
+ NewType.NEW,
254
+ NewProvider,
255
+ {'default_config': 'value'}
256
+ )
257
+
258
+ assert NewType.NEW in TranslationProviderFactory._PROVIDER_REGISTRY
259
+ assert TranslationProviderFactory._PROVIDER_REGISTRY[NewType.NEW] == NewProvider
260
+ assert TranslationProviderFactory._DEFAULT_CONFIGS[NewType.NEW] == {'default_config': 'value'}
261
+
262
+ def test_get_supported_provider_types(self):
263
+ """Test getting supported provider types."""
264
+ types = TranslationProviderFactory.get_supported_provider_types()
265
+
266
+ assert TranslationProviderType.NLLB in types
267
+ assert isinstance(types, list)
268
+
269
+ def test_is_provider_available_caching(self):
270
+ """Test provider availability caching."""
271
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True) as mock_available:
272
+ # First call should check availability
273
+ available1 = self.factory._is_provider_available(TranslationProviderType.NLLB)
274
+
275
+ # Second call should use cache
276
+ available2 = self.factory._is_provider_available(TranslationProviderType.NLLB)
277
+
278
+ assert available1 is True
279
+ assert available2 is True
280
+ # Should only be called once due to caching
281
+ assert mock_available.call_count == 1
282
+
283
+ def test_is_provider_available_force_check(self):
284
+ """Test forcing provider availability check."""
285
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True) as mock_available:
286
+ # First call
287
+ self.factory._is_provider_available(TranslationProviderType.NLLB)
288
+
289
+ # Force check should ignore cache
290
+ self.factory._is_provider_available(TranslationProviderType.NLLB, force_check=True)
291
+
292
+ # Should be called twice
293
+ assert mock_available.call_count == 2
294
+
295
+ def test_is_provider_available_error_handling(self):
296
+ """Test error handling in availability check."""
297
+ with patch.object(MockTranslationProvider, '__init__', side_effect=Exception("Test error")):
298
+ available = self.factory._is_provider_available(TranslationProviderType.NLLB)
299
+
300
+ assert available is False
301
+ # Should cache the error result
302
+ assert self.factory._availability_cache[TranslationProviderType.NLLB] is False
303
+
304
+
305
+ class TestConvenienceFunctions:
306
+ """Test cases for convenience functions."""
307
+
308
+ def setup_method(self):
309
+ """Set up test fixtures."""
310
+ # Patch the provider registry for testing
311
+ self.original_registry = TranslationProviderFactory._PROVIDER_REGISTRY.copy()
312
+ TranslationProviderFactory._PROVIDER_REGISTRY = {
313
+ TranslationProviderType.NLLB: MockTranslationProvider
314
+ }
315
+
316
+ def teardown_method(self):
317
+ """Clean up after tests."""
318
+ TranslationProviderFactory._PROVIDER_REGISTRY = self.original_registry
319
+
320
+ def test_create_translation_provider(self):
321
+ """Test convenience function for creating provider."""
322
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
323
+ provider = create_translation_provider()
324
+
325
+ assert isinstance(provider, MockTranslationProvider)
326
+
327
+ def test_create_translation_provider_with_config(self):
328
+ """Test convenience function with config."""
329
+ config = {'test': 'value'}
330
+
331
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
332
+ provider = create_translation_provider(TranslationProviderType.NLLB, config)
333
+
334
+ assert isinstance(provider, MockTranslationProvider)
335
+ assert provider._config['test'] == 'value'
336
+
337
+ def test_get_default_translation_provider(self):
338
+ """Test convenience function for getting default provider."""
339
+ with patch.object(MockTranslationProvider, 'is_available', return_value=True):
340
+ provider = get_default_translation_provider()
341
+
342
+ assert isinstance(provider, MockTranslationProvider)
343
+
344
+ def test_global_factory_instance(self):
345
+ """Test global factory instance."""
346
+ assert isinstance(translation_provider_factory, TranslationProviderFactory)
tests/unit/infrastructure/factories/test_tts_provider_factory.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for TTSProviderFactory."""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch, MagicMock
5
+
6
+ from src.infrastructure.tts.provider_factory import TTSProviderFactory
7
+ from src.infrastructure.base.tts_provider_base import TTSProviderBase
8
+ from src.domain.exceptions import SpeechSynthesisException
9
+
10
+
11
+ class MockTTSProvider(TTSProviderBase):
12
+ """Mock TTS provider for testing."""
13
+
14
+ def __init__(self, provider_name="mock", available=True, voices=None):
15
+ super().__init__(provider_name)
16
+ self._available = available
17
+ self._voices = voices or ["voice1", "voice2"]
18
+
19
+ def _generate_audio(self, request):
20
+ return b"mock_audio", 44100
21
+
22
+ def _generate_audio_stream(self, request):
23
+ yield b"chunk1", 44100, False
24
+ yield b"chunk2", 44100, True
25
+
26
+ def is_available(self):
27
+ return self._available
28
+
29
+ def get_available_voices(self):
30
+ return self._voices
31
+
32
+
33
+ class TestTTSProviderFactory:
34
+ """Test cases for TTSProviderFactory."""
35
+
36
+ def setup_method(self):
37
+ """Set up test fixtures."""
38
+ self.factory = TTSProviderFactory()
39
+
40
+ def test_factory_initialization(self):
41
+ """Test factory initialization."""
42
+ assert isinstance(self.factory._providers, dict)
43
+ assert isinstance(self.factory._provider_instances, dict)
44
+ assert 'dummy' in self.factory._providers
45
+
46
+ @patch('src.infrastructure.tts.provider_factory.DummyTTSProvider')
47
+ def test_register_default_providers_dummy(self, mock_dummy):
48
+ """Test registration of dummy provider."""
49
+ factory = TTSProviderFactory()
50
+
51
+ assert 'dummy' in factory._providers
52
+ assert factory._providers['dummy'] == mock_dummy
53
+
54
+ @patch('src.infrastructure.tts.provider_factory.KokoroTTSProvider')
55
+ def test_register_default_providers_kokoro_available(self, mock_kokoro):
56
+ """Test registration of Kokoro provider when available."""
57
+ factory = TTSProviderFactory()
58
+
59
+ assert 'kokoro' in factory._providers
60
+ assert factory._providers['kokoro'] == mock_kokoro
61
+
62
+ @patch('src.infrastructure.tts.kokoro_provider.KokoroTTSProvider', side_effect=ImportError("Not available"))
63
+ def test_register_default_providers_kokoro_unavailable(self, mock_kokoro):
64
+ """Test handling when Kokoro provider is not available."""
65
+ factory = TTSProviderFactory()
66
+
67
+ # Should not crash, just not register the provider
68
+ assert 'kokoro' not in factory._providers or factory._providers.get('kokoro') is None
69
+
70
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
71
+ def test_get_available_providers(self):
72
+ """Test getting available providers."""
73
+ with patch.object(MockTTSProvider, 'is_available', return_value=True):
74
+ available = self.factory.get_available_providers()
75
+ assert 'mock' in available
76
+
77
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
78
+ def test_get_available_providers_unavailable(self):
79
+ """Test getting available providers when provider is unavailable."""
80
+ with patch.object(MockTTSProvider, 'is_available', return_value=False):
81
+ available = self.factory.get_available_providers()
82
+ assert 'mock' not in available
83
+
84
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
85
+ def test_get_available_providers_error(self):
86
+ """Test handling errors when checking provider availability."""
87
+ with patch.object(MockTTSProvider, '__init__', side_effect=Exception("Test error")):
88
+ available = self.factory.get_available_providers()
89
+ assert 'mock' not in available
90
+
91
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
92
+ def test_create_provider_success(self):
93
+ """Test successful provider creation."""
94
+ with patch.object(MockTTSProvider, 'is_available', return_value=True):
95
+ provider = self.factory.create_provider('mock')
96
+
97
+ assert isinstance(provider, MockTTSProvider)
98
+ assert provider.provider_name == 'mock'
99
+
100
+ def test_create_provider_unknown(self):
101
+ """Test creating unknown provider raises exception."""
102
+ with pytest.raises(SpeechSynthesisException, match="Unknown TTS provider: unknown"):
103
+ self.factory.create_provider('unknown')
104
+
105
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
106
+ def test_create_provider_unavailable(self):
107
+ """Test creating unavailable provider raises exception."""
108
+ with patch.object(MockTTSProvider, 'is_available', return_value=False):
109
+ with pytest.raises(SpeechSynthesisException, match="TTS provider mock is not available"):
110
+ self.factory.create_provider('mock')
111
+
112
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
113
+ def test_create_provider_creation_error(self):
114
+ """Test handling provider creation errors."""
115
+ with patch.object(MockTTSProvider, '__init__', side_effect=Exception("Creation error")):
116
+ with pytest.raises(SpeechSynthesisException, match="Failed to create TTS provider mock"):
117
+ self.factory.create_provider('mock')
118
+
119
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
120
+ def test_create_provider_with_lang_code(self):
121
+ """Test creating provider with language code."""
122
+ with patch.object(MockTTSProvider, 'is_available', return_value=True):
123
+ # Mock providers that accept lang_code
124
+ self.factory._providers['kokoro'] = MockTTSProvider
125
+
126
+ provider = self.factory.create_provider('kokoro', lang_code='en')
127
+ assert isinstance(provider, MockTTSProvider)
128
+
129
+ @patch.object(TTSProviderFactory, '_providers', {
130
+ 'available1': MockTTSProvider,
131
+ 'available2': MockTTSProvider,
132
+ 'unavailable': MockTTSProvider
133
+ })
134
+ def test_get_provider_with_fallback_success(self):
135
+ """Test getting provider with fallback logic."""
136
+ def mock_is_available(self):
137
+ return self.provider_name in ['available1', 'available2']
138
+
139
+ with patch.object(MockTTSProvider, 'is_available', mock_is_available):
140
+ provider = self.factory.get_provider_with_fallback(['unavailable', 'available1'])
141
+ assert provider.provider_name == 'available1'
142
+
143
+ @patch.object(TTSProviderFactory, '_providers', {
144
+ 'available': MockTTSProvider,
145
+ 'unavailable': MockTTSProvider
146
+ })
147
+ def test_get_provider_with_fallback_to_any_available(self):
148
+ """Test fallback to any available provider."""
149
+ def mock_is_available(self):
150
+ return self.provider_name == 'available'
151
+
152
+ with patch.object(MockTTSProvider, 'is_available', mock_is_available):
153
+ provider = self.factory.get_provider_with_fallback(['unavailable'])
154
+ assert provider.provider_name == 'available'
155
+
156
+ @patch.object(TTSProviderFactory, '_providers', {'unavailable': MockTTSProvider})
157
+ def test_get_provider_with_fallback_none_available(self):
158
+ """Test fallback when no providers are available."""
159
+ with patch.object(MockTTSProvider, 'is_available', return_value=False):
160
+ with pytest.raises(SpeechSynthesisException, match="No TTS providers are available"):
161
+ self.factory.get_provider_with_fallback(['unavailable'])
162
+
163
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
164
+ def test_get_provider_info_success(self):
165
+ """Test getting provider information."""
166
+ with patch.object(MockTTSProvider, 'is_available', return_value=True):
167
+ info = self.factory.get_provider_info('mock')
168
+
169
+ assert info['available'] is True
170
+ assert info['name'] == 'mock'
171
+ assert 'supported_languages' in info
172
+ assert 'available_voices' in info
173
+
174
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
175
+ def test_get_provider_info_unavailable(self):
176
+ """Test getting info for unavailable provider."""
177
+ with patch.object(MockTTSProvider, 'is_available', return_value=False):
178
+ info = self.factory.get_provider_info('mock')
179
+
180
+ assert info['available'] is False
181
+ assert info['available_voices'] == []
182
+
183
+ def test_get_provider_info_unknown(self):
184
+ """Test getting info for unknown provider."""
185
+ info = self.factory.get_provider_info('unknown')
186
+
187
+ assert info['available'] is False
188
+ assert 'error' in info
189
+
190
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
191
+ def test_get_provider_info_error(self):
192
+ """Test handling errors when getting provider info."""
193
+ with patch.object(MockTTSProvider, '__init__', side_effect=Exception("Test error")):
194
+ info = self.factory.get_provider_info('mock')
195
+
196
+ assert info['available'] is False
197
+ assert info['error'] == 'Test error'
198
+
199
+ def test_cleanup_providers(self):
200
+ """Test cleaning up provider instances."""
201
+ # Create mock provider with cleanup method
202
+ mock_provider = Mock()
203
+ mock_provider._cleanup_temp_files = Mock()
204
+
205
+ self.factory._provider_instances['test'] = mock_provider
206
+
207
+ self.factory.cleanup_providers()
208
+
209
+ mock_provider._cleanup_temp_files.assert_called_once()
210
+ assert len(self.factory._provider_instances) == 0
211
+
212
+ def test_cleanup_providers_no_cleanup_method(self):
213
+ """Test cleanup when provider has no cleanup method."""
214
+ mock_provider = Mock()
215
+ del mock_provider._cleanup_temp_files # Remove cleanup method
216
+
217
+ self.factory._provider_instances['test'] = mock_provider
218
+
219
+ # Should not raise exception
220
+ self.factory.cleanup_providers()
221
+ assert len(self.factory._provider_instances) == 0
222
+
223
+ def test_cleanup_providers_cleanup_error(self):
224
+ """Test handling cleanup errors."""
225
+ mock_provider = Mock()
226
+ mock_provider._cleanup_temp_files.side_effect = Exception("Cleanup error")
227
+
228
+ self.factory._provider_instances['test'] = mock_provider
229
+
230
+ # Should not raise exception
231
+ self.factory.cleanup_providers()
232
+ assert len(self.factory._provider_instances) == 0
233
+
234
+ @patch.object(TTSProviderFactory, '_providers', {'mock': MockTTSProvider})
235
+ def test_provider_instance_caching(self):
236
+ """Test that provider instances are cached."""
237
+ with patch.object(MockTTSProvider, 'is_available', return_value=True):
238
+ # First call to get_available_providers should create instance
239
+ available1 = self.factory.get_available_providers()
240
+
241
+ # Second call should use cached instance
242
+ available2 = self.factory.get_available_providers()
243
+
244
+ assert available1 == available2
245
+ assert 'mock' in self.factory._provider_instances