# teachingAssistant/src/infrastructure/tts/chatterbox_provider.py
"""Chatterbox TTS provider implementation."""
import logging
import numpy as np
import soundfile as sf
import io
from typing import Iterator, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException
logger = logging.getLogger(__name__)
# Flag to track Chatterbox availability
CHATTERBOX_AVAILABLE = False
# Try to import Chatterbox
try:
import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
CHATTERBOX_AVAILABLE = True
logger.info("Chatterbox TTS engine is available")
except ImportError as e:
logger.warning(f"Chatterbox TTS engine is not available: {e}")
except Exception as e:
logger.error(f"Chatterbox import failed with unexpected error: {str(e)}")
CHATTERBOX_AVAILABLE = False
class ChatterboxTTSProvider(TTSProviderBase):
"""Chatterbox TTS provider implementation."""
def __init__(self, lang_code: str = 'en'):
"""Initialize the Chatterbox TTS provider."""
super().__init__(
provider_name="Chatterbox",
supported_languages=['en'] # Chatterbox primarily supports English
)
self.lang_code = lang_code
self.model = None
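        # Note: the model is loaded lazily on first use (see _ensure_model),
        # so constructing the provider stays lightweight even when Chatterbox is installed.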
        self.device = "cuda" if CHATTERBOX_AVAILABLE and torch.cuda.is_available() else "cpu"
def _ensure_model(self):
"""Ensure the model is loaded."""
if self.model is None and CHATTERBOX_AVAILABLE:
try:
logger.info(f"Loading Chatterbox model on device: {self.device}")
self.model = ChatterboxTTS.from_pretrained(device=self.device)
logger.info("Chatterbox model successfully loaded")
except Exception as e:
logger.error(f"Failed to initialize Chatterbox model: {str(e)}")
self.model = None
return self.model is not None
def is_available(self) -> bool:
"""Check if Chatterbox TTS is available."""
return CHATTERBOX_AVAILABLE and self._ensure_model()
def get_available_voices(self) -> list[str]:
"""Get available voices for Chatterbox."""
# Chatterbox supports voice cloning with audio prompts
# Default voice is the base model voice
return ['default', 'custom']
def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
"""Generate audio using Chatterbox TTS."""
if not self.is_available():
raise SpeechSynthesisException("Chatterbox TTS engine is not available")
try:
# Extract parameters from request
text = request.text_content.text
voice = request.voice_settings.voice_id
# Generate speech using Chatterbox
            if voice == 'custom' and getattr(request.voice_settings, 'audio_prompt_path', None):
                # Use custom voice cloning with an audio prompt
                audio_prompt_path = request.voice_settings.audio_prompt_path
                wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
            else:
                # Use the default voice
                wav = self.model.generate(text)
            # Convert the tensor output to a flat numpy array of samples
            if hasattr(wav, 'detach'):
                wav = wav.detach().cpu().numpy()
            wav = np.asarray(wav).squeeze()
# Get sample rate from model
sample_rate = self.model.sr
# Convert numpy array to bytes
audio_bytes = self._numpy_to_bytes(wav, sample_rate)
return audio_bytes, sample_rate
except Exception as e:
self._handle_provider_error(e, "audio generation")
def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
"""Generate audio stream using Chatterbox TTS."""
if not self.is_available():
raise SpeechSynthesisException("Chatterbox TTS engine is not available")
try:
# Chatterbox doesn't natively support streaming, so we'll generate the full audio
# and split it into chunks for streaming
text = request.text_content.text
voice = request.voice_settings.voice_id
# Generate full audio
            if voice == 'custom' and getattr(request.voice_settings, 'audio_prompt_path', None):
                audio_prompt_path = request.voice_settings.audio_prompt_path
                wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
            else:
                wav = self.model.generate(text)
            # Convert the tensor output to a flat numpy array of samples
            if hasattr(wav, 'detach'):
                wav = wav.detach().cpu().numpy()
            wav = np.asarray(wav).squeeze()
            sample_rate = self.model.sr
            # Split the audio into one-second chunks for streaming
            chunk_size = int(sample_rate * 1.0)
            total_samples = len(wav)
for start_idx in range(0, total_samples, chunk_size):
end_idx = min(start_idx + chunk_size, total_samples)
chunk = wav[start_idx:end_idx]
# Convert chunk to bytes
audio_bytes = self._numpy_to_bytes(chunk, sample_rate)
# Check if this is the final chunk
is_final = (end_idx >= total_samples)
yield audio_bytes, sample_rate, is_final
except Exception as e:
self._handle_provider_error(e, "streaming audio generation")
def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
"""Convert numpy audio array to bytes."""
try:
# Ensure audio is in the right format
if audio_array.dtype != np.float32:
audio_array = audio_array.astype(np.float32)
            # Normalize if the signal exceeds the [-1.0, 1.0] range
            peak = np.max(np.abs(audio_array))
            if peak > 1.0:
                audio_array = audio_array / peak
# Create an in-memory buffer
buffer = io.BytesIO()
# Write audio data to buffer as WAV
sf.write(buffer, audio_array, sample_rate, format='WAV')
# Get bytes from buffer
buffer.seek(0)
return buffer.read()
except Exception as e:
raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
def generate_with_voice_prompt(self, text: str, audio_prompt_path: str) -> tuple[bytes, int]:
"""
Generate audio with a custom voice prompt.
Args:
text: Text to synthesize
audio_prompt_path: Path to audio file for voice cloning
Returns:
tuple: (audio_bytes, sample_rate)
"""
if not self.is_available():
raise SpeechSynthesisException("Chatterbox TTS engine is not available")
try:
wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
            # Convert the tensor output to a flat numpy array of samples
            if hasattr(wav, 'detach'):
                wav = wav.detach().cpu().numpy()
            wav = np.asarray(wav).squeeze()
sample_rate = self.model.sr
audio_bytes = self._numpy_to_bytes(wav, sample_rate)
return audio_bytes, sample_rate
except Exception as e:
self._handle_provider_error(e, "voice prompt audio generation")