File size: 12,059 Bytes
4e4961e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e4961e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3cb97b
 
 
 
 
4e4961e
 
 
 
e3cb97b
4e4961e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6514731
e3cb97b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
"""
Base class for TTS provider implementations.

This module provides the abstract base class for all Text-to-Speech provider
implementations in the infrastructure layer. It implements common functionality
and defines the contract that all TTS providers must follow.

The base class handles:
- Common validation logic
- File management and cleanup
- Error handling and logging
- Audio format processing
- Provider lifecycle management

Example implementation:
    ```python
    from src.infrastructure.base.tts_provider_base import TTSProviderBase

    class MyTTSProvider(TTSProviderBase):
        def __init__(self):
            super().__init__("my_tts", ["en", "es"])

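        def _generate_audio(self, request):
            # Implement TTS-specific logic
            audio_data = my_tts_engine.synthesize(request.text_content.text)
            return audio_data, 22050  # audio_bytes, sample_rate

        def _generate_audio_stream(self, request):
            # Required as well: the base class declares this method abstract
            audio_data, sample_rate = self._generate_audio(request)
            yield audio_data, sample_rate, True  # audio_bytes, sample_rate, is_final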

        def is_available(self):
            return my_tts_engine.is_loaded()

        def get_available_voices(self):
            return ["voice1", "voice2"]
    ```
"""

import logging
import os
import time
import tempfile
from abc import ABC, abstractmethod
from typing import Iterator, Optional, TYPE_CHECKING
from pathlib import Path

if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.audio_chunk import AudioChunk

from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
from ...domain.exceptions import SpeechSynthesisException

logger = logging.getLogger(__name__)


class TTSProviderBase(ISpeechSynthesisService, ABC):
    """
    Abstract base class for TTS provider implementations.

    This class provides a foundation for implementing Text-to-Speech providers
    in the infrastructure layer. It handles common concerns like validation,
    file management, error handling, and audio processing while allowing
    concrete implementations to focus on provider-specific logic.

    Key features:
    - Automatic validation of synthesis requests
    - Temporary file management with cleanup
    - Consistent error handling and logging
    - Support for both batch and streaming synthesis
    - Audio format standardization
    - Provider availability checking

    Subclasses must implement:
    - _generate_audio(): Core synthesis logic
    - _generate_audio_stream(): Streaming synthesis (required; declared abstract)
    - is_available(): Provider availability check
    - get_available_voices(): Voice enumeration

    The base class ensures that all providers follow the same patterns
    for error handling, logging, and resource management, making the
    system more maintainable and predictable.
    """

    def __init__(self, provider_name: str, supported_languages: list[str] = None):
        """
        Initialize the TTS provider.

        Sets up the provider with basic configuration and creates necessary
        directories for temporary file storage. This constructor should be
        called by all subclass implementations.

        Args:
            provider_name: Unique identifier for this TTS provider (e.g., "kokoro", "dia").
                         Used for logging, error messages, and provider selection.
            supported_languages: List of ISO language codes supported by this provider
                               (e.g., ["en", "zh", "es"]). If None, no language validation
                               will be performed.

        Example:
            ```python
            class MyTTSProvider(TTSProviderBase):
                def __init__(self):
                    super().__init__(
                        provider_name="my_tts",
                        supported_languages=["en", "es", "fr"]
                    )
            ```
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._output_dir = self._ensure_output_directory()

    def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent':
        """
        Synthesize speech from text.

        Validates the request, delegates audio generation to the
        provider-specific ``_generate_audio`` and wraps the result in an
        ``AudioContent``. The content is always labelled as ``'wav'``; its
        duration is derived by ``_calculate_duration`` with that helper's
        default channel count and sample width, and its filename is
        ``<provider_name>_<unix seconds>.wav``.

        Args:
            request: The speech synthesis request

        Returns:
            AudioContent: The synthesized audio

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. Domain
                exceptions raised by validation or by the provider are
                propagated unchanged; any other exception is wrapped, with the
                original attached as ``__cause__``.
        """
        try:
            logger.info(f"Starting synthesis with {self.provider_name} provider")
            self._validate_request(request)

            # Generate audio using provider-specific implementation
            audio_data, sample_rate = self._generate_audio(request)

            # Imported at call time rather than module import time
            from ...domain.models.audio_content import AudioContent

            audio_content = AudioContent(
                data=audio_data,
                format='wav',  # Most providers output WAV
                sample_rate=sample_rate,
                duration=self._calculate_duration(audio_data, sample_rate),
                filename=f"{self.provider_name}_{int(time.time())}.wav"
            )

            logger.info(f"Synthesis completed successfully with {self.provider_name}")
            return audio_content

        except SpeechSynthesisException as e:
            # Already a domain error (e.g. from _validate_request): wrapping it
            # again would only stack a second "TTS synthesis failed" prefix.
            logger.error(f"Synthesis failed with {self.provider_name}: {str(e)}")
            raise

        except Exception as e:
            # Unexpected failure: keep the traceback in the log before converting.
            logger.error(f"Synthesis failed with {self.provider_name}: {str(e)}", exc_info=True)
            raise SpeechSynthesisException(f"TTS synthesis failed: {str(e)}") from e

    def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']:
        """
        Synthesize speech from text as a stream.

        This is a generator function: nothing (including request validation
        and the start log line) runs until the caller begins iterating. Each
        ``(audio_data, sample_rate, is_final)`` tuple produced by
        ``_generate_audio_stream`` is wrapped in an ``AudioChunk`` labelled
        ``'wav'``, numbered from 0 and stamped with the current time.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator[AudioChunk]: Stream of audio chunks

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. Domain
                exceptions are propagated unchanged; any other exception is
                wrapped, with the original attached as ``__cause__``.
        """
        try:
            logger.info(f"Starting streaming synthesis with {self.provider_name} provider")
            self._validate_request(request)

            # Resolve the chunk model once instead of re-running the import per chunk
            from ...domain.models.audio_chunk import AudioChunk

            # Generate audio stream using provider-specific implementation
            stream = self._generate_audio_stream(request)
            for chunk_index, (audio_data, sample_rate, is_final) in enumerate(stream):
                yield AudioChunk(
                    data=audio_data,
                    format='wav',
                    sample_rate=sample_rate,
                    chunk_index=chunk_index,
                    is_final=is_final,
                    timestamp=time.time()
                )

            logger.info(f"Streaming synthesis completed with {self.provider_name}")

        except SpeechSynthesisException as e:
            # Already a domain error: re-raise as-is rather than nesting it in
            # a second SpeechSynthesisException.
            logger.error(f"Streaming synthesis failed with {self.provider_name}: {str(e)}")
            raise

        except Exception as e:
            # Unexpected failure: keep the traceback in the log before converting.
            logger.error(
                f"Streaming synthesis failed with {self.provider_name}: {str(e)}",
                exc_info=True,
            )
            raise SpeechSynthesisException(f"TTS streaming synthesis failed: {str(e)}") from e

    @abstractmethod
    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """
        Generate audio data from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            tuple: (audio_data_bytes, sample_rate)
        """
        pass

    @abstractmethod
    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """
        Generate audio data stream from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator: (audio_data_bytes, sample_rate, is_final) tuples
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the TTS provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """
        pass

    @abstractmethod
    def get_available_voices(self) -> list[str]:
        """
        Get list of available voices for this provider.

        Returns:
            list[str]: List of voice identifiers
        """
        pass

    def _validate_request(self, request: 'SpeechSynthesisRequest') -> None:
        """
        Validate the synthesis request.

        Args:
            request: The synthesis request to validate

        Raises:
            SpeechSynthesisException: If request is invalid
        """
        if not request.text_content.text.strip():
            raise SpeechSynthesisException("Text content cannot be empty")

        if self.supported_languages and request.text_content.language not in self.supported_languages:
            raise SpeechSynthesisException(
                f"Language {request.text_content.language} not supported by {self.provider_name}. "
                f"Supported languages: {self.supported_languages}"
            )

        available_voices = self.get_available_voices()
        if available_voices and request.voice_settings.voice_id not in available_voices:
            raise SpeechSynthesisException(
                f"Voice {request.voice_settings.voice_id} not available for {self.provider_name}. "
                f"Available voices: {available_voices}"
            )

    def _ensure_output_directory(self) -> Path:
        """
        Ensure output directory exists and return its path.

        Returns:
            Path: Path to the output directory
        """
        output_dir = Path(tempfile.gettempdir()) / "tts_output"
        output_dir.mkdir(exist_ok=True)
        return output_dir

    def _generate_output_path(self, prefix: str = None, extension: str = "wav") -> Path:
        """
        Generate a unique output path for audio files.

        Args:
            prefix: Optional prefix for the filename
            extension: File extension (default: wav)

        Returns:
            Path: Unique file path
        """
        prefix = prefix or self.provider_name
        timestamp = int(time.time() * 1000)
        filename = f"{prefix}_{timestamp}.{extension}"
        return self._output_dir / filename

    def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float:
        """
        Calculate audio duration from raw audio data.

        Args:
            audio_data: Raw audio data in bytes
            sample_rate: Sample rate in Hz
            channels: Number of audio channels (default: 1)
            sample_width: Sample width in bytes (default: 2 for 16-bit)

        Returns:
            float: Duration in seconds
        """
        if not audio_data or sample_rate <= 0:
            return 0.0

        bytes_per_sample = channels * sample_width
        total_samples = len(audio_data) // bytes_per_sample
        return total_samples / sample_rate

    def _cleanup_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            current_time = time.time()
            max_age_seconds = max_age_hours * 3600

            for file_path in self._output_dir.glob("*"):
                if file_path.is_file():
                    file_age = current_time - file_path.stat().st_mtime
                    if file_age > max_age_seconds:
                        file_path.unlink()
                        logger.info(f"Cleaned up old temp file: {file_path}")

        except Exception as e:
            logger.warning(f"Failed to cleanup temp files: {str(e)}")

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Args:
            error: The original error
            context: Additional context about when the error occurred
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"

        logger.error(error_msg, exception=error)
        raise SpeechSynthesisException(error_msg) from error