Michael Hu committed on
Commit
0f99c8d
·
1 Parent(s): c7f7521

add chatterbox

Browse files
pyproject.toml CHANGED
@@ -25,7 +25,7 @@ dependencies = [
25
  "phonemizer-fork>=3.3.2",
26
  "nemo_toolkit[asr]",
27
  "faster-whisper>=1.1.1",
28
- "descript-audio-codec>=1.0.0"
29
  ]
30
 
31
  [project.optional-dependencies]
 
25
  "phonemizer-fork>=3.3.2",
26
  "nemo_toolkit[asr]",
27
  "faster-whisper>=1.1.1",
28
+ "chatterbox-tts"
29
  ]
30
 
31
  [project.optional-dependencies]
requirements.txt CHANGED
@@ -14,4 +14,4 @@ ordered-set>=4.1.0
14
  phonemizer-fork>=3.3.2
15
  nemo_toolkit[asr]
16
  faster-whisper>=1.1.1
17
- descript-audio-codec>=1.0.0
 
14
  phonemizer-fork>=3.3.2
15
  nemo_toolkit[asr]
16
  faster-whisper>=1.1.1
17
+ chatterbox-tts
src/domain/models/text_content.py CHANGED
@@ -8,71 +8,71 @@ import re
8
  @dataclass(frozen=True)
9
  class TextContent:
10
  """Value object representing text content with language and encoding information."""
11
-
12
  text: str
13
  language: str
14
  encoding: str = 'utf-8'
15
-
16
  def __post_init__(self):
17
  """Validate text content after initialization."""
18
  self._validate()
19
-
20
  def _validate(self):
21
  """Validate text content properties."""
22
  if not isinstance(self.text, str):
23
  raise TypeError("Text must be a string")
24
-
25
  if not self.text.strip():
26
  raise ValueError("Text content cannot be empty or whitespace only")
27
-
28
  if len(self.text) > 50000: # Reasonable limit for TTS processing
29
  raise ValueError("Text content too long (maximum 50,000 characters)")
30
-
31
  if not isinstance(self.language, str):
32
  raise TypeError("Language must be a string")
33
-
34
  if not self.language.strip():
35
  raise ValueError("Language cannot be empty")
36
-
37
  # Validate language code format (ISO 639-1 or ISO 639-3)
38
  if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
39
  raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
40
-
41
  if not isinstance(self.encoding, str):
42
  raise TypeError("Encoding must be a string")
43
-
44
  if self.encoding not in ['utf-8', 'utf-16', 'ascii', 'latin-1']:
45
  raise ValueError(f"Unsupported encoding: {self.encoding}. Supported: utf-8, utf-16, ascii, latin-1")
46
-
47
  # Validate that text can be encoded with specified encoding
48
  try:
49
  self.text.encode(self.encoding)
50
  except UnicodeEncodeError:
51
  raise ValueError(f"Text cannot be encoded with {self.encoding} encoding")
52
-
53
  @property
54
  def word_count(self) -> int:
55
  """Get the approximate word count of the text."""
56
  return len(self.text.split())
57
-
58
  @property
59
  def character_count(self) -> int:
60
  """Get the character count of the text."""
61
  return len(self.text)
62
-
63
  @property
64
  def is_empty(self) -> bool:
65
  """Check if the text content is effectively empty."""
66
  return not self.text.strip()
67
-
68
  def truncate(self, max_length: int) -> 'TextContent':
69
  """Create a new TextContent with truncated text."""
70
  if max_length <= 0:
71
  raise ValueError("Max length must be positive")
72
-
73
  if len(self.text) <= max_length:
74
  return self
75
-
76
  truncated_text = self.text[:max_length].rstrip()
77
  return TextContent(
78
  text=truncated_text,
 
8
  @dataclass(frozen=True)
9
  class TextContent:
10
  """Value object representing text content with language and encoding information."""
11
+
12
  text: str
13
  language: str
14
  encoding: str = 'utf-8'
15
+
16
  def __post_init__(self):
17
  """Validate text content after initialization."""
18
  self._validate()
19
+
20
  def _validate(self):
21
  """Validate text content properties."""
22
  if not isinstance(self.text, str):
23
  raise TypeError("Text must be a string")
24
+
25
  if not self.text.strip():
26
  raise ValueError("Text content cannot be empty or whitespace only")
27
+
28
  if len(self.text) > 50000: # Reasonable limit for TTS processing
29
  raise ValueError("Text content too long (maximum 50,000 characters)")
30
+
31
  if not isinstance(self.language, str):
32
  raise TypeError("Language must be a string")
33
+
34
  if not self.language.strip():
35
  raise ValueError("Language cannot be empty")
36
+
37
  # Validate language code format (ISO 639-1 or ISO 639-3)
38
  if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
39
  raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
40
+
41
  if not isinstance(self.encoding, str):
42
  raise TypeError("Encoding must be a string")
43
+
44
  if self.encoding not in ['utf-8', 'utf-16', 'ascii', 'latin-1']:
45
  raise ValueError(f"Unsupported encoding: {self.encoding}. Supported: utf-8, utf-16, ascii, latin-1")
46
+
47
  # Validate that text can be encoded with specified encoding
48
  try:
49
  self.text.encode(self.encoding)
50
  except UnicodeEncodeError:
51
  raise ValueError(f"Text cannot be encoded with {self.encoding} encoding")
52
+
53
  @property
54
  def word_count(self) -> int:
55
  """Get the approximate word count of the text."""
56
  return len(self.text.split())
57
+
58
  @property
59
  def character_count(self) -> int:
60
  """Get the character count of the text."""
61
  return len(self.text)
62
+
63
  @property
64
  def is_empty(self) -> bool:
65
  """Check if the text content is effectively empty."""
66
  return not self.text.strip()
67
+
68
  def truncate(self, max_length: int) -> 'TextContent':
69
  """Create a new TextContent with truncated text."""
70
  if max_length <= 0:
71
  raise ValueError("Max length must be positive")
72
+
73
  if len(self.text) <= max_length:
74
  return self
75
+
76
  truncated_text = self.text[:max_length].rstrip()
77
  return TextContent(
78
  text=truncated_text,
src/domain/models/voice_settings.py CHANGED
@@ -8,74 +8,82 @@ import re
8
  @dataclass(frozen=True)
9
  class VoiceSettings:
10
  """Value object representing voice settings for text-to-speech synthesis."""
11
-
12
  voice_id: str
13
  speed: float
14
  language: str
15
  pitch: Optional[float] = None
16
  volume: Optional[float] = None
17
-
 
18
  def __post_init__(self):
19
  """Validate voice settings after initialization."""
20
  self._validate()
21
-
22
  def _validate(self):
23
  """Validate voice settings properties."""
24
  if not isinstance(self.voice_id, str):
25
  raise TypeError("Voice ID must be a string")
26
-
27
  if not self.voice_id.strip():
28
  raise ValueError("Voice ID cannot be empty")
29
-
30
  # Voice ID should be alphanumeric with possible underscores/hyphens
31
  if not re.match(r'^[a-zA-Z0-9_-]+$', self.voice_id):
32
  raise ValueError(f"Invalid voice ID format: {self.voice_id}. Must contain only letters, numbers, underscores, and hyphens")
33
-
34
  if not isinstance(self.speed, (int, float)):
35
  raise TypeError("Speed must be a number")
36
-
37
  if not 0.1 <= self.speed <= 3.0:
38
  raise ValueError(f"Speed must be between 0.1 and 3.0, got {self.speed}")
39
-
40
  if not isinstance(self.language, str):
41
  raise TypeError("Language must be a string")
42
-
43
  if not self.language.strip():
44
  raise ValueError("Language cannot be empty")
45
-
46
  # Validate language code format (ISO 639-1 or ISO 639-3)
47
  if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
48
  raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
49
-
50
  if self.pitch is not None:
51
  if not isinstance(self.pitch, (int, float)):
52
  raise TypeError("Pitch must be a number")
53
-
54
  if not -2.0 <= self.pitch <= 2.0:
55
  raise ValueError(f"Pitch must be between -2.0 and 2.0, got {self.pitch}")
56
-
57
  if self.volume is not None:
58
  if not isinstance(self.volume, (int, float)):
59
  raise TypeError("Volume must be a number")
60
-
61
  if not 0.0 <= self.volume <= 2.0:
62
  raise ValueError(f"Volume must be between 0.0 and 2.0, got {self.volume}")
63
-
 
 
 
 
 
 
 
64
  @property
65
  def is_default_speed(self) -> bool:
66
  """Check if speed is at default value (1.0)."""
67
  return abs(self.speed - 1.0) < 0.01
68
-
69
  @property
70
  def is_default_pitch(self) -> bool:
71
  """Check if pitch is at default value (0.0 or None)."""
72
  return self.pitch is None or abs(self.pitch) < 0.01
73
-
74
  @property
75
  def is_default_volume(self) -> bool:
76
  """Check if volume is at default value (1.0 or None)."""
77
  return self.volume is None or abs(self.volume - 1.0) < 0.01
78
-
79
  def with_speed(self, speed: float) -> 'VoiceSettings':
80
  """Create a new VoiceSettings with different speed."""
81
  return VoiceSettings(
@@ -83,9 +91,10 @@ class VoiceSettings:
83
  speed=speed,
84
  language=self.language,
85
  pitch=self.pitch,
86
- volume=self.volume
 
87
  )
88
-
89
  def with_pitch(self, pitch: Optional[float]) -> 'VoiceSettings':
90
  """Create a new VoiceSettings with different pitch."""
91
  return VoiceSettings(
@@ -93,5 +102,17 @@ class VoiceSettings:
93
  speed=self.speed,
94
  language=self.language,
95
  pitch=pitch,
96
- volume=self.volume
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
 
8
  @dataclass(frozen=True)
9
  class VoiceSettings:
10
  """Value object representing voice settings for text-to-speech synthesis."""
11
+
12
  voice_id: str
13
  speed: float
14
  language: str
15
  pitch: Optional[float] = None
16
  volume: Optional[float] = None
17
+ audio_prompt_path: Optional[str] = None # For voice cloning (e.g., Chatterbox)
18
+
19
  def __post_init__(self):
20
  """Validate voice settings after initialization."""
21
  self._validate()
22
+
23
  def _validate(self):
24
  """Validate voice settings properties."""
25
  if not isinstance(self.voice_id, str):
26
  raise TypeError("Voice ID must be a string")
27
+
28
  if not self.voice_id.strip():
29
  raise ValueError("Voice ID cannot be empty")
30
+
31
  # Voice ID should be alphanumeric with possible underscores/hyphens
32
  if not re.match(r'^[a-zA-Z0-9_-]+$', self.voice_id):
33
  raise ValueError(f"Invalid voice ID format: {self.voice_id}. Must contain only letters, numbers, underscores, and hyphens")
34
+
35
  if not isinstance(self.speed, (int, float)):
36
  raise TypeError("Speed must be a number")
37
+
38
  if not 0.1 <= self.speed <= 3.0:
39
  raise ValueError(f"Speed must be between 0.1 and 3.0, got {self.speed}")
40
+
41
  if not isinstance(self.language, str):
42
  raise TypeError("Language must be a string")
43
+
44
  if not self.language.strip():
45
  raise ValueError("Language cannot be empty")
46
+
47
  # Validate language code format (ISO 639-1 or ISO 639-3)
48
  if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
49
  raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
50
+
51
  if self.pitch is not None:
52
  if not isinstance(self.pitch, (int, float)):
53
  raise TypeError("Pitch must be a number")
54
+
55
  if not -2.0 <= self.pitch <= 2.0:
56
  raise ValueError(f"Pitch must be between -2.0 and 2.0, got {self.pitch}")
57
+
58
  if self.volume is not None:
59
  if not isinstance(self.volume, (int, float)):
60
  raise TypeError("Volume must be a number")
61
+
62
  if not 0.0 <= self.volume <= 2.0:
63
  raise ValueError(f"Volume must be between 0.0 and 2.0, got {self.volume}")
64
+
65
+ if self.audio_prompt_path is not None:
66
+ if not isinstance(self.audio_prompt_path, str):
67
+ raise TypeError("Audio prompt path must be a string")
68
+
69
+ if not self.audio_prompt_path.strip():
70
+ raise ValueError("Audio prompt path cannot be empty")
71
+
72
  @property
73
  def is_default_speed(self) -> bool:
74
  """Check if speed is at default value (1.0)."""
75
  return abs(self.speed - 1.0) < 0.01
76
+
77
  @property
78
  def is_default_pitch(self) -> bool:
79
  """Check if pitch is at default value (0.0 or None)."""
80
  return self.pitch is None or abs(self.pitch) < 0.01
81
+
82
  @property
83
  def is_default_volume(self) -> bool:
84
  """Check if volume is at default value (1.0 or None)."""
85
  return self.volume is None or abs(self.volume - 1.0) < 0.01
86
+
87
  def with_speed(self, speed: float) -> 'VoiceSettings':
88
  """Create a new VoiceSettings with different speed."""
89
  return VoiceSettings(
 
91
  speed=speed,
92
  language=self.language,
93
  pitch=self.pitch,
94
+ volume=self.volume,
95
+ audio_prompt_path=self.audio_prompt_path
96
  )
97
+
98
  def with_pitch(self, pitch: Optional[float]) -> 'VoiceSettings':
99
  """Create a new VoiceSettings with different pitch."""
100
  return VoiceSettings(
 
102
  speed=self.speed,
103
  language=self.language,
104
  pitch=pitch,
105
+ volume=self.volume,
106
+ audio_prompt_path=self.audio_prompt_path
107
+ )
108
+
109
+ def with_audio_prompt(self, audio_prompt_path: Optional[str]) -> 'VoiceSettings':
110
+ """Create a new VoiceSettings with different audio prompt path."""
111
+ return VoiceSettings(
112
+ voice_id=self.voice_id,
113
+ speed=self.speed,
114
+ language=self.language,
115
+ pitch=self.pitch,
116
+ volume=self.volume,
117
+ audio_prompt_path=audio_prompt_path
118
  )
src/infrastructure/tts/__init__.py CHANGED
@@ -19,10 +19,16 @@ try:
19
  except ImportError:
20
  CosyVoice2TTSProvider = None
21
 
 
 
 
 
 
22
  __all__ = [
23
  'TTSProviderFactory',
24
  'DummyTTSProvider',
25
  'KokoroTTSProvider',
26
- 'DiaTTSProvider',
27
- 'CosyVoice2TTSProvider'
 
28
  ]
 
19
  except ImportError:
20
  CosyVoice2TTSProvider = None
21
 
22
+ try:
23
+ from .chatterbox_provider import ChatterboxTTSProvider
24
+ except ImportError:
25
+ ChatterboxTTSProvider = None
26
+
27
  __all__ = [
28
  'TTSProviderFactory',
29
  'DummyTTSProvider',
30
  'KokoroTTSProvider',
31
+ 'DiaTTSProvider',
32
+ 'CosyVoice2TTSProvider',
33
+ 'ChatterboxTTSProvider'
34
  ]
src/infrastructure/tts/chatterbox_provider.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chatterbox TTS provider implementation."""
2
+
3
+ import logging
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import io
7
+ from typing import Iterator, Optional, TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
11
+
12
+ from ..base.tts_provider_base import TTSProviderBase
13
+ from ...domain.exceptions import SpeechSynthesisException
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Flag to track Chatterbox availability
18
+ CHATTERBOX_AVAILABLE = False
19
+
20
+ # Try to import Chatterbox
21
+ try:
22
+ import torch
23
+ import torchaudio as ta
24
+ from chatterbox.tts import ChatterboxTTS
25
+ CHATTERBOX_AVAILABLE = True
26
+ logger.info("Chatterbox TTS engine is available")
27
+ except ImportError as e:
28
+ logger.warning(f"Chatterbox TTS engine is not available: {e}")
29
+ except Exception as e:
30
+ logger.error(f"Chatterbox import failed with unexpected error: {str(e)}")
31
+ CHATTERBOX_AVAILABLE = False
32
+
33
+
34
class ChatterboxTTSProvider(TTSProviderBase):
    """TTS provider backed by the Chatterbox engine.

    The Chatterbox model is loaded lazily on the first availability check and
    cached on the instance. Two voice ids are advertised: ``'default'`` (the
    base model voice) and ``'custom'`` (voice cloning from the audio prompt
    file referenced by the request's voice settings).
    """

    def __init__(self, lang_code: str = 'en'):
        """Initialize the Chatterbox TTS provider.

        Args:
            lang_code: Language code stored on the instance. It is not
                consulted by the synthesis methods of this class.
        """
        super().__init__(
            provider_name="Chatterbox",
            supported_languages=['en']  # Chatterbox primarily supports English
        )
        self.lang_code = lang_code
        self.model = None
        # `torch` is only bound when the guarded module-level import
        # succeeded; touching it otherwise would raise NameError and defeat
        # the graceful "engine not installed" path.
        if CHATTERBOX_AVAILABLE and torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

    def _ensure_model(self):
        """Load the Chatterbox model if it is not loaded yet.

        Returns:
            bool: True when a model instance is available after the call.
            Load failures are logged and reported as False, not raised.
        """
        if self.model is None and CHATTERBOX_AVAILABLE:
            try:
                logger.info(f"Loading Chatterbox model on device: {self.device}")
                self.model = ChatterboxTTS.from_pretrained(device=self.device)
                logger.info("Chatterbox model successfully loaded")
            except Exception as e:
                logger.error(f"Failed to initialize Chatterbox model: {str(e)}")
                self.model = None
        return self.model is not None

    def is_available(self) -> bool:
        """Check if Chatterbox TTS is available.

        Note that this triggers the (potentially slow) model load on first use.
        """
        return CHATTERBOX_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for Chatterbox.

        Returns:
            ``['default', 'custom']`` - the base model voice and the
            audio-prompt (voice cloning) mode.
        """
        return ['default', 'custom']

    @staticmethod
    def _resolve_audio_prompt(request: 'SpeechSynthesisRequest') -> Optional[str]:
        """Return the audio prompt path to clone from, or None for the default voice."""
        settings = request.voice_settings
        if settings.voice_id != 'custom':
            return None
        # The attribute can exist and still be None, so test the value rather
        # than mere attribute presence.
        return getattr(settings, 'audio_prompt_path', None)

    def _synthesize_waveform(
        self, text: str, audio_prompt_path: Optional[str] = None
    ) -> tuple[np.ndarray, int]:
        """Run the model and return ``(mono_samples, sample_rate)``."""
        if audio_prompt_path is not None:
            wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            wav = self.model.generate(text)

        # Tensor -> numpy. detach() first so grad-tracking tensors convert too.
        if hasattr(wav, 'detach'):
            wav = wav.detach().cpu().numpy()

        # NOTE(review): Chatterbox appears to return a (1, num_samples)
        # tensor (its README saves the result with torchaudio.save, which is
        # channel-first) - confirm. Flatten to 1-D so that len() counts
        # samples and soundfile does not read the array as 1 frame x N channels.
        samples = np.atleast_1d(np.squeeze(np.asarray(wav)))
        if samples.ndim > 1:
            samples = samples[0]  # assumes channel-first layout; keep channel 0

        return samples, self.model.sr

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Chatterbox TTS.

        Args:
            request: Synthesis request; its text and voice settings are read.

        Returns:
            tuple: (WAV-encoded audio bytes, sample rate)

        Raises:
            SpeechSynthesisException: If the engine is not available. Other
                failures are delegated to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            samples, sample_rate = self._synthesize_waveform(
                request.text_content.text,
                self._resolve_audio_prompt(request),
            )
            return self._numpy_to_bytes(samples, sample_rate), sample_rate

        except Exception as e:
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Chatterbox TTS.

        The full utterance is synthesized first and then sliced into
        one-second pieces; each piece is encoded on its own via
        ``_numpy_to_bytes`` (i.e. as a standalone WAV payload).

        Yields:
            tuple: (chunk bytes, sample rate, is_final flag)

        Raises:
            SpeechSynthesisException: If the engine is not available. Other
                failures are delegated to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            samples, sample_rate = self._synthesize_waveform(
                request.text_content.text,
                self._resolve_audio_prompt(request),
            )

            chunk_size = int(sample_rate * 1.0)  # 1 second chunks
            total_samples = len(samples)

            for start_idx in range(0, total_samples, chunk_size):
                end_idx = min(start_idx + chunk_size, total_samples)
                audio_bytes = self._numpy_to_bytes(samples[start_idx:end_idx], sample_rate)
                yield audio_bytes, sample_rate, end_idx >= total_samples

        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to WAV bytes.

        Args:
            audio_array: Audio samples; converted to float32 if needed.
            sample_rate: Sample rate written into the WAV header.

        Returns:
            The in-memory WAV file contents.

        Raises:
            SpeechSynthesisException: If conversion or encoding fails.
        """
        try:
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Peak-normalize only when the signal would clip. Guard the
            # reduction because np.max raises on a zero-size array.
            peak = float(np.max(np.abs(audio_array))) if audio_array.size else 0.0
            if peak > 1.0:
                audio_array = audio_array / peak

            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            return buffer.getvalue()

        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e

    def generate_with_voice_prompt(self, text: str, audio_prompt_path: str) -> tuple[bytes, int]:
        """
        Generate audio with a custom voice prompt.

        Args:
            text: Text to synthesize
            audio_prompt_path: Path to audio file for voice cloning

        Returns:
            tuple: (audio_bytes, sample_rate)

        Raises:
            SpeechSynthesisException: If the engine is not available. Other
                failures are delegated to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            samples, sample_rate = self._synthesize_waveform(text, audio_prompt_path)
            return self._numpy_to_bytes(samples, sample_rate), sample_rate

        except Exception as e:
            self._handle_provider_error(e, "voice prompt audio generation")
src/infrastructure/tts/provider_factory.py CHANGED
@@ -58,6 +58,14 @@ class TTSProviderFactory:
58
  except ImportError as e:
59
  logger.info(f"CosyVoice2 TTS provider not available: {e}")
60
 
 
 
 
 
 
 
 
 
61
  def get_available_providers(self) -> List[str]:
62
  """Get list of available TTS providers."""
63
  logger.info("🔍 Checking availability of TTS providers...")
@@ -76,6 +84,8 @@ class TTSProviderFactory:
76
  self._provider_instances[name] = provider_class()
77
  elif name == 'cosyvoice2':
78
  self._provider_instances[name] = provider_class()
 
 
79
  else:
80
  self._provider_instances[name] = provider_class()
81
 
@@ -124,8 +134,8 @@ class TTSProviderFactory:
124
  provider_class = self._providers[provider_name]
125
 
126
  # Create instance with appropriate parameters
127
- if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
128
- lang_code = kwargs.get('lang_code', 'z')
129
  provider = provider_class(lang_code=lang_code)
130
  else:
131
  provider = provider_class(**kwargs)
@@ -156,7 +166,7 @@ class TTSProviderFactory:
156
  SpeechSynthesisException: If no providers are available
157
  """
158
  if preferred_providers is None:
159
- preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
160
 
161
  logger.info(f"🔄 Getting TTS provider with fallback, preferred order: {preferred_providers}")
162
  available_providers = self.get_available_providers()
@@ -204,7 +214,7 @@ class TTSProviderFactory:
204
  # Create instance if not cached
205
  if provider_name not in self._provider_instances:
206
  provider_class = self._providers[provider_name]
207
- if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
208
  self._provider_instances[provider_name] = provider_class()
209
  else:
210
  self._provider_instances[provider_name] = provider_class()
 
58
  except ImportError as e:
59
  logger.info(f"CosyVoice2 TTS provider not available: {e}")
60
 
61
+ # Try to register Chatterbox provider
62
+ try:
63
+ from .chatterbox_provider import ChatterboxTTSProvider
64
+ self._providers['chatterbox'] = ChatterboxTTSProvider
65
+ logger.info("Registered Chatterbox TTS provider")
66
+ except ImportError as e:
67
+ logger.info(f"Chatterbox TTS provider not available: {e}")
68
+
69
  def get_available_providers(self) -> List[str]:
70
  """Get list of available TTS providers."""
71
  logger.info("🔍 Checking availability of TTS providers...")
 
84
  self._provider_instances[name] = provider_class()
85
  elif name == 'cosyvoice2':
86
  self._provider_instances[name] = provider_class()
87
+ elif name == 'chatterbox':
88
+ self._provider_instances[name] = provider_class()
89
  else:
90
  self._provider_instances[name] = provider_class()
91
 
 
134
  provider_class = self._providers[provider_name]
135
 
136
  # Create instance with appropriate parameters
137
+ if provider_name in ['kokoro', 'dia', 'cosyvoice2', 'chatterbox']:
138
+ lang_code = kwargs.get('lang_code', 'en' if provider_name == 'chatterbox' else 'z')
139
  provider = provider_class(lang_code=lang_code)
140
  else:
141
  provider = provider_class(**kwargs)
 
166
  SpeechSynthesisException: If no providers are available
167
  """
168
  if preferred_providers is None:
169
+ preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'chatterbox', 'dummy']
170
 
171
  logger.info(f"🔄 Getting TTS provider with fallback, preferred order: {preferred_providers}")
172
  available_providers = self.get_available_providers()
 
214
  # Create instance if not cached
215
  if provider_name not in self._provider_instances:
216
  provider_class = self._providers[provider_name]
217
+ if provider_name in ['kokoro', 'dia', 'cosyvoice2', 'chatterbox']:
218
  self._provider_instances[provider_name] = provider_class()
219
  else:
220
  self._provider_instances[provider_name] = provider_class()