Spaces:
Build error
Build error
"""Kokoro TTS provider implementation.""" | |
import logging | |
import numpy as np | |
import soundfile as sf | |
import io | |
from typing import Iterator, TYPE_CHECKING | |
if TYPE_CHECKING: | |
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
from ..base.tts_provider_base import TTSProviderBase | |
from ...domain.exceptions import SpeechSynthesisException | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Kokoro is an optional dependency. This flag records whether the import
# below succeeded; it stays False on any import failure.
KOKORO_AVAILABLE = False

try:
    from kokoro import KPipeline

    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")
except ImportError:
    # The package (or one of its own imports) could not be found.
    logger.warning("Kokoro TTS engine is not available")
except Exception as import_failure:
    # Anything else raised while importing: report it and keep the
    # provider disabled instead of breaking the import of this module.
    KOKORO_AVAILABLE = False
    logger.error(f"Kokoro import failed with unexpected error: {str(import_failure)}")
class KokoroTTSProvider(TTSProviderBase):
    """TTS provider that synthesizes speech with the Kokoro ``KPipeline``.

    The pipeline is created lazily on first use, so constructing the
    provider never fails even when Kokoro is not installed.
    """

    # Sample rate (Hz) written into, and reported alongside, every WAV
    # payload this provider produces.
    _SAMPLE_RATE = 24000

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Kokoro TTS provider.

        Args:
            lang_code: Language code passed to ``KPipeline`` when the
                pipeline is first created.
        """
        # NOTE(review): 'z' appears to be Kokoro's Mandarin Chinese code,
        # while get_available_voices() lists English voices - confirm that
        # the default and the 'en'/'z' language list are intended.
        super().__init__(
            provider_name="Kokoro",
            supported_languages=['en', 'z'],
        )
        self.lang_code = lang_code
        self.pipeline = None  # created on demand by _ensure_pipeline()

    def _ensure_pipeline(self) -> bool:
        """Create the Kokoro pipeline if needed.

        Initialization errors are logged, not raised; a later call will
        try again.

        Returns:
            True when a pipeline instance is ready to use, False otherwise.
        """
        if self.pipeline is None and KOKORO_AVAILABLE:
            try:
                self.pipeline = KPipeline(lang_code=self.lang_code)
                logger.info("Kokoro pipeline successfully loaded")
            except Exception as e:
                logger.error("Failed to initialize Kokoro pipeline: %s", e)
                self.pipeline = None
        return self.pipeline is not None

    def is_available(self) -> bool:
        """Return True when Kokoro was imported and its pipeline is loaded."""
        return KOKORO_AVAILABLE and self._ensure_pipeline()

    def get_available_voices(self) -> list[str]:
        """Return the hard-coded list of voice ids offered for Kokoro."""
        return [
            'af_heart', 'af_bella', 'af_sarah', 'af_nicole',
            'am_adam', 'am_michael', 'bf_emma', 'bf_isabella',
        ]

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Synthesize the whole request into a single WAV payload.

        Every segment yielded by the pipeline is collected and joined, so
        text that Kokoro splits into several segments is returned in full
        rather than being cut off after the first one.

        Args:
            request: Synthesis request; its text, voice id and speed are read.

        Returns:
            A ``(wav_bytes, sample_rate)`` tuple.

        Raises:
            SpeechSynthesisException: If Kokoro is unavailable. Any other
                failure is passed to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Kokoro TTS engine is not available")

        try:
            text = request.text_content.text
            voice = request.voice_settings.voice_id
            speed = request.voice_settings.speed

            # Defensive: ignore segments for which no audio was produced.
            segments = [
                np.asarray(audio)
                for _, _, audio in self.pipeline(text, voice=voice, speed=speed)
                if audio is not None
            ]
            if not segments:
                raise SpeechSynthesisException("Kokoro failed to generate audio")

            waveform = segments[0] if len(segments) == 1 else np.concatenate(segments)
            audio_bytes = self._numpy_to_bytes(waveform, sample_rate=self._SAMPLE_RATE)
            return audio_bytes, self._SAMPLE_RATE
        except Exception as e:
            # NOTE(review): _handle_provider_error lives in the base class
            # and is presumably expected to raise - confirm, otherwise this
            # method falls through and returns None.
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Yield the synthesized audio one WAV-encoded segment at a time.

        Each segment is held back until the next one arrives (or the
        pipeline is exhausted), so that only the last yielded chunk carries
        ``is_final=True``.

        Args:
            request: Synthesis request; its text, voice id and speed are read.

        Yields:
            ``(wav_bytes, sample_rate, is_final)`` tuples.

        Raises:
            SpeechSynthesisException: If Kokoro is unavailable. Any other
                failure is passed to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Kokoro TTS engine is not available")

        try:
            text = request.text_content.text
            voice = request.voice_settings.voice_id
            speed = request.voice_settings.speed

            held_back = None  # encoded chunk waiting to learn whether it is last
            for _, _, audio in self.pipeline(text, voice=voice, speed=speed):
                if audio is None:
                    # Defensive: nothing to emit for this segment.
                    continue
                encoded = self._numpy_to_bytes(np.asarray(audio), sample_rate=self._SAMPLE_RATE)
                if held_back is not None:
                    yield held_back, self._SAMPLE_RATE, False
                held_back = encoded

            if held_back is not None:
                yield held_back, self._SAMPLE_RATE, True
        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Encode an audio array as an in-memory WAV file.

        Args:
            audio_array: Audio samples to encode.
            sample_rate: Sample rate written into the WAV header.

        Returns:
            The complete WAV file contents.

        Raises:
            SpeechSynthesisException: If the encoding fails.
        """
        try:
            with io.BytesIO() as buffer:
                sf.write(buffer, audio_array, sample_rate, format='WAV')
                return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e