Michael Hu committed on
Commit
1f9c751
·
1 Parent(s): e3cb97b

Migrate existing TTS providers to infrastructure layer

src/domain/models/speech_synthesis_request.py CHANGED
@@ -10,7 +10,7 @@ from .voice_settings import VoiceSettings
 class SpeechSynthesisRequest:
     """Value object representing a speech synthesis request."""
 
-    text: TextContent
+    text_content: TextContent
     voice_settings: VoiceSettings
     output_format: str = 'wav'
     sample_rate: Optional[int] = None
@@ -21,7 +21,7 @@ class SpeechSynthesisRequest:
 
     def _validate(self):
         """Validate speech synthesis request properties."""
-        if not isinstance(self.text, TextContent):
+        if not isinstance(self.text_content, TextContent):
             raise TypeError("Text must be a TextContent instance")
 
         if not isinstance(self.voice_settings, VoiceSettings):
@@ -44,8 +44,8 @@ class SpeechSynthesisRequest:
             raise ValueError("Sample rate must be between 8000 and 192000 Hz")
 
         # Validate that text and voice settings have compatible languages
-        if self.text.language != self.voice_settings.language:
-            raise ValueError(f"Text language ({self.text.language}) must match voice language ({self.voice_settings.language})")
+        if self.text_content.language != self.voice_settings.language:
+            raise ValueError(f"Text language ({self.text_content.language}) must match voice language ({self.voice_settings.language})")
 
     @property
     def estimated_duration_seconds(self) -> float:
@@ -53,12 +53,12 @@ class SpeechSynthesisRequest:
         # Rough estimation: average speaking rate is about 150-200 words per minute
         # Adjusted by speed setting
         words_per_minute = 175 / self.voice_settings.speed
-        return (self.text.word_count / words_per_minute) * 60
+        return (self.text_content.word_count / words_per_minute) * 60
 
     @property
     def is_long_text(self) -> bool:
         """Check if the text is considered long for TTS processing."""
-        return self.text.character_count > 5000
+        return self.text_content.character_count > 5000
 
     @property
     def effective_sample_rate(self) -> int:
@@ -68,7 +68,7 @@ class SpeechSynthesisRequest:
     def with_output_format(self, output_format: str) -> 'SpeechSynthesisRequest':
         """Create a new SpeechSynthesisRequest with different output format."""
         return SpeechSynthesisRequest(
-            text=self.text,
+            text_content=self.text_content,
             voice_settings=self.voice_settings,
             output_format=output_format,
             sample_rate=self.sample_rate
@@ -77,7 +77,7 @@ class SpeechSynthesisRequest:
     def with_sample_rate(self, sample_rate: Optional[int]) -> 'SpeechSynthesisRequest':
         """Create a new SpeechSynthesisRequest with different sample rate."""
         return SpeechSynthesisRequest(
-            text=self.text,
+            text_content=self.text_content,
             voice_settings=self.voice_settings,
             output_format=self.output_format,
             sample_rate=sample_rate
@@ -86,7 +86,7 @@ class SpeechSynthesisRequest:
     def with_voice_settings(self, voice_settings: VoiceSettings) -> 'SpeechSynthesisRequest':
        """Create a new SpeechSynthesisRequest with different voice settings."""
        return SpeechSynthesisRequest(
-            text=self.text,
+            text_content=self.text_content,
            voice_settings=voice_settings,
            output_format=self.output_format,
            sample_rate=self.sample_rate
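For reference, callers now build the value object through the renamed text_content field. A minimal sketch based on the test script added in this commit (the TextContent and VoiceSettings constructors are assumed from that script):

from src.domain.models.text_content import TextContent
from src.domain.models.voice_settings import VoiceSettings
from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest

# Languages must match, otherwise _validate() raises ValueError
text_content = TextContent(text="Hello, world!", language="en")
voice_settings = VoiceSettings(voice_id="default", speed=1.0, language="en")
request = SpeechSynthesisRequest(text_content=text_content, voice_settings=voice_settings)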
src/infrastructure/tts/__init__.py ADDED
@@ -0,0 +1,28 @@
+"""TTS provider implementations."""
+
+from .provider_factory import TTSProviderFactory
+from .dummy_provider import DummyTTSProvider
+
+# Try to import optional providers
+try:
+    from .kokoro_provider import KokoroTTSProvider
+except ImportError:
+    KokoroTTSProvider = None
+
+try:
+    from .dia_provider import DiaTTSProvider
+except ImportError:
+    DiaTTSProvider = None
+
+try:
+    from .cosyvoice2_provider import CosyVoice2TTSProvider
+except ImportError:
+    CosyVoice2TTSProvider = None
+
+__all__ = [
+    'TTSProviderFactory',
+    'DummyTTSProvider',
+    'KokoroTTSProvider',
+    'DiaTTSProvider',
+    'CosyVoice2TTSProvider'
+]
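Because the optional providers resolve to None when their dependencies are missing, importing code is presumably expected to guard on that before instantiating; a minimal sketch of this assumed usage:

from src.infrastructure.tts import KokoroTTSProvider, DummyTTSProvider

# Fall back to the always-available dummy provider if Kokoro's dependencies are absent
provider = KokoroTTSProvider() if KokoroTTSProvider is not None else DummyTTSProvider()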
src/infrastructure/tts/cosyvoice2_provider.py ADDED
@@ -0,0 +1,159 @@
+"""CosyVoice2 TTS provider implementation."""
+
+import logging
+import numpy as np
+import soundfile as sf
+import io
+from typing import Iterator, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
+
+from ..base.tts_provider_base import TTSProviderBase
+from ...domain.exceptions import SpeechSynthesisException
+
+logger = logging.getLogger(__name__)
+
+# Flag to track CosyVoice2 availability
+COSYVOICE2_AVAILABLE = False
+DEFAULT_SAMPLE_RATE = 24000
+
+# Try to import CosyVoice2 dependencies
+try:
+    import torch
+    # Import CosyVoice2 - assuming it's installed and has a similar API to Dia
+    # since they're both from nari-labs according to the GitHub link
+    from cosyvoice2.model import CosyVoice2
+    COSYVOICE2_AVAILABLE = True
+    logger.info("CosyVoice2 TTS engine is available")
+except ImportError:
+    logger.warning("CosyVoice2 TTS engine is not available")
+except ModuleNotFoundError as e:
+    logger.warning(f"CosyVoice2 TTS engine is not available: {str(e)}")
+    COSYVOICE2_AVAILABLE = False
+
+
+class CosyVoice2TTSProvider(TTSProviderBase):
+    """CosyVoice2 TTS provider implementation."""
+
+    def __init__(self, lang_code: str = 'z'):
+        """Initialize the CosyVoice2 TTS provider."""
+        super().__init__(
+            provider_name="CosyVoice2",
+            supported_languages=['en', 'z']  # CosyVoice2 supports English and multilingual
+        )
+        self.lang_code = lang_code
+        self.model = None
+
+    def _ensure_model(self):
+        """Ensure the model is loaded."""
+        if self.model is None and COSYVOICE2_AVAILABLE:
+            try:
+                import torch
+                from cosyvoice2.model import CosyVoice2
+                self.model = CosyVoice2.from_pretrained()
+                logger.info("CosyVoice2 model successfully loaded")
+            except ImportError as e:
+                logger.error(f"Failed to import CosyVoice2 dependencies: {str(e)}")
+                self.model = None
+            except FileNotFoundError as e:
+                logger.error(f"Failed to load CosyVoice2 model files: {str(e)}")
+                self.model = None
+            except Exception as e:
+                logger.error(f"Failed to initialize CosyVoice2 model: {str(e)}")
+                self.model = None
+        return self.model is not None
+
+    def is_available(self) -> bool:
+        """Check if CosyVoice2 TTS is available."""
+        return COSYVOICE2_AVAILABLE and self._ensure_model()
+
+    def get_available_voices(self) -> list[str]:
+        """Get available voices for CosyVoice2."""
+        # CosyVoice2 typically uses a default voice
+        return ['default']
+
+    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
+        """Generate audio using CosyVoice2 TTS."""
+        if not self.is_available():
+            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")
+
+        try:
+            import torch
+
+            # Extract parameters from request
+            text = request.text_content.text
+
+            # Generate audio using CosyVoice2
+            with torch.inference_mode():
+                # Assuming CosyVoice2 has a similar API to Dia
+                output_audio_np = self.model.generate(
+                    text,
+                    max_tokens=None,
+                    cfg_scale=3.0,
+                    temperature=1.3,
+                    top_p=0.95,
+                    use_torch_compile=False,
+                    verbose=False
+                )
+
+            if output_audio_np is None:
+                raise SpeechSynthesisException("CosyVoice2 model returned None for audio output")
+
+            # Convert numpy array to bytes
+            audio_bytes = self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)
+            return audio_bytes, DEFAULT_SAMPLE_RATE
+
+        except Exception as e:
+            self._handle_provider_error(e, "audio generation")
+
+    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
+        """Generate audio stream using CosyVoice2 TTS."""
+        if not self.is_available():
+            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")
+
+        try:
+            import torch
+
+            # Extract parameters from request
+            text = request.text_content.text
+
+            # Generate audio using CosyVoice2
+            with torch.inference_mode():
+                # Assuming CosyVoice2 has a similar API to Dia
+                output_audio_np = self.model.generate(
+                    text,
+                    max_tokens=None,
+                    cfg_scale=3.0,
+                    temperature=1.3,
+                    top_p=0.95,
+                    use_torch_compile=False,
+                    verbose=False
+                )
+
+            if output_audio_np is None:
+                raise SpeechSynthesisException("CosyVoice2 model returned None for audio output")
+
+            # Convert numpy array to bytes
+            audio_bytes = self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)
+            # CosyVoice2 generates complete audio in one go
+            yield audio_bytes, DEFAULT_SAMPLE_RATE, True
+
+        except Exception as e:
+            self._handle_provider_error(e, "streaming audio generation")
+
+    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
+        """Convert numpy audio array to bytes."""
+        try:
+            # Create an in-memory buffer
+            buffer = io.BytesIO()
+
+            # Write audio data to buffer as WAV
+            sf.write(buffer, audio_array, sample_rate, format='WAV')
+
+            # Get bytes from buffer
+            buffer.seek(0)
+            return buffer.read()
+
+        except Exception as e:
+            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
src/infrastructure/tts/dia_provider.py ADDED
@@ -0,0 +1,170 @@
+"""Dia TTS provider implementation."""
+
+import logging
+import numpy as np
+import soundfile as sf
+import io
+from typing import Iterator, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
+
+from ..base.tts_provider_base import TTSProviderBase
+from ...domain.exceptions import SpeechSynthesisException
+
+logger = logging.getLogger(__name__)
+
+# Flag to track Dia availability
+DIA_AVAILABLE = False
+DEFAULT_SAMPLE_RATE = 24000
+
+# Try to import Dia dependencies
+try:
+    import torch
+    from dia.model import Dia
+    DIA_AVAILABLE = True
+    logger.info("Dia TTS engine is available")
+except ImportError:
+    logger.warning("Dia TTS engine is not available")
+except ModuleNotFoundError as e:
+    if "dac" in str(e):
+        logger.warning("Dia TTS engine is not available due to missing 'dac' module")
+    else:
+        logger.warning(f"Dia TTS engine is not available: {str(e)}")
+    DIA_AVAILABLE = False
+
+
+class DiaTTSProvider(TTSProviderBase):
+    """Dia TTS provider implementation."""
+
+    def __init__(self, lang_code: str = 'z'):
+        """Initialize the Dia TTS provider."""
+        super().__init__(
+            provider_name="Dia",
+            supported_languages=['en', 'z']  # Dia supports English and multilingual
+        )
+        self.lang_code = lang_code
+        self.model = None
+
+    def _ensure_model(self):
+        """Ensure the model is loaded."""
+        if self.model is None and DIA_AVAILABLE:
+            try:
+                import torch
+                from dia.model import Dia
+                self.model = Dia.from_pretrained()
+                logger.info("Dia model successfully loaded")
+            except ImportError as e:
+                logger.error(f"Failed to import Dia dependencies: {str(e)}")
+                self.model = None
+            except FileNotFoundError as e:
+                logger.error(f"Failed to load Dia model files: {str(e)}")
+                self.model = None
+            except Exception as e:
+                logger.error(f"Failed to initialize Dia model: {str(e)}")
+                self.model = None
+        return self.model is not None
+
+    def is_available(self) -> bool:
+        """Check if Dia TTS is available."""
+        return DIA_AVAILABLE and self._ensure_model()
+
+    def get_available_voices(self) -> list[str]:
+        """Get available voices for Dia."""
+        # Dia typically uses a default voice
+        return ['default']
+
+    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
+        """Generate audio using Dia TTS."""
+        if not self.is_available():
+            raise SpeechSynthesisException("Dia TTS engine is not available")
+
+        try:
+            import torch
+
+            # Extract parameters from request
+            text = request.text_content.text
+
+            # Generate audio using Dia
+            with torch.inference_mode():
+                output_audio_np = self.model.generate(
+                    text,
+                    max_tokens=None,
+                    cfg_scale=3.0,
+                    temperature=1.3,
+                    top_p=0.95,
+                    cfg_filter_top_k=35,
+                    use_torch_compile=False,
+                    verbose=False
+                )
+
+            if output_audio_np is None:
+                raise SpeechSynthesisException("Dia model returned None for audio output")
+
+            # Convert numpy array to bytes
+            audio_bytes = self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)
+            return audio_bytes, DEFAULT_SAMPLE_RATE
+
+        except ModuleNotFoundError as e:
+            if "dac" in str(e):
+                raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from e
+            else:
+                self._handle_provider_error(e, "audio generation")
+        except Exception as e:
+            self._handle_provider_error(e, "audio generation")
+
+    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
+        """Generate audio stream using Dia TTS."""
+        if not self.is_available():
+            raise SpeechSynthesisException("Dia TTS engine is not available")
+
+        try:
+            import torch
+
+            # Extract parameters from request
+            text = request.text_content.text
+
+            # Generate audio using Dia
+            with torch.inference_mode():
+                output_audio_np = self.model.generate(
+                    text,
+                    max_tokens=None,
+                    cfg_scale=3.0,
+                    temperature=1.3,
+                    top_p=0.95,
+                    cfg_filter_top_k=35,
+                    use_torch_compile=False,
+                    verbose=False
+                )
+
+            if output_audio_np is None:
+                raise SpeechSynthesisException("Dia model returned None for audio output")
+
+            # Convert numpy array to bytes
+            audio_bytes = self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)
+            # Dia generates complete audio in one go
+            yield audio_bytes, DEFAULT_SAMPLE_RATE, True
+
+        except ModuleNotFoundError as e:
+            if "dac" in str(e):
+                raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from e
+            else:
+                self._handle_provider_error(e, "streaming audio generation")
+        except Exception as e:
+            self._handle_provider_error(e, "streaming audio generation")
+
+    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
+        """Convert numpy audio array to bytes."""
+        try:
+            # Create an in-memory buffer
+            buffer = io.BytesIO()
+
+            # Write audio data to buffer as WAV
+            sf.write(buffer, audio_array, sample_rate, format='WAV')
+
+            # Get bytes from buffer
+            buffer.seek(0)
+            return buffer.read()
+
+        except Exception as e:
+            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
src/infrastructure/tts/dummy_provider.py ADDED
@@ -0,0 +1,139 @@
+"""Dummy TTS provider implementation for testing and fallback."""
+
+import logging
+import numpy as np
+import soundfile as sf
+import io
+from typing import Iterator, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
+
+from ..base.tts_provider_base import TTSProviderBase
+from ...domain.exceptions import SpeechSynthesisException
+
+logger = logging.getLogger(__name__)
+
+
+class DummyTTSProvider(TTSProviderBase):
+    """Dummy TTS provider that generates sine wave audio for testing."""
+
+    def __init__(self):
+        """Initialize the Dummy TTS provider."""
+        super().__init__(
+            provider_name="Dummy",
+            supported_languages=['en', 'es', 'fr', 'de', 'it', 'pt', 'ru', 'ja', 'ko', 'zh']
+        )
+
+    def is_available(self) -> bool:
+        """Dummy TTS is always available."""
+        return True
+
+    def get_available_voices(self) -> list[str]:
+        """Get available voices for Dummy TTS."""
+        return ['default', 'male', 'female', 'robot']
+
+    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
+        """Generate dummy sine wave audio."""
+        try:
+            # Extract parameters from request
+            text = request.text_content.text
+            speed = request.voice_settings.speed
+
+            # Generate a simple sine wave based on text length and speed
+            sample_rate = 24000
+            # Rough approximation of speech duration adjusted by speed
+            duration = min(len(text) / (20 * speed), 10)
+
+            # Create time array
+            t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
+
+            # Generate sine wave (440 Hz base frequency)
+            frequency = 440
+            audio = 0.5 * np.sin(2 * np.pi * frequency * t)
+
+            # Add some variation based on voice setting
+            voice = request.voice_settings.voice_id
+            if voice == 'male':
+                # Lower frequency for male voice
+                audio = 0.5 * np.sin(2 * np.pi * 220 * t)
+            elif voice == 'female':
+                # Higher frequency for female voice
+                audio = 0.5 * np.sin(2 * np.pi * 660 * t)
+            elif voice == 'robot':
+                # Square wave for robot voice
+                audio = 0.5 * np.sign(np.sin(2 * np.pi * 440 * t))
+
+            # Convert to bytes
+            audio_bytes = self._numpy_to_bytes(audio, sample_rate)
+
+            logger.info(f"Generated dummy audio: duration={duration:.2f}s, voice={voice}")
+            return audio_bytes, sample_rate
+
+        except Exception as e:
+            self._handle_provider_error(e, "dummy audio generation")
+
+    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
+        """Generate dummy sine wave audio stream."""
+        try:
+            # Extract parameters from request
+            text = request.text_content.text
+            speed = request.voice_settings.speed
+
+            # Generate audio in chunks
+            sample_rate = 24000
+            chunk_duration = 1.0  # 1 second chunks
+            total_duration = min(len(text) / (20 * speed), 10)
+
+            chunks_count = int(np.ceil(total_duration / chunk_duration))
+
+            for chunk_idx in range(chunks_count):
+                start_time = chunk_idx * chunk_duration
+                end_time = min((chunk_idx + 1) * chunk_duration, total_duration)
+                actual_duration = end_time - start_time
+
+                if actual_duration <= 0:
+                    break
+
+                # Create time array for this chunk
+                t = np.linspace(0, actual_duration, int(sample_rate * actual_duration), endpoint=False)
+
+                # Generate sine wave
+                frequency = 440
+                audio = 0.5 * np.sin(2 * np.pi * frequency * t)
+
+                # Apply voice variations
+                voice = request.voice_settings.voice_id
+                if voice == 'male':
+                    audio = 0.5 * np.sin(2 * np.pi * 220 * t)
+                elif voice == 'female':
+                    audio = 0.5 * np.sin(2 * np.pi * 660 * t)
+                elif voice == 'robot':
+                    audio = 0.5 * np.sign(np.sin(2 * np.pi * 440 * t))
+
+                # Convert to bytes
+                audio_bytes = self._numpy_to_bytes(audio, sample_rate)
+
+                # Check if this is the final chunk
+                is_final = (chunk_idx == chunks_count - 1)
+
+                yield audio_bytes, sample_rate, is_final
+
+        except Exception as e:
+            self._handle_provider_error(e, "dummy streaming audio generation")
+
+    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
+        """Convert numpy audio array to bytes."""
+        try:
+            # Create an in-memory buffer
+            buffer = io.BytesIO()
+
+            # Write audio data to buffer as WAV
+            sf.write(buffer, audio_array, sample_rate, format='WAV')
+
+            # Get bytes from buffer
+            buffer.seek(0)
+            return buffer.read()
+
+        except Exception as e:
+            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
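As a quick sanity check on the dummy provider's duration heuristic: a 100-character text at speed 1.0 yields min(100 / (20 * 1.0), 10) = 5.0 seconds of tone, and anything beyond 200 characters is capped at 10 seconds.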
src/infrastructure/tts/kokoro_provider.py ADDED
@@ -0,0 +1,131 @@
+"""Kokoro TTS provider implementation."""
+
+import logging
+import numpy as np
+import soundfile as sf
+import io
+from typing import Iterator, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
+
+from ..base.tts_provider_base import TTSProviderBase
+from ...domain.exceptions import SpeechSynthesisException
+
+logger = logging.getLogger(__name__)
+
+# Flag to track Kokoro availability
+KOKORO_AVAILABLE = False
+
+# Try to import Kokoro
+try:
+    from kokoro import KPipeline
+    KOKORO_AVAILABLE = True
+    logger.info("Kokoro TTS engine is available")
+except ImportError:
+    logger.warning("Kokoro TTS engine is not available")
+except Exception as e:
+    logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
+    KOKORO_AVAILABLE = False
+
+
+class KokoroTTSProvider(TTSProviderBase):
+    """Kokoro TTS provider implementation."""
+
+    def __init__(self, lang_code: str = 'z'):
+        """Initialize the Kokoro TTS provider."""
+        super().__init__(
+            provider_name="Kokoro",
+            supported_languages=['en', 'z']  # Kokoro supports English and multilingual
+        )
+        self.lang_code = lang_code
+        self.pipeline = None
+
+    def _ensure_pipeline(self):
+        """Ensure the pipeline is loaded."""
+        if self.pipeline is None and KOKORO_AVAILABLE:
+            try:
+                self.pipeline = KPipeline(lang_code=self.lang_code)
+                logger.info("Kokoro pipeline successfully loaded")
+            except Exception as e:
+                logger.error(f"Failed to initialize Kokoro pipeline: {str(e)}")
+                self.pipeline = None
+        return self.pipeline is not None
+
+    def is_available(self) -> bool:
+        """Check if Kokoro TTS is available."""
+        return KOKORO_AVAILABLE and self._ensure_pipeline()
+
+    def get_available_voices(self) -> list[str]:
+        """Get available voices for Kokoro."""
+        # Common Kokoro voices based on the original implementation
+        return [
+            'af_heart', 'af_bella', 'af_sarah', 'af_nicole',
+            'am_adam', 'am_michael', 'bf_emma', 'bf_isabella'
+        ]
+
+    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
+        """Generate audio using Kokoro TTS."""
+        if not self.is_available():
+            raise SpeechSynthesisException("Kokoro TTS engine is not available")
+
+        try:
+            # Extract parameters from request
+            text = request.text_content.text
+            voice = request.voice_settings.voice_id
+            speed = request.voice_settings.speed
+
+            # Generate speech using Kokoro
+            generator = self.pipeline(text, voice=voice, speed=speed)
+
+            for _, _, audio in generator:
+                # Convert numpy array to bytes
+                audio_bytes = self._numpy_to_bytes(audio, sample_rate=24000)
+                return audio_bytes, 24000
+
+            raise SpeechSynthesisException("Kokoro failed to generate audio")
+
+        except Exception as e:
+            self._handle_provider_error(e, "audio generation")
+
+    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
+        """Generate audio stream using Kokoro TTS."""
+        if not self.is_available():
+            raise SpeechSynthesisException("Kokoro TTS engine is not available")
+
+        try:
+            # Extract parameters from request
+            text = request.text_content.text
+            voice = request.voice_settings.voice_id
+            speed = request.voice_settings.speed
+
+            # Generate speech stream using Kokoro
+            generator = self.pipeline(text, voice=voice, speed=speed)
+
+            chunk_count = 0
+            for _, _, audio in generator:
+                chunk_count += 1
+                # Convert numpy array to bytes
+                audio_bytes = self._numpy_to_bytes(audio, sample_rate=24000)
+                # Assume this is the final chunk for now (Kokoro typically generates one chunk)
+                is_final = True
+                yield audio_bytes, 24000, is_final
+
+        except Exception as e:
+            self._handle_provider_error(e, "streaming audio generation")
+
+    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
+        """Convert numpy audio array to bytes."""
+        try:
+            # Create an in-memory buffer
+            buffer = io.BytesIO()
+
+            # Write audio data to buffer as WAV
+            sf.write(buffer, audio_array, sample_rate, format='WAV')
+
+            # Get bytes from buffer
+            buffer.seek(0)
+            return buffer.read()
+
+        except Exception as e:
+            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
src/infrastructure/tts/provider_factory.py ADDED
@@ -0,0 +1,204 @@
+"""TTS provider factory for creating and managing TTS providers."""
+
+import logging
+from typing import Dict, List, Optional, Type
+from ..base.tts_provider_base import TTSProviderBase
+from ...domain.exceptions import SpeechSynthesisException
+
+logger = logging.getLogger(__name__)
+
+
+class TTSProviderFactory:
+    """Factory for creating and managing TTS providers."""
+
+    def __init__(self):
+        """Initialize the TTS provider factory."""
+        self._providers: Dict[str, Type[TTSProviderBase]] = {}
+        self._provider_instances: Dict[str, TTSProviderBase] = {}
+        self._register_default_providers()
+
+    def _register_default_providers(self):
+        """Register all available TTS providers."""
+        # Import providers dynamically to avoid import errors if dependencies are missing
+
+        # Always register dummy provider as fallback
+        from .dummy_provider import DummyTTSProvider
+        self._providers['dummy'] = DummyTTSProvider
+
+        # Try to register Kokoro provider
+        try:
+            from .kokoro_provider import KokoroTTSProvider
+            self._providers['kokoro'] = KokoroTTSProvider
+            logger.info("Registered Kokoro TTS provider")
+        except ImportError as e:
+            logger.debug(f"Kokoro TTS provider not available: {e}")
+
+        # Try to register Dia provider
+        try:
+            from .dia_provider import DiaTTSProvider
+            self._providers['dia'] = DiaTTSProvider
+            logger.info("Registered Dia TTS provider")
+        except ImportError as e:
+            logger.debug(f"Dia TTS provider not available: {e}")
+
+        # Try to register CosyVoice2 provider
+        try:
+            from .cosyvoice2_provider import CosyVoice2TTSProvider
+            self._providers['cosyvoice2'] = CosyVoice2TTSProvider
+            logger.info("Registered CosyVoice2 TTS provider")
+        except ImportError as e:
+            logger.debug(f"CosyVoice2 TTS provider not available: {e}")
+
+    def get_available_providers(self) -> List[str]:
+        """Get list of available TTS providers."""
+        available = []
+        for name, provider_class in self._providers.items():
+            try:
+                # Create instance if not cached
+                if name not in self._provider_instances:
+                    if name == 'kokoro':
+                        self._provider_instances[name] = provider_class()
+                    elif name == 'dia':
+                        self._provider_instances[name] = provider_class()
+                    elif name == 'cosyvoice2':
+                        self._provider_instances[name] = provider_class()
+                    else:
+                        self._provider_instances[name] = provider_class()
+
+                # Check if provider is available
+                if self._provider_instances[name].is_available():
+                    available.append(name)
+
+            except Exception as e:
+                logger.warning(f"Failed to check availability of {name} provider: {e}")
+
+        return available
+
+    def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
+        """
+        Create a TTS provider instance.
+
+        Args:
+            provider_name: Name of the provider to create
+            **kwargs: Additional arguments for provider initialization
+
+        Returns:
+            TTSProviderBase: The created provider instance
+
+        Raises:
+            SpeechSynthesisException: If provider is not available or creation fails
+        """
+        if provider_name not in self._providers:
+            available = list(self._providers.keys())
+            raise SpeechSynthesisException(
+                f"Unknown TTS provider: {provider_name}. Available providers: {available}"
+            )
+
+        try:
+            provider_class = self._providers[provider_name]
+
+            # Create instance with appropriate parameters
+            if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
+                lang_code = kwargs.get('lang_code', 'z')
+                provider = provider_class(lang_code=lang_code)
+            else:
+                provider = provider_class(**kwargs)
+
+            # Verify the provider is available
+            if not provider.is_available():
+                raise SpeechSynthesisException(f"TTS provider {provider_name} is not available")
+
+            logger.info(f"Created TTS provider: {provider_name}")
+            return provider
+
+        except Exception as e:
+            logger.error(f"Failed to create TTS provider {provider_name}: {e}")
+            raise SpeechSynthesisException(f"Failed to create TTS provider {provider_name}: {e}") from e
+
+    def get_provider_with_fallback(self, preferred_providers: List[str] = None, **kwargs) -> TTSProviderBase:
+        """
+        Get a TTS provider with fallback logic.
+
+        Args:
+            preferred_providers: List of preferred providers in order of preference
+            **kwargs: Additional arguments for provider initialization
+
+        Returns:
+            TTSProviderBase: The first available provider
+
+        Raises:
+            SpeechSynthesisException: If no providers are available
+        """
+        if preferred_providers is None:
+            preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
+
+        available_providers = self.get_available_providers()
+
+        # Try preferred providers in order
+        for provider_name in preferred_providers:
+            if provider_name in available_providers:
+                try:
+                    return self.create_provider(provider_name, **kwargs)
+                except Exception as e:
+                    logger.warning(f"Failed to create preferred provider {provider_name}: {e}")
+                    continue
+
+        # If no preferred providers work, try any available provider
+        for provider_name in available_providers:
+            if provider_name not in preferred_providers:
+                try:
+                    return self.create_provider(provider_name, **kwargs)
+                except Exception as e:
+                    logger.warning(f"Failed to create fallback provider {provider_name}: {e}")
+                    continue
+
+        raise SpeechSynthesisException("No TTS providers are available")
+
+    def get_provider_info(self, provider_name: str) -> Dict:
+        """
+        Get information about a specific provider.
+
+        Args:
+            provider_name: Name of the provider
+
+        Returns:
+            Dict: Provider information including availability and supported features
+        """
+        if provider_name not in self._providers:
+            return {"available": False, "error": "Provider not registered"}
+
+        try:
+            # Create instance if not cached
+            if provider_name not in self._provider_instances:
+                provider_class = self._providers[provider_name]
+                if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
+                    self._provider_instances[provider_name] = provider_class()
+                else:
+                    self._provider_instances[provider_name] = provider_class()
+
+            provider = self._provider_instances[provider_name]
+
+            return {
+                "available": provider.is_available(),
+                "name": provider.provider_name,
+                "supported_languages": provider.supported_languages,
+                "available_voices": provider.get_available_voices() if provider.is_available() else []
+            }
+
+        except Exception as e:
+            return {
+                "available": False,
+                "error": str(e)
+            }
+
+    def cleanup_providers(self):
+        """Clean up provider instances and resources."""
+        for provider in self._provider_instances.values():
+            try:
+                if hasattr(provider, '_cleanup_temp_files'):
+                    provider._cleanup_temp_files()
+            except Exception as e:
+                logger.warning(f"Failed to cleanup provider {provider.provider_name}: {e}")
+
+        self._provider_instances.clear()
+        logger.info("Cleaned up TTS provider instances")
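The factory's default fallback order is kokoro, dia, cosyvoice2, then dummy. A minimal usage sketch, under the assumption (taken from the test script below) that the provider base class exposes a synthesize() method:

from src.infrastructure.tts import TTSProviderFactory

factory = TTSProviderFactory()
# Returns the first preferred provider whose is_available() check passes; 'dummy' is the last resort
provider = factory.get_provider_with_fallback(preferred_providers=['kokoro', 'dummy'])
audio_content = provider.synthesize(request)  # request: a SpeechSynthesisRequest built as shown earlier
factory.cleanup_providers()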
test_tts_migration.py ADDED
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+"""Test script to verify TTS provider migration."""
+
+import sys
+import os
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+
+def test_provider_imports():
+    """Test that all providers can be imported."""
+    print("Testing provider imports...")
+
+    try:
+        from src.infrastructure.tts import TTSProviderFactory, DummyTTSProvider
+        print("✓ Core TTS components imported successfully")
+    except Exception as e:
+        print(f"✗ Failed to import core TTS components: {e}")
+        return False
+
+    try:
+        from src.domain.models.text_content import TextContent
+        from src.domain.models.voice_settings import VoiceSettings
+        from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest
+        print("✓ Domain models imported successfully")
+    except Exception as e:
+        print(f"✗ Failed to import domain models: {e}")
+        return False
+
+    return True
+
+def test_dummy_provider():
+    """Test the dummy provider functionality."""
+    print("\nTesting dummy provider...")
+
+    try:
+        from src.infrastructure.tts import DummyTTSProvider
+        from src.domain.models.text_content import TextContent
+        from src.domain.models.voice_settings import VoiceSettings
+        from src.domain.models.speech_synthesis_request import SpeechSynthesisRequest
+
+        # Create provider
+        provider = DummyTTSProvider()
+        print(f"✓ Created dummy provider: {provider.provider_name}")
+
+        # Check availability
+        if provider.is_available():
+            print("✓ Dummy provider is available")
+        else:
+            print("✗ Dummy provider is not available")
+            return False
+
+        # Check voices
+        voices = provider.get_available_voices()
+        print(f"✓ Available voices: {voices}")
+
+        # Create a synthesis request
+        text_content = TextContent(text="Hello, world!", language="en")
+        voice_settings = VoiceSettings(voice_id="default", speed=1.0, language="en")
+        request = SpeechSynthesisRequest(
+            text_content=text_content,
+            voice_settings=voice_settings
+        )
+        print("✓ Created synthesis request")
+
+        # Test synthesis
+        audio_content = provider.synthesize(request)
+        print(f"✓ Generated audio: {len(audio_content.data)} bytes, {audio_content.duration:.2f}s")
+
+        return True
+
+    except Exception as e:
+        print(f"✗ Dummy provider test failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_provider_factory():
+    """Test the provider factory."""
+    print("\nTesting provider factory...")
+
+    try:
+        from src.infrastructure.tts import TTSProviderFactory
+
+        factory = TTSProviderFactory()
+        print("✓ Created provider factory")
+
+        available = factory.get_available_providers()
+        print(f"✓ Available providers: {available}")
+
+        if 'dummy' not in available:
+            print("✗ Dummy provider should always be available")
+            return False
+
+        # Test creating dummy provider
+        provider = factory.create_provider('dummy')
+        print(f"✓ Created provider via factory: {provider.provider_name}")
+
+        # Test fallback logic
+        provider = factory.get_provider_with_fallback()
+        print(f"✓ Got provider with fallback: {provider.provider_name}")
+
+        return True
+
+    except Exception as e:
+        print(f"✗ Provider factory test failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def main():
+    """Run all tests."""
+    print("=== TTS Provider Migration Test ===\n")
+
+    tests = [
+        test_provider_imports,
+        test_dummy_provider,
+        test_provider_factory
+    ]
+
+    passed = 0
+    for test in tests:
+        if test():
+            passed += 1
+        print()
+
+    print(f"=== Results: {passed}/{len(tests)} tests passed ===")
+
+    if passed == len(tests):
+        print("🎉 All tests passed! TTS provider migration successful.")
+        return 0
+    else:
+        print("❌ Some tests failed. Check the output above.")
+        return 1
+
+if __name__ == "__main__":
+    sys.exit(main())