File size: 7,596 Bytes
0f99c8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""Chatterbox TTS provider implementation."""

import logging
import numpy as np
import soundfile as sf
import io
from typing import Iterator, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest

from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException

logger = logging.getLogger(__name__)

# Flag to track Chatterbox availability.
# True only when torch, torchaudio, and chatterbox all import successfully;
# the provider class below checks this before touching any of those names.
CHATTERBOX_AVAILABLE = False

# Try to import Chatterbox and its dependencies. Failure is non-fatal:
# the provider stays constructible and reports is_available() == False.
try:
    import torch
    import torchaudio as ta
    from chatterbox.tts import ChatterboxTTS
    CHATTERBOX_AVAILABLE = True
    logger.info("Chatterbox TTS engine is available")
except ImportError as e:
    # Expected path when the optional dependency is simply not installed.
    logger.warning(f"Chatterbox TTS engine is not available: {e}")
except Exception as e:
    # Defensive: some builds fail at import time with non-ImportError
    # exceptions (e.g. CUDA/driver issues); treat those as unavailable too.
    logger.error(f"Chatterbox import failed with unexpected error: {str(e)}")
    CHATTERBOX_AVAILABLE = False


class ChatterboxTTSProvider(TTSProviderBase):
    """Chatterbox TTS provider implementation.

    Wraps the optional ``chatterbox`` package behind the common
    ``TTSProviderBase`` interface. The model is loaded lazily on first
    use, and all audio is returned as WAV-encoded bytes.
    """

    def __init__(self, lang_code: str = 'en'):
        """Initialize the Chatterbox TTS provider.

        Args:
            lang_code: Language code for synthesis. Chatterbox primarily
                supports English, so this is informational.
        """
        super().__init__(
            provider_name="Chatterbox",
            supported_languages=['en']  # Chatterbox primarily supports English
        )
        self.lang_code = lang_code
        self.model = None
        # Fix: guard the torch reference. If the optional import at module
        # level failed, the name `torch` is undefined and referencing it
        # here raised NameError — preventing even graceful construction of
        # an "unavailable" provider.
        if CHATTERBOX_AVAILABLE and torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

    def _ensure_model(self):
        """Lazily load the Chatterbox model; return True when loaded."""
        if self.model is None and CHATTERBOX_AVAILABLE:
            try:
                logger.info(f"Loading Chatterbox model on device: {self.device}")
                self.model = ChatterboxTTS.from_pretrained(device=self.device)
                logger.info("Chatterbox model successfully loaded")
            except Exception as e:
                # Best-effort: a failed load leaves the provider unavailable
                # rather than raising out of an availability probe.
                logger.error(f"Failed to initialize Chatterbox model: {str(e)}")
                self.model = None
        return self.model is not None

    def is_available(self) -> bool:
        """Check if Chatterbox TTS is available (import succeeded and model loads)."""
        return CHATTERBOX_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for Chatterbox.

        Chatterbox supports voice cloning with audio prompts; 'default'
        is the base model voice, 'custom' selects prompt-based cloning.
        """
        return ['default', 'custom']

    @staticmethod
    def _to_numpy(wav) -> np.ndarray:
        """Convert a model output (torch tensor or array-like) to a numpy array.

        Fix: detach before converting — calling ``.numpy()`` on a tensor
        that tracks gradients raises RuntimeError; the original code also
        skipped ``.cpu()`` on the detach-only branch.
        """
        if hasattr(wav, 'detach'):
            wav = wav.detach()
        if hasattr(wav, 'cpu'):
            wav = wav.cpu()
        if hasattr(wav, 'numpy'):
            wav = wav.numpy()
        return np.asarray(wav)

    def _synthesize(self, request: 'SpeechSynthesisRequest') -> tuple[np.ndarray, int]:
        """Run Chatterbox generation for a request.

        Shared by the one-shot and streaming paths (previously duplicated).

        Returns:
            tuple: (waveform as numpy array, sample_rate)
        """
        text = request.text_content.text
        voice = request.voice_settings.voice_id

        if voice == 'custom' and hasattr(request.voice_settings, 'audio_prompt_path'):
            # Custom voice: clone from the caller-supplied audio prompt.
            wav = self.model.generate(
                text, audio_prompt_path=request.voice_settings.audio_prompt_path
            )
        else:
            # Default voice: base model voice.
            wav = self.model.generate(text)

        return self._to_numpy(wav), self.model.sr

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Chatterbox TTS.

        Returns:
            tuple: (WAV-encoded audio bytes, sample_rate)

        Raises:
            SpeechSynthesisException: if the engine is unavailable.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            wav, sample_rate = self._synthesize(request)
            return self._numpy_to_bytes(wav, sample_rate), sample_rate
        except Exception as e:
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Chatterbox TTS.

        Chatterbox doesn't natively support streaming, so the full audio
        is generated and split into ~1-second chunks.

        Yields:
            tuple: (WAV-encoded chunk bytes, sample_rate, is_final_chunk)

        Raises:
            SpeechSynthesisException: if the engine is unavailable.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            wav, sample_rate = self._synthesize(request)

            chunk_size = int(sample_rate * 1.0)  # 1 second of samples per chunk
            total_samples = len(wav)

            for start_idx in range(0, total_samples, chunk_size):
                end_idx = min(start_idx + chunk_size, total_samples)
                chunk = wav[start_idx:end_idx]
                audio_bytes = self._numpy_to_bytes(chunk, sample_rate)
                # Flag the last chunk so consumers can finalize playback.
                yield audio_bytes, sample_rate, end_idx >= total_samples
        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Encode a numpy audio array as WAV bytes.

        Raises:
            SpeechSynthesisException: if encoding fails.
        """
        try:
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Peak-normalize only when clipping would occur. Fix: compute
            # the peak once (was computed twice) and guard the empty-array
            # case, where np.max raises ValueError.
            peak = float(np.max(np.abs(audio_array))) if audio_array.size else 0.0
            if peak > 1.0:
                audio_array = audio_array / peak

            # Encode to WAV in memory; no temp files.
            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e

    def generate_with_voice_prompt(self, text: str, audio_prompt_path: str) -> tuple[bytes, int]:
        """
        Generate audio with a custom voice prompt.

        Args:
            text: Text to synthesize
            audio_prompt_path: Path to audio file for voice cloning

        Returns:
            tuple: (audio_bytes, sample_rate)

        Raises:
            SpeechSynthesisException: if the engine is unavailable.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")

        try:
            wav = self._to_numpy(
                self.model.generate(text, audio_prompt_path=audio_prompt_path)
            )
            sample_rate = self.model.sr
            return self._numpy_to_bytes(wav, sample_rate), sample_rate
        except Exception as e:
            self._handle_provider_error(e, "voice prompt audio generation")