Spaces:
Build error
Build error
File size: 12,059 Bytes
4e4961e e3cb97b 4e4961e e3cb97b 4e4961e e3cb97b 4e4961e e3cb97b fdc056d e3cb97b 6514731 e3cb97b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 |
"""
Base class for TTS provider implementations.
This module provides the abstract base class for all Text-to-Speech provider
implementations in the infrastructure layer. It implements common functionality
and defines the contract that all TTS providers must follow.
The base class handles:
- Common validation logic
- File management and cleanup
- Error handling and logging
- Audio format processing
- Provider lifecycle management
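Example implementation:
    ```python
    from src.infrastructure.base.tts_provider_base import TTSProviderBase

    class MyTTSProvider(TTSProviderBase):
        def __init__(self):
            super().__init__("my_tts", ["en", "es"])

        def _generate_audio(self, request):
            # Implement TTS-specific logic
            audio_data = my_tts_engine.synthesize(request.text_content.text)
            return audio_data, 22050  # audio_bytes, sample_rate

        def _generate_audio_stream(self, request):
            # Also abstract in the base class, so it must be implemented;
            # here the whole utterance is emitted as one final chunk.
            audio_data, sample_rate = self._generate_audio(request)
            yield audio_data, sample_rate, True  # audio_bytes, sample_rate, is_final

        def is_available(self):
            return my_tts_engine.is_loaded()

        def get_available_voices(self):
            return ["voice1", "voice2"]
    ```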
"""
import logging
import os
import time
import tempfile
from abc import ABC, abstractmethod
from typing import Iterator, Optional, TYPE_CHECKING
from pathlib import Path
if TYPE_CHECKING:
    # Imported for annotations only. The class refers to these names through
    # string annotations and re-imports AudioContent / AudioChunk at call time
    # inside synthesize() / synthesize_stream(), so nothing here is needed
    # when the module is actually executed.
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.audio_chunk import AudioChunk
from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
from ...domain.exceptions import SpeechSynthesisException
logger = logging.getLogger(__name__)
class TTSProviderBase(ISpeechSynthesisService, ABC):
    """
    Abstract base class for TTS provider implementations.

    This class provides a foundation for implementing Text-to-Speech providers
    in the infrastructure layer. It handles common concerns like validation,
    file management, error handling, and audio processing while allowing
    concrete implementations to focus on provider-specific logic.

    Key features:
    - Automatic validation of synthesis requests
    - Temporary file management with cleanup
    - Consistent error handling and logging
    - Support for both batch and streaming synthesis
    - Audio format standardization
    - Provider availability checking

    Subclasses must implement:
    - _generate_audio(): Core synthesis logic
    - _generate_audio_stream(): Streaming synthesis (declared abstract, so it
      is required even for providers without native streaming support)
    - is_available(): Provider availability check
    - get_available_voices(): Voice enumeration

    The base class ensures that all providers follow the same patterns
    for error handling, logging, and resource management, making the
    system more maintainable and predictable.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the TTS provider.

        Stores the provider configuration and makes sure the directory used
        for temporary audio files exists. This constructor should be called
        by all subclass implementations.

        Args:
            provider_name: Unique identifier for this TTS provider (e.g., "kokoro", "dia").
                Used for logging, error messages, and generated file names.
            supported_languages: List of ISO language codes supported by this provider
                (e.g., ["en", "zh", "es"]). If None or empty, no language
                validation is performed.

        Example:
            ```python
            class MyTTSProvider(TTSProviderBase):
                def __init__(self):
                    super().__init__(
                        provider_name="my_tts",
                        supported_languages=["en", "es", "fr"]
                    )
            ```
        """
        self.provider_name = provider_name
        # An empty list means "accept any language" in _validate_request().
        self.supported_languages = supported_languages or []
        self._output_dir = self._ensure_output_directory()

    def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent':
        """
        Synthesize speech from text.

        Validates the request, delegates to _generate_audio() and wraps the
        result in an AudioContent labelled as WAV.

        Args:
            request: The speech synthesis request

        Returns:
            AudioContent: The synthesized audio

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. Every
                failure, including a validation error, is re-raised with a
                "TTS synthesis failed: " prefix and the original exception
                chained as its cause.
        """
        try:
            logger.info(f"Starting synthesis with {self.provider_name} provider")
            self._validate_request(request)

            # Generate audio using provider-specific implementation
            audio_data, sample_rate = self._generate_audio(request)

            # Imported at call time (the module-level import appears to be
            # for type checking only).
            from ...domain.models.audio_content import AudioContent

            audio_content = AudioContent(
                data=audio_data,
                format='wav',  # Most providers output WAV
                sample_rate=sample_rate,
                duration=self._calculate_duration(audio_data, sample_rate),
                filename=f"{self.provider_name}_{int(time.time())}.wav"
            )

            logger.info(f"Synthesis completed successfully with {self.provider_name}")
            return audio_content
        except Exception as e:
            logger.error(f"Synthesis failed with {self.provider_name}: {e}")
            raise SpeechSynthesisException(f"TTS synthesis failed: {e}") from e

    def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']:
        """
        Synthesize speech from text as a stream.

        This is a generator: nothing runs, including request validation,
        until the caller starts iterating.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator[AudioChunk]: Stream of audio chunks, indexed from 0 in
            the order the provider produces them

        Raises:
            SpeechSynthesisException: If validation or synthesis fails; raised
                during iteration, with the original exception chained.
        """
        try:
            logger.info(f"Starting streaming synthesis with {self.provider_name} provider")
            self._validate_request(request)

            # Imported at call time (the module-level import appears to be
            # for type checking only); done once rather than per chunk.
            from ...domain.models.audio_chunk import AudioChunk

            # Generate audio stream using provider-specific implementation
            stream = self._generate_audio_stream(request)
            for chunk_index, (audio_data, sample_rate, is_final) in enumerate(stream):
                yield AudioChunk(
                    data=audio_data,
                    format='wav',
                    sample_rate=sample_rate,
                    chunk_index=chunk_index,
                    is_final=is_final,
                    timestamp=time.time()
                )

            logger.info(f"Streaming synthesis completed with {self.provider_name}")
        except Exception as e:
            logger.error(f"Streaming synthesis failed with {self.provider_name}: {e}")
            raise SpeechSynthesisException(f"TTS streaming synthesis failed: {e}") from e

    @abstractmethod
    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """
        Generate audio data from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            tuple: (audio_data_bytes, sample_rate)
        """
        pass

    @abstractmethod
    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """
        Generate audio data stream from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator: (audio_data_bytes, sample_rate, is_final) tuples
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the TTS provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """
        pass

    @abstractmethod
    def get_available_voices(self) -> list[str]:
        """
        Get list of available voices for this provider.

        Returns:
            list[str]: List of voice identifiers
        """
        pass

    def _validate_request(self, request: 'SpeechSynthesisRequest') -> None:
        """
        Validate the synthesis request.

        Checks, in order: non-blank text, language membership (only when the
        provider declared supported languages) and voice membership (only when
        get_available_voices() returns a non-empty list).

        Args:
            request: The synthesis request to validate

        Raises:
            SpeechSynthesisException: If request is invalid
        """
        if not request.text_content.text.strip():
            raise SpeechSynthesisException("Text content cannot be empty")

        language = request.text_content.language
        if self.supported_languages and language not in self.supported_languages:
            raise SpeechSynthesisException(
                f"Language {language} not supported by {self.provider_name}. "
                f"Supported languages: {self.supported_languages}"
            )

        available_voices = self.get_available_voices()
        voice_id = request.voice_settings.voice_id
        if available_voices and voice_id not in available_voices:
            raise SpeechSynthesisException(
                f"Voice {voice_id} not available for {self.provider_name}. "
                f"Available voices: {available_voices}"
            )

    def _ensure_output_directory(self) -> Path:
        """
        Ensure output directory exists and return its path.

        The directory is "tts_output" under the system temp directory; the
        path does not depend on the provider name, so every provider instance
        resolves to the same directory.

        Returns:
            Path: Path to the output directory
        """
        output_dir = Path(tempfile.gettempdir()) / "tts_output"
        # parents=True so a configured temp dir that does not exist yet
        # does not make provider construction fail.
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def _generate_output_path(self, prefix: Optional[str] = None, extension: str = "wav") -> Path:
        """
        Generate a timestamped output path for audio files.

        The name is "<prefix>_<epoch-milliseconds>.<extension>". Two calls with
        the same prefix inside the same millisecond yield the same path; no
        file is created by this method.

        Args:
            prefix: Optional prefix for the filename (defaults to the provider name)
            extension: File extension (default: wav)

        Returns:
            Path: File path inside the output directory
        """
        prefix = prefix or self.provider_name
        timestamp = int(time.time() * 1000)
        filename = f"{prefix}_{timestamp}.{extension}"
        return self._output_dir / filename

    def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float:
        """
        Calculate audio duration from raw audio data.

        The computation treats every byte of audio_data as sample data.

        Args:
            audio_data: Raw audio data in bytes
            sample_rate: Sample rate in Hz
            channels: Number of audio channels (default: 1)
            sample_width: Sample width in bytes (default: 2 for 16-bit)

        Returns:
            float: Duration in seconds; 0.0 when the data is empty or any of
            sample_rate, channels, sample_width is not positive
        """
        bytes_per_frame = channels * sample_width
        # Guard the divisors: a zero channel count or sample width would
        # otherwise raise ZeroDivisionError.
        if not audio_data or sample_rate <= 0 or bytes_per_frame <= 0:
            return 0.0

        total_frames = len(audio_data) // bytes_per_frame
        return total_frames / sample_rate

    def _cleanup_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Best-effort: failures are logged as warnings and never raised. Every
        regular file directly inside the output directory is considered,
        whatever its name.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            current_time = time.time()
            max_age_seconds = max_age_hours * 3600

            for file_path in self._output_dir.glob("*"):
                # Handle errors per file so one locked or vanished file does
                # not stop the remaining files from being cleaned up.
                try:
                    if not file_path.is_file():
                        continue
                    file_age = current_time - file_path.stat().st_mtime
                    if file_age > max_age_seconds:
                        file_path.unlink()
                        logger.info(f"Cleaned up old temp file: {file_path}")
                except OSError as e:
                    logger.warning(f"Failed to cleanup temp files: {e}")
        except Exception as e:
            logger.warning(f"Failed to cleanup temp files: {e}")

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Logs the error with its traceback and always raises; it never returns
        normally.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechSynthesisException: Always, chained from the original error
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {error}"

        # exc_info is the logging keyword for attaching an exception; an
        # unknown keyword such as "exception" makes Logger.error raise
        # TypeError before the domain exception can be raised.
        logger.error(error_msg, exc_info=error)
        raise SpeechSynthesisException(error_msg) from error