Spaces:
Build error
Build error
Michael Hu
committed on
Commit
·
0f99c8d
1
Parent(s):
c7f7521
add chatterbox
Browse files- pyproject.toml +1 -1
- requirements.txt +1 -1
- src/domain/models/text_content.py +17 -17
- src/domain/models/voice_settings.py +42 -21
- src/infrastructure/tts/__init__.py +8 -2
- src/infrastructure/tts/chatterbox_provider.py +201 -0
- src/infrastructure/tts/provider_factory.py +14 -4
pyproject.toml
CHANGED
@@ -25,7 +25,7 @@ dependencies = [
|
|
25 |
"phonemizer-fork>=3.3.2",
|
26 |
"nemo_toolkit[asr]",
|
27 |
"faster-whisper>=1.1.1",
|
28 |
-
"
|
29 |
]
|
30 |
|
31 |
[project.optional-dependencies]
|
|
|
25 |
"phonemizer-fork>=3.3.2",
|
26 |
"nemo_toolkit[asr]",
|
27 |
"faster-whisper>=1.1.1",
|
28 |
+
"chatterbox-tts"
|
29 |
]
|
30 |
|
31 |
[project.optional-dependencies]
|
requirements.txt
CHANGED
@@ -14,4 +14,4 @@ ordered-set>=4.1.0
|
|
14 |
phonemizer-fork>=3.3.2
|
15 |
nemo_toolkit[asr]
|
16 |
faster-whisper>=1.1.1
|
17 |
-
|
|
|
14 |
phonemizer-fork>=3.3.2
|
15 |
nemo_toolkit[asr]
|
16 |
faster-whisper>=1.1.1
|
17 |
+
chatterbox-tts
|
src/domain/models/text_content.py
CHANGED
@@ -8,71 +8,71 @@ import re
|
|
8 |
@dataclass(frozen=True)
|
9 |
class TextContent:
|
10 |
"""Value object representing text content with language and encoding information."""
|
11 |
-
|
12 |
text: str
|
13 |
language: str
|
14 |
encoding: str = 'utf-8'
|
15 |
-
|
16 |
def __post_init__(self):
|
17 |
"""Validate text content after initialization."""
|
18 |
self._validate()
|
19 |
-
|
20 |
def _validate(self):
|
21 |
"""Validate text content properties."""
|
22 |
if not isinstance(self.text, str):
|
23 |
raise TypeError("Text must be a string")
|
24 |
-
|
25 |
if not self.text.strip():
|
26 |
raise ValueError("Text content cannot be empty or whitespace only")
|
27 |
-
|
28 |
if len(self.text) > 50000: # Reasonable limit for TTS processing
|
29 |
raise ValueError("Text content too long (maximum 50,000 characters)")
|
30 |
-
|
31 |
if not isinstance(self.language, str):
|
32 |
raise TypeError("Language must be a string")
|
33 |
-
|
34 |
if not self.language.strip():
|
35 |
raise ValueError("Language cannot be empty")
|
36 |
-
|
37 |
# Validate language code format (ISO 639-1 or ISO 639-3)
|
38 |
if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
|
39 |
raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
|
40 |
-
|
41 |
if not isinstance(self.encoding, str):
|
42 |
raise TypeError("Encoding must be a string")
|
43 |
-
|
44 |
if self.encoding not in ['utf-8', 'utf-16', 'ascii', 'latin-1']:
|
45 |
raise ValueError(f"Unsupported encoding: {self.encoding}. Supported: utf-8, utf-16, ascii, latin-1")
|
46 |
-
|
47 |
# Validate that text can be encoded with specified encoding
|
48 |
try:
|
49 |
self.text.encode(self.encoding)
|
50 |
except UnicodeEncodeError:
|
51 |
raise ValueError(f"Text cannot be encoded with {self.encoding} encoding")
|
52 |
-
|
53 |
@property
|
54 |
def word_count(self) -> int:
|
55 |
"""Get the approximate word count of the text."""
|
56 |
return len(self.text.split())
|
57 |
-
|
58 |
@property
|
59 |
def character_count(self) -> int:
|
60 |
"""Get the character count of the text."""
|
61 |
return len(self.text)
|
62 |
-
|
63 |
@property
|
64 |
def is_empty(self) -> bool:
|
65 |
"""Check if the text content is effectively empty."""
|
66 |
return not self.text.strip()
|
67 |
-
|
68 |
def truncate(self, max_length: int) -> 'TextContent':
|
69 |
"""Create a new TextContent with truncated text."""
|
70 |
if max_length <= 0:
|
71 |
raise ValueError("Max length must be positive")
|
72 |
-
|
73 |
if len(self.text) <= max_length:
|
74 |
return self
|
75 |
-
|
76 |
truncated_text = self.text[:max_length].rstrip()
|
77 |
return TextContent(
|
78 |
text=truncated_text,
|
|
|
8 |
@dataclass(frozen=True)
|
9 |
class TextContent:
|
10 |
"""Value object representing text content with language and encoding information."""
|
11 |
+
|
12 |
text: str
|
13 |
language: str
|
14 |
encoding: str = 'utf-8'
|
15 |
+
|
16 |
def __post_init__(self):
|
17 |
"""Validate text content after initialization."""
|
18 |
self._validate()
|
19 |
+
|
20 |
def _validate(self):
|
21 |
"""Validate text content properties."""
|
22 |
if not isinstance(self.text, str):
|
23 |
raise TypeError("Text must be a string")
|
24 |
+
|
25 |
if not self.text.strip():
|
26 |
raise ValueError("Text content cannot be empty or whitespace only")
|
27 |
+
|
28 |
if len(self.text) > 50000: # Reasonable limit for TTS processing
|
29 |
raise ValueError("Text content too long (maximum 50,000 characters)")
|
30 |
+
|
31 |
if not isinstance(self.language, str):
|
32 |
raise TypeError("Language must be a string")
|
33 |
+
|
34 |
if not self.language.strip():
|
35 |
raise ValueError("Language cannot be empty")
|
36 |
+
|
37 |
# Validate language code format (ISO 639-1 or ISO 639-3)
|
38 |
if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
|
39 |
raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
|
40 |
+
|
41 |
if not isinstance(self.encoding, str):
|
42 |
raise TypeError("Encoding must be a string")
|
43 |
+
|
44 |
if self.encoding not in ['utf-8', 'utf-16', 'ascii', 'latin-1']:
|
45 |
raise ValueError(f"Unsupported encoding: {self.encoding}. Supported: utf-8, utf-16, ascii, latin-1")
|
46 |
+
|
47 |
# Validate that text can be encoded with specified encoding
|
48 |
try:
|
49 |
self.text.encode(self.encoding)
|
50 |
except UnicodeEncodeError:
|
51 |
raise ValueError(f"Text cannot be encoded with {self.encoding} encoding")
|
52 |
+
|
53 |
@property
|
54 |
def word_count(self) -> int:
|
55 |
"""Get the approximate word count of the text."""
|
56 |
return len(self.text.split())
|
57 |
+
|
58 |
@property
|
59 |
def character_count(self) -> int:
|
60 |
"""Get the character count of the text."""
|
61 |
return len(self.text)
|
62 |
+
|
63 |
@property
|
64 |
def is_empty(self) -> bool:
|
65 |
"""Check if the text content is effectively empty."""
|
66 |
return not self.text.strip()
|
67 |
+
|
68 |
def truncate(self, max_length: int) -> 'TextContent':
|
69 |
"""Create a new TextContent with truncated text."""
|
70 |
if max_length <= 0:
|
71 |
raise ValueError("Max length must be positive")
|
72 |
+
|
73 |
if len(self.text) <= max_length:
|
74 |
return self
|
75 |
+
|
76 |
truncated_text = self.text[:max_length].rstrip()
|
77 |
return TextContent(
|
78 |
text=truncated_text,
|
src/domain/models/voice_settings.py
CHANGED
@@ -8,74 +8,82 @@ import re
|
|
8 |
@dataclass(frozen=True)
|
9 |
class VoiceSettings:
|
10 |
"""Value object representing voice settings for text-to-speech synthesis."""
|
11 |
-
|
12 |
voice_id: str
|
13 |
speed: float
|
14 |
language: str
|
15 |
pitch: Optional[float] = None
|
16 |
volume: Optional[float] = None
|
17 |
-
|
|
|
18 |
def __post_init__(self):
|
19 |
"""Validate voice settings after initialization."""
|
20 |
self._validate()
|
21 |
-
|
22 |
def _validate(self):
|
23 |
"""Validate voice settings properties."""
|
24 |
if not isinstance(self.voice_id, str):
|
25 |
raise TypeError("Voice ID must be a string")
|
26 |
-
|
27 |
if not self.voice_id.strip():
|
28 |
raise ValueError("Voice ID cannot be empty")
|
29 |
-
|
30 |
# Voice ID should be alphanumeric with possible underscores/hyphens
|
31 |
if not re.match(r'^[a-zA-Z0-9_-]+$', self.voice_id):
|
32 |
raise ValueError(f"Invalid voice ID format: {self.voice_id}. Must contain only letters, numbers, underscores, and hyphens")
|
33 |
-
|
34 |
if not isinstance(self.speed, (int, float)):
|
35 |
raise TypeError("Speed must be a number")
|
36 |
-
|
37 |
if not 0.1 <= self.speed <= 3.0:
|
38 |
raise ValueError(f"Speed must be between 0.1 and 3.0, got {self.speed}")
|
39 |
-
|
40 |
if not isinstance(self.language, str):
|
41 |
raise TypeError("Language must be a string")
|
42 |
-
|
43 |
if not self.language.strip():
|
44 |
raise ValueError("Language cannot be empty")
|
45 |
-
|
46 |
# Validate language code format (ISO 639-1 or ISO 639-3)
|
47 |
if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
|
48 |
raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
|
49 |
-
|
50 |
if self.pitch is not None:
|
51 |
if not isinstance(self.pitch, (int, float)):
|
52 |
raise TypeError("Pitch must be a number")
|
53 |
-
|
54 |
if not -2.0 <= self.pitch <= 2.0:
|
55 |
raise ValueError(f"Pitch must be between -2.0 and 2.0, got {self.pitch}")
|
56 |
-
|
57 |
if self.volume is not None:
|
58 |
if not isinstance(self.volume, (int, float)):
|
59 |
raise TypeError("Volume must be a number")
|
60 |
-
|
61 |
if not 0.0 <= self.volume <= 2.0:
|
62 |
raise ValueError(f"Volume must be between 0.0 and 2.0, got {self.volume}")
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
@property
|
65 |
def is_default_speed(self) -> bool:
|
66 |
"""Check if speed is at default value (1.0)."""
|
67 |
return abs(self.speed - 1.0) < 0.01
|
68 |
-
|
69 |
@property
|
70 |
def is_default_pitch(self) -> bool:
|
71 |
"""Check if pitch is at default value (0.0 or None)."""
|
72 |
return self.pitch is None or abs(self.pitch) < 0.01
|
73 |
-
|
74 |
@property
|
75 |
def is_default_volume(self) -> bool:
|
76 |
"""Check if volume is at default value (1.0 or None)."""
|
77 |
return self.volume is None or abs(self.volume - 1.0) < 0.01
|
78 |
-
|
79 |
def with_speed(self, speed: float) -> 'VoiceSettings':
|
80 |
"""Create a new VoiceSettings with different speed."""
|
81 |
return VoiceSettings(
|
@@ -83,9 +91,10 @@ class VoiceSettings:
|
|
83 |
speed=speed,
|
84 |
language=self.language,
|
85 |
pitch=self.pitch,
|
86 |
-
volume=self.volume
|
|
|
87 |
)
|
88 |
-
|
89 |
def with_pitch(self, pitch: Optional[float]) -> 'VoiceSettings':
|
90 |
"""Create a new VoiceSettings with different pitch."""
|
91 |
return VoiceSettings(
|
@@ -93,5 +102,17 @@ class VoiceSettings:
|
|
93 |
speed=self.speed,
|
94 |
language=self.language,
|
95 |
pitch=pitch,
|
96 |
-
volume=self.volume
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
)
|
|
|
8 |
@dataclass(frozen=True)
|
9 |
class VoiceSettings:
|
10 |
"""Value object representing voice settings for text-to-speech synthesis."""
|
11 |
+
|
12 |
voice_id: str
|
13 |
speed: float
|
14 |
language: str
|
15 |
pitch: Optional[float] = None
|
16 |
volume: Optional[float] = None
|
17 |
+
audio_prompt_path: Optional[str] = None # For voice cloning (e.g., Chatterbox)
|
18 |
+
|
19 |
def __post_init__(self):
|
20 |
"""Validate voice settings after initialization."""
|
21 |
self._validate()
|
22 |
+
|
23 |
def _validate(self):
|
24 |
"""Validate voice settings properties."""
|
25 |
if not isinstance(self.voice_id, str):
|
26 |
raise TypeError("Voice ID must be a string")
|
27 |
+
|
28 |
if not self.voice_id.strip():
|
29 |
raise ValueError("Voice ID cannot be empty")
|
30 |
+
|
31 |
# Voice ID should be alphanumeric with possible underscores/hyphens
|
32 |
if not re.match(r'^[a-zA-Z0-9_-]+$', self.voice_id):
|
33 |
raise ValueError(f"Invalid voice ID format: {self.voice_id}. Must contain only letters, numbers, underscores, and hyphens")
|
34 |
+
|
35 |
if not isinstance(self.speed, (int, float)):
|
36 |
raise TypeError("Speed must be a number")
|
37 |
+
|
38 |
if not 0.1 <= self.speed <= 3.0:
|
39 |
raise ValueError(f"Speed must be between 0.1 and 3.0, got {self.speed}")
|
40 |
+
|
41 |
if not isinstance(self.language, str):
|
42 |
raise TypeError("Language must be a string")
|
43 |
+
|
44 |
if not self.language.strip():
|
45 |
raise ValueError("Language cannot be empty")
|
46 |
+
|
47 |
# Validate language code format (ISO 639-1 or ISO 639-3)
|
48 |
if not re.match(r'^[a-z]{2,3}(-[A-Z]{2})?$', self.language):
|
49 |
raise ValueError(f"Invalid language code format: {self.language}. Expected format: 'en', 'en-US', etc.")
|
50 |
+
|
51 |
if self.pitch is not None:
|
52 |
if not isinstance(self.pitch, (int, float)):
|
53 |
raise TypeError("Pitch must be a number")
|
54 |
+
|
55 |
if not -2.0 <= self.pitch <= 2.0:
|
56 |
raise ValueError(f"Pitch must be between -2.0 and 2.0, got {self.pitch}")
|
57 |
+
|
58 |
if self.volume is not None:
|
59 |
if not isinstance(self.volume, (int, float)):
|
60 |
raise TypeError("Volume must be a number")
|
61 |
+
|
62 |
if not 0.0 <= self.volume <= 2.0:
|
63 |
raise ValueError(f"Volume must be between 0.0 and 2.0, got {self.volume}")
|
64 |
+
|
65 |
+
if self.audio_prompt_path is not None:
|
66 |
+
if not isinstance(self.audio_prompt_path, str):
|
67 |
+
raise TypeError("Audio prompt path must be a string")
|
68 |
+
|
69 |
+
if not self.audio_prompt_path.strip():
|
70 |
+
raise ValueError("Audio prompt path cannot be empty")
|
71 |
+
|
72 |
@property
|
73 |
def is_default_speed(self) -> bool:
|
74 |
"""Check if speed is at default value (1.0)."""
|
75 |
return abs(self.speed - 1.0) < 0.01
|
76 |
+
|
77 |
@property
|
78 |
def is_default_pitch(self) -> bool:
|
79 |
"""Check if pitch is at default value (0.0 or None)."""
|
80 |
return self.pitch is None or abs(self.pitch) < 0.01
|
81 |
+
|
82 |
@property
|
83 |
def is_default_volume(self) -> bool:
|
84 |
"""Check if volume is at default value (1.0 or None)."""
|
85 |
return self.volume is None or abs(self.volume - 1.0) < 0.01
|
86 |
+
|
87 |
def with_speed(self, speed: float) -> 'VoiceSettings':
|
88 |
"""Create a new VoiceSettings with different speed."""
|
89 |
return VoiceSettings(
|
|
|
91 |
speed=speed,
|
92 |
language=self.language,
|
93 |
pitch=self.pitch,
|
94 |
+
volume=self.volume,
|
95 |
+
audio_prompt_path=self.audio_prompt_path
|
96 |
)
|
97 |
+
|
98 |
def with_pitch(self, pitch: Optional[float]) -> 'VoiceSettings':
|
99 |
"""Create a new VoiceSettings with different pitch."""
|
100 |
return VoiceSettings(
|
|
|
102 |
speed=self.speed,
|
103 |
language=self.language,
|
104 |
pitch=pitch,
|
105 |
+
volume=self.volume,
|
106 |
+
audio_prompt_path=self.audio_prompt_path
|
107 |
+
)
|
108 |
+
|
109 |
+
def with_audio_prompt(self, audio_prompt_path: Optional[str]) -> 'VoiceSettings':
|
110 |
+
"""Create a new VoiceSettings with different audio prompt path."""
|
111 |
+
return VoiceSettings(
|
112 |
+
voice_id=self.voice_id,
|
113 |
+
speed=self.speed,
|
114 |
+
language=self.language,
|
115 |
+
pitch=self.pitch,
|
116 |
+
volume=self.volume,
|
117 |
+
audio_prompt_path=audio_prompt_path
|
118 |
)
|
src/infrastructure/tts/__init__.py
CHANGED
@@ -19,10 +19,16 @@ try:
|
|
19 |
except ImportError:
|
20 |
CosyVoice2TTSProvider = None
|
21 |
|
|
|
|
|
|
|
|
|
|
|
22 |
__all__ = [
|
23 |
'TTSProviderFactory',
|
24 |
'DummyTTSProvider',
|
25 |
'KokoroTTSProvider',
|
26 |
-
'DiaTTSProvider',
|
27 |
-
'CosyVoice2TTSProvider'
|
|
|
28 |
]
|
|
|
19 |
except ImportError:
|
20 |
CosyVoice2TTSProvider = None
|
21 |
|
22 |
+
try:
|
23 |
+
from .chatterbox_provider import ChatterboxTTSProvider
|
24 |
+
except ImportError:
|
25 |
+
ChatterboxTTSProvider = None
|
26 |
+
|
27 |
__all__ = [
|
28 |
'TTSProviderFactory',
|
29 |
'DummyTTSProvider',
|
30 |
'KokoroTTSProvider',
|
31 |
+
'DiaTTSProvider',
|
32 |
+
'CosyVoice2TTSProvider',
|
33 |
+
'ChatterboxTTSProvider'
|
34 |
]
|
src/infrastructure/tts/chatterbox_provider.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Chatterbox TTS provider implementation."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import numpy as np
|
5 |
+
import soundfile as sf
|
6 |
+
import io
|
7 |
+
from typing import Iterator, Optional, TYPE_CHECKING
|
8 |
+
|
9 |
+
if TYPE_CHECKING:
|
10 |
+
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
|
11 |
+
|
12 |
+
from ..base.tts_provider_base import TTSProviderBase
|
13 |
+
from ...domain.exceptions import SpeechSynthesisException
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
# Flag to track Chatterbox availability
|
18 |
+
CHATTERBOX_AVAILABLE = False
|
19 |
+
|
20 |
+
# Try to import Chatterbox
|
21 |
+
try:
|
22 |
+
import torch
|
23 |
+
import torchaudio as ta
|
24 |
+
from chatterbox.tts import ChatterboxTTS
|
25 |
+
CHATTERBOX_AVAILABLE = True
|
26 |
+
logger.info("Chatterbox TTS engine is available")
|
27 |
+
except ImportError as e:
|
28 |
+
logger.warning(f"Chatterbox TTS engine is not available: {e}")
|
29 |
+
except Exception as e:
|
30 |
+
logger.error(f"Chatterbox import failed with unexpected error: {str(e)}")
|
31 |
+
CHATTERBOX_AVAILABLE = False
|
32 |
+
|
33 |
+
|
34 |
+
class ChatterboxTTSProvider(TTSProviderBase):
    """Text-to-speech provider backed by the Chatterbox engine.

    The Chatterbox model is loaded lazily on first use. Two voice ids are
    advertised: ``'default'`` (the base model voice) and ``'custom'`` (voice
    cloning driven by ``voice_settings.audio_prompt_path``).

    The optional ``torch`` / ``chatterbox`` imports are guarded at module
    level; when they fail, ``CHATTERBOX_AVAILABLE`` is False and this provider
    reports itself as unavailable instead of failing at construction time.
    """

    # Length, in seconds, of each chunk yielded by the pseudo-streaming API.
    _STREAM_CHUNK_SECONDS = 1.0

    def __init__(self, lang_code: str = 'en'):
        """Initialize the Chatterbox TTS provider.

        Args:
            lang_code: Language code stored on the instance as ``lang_code``.
        """
        super().__init__(
            provider_name="Chatterbox",
            supported_languages=['en']  # Chatterbox primarily supports English
        )
        self.lang_code = lang_code
        self.model = None
        # ``torch`` is only bound when the guarded module-level import
        # succeeded; touching it otherwise would raise NameError here and make
        # the provider impossible to even instantiate.
        if CHATTERBOX_AVAILABLE and torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

    def _ensure_model(self) -> bool:
        """Load the Chatterbox model if it is not loaded yet.

        Returns:
            True when a model instance is available, False otherwise. Loading
            failures are logged and reported as False rather than raised.
        """
        if self.model is None and CHATTERBOX_AVAILABLE:
            try:
                logger.info(f"Loading Chatterbox model on device: {self.device}")
                self.model = ChatterboxTTS.from_pretrained(device=self.device)
                logger.info("Chatterbox model successfully loaded")
            except Exception as e:
                logger.error(f"Failed to initialize Chatterbox model: {str(e)}")
                self.model = None
        return self.model is not None

    def is_available(self) -> bool:
        """Check if Chatterbox TTS is available.

        Note that this triggers the (lazy) model load on first call.
        """
        return CHATTERBOX_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for Chatterbox."""
        # Chatterbox supports voice cloning with audio prompts;
        # 'default' is the base model voice.
        return ['default', 'custom']

    @staticmethod
    def _to_sample_array(wav) -> np.ndarray:
        """Convert the model output to a numpy array without singleton axes."""
        # Detach before moving/converting: calling ``.numpy()`` on a tensor
        # that still tracks gradients raises.
        if hasattr(wav, 'detach'):
            wav = wav.detach()
        if hasattr(wav, 'cpu'):
            wav = wav.cpu()
        if hasattr(wav, 'numpy'):
            wav = wav.numpy()
        # NOTE(review): upstream Chatterbox examples hand the result straight
        # to torchaudio.save, which implies a (channels, samples) tensor with a
        # single channel -- confirm. Dropping singleton axes makes ``len()``
        # count samples and lets soundfile treat the data as mono frames.
        return np.atleast_1d(np.squeeze(np.asarray(wav)))

    def _synthesize(self, text: str, voice: str,
                    audio_prompt_path: Optional[str] = None) -> np.ndarray:
        """Run the model for ``text`` and return the waveform as a numpy array.

        The audio prompt is only used when ``voice`` is ``'custom'`` and a
        prompt path was actually provided; otherwise the default voice is used.
        """
        if voice == 'custom' and audio_prompt_path:
            wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            wav = self.model.generate(text)
        return self._to_sample_array(wav)

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Chatterbox TTS.

        Args:
            request: Synthesis request; ``text_content.text`` and
                ``voice_settings`` (``voice_id``, ``audio_prompt_path``) are read.

        Returns:
            tuple: (WAV-encoded audio bytes, sample rate).

        Raises:
            SpeechSynthesisException: If the engine is not available.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            settings = request.voice_settings
            wav = self._synthesize(
                request.text_content.text,
                settings.voice_id,
                getattr(settings, 'audio_prompt_path', None),
            )

            sample_rate = self.model.sr
            audio_bytes = self._numpy_to_bytes(wav, sample_rate)
            return audio_bytes, sample_rate

        except Exception as e:
            # _handle_provider_error is defined on the base class (not visible
            # here); presumably it logs and re-raises -- confirm.
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Chatterbox TTS.

        Chatterbox is not driven incrementally here: the full waveform is
        generated first and then split into fixed-length chunks.

        Yields:
            tuple: (WAV-encoded chunk bytes, sample rate, is_final flag).

        Raises:
            SpeechSynthesisException: If the engine is not available.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            settings = request.voice_settings
            wav = self._synthesize(
                request.text_content.text,
                settings.voice_id,
                getattr(settings, 'audio_prompt_path', None),
            )

            sample_rate = self.model.sr

            # Guard against a zero step, which range() rejects.
            chunk_size = max(1, int(sample_rate * self._STREAM_CHUNK_SECONDS))
            total_samples = len(wav)

            for start_idx in range(0, total_samples, chunk_size):
                end_idx = min(start_idx + chunk_size, total_samples)
                audio_bytes = self._numpy_to_bytes(wav[start_idx:end_idx], sample_rate)
                is_final = end_idx >= total_samples
                yield audio_bytes, sample_rate, is_final

        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to WAV-encoded bytes.

        Args:
            audio_array: Audio samples; converted to float32 if needed.
            sample_rate: Sample rate written into the WAV header.

        Returns:
            The WAV file contents.

        Raises:
            SpeechSynthesisException: If conversion or encoding fails.
        """
        try:
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Peak-normalize only when samples exceed [-1, 1]. The size check
            # keeps np.max from raising on an empty array.
            if audio_array.size:
                peak = np.max(np.abs(audio_array))
                if peak > 1.0:
                    audio_array = audio_array / peak

            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            return buffer.getvalue()

        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e

    def generate_with_voice_prompt(self, text: str, audio_prompt_path: str) -> tuple[bytes, int]:
        """
        Generate audio with a custom voice prompt.

        Args:
            text: Text to synthesize
            audio_prompt_path: Path to audio file for voice cloning

        Returns:
            tuple: (audio_bytes, sample_rate)

        Raises:
            SpeechSynthesisException: If the engine is not available.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            wav = self._to_sample_array(
                self.model.generate(text, audio_prompt_path=audio_prompt_path)
            )

            sample_rate = self.model.sr
            audio_bytes = self._numpy_to_bytes(wav, sample_rate)
            return audio_bytes, sample_rate

        except Exception as e:
            self._handle_provider_error(e, "voice prompt audio generation")
|
src/infrastructure/tts/provider_factory.py
CHANGED
@@ -58,6 +58,14 @@ class TTSProviderFactory:
|
|
58 |
except ImportError as e:
|
59 |
logger.info(f"CosyVoice2 TTS provider not available: {e}")
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def get_available_providers(self) -> List[str]:
|
62 |
"""Get list of available TTS providers."""
|
63 |
logger.info("π Checking availability of TTS providers...")
|
@@ -76,6 +84,8 @@ class TTSProviderFactory:
|
|
76 |
self._provider_instances[name] = provider_class()
|
77 |
elif name == 'cosyvoice2':
|
78 |
self._provider_instances[name] = provider_class()
|
|
|
|
|
79 |
else:
|
80 |
self._provider_instances[name] = provider_class()
|
81 |
|
@@ -124,8 +134,8 @@ class TTSProviderFactory:
|
|
124 |
provider_class = self._providers[provider_name]
|
125 |
|
126 |
# Create instance with appropriate parameters
|
127 |
-
if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
|
128 |
-
lang_code = kwargs.get('lang_code', 'z')
|
129 |
provider = provider_class(lang_code=lang_code)
|
130 |
else:
|
131 |
provider = provider_class(**kwargs)
|
@@ -156,7 +166,7 @@ class TTSProviderFactory:
|
|
156 |
SpeechSynthesisException: If no providers are available
|
157 |
"""
|
158 |
if preferred_providers is None:
|
159 |
-
preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
|
160 |
|
161 |
logger.info(f"π Getting TTS provider with fallback, preferred order: {preferred_providers}")
|
162 |
available_providers = self.get_available_providers()
|
@@ -204,7 +214,7 @@ class TTSProviderFactory:
|
|
204 |
# Create instance if not cached
|
205 |
if provider_name not in self._provider_instances:
|
206 |
provider_class = self._providers[provider_name]
|
207 |
-
if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
|
208 |
self._provider_instances[provider_name] = provider_class()
|
209 |
else:
|
210 |
self._provider_instances[provider_name] = provider_class()
|
|
|
58 |
except ImportError as e:
|
59 |
logger.info(f"CosyVoice2 TTS provider not available: {e}")
|
60 |
|
61 |
+
# Try to register Chatterbox provider
|
62 |
+
try:
|
63 |
+
from .chatterbox_provider import ChatterboxTTSProvider
|
64 |
+
self._providers['chatterbox'] = ChatterboxTTSProvider
|
65 |
+
logger.info("Registered Chatterbox TTS provider")
|
66 |
+
except ImportError as e:
|
67 |
+
logger.info(f"Chatterbox TTS provider not available: {e}")
|
68 |
+
|
69 |
def get_available_providers(self) -> List[str]:
|
70 |
"""Get list of available TTS providers."""
|
71 |
logger.info("π Checking availability of TTS providers...")
|
|
|
84 |
self._provider_instances[name] = provider_class()
|
85 |
elif name == 'cosyvoice2':
|
86 |
self._provider_instances[name] = provider_class()
|
87 |
+
elif name == 'chatterbox':
|
88 |
+
self._provider_instances[name] = provider_class()
|
89 |
else:
|
90 |
self._provider_instances[name] = provider_class()
|
91 |
|
|
|
134 |
provider_class = self._providers[provider_name]
|
135 |
|
136 |
# Create instance with appropriate parameters
|
137 |
+
if provider_name in ['kokoro', 'dia', 'cosyvoice2', 'chatterbox']:
|
138 |
+
lang_code = kwargs.get('lang_code', 'en' if provider_name == 'chatterbox' else 'z')
|
139 |
provider = provider_class(lang_code=lang_code)
|
140 |
else:
|
141 |
provider = provider_class(**kwargs)
|
|
|
166 |
SpeechSynthesisException: If no providers are available
|
167 |
"""
|
168 |
if preferred_providers is None:
|
169 |
+
preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'chatterbox', 'dummy']
|
170 |
|
171 |
logger.info(f"π Getting TTS provider with fallback, preferred order: {preferred_providers}")
|
172 |
available_providers = self.get_available_providers()
|
|
|
214 |
# Create instance if not cached
|
215 |
if provider_name not in self._provider_instances:
|
216 |
provider_class = self._providers[provider_name]
|
217 |
+
if provider_name in ['kokoro', 'dia', 'cosyvoice2', 'chatterbox']:
|
218 |
self._provider_instances[provider_name] = provider_class()
|
219 |
else:
|
220 |
self._provider_instances[provider_name] = provider_class()
|