teachingAssistant / src /infrastructure /tts /chatterbox_provider.py
Michael Hu
add chatterbox
0f99c8d
"""Chatterbox TTS provider implementation."""
import logging
import numpy as np
import soundfile as sf
import io
from typing import Iterator, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException
logger = logging.getLogger(__name__)

# Flag to track Chatterbox availability. It is set to True only after every
# optional dependency below imported successfully; when it is False, the names
# ``torch``, ``ta`` and ``ChatterboxTTS`` may not be bound in this module.
CHATTERBOX_AVAILABLE = False

# Try to import Chatterbox. Failures are logged rather than raised so that this
# module stays importable on machines without the optional dependencies.
try:
    import torch
    import torchaudio as ta
    from chatterbox.tts import ChatterboxTTS
except ImportError as e:
    logger.warning("Chatterbox TTS engine is not available: %s", e)
except Exception as e:
    # Any non-ImportError raised while importing (for example, a native
    # library failing to load) also marks the engine as unavailable.
    logger.error("Chatterbox import failed with unexpected error: %s", e)
else:
    CHATTERBOX_AVAILABLE = True
    logger.info("Chatterbox TTS engine is available")
class ChatterboxTTSProvider(TTSProviderBase):
    """Chatterbox TTS provider implementation.

    Wraps ``chatterbox.tts.ChatterboxTTS``. The model is loaded lazily the
    first time :meth:`is_available` is called (every synthesis entry point
    calls it), on CUDA when available and on CPU otherwise.
    """

    # Length of each chunk emitted by ``_generate_audio_stream``, in seconds.
    STREAM_CHUNK_SECONDS = 1.0

    def __init__(self, lang_code: str = 'en'):
        """Initialize the Chatterbox TTS provider.

        Args:
            lang_code: Language code stored on the instance. Only ``'en'`` is
                registered with the base class as a supported language.
        """
        super().__init__(
            provider_name="Chatterbox",
            supported_languages=['en']  # Chatterbox primarily supports English
        )
        self.lang_code = lang_code
        self.model = None
        # ``torch`` is only guaranteed to be bound when the guarded module-level
        # import succeeded; check the flag first so constructing the provider
        # without torch installed does not raise NameError.
        self.device = "cuda" if CHATTERBOX_AVAILABLE and torch.cuda.is_available() else "cpu"

    def _ensure_model(self) -> bool:
        """Load the Chatterbox model if it is not loaded yet.

        Returns:
            True if a model is loaded after the call, False otherwise. A failed
            load is logged and attempted again on the next call.
        """
        if self.model is None and CHATTERBOX_AVAILABLE:
            try:
                logger.info("Loading Chatterbox model on device: %s", self.device)
                self.model = ChatterboxTTS.from_pretrained(device=self.device)
                logger.info("Chatterbox model successfully loaded")
            except Exception as e:
                logger.error("Failed to initialize Chatterbox model: %s", e)
                self.model = None
        return self.model is not None

    def is_available(self) -> bool:
        """Check if Chatterbox TTS is available.

        Note: this loads the model on first use, so the first call may be slow.
        """
        return CHATTERBOX_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for Chatterbox.

        ``'default'`` uses the base model voice; ``'custom'`` clones the voice
        from ``voice_settings.audio_prompt_path`` when the request provides it.
        """
        return ['default', 'custom']

    @staticmethod
    def _audio_prompt_for(request: 'SpeechSynthesisRequest') -> Optional[str]:
        """Return the voice-cloning prompt path for ``request``, or None.

        A prompt is only used when the voice id is ``'custom'`` and the voice
        settings carry an ``audio_prompt_path`` attribute.
        """
        voice_settings = request.voice_settings
        if voice_settings.voice_id != 'custom':
            return None
        return getattr(voice_settings, 'audio_prompt_path', None)

    @staticmethod
    def _to_numpy_audio(wav) -> np.ndarray:
        """Convert model output (tensor or array-like) to a NumPy sample array.

        The result has shape ``(frames,)`` for mono, or ``(frames, channels)``,
        which is the layout ``soundfile.write`` expects.
        """
        # Detach before moving to CPU: ``.numpy()`` refuses tensors that track grad.
        if hasattr(wav, 'detach'):
            wav = wav.detach()
        if hasattr(wav, 'cpu'):
            wav = wav.cpu()
        if hasattr(wav, 'numpy'):
            wav = wav.numpy()
        audio = np.asarray(wav)
        # ChatterboxTTS.generate appears to return a channel-first
        # (1, num_samples) tensor (upstream examples pass it directly to
        # torchaudio.save). Drop singleton axes to get mono samples; any
        # remaining 2-D array is assumed channel-first and transposed to
        # (frames, channels). NOTE(review): confirm if a multi-channel model is used.
        if audio.ndim > 1:
            audio = np.squeeze(audio)
            if audio.ndim == 2:
                audio = audio.T
        return np.atleast_1d(audio)

    def _synthesize(self, text: str, audio_prompt_path: Optional[str] = None) -> tuple[np.ndarray, int]:
        """Run the model and return ``(samples, sample_rate)``.

        Args:
            text: Text to synthesize.
            audio_prompt_path: Optional path to an audio file for voice cloning;
                when None the model's default voice is used.
        """
        if audio_prompt_path is not None:
            wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
        else:
            wav = self.model.generate(text)
        return self._to_numpy_audio(wav), self.model.sr

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Chatterbox TTS.

        Returns:
            ``(wav_bytes, sample_rate)`` where ``wav_bytes`` is a complete WAV file.

        Raises:
            SpeechSynthesisException: if the engine is not available. Other
                failures are passed to ``_handle_provider_error`` (base class;
                presumably re-raises — confirm in ``TTSProviderBase``).
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")
        try:
            audio, sample_rate = self._synthesize(
                request.text_content.text, self._audio_prompt_for(request)
            )
            return self._numpy_to_bytes(audio, sample_rate), sample_rate
        except Exception as e:
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Chatterbox TTS.

        Chatterbox does not stream natively, so the full utterance is generated
        first and then split into ``STREAM_CHUNK_SECONDS``-long pieces. Each
        yielded item is ``(wav_bytes, sample_rate, is_final)``; every chunk is
        a standalone WAV file. Exactly one item has ``is_final=True``, even when
        the model returns no samples.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")
        try:
            audio, sample_rate = self._synthesize(
                request.text_content.text, self._audio_prompt_for(request)
            )
            total_samples = len(audio)
            if total_samples == 0:
                # Emit a single empty final chunk so consumers waiting for
                # is_final still terminate.
                yield self._numpy_to_bytes(audio, sample_rate), sample_rate, True
                return
            chunk_size = max(1, int(sample_rate * self.STREAM_CHUNK_SECONDS))
            for start_idx in range(0, total_samples, chunk_size):
                end_idx = min(start_idx + chunk_size, total_samples)
                chunk_bytes = self._numpy_to_bytes(audio[start_idx:end_idx], sample_rate)
                yield chunk_bytes, sample_rate, end_idx >= total_samples
        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Encode a sample array as an in-memory WAV file.

        Samples are cast to float32 and, if the peak magnitude exceeds 1.0,
        scaled down so the peak is exactly 1.0.

        Raises:
            SpeechSynthesisException: if conversion or encoding fails.
        """
        try:
            audio_array = np.asarray(audio_array, dtype=np.float32)
            # Guard the reduction: np.max on an empty array raises ValueError.
            peak = float(np.max(np.abs(audio_array))) if audio_array.size else 0.0
            if peak > 1.0:
                audio_array = audio_array / peak
            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e

    def generate_with_voice_prompt(self, text: str, audio_prompt_path: str) -> tuple[bytes, int]:
        """
        Generate audio with a custom voice prompt.

        Args:
            text: Text to synthesize
            audio_prompt_path: Path to audio file for voice cloning

        Returns:
            tuple: (audio_bytes, sample_rate)

        Raises:
            SpeechSynthesisException: if the engine is not available. Other
                failures are passed to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Chatterbox TTS engine is not available")
        try:
            audio, sample_rate = self._synthesize(text, audio_prompt_path)
            return self._numpy_to_bytes(audio, sample_rate), sample_rate
        except Exception as e:
            self._handle_provider_error(e, "voice prompt audio generation")