# teachingAssistant/utils/tts_kokoro.py
# Michael Hu — "refactor tts" (commit 60bd17d)
import os
import time
import logging
import numpy as np
import soundfile as sf
from typing import Optional, Tuple, Generator
# Set the root logger to INFO (basicConfig does nothing if the root logger
# already has handlers), then obtain this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sample rate reported alongside each Kokoro audio chunk.
DEFAULT_SAMPLE_RATE: int = 24_000

# Lazily created Kokoro pipeline; remains None until _get_pipeline() builds one.
_pipeline = None
# Pipelines that have already been built, keyed by Kokoro language code. Each
# KPipeline is constructed for a single lang_code. Caching per code stops a
# request for another language from silently reusing the first pipeline.
_pipelines_by_lang = {}


def _get_pipeline(lang_code: str = 'z'):
    """Return a Kokoro ``KPipeline`` for *lang_code*, loading it on first use.

    ``kokoro`` is imported only when a pipeline is first needed. Pipelines
    are cached per language code, so each language is built at most once
    and different languages never share a pipeline.

    The module-level ``_pipeline`` is set to the pipeline returned by the
    most recent call, for any code that still reads it directly.

    Args:
        lang_code: Kokoro language code (e.g. 'a', 'b', 'j', 'z').

    Returns:
        The ``KPipeline`` instance for *lang_code*.

    Raises:
        ImportError: If ``kokoro`` (or one of its dependencies) cannot be
            imported. The error is logged and re-raised.
        Exception: Any error raised while constructing the pipeline is
            logged and re-raised. Nothing is cached in that case.
    """
    global _pipeline
    pipeline = _pipelines_by_lang.get(lang_code)
    if pipeline is None:
        logger.info("Loading Kokoro pipeline...")
        try:
            # Deferred import: kokoro is only needed once a pipeline is requested.
            from kokoro import KPipeline

            logger.info(f"Initializing Kokoro pipeline with language code: {lang_code}")
            pipeline = KPipeline(lang_code=lang_code)
            logger.info("Kokoro pipeline loaded successfully")
            logger.info(f"Pipeline type: {type(pipeline).__name__}")
        except ImportError as import_err:
            logger.error(f"Import error loading Kokoro pipeline: {import_err}")
            logger.error("This may indicate missing dependencies")
            raise
        except Exception as e:
            logger.error(f"Error loading Kokoro pipeline: {e}", exc_info=True)
            logger.error(f"Error type: {type(e).__name__}")
            raise
        # Cache only after a successful build so a failed load can be retried.
        _pipelines_by_lang[lang_code] = pipeline
    _pipeline = pipeline
    return pipeline
def generate_speech(text: str, language: str = "z", voice: str = "af_heart", speed: float = 1.0) -> str:
    """Synthesize *text* with Kokoro and return the path of the audio file.

    Legacy entry point kept for backward compatibility. New code should use
    the factory-pattern engine implementation directly.

    Args:
        text (str): Input text to synthesize
        language (str): Language code ('a' for US English, 'b' for British English,
            'j' for Japanese, 'z' for Mandarin Chinese)
        voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
        speed (float): Speech speed multiplier (0.5 to 2.0)

    Returns:
        str: Path to the generated audio file. If the Kokoro engine raises,
        the result of ``DummyTTSEngine().generate_speech(text)`` is returned.
    """
    logger.info(f"Legacy Kokoro generate_speech called with text length: {len(text)}")

    # Delegate to the factory-pattern engine. The import is local, and failures
    # of this import are NOT covered by the fallback below.
    from utils.tts_engines import KokoroTTSEngine

    try:
        return KokoroTTSEngine(language).generate_speech(text, voice, speed)
    except Exception as exc:
        logger.error(f"Error in legacy Kokoro generate_speech: {exc}", exc_info=True)
        # Any Kokoro failure degrades to the dummy engine rather than raising.
        from utils.tts_base import DummyTTSEngine

        return DummyTTSEngine().generate_speech(text)
def generate_speech_stream(
    text: str,
    language: str = "z",
    voice: str = "af_heart",
    speed: float = 1.0,
) -> Generator[Tuple[int, np.ndarray], None, None]:
    """Stream synthesized speech for *text* from the Kokoro pipeline.

    Args:
        text (str): Input text to synthesize
        language (str): Language code, used to select the Kokoro pipeline
        voice (str): Voice ID to use
        speed (float): Speech speed multiplier

    Yields:
        tuple: ``(sample_rate, audio_data)`` pairs, one per segment produced
        by the pipeline. ``sample_rate`` is ``DEFAULT_SAMPLE_RATE``.

    Fallback behaviour: if Kokoro fails before any segment has been yielded,
    the stream from ``DummyTTSEngine.generate_speech_stream(text)`` is yielded
    instead. If it fails after real segments have already been yielded, the
    error is logged and the stream simply ends, so dummy audio for the whole
    text is not appended after real speech.
    """
    logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")
    yielded_any = False
    try:
        pipeline = _get_pipeline(language)
        # Each pipeline result unpacks into three items; only the third (audio) is used.
        for _, _, audio in pipeline(text, voice=voice, speed=speed):
            yielded_any = True
            yield DEFAULT_SAMPLE_RATE, audio
    except Exception as e:
        logger.error(f"Error in Kokoro generate_speech_stream: {str(e)}", exc_info=True)
        if yielded_any:
            # The consumer already received real audio. Re-synthesizing the
            # entire text with the dummy engine would duplicate/garble output.
            logger.warning("Kokoro stream failed after partial output; ending stream without dummy fallback")
            return
        # Nothing was produced yet, so fall back to dummy TTS for the full text.
        from utils.tts_base import DummyTTSEngine
        dummy_engine = DummyTTSEngine()
        yield from dummy_engine.generate_speech_stream(text)