teachingAssistant / utils /tts_factory.py
Michael Hu
refactor tts part
3ed3b5a
raw
history blame
4.73 kB
import logging
from typing import Optional, List
# Module-level logger
logger = logging.getLogger(__name__)
# Import the base class
from utils.tts_base import TTSEngineBase, DummyTTSEngine
class TTSFactory:
    """Builds text-to-speech engines.

    The concrete engine is chosen from the names reported by
    ``utils.tts_engines.get_available_engines``. An explicit request is honoured
    when that engine is available; otherwise the highest-priority available
    engine is used, and ``DummyTTSEngine`` is the last resort.
    """

    @staticmethod
    def create_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSEngineBase:
        """Return a TTS engine instance.

        Args:
            engine_type (str, optional): Engine to build ('kokoro', 'kokoro_space',
                'dia', 'dummy'). When None, or when the requested name is not in
                the available set, the best available engine is chosen instead
                (a warning is logged for the unavailable request).
            lang_code (str): Language code forwarded to the engine constructor.

        Returns:
            TTSEngineBase: The engine built by ``utils.tts_engines.create_engine``,
            or a ``DummyTTSEngine`` when none of 'kokoro', 'kokoro_space' or
            'dia' is available.
        """
        # Imported inside the method rather than at module top; presumably to
        # avoid an import cycle or eager engine imports -- TODO confirm.
        from utils.tts_engines import get_available_engines, create_engine as build_engine

        available = get_available_engines()
        logger.info(f"Available TTS engines: {available}")

        if engine_type is not None:
            if engine_type not in available:
                logger.warning(f"Requested engine '{engine_type}' is not available")
            else:
                logger.info(f"Creating requested engine: {engine_type}")
                return build_engine(engine_type, lang_code)

        # Priority: kokoro > kokoro_space > dia; the dummy engine is the last resort.
        preferred = next(
            (name for name in ('kokoro', 'kokoro_space', 'dia') if name in available),
            None,
        )
        if preferred is None:
            logger.warning("No TTS engines available, falling back to dummy engine")
            return DummyTTSEngine(lang_code)

        logger.info(f"Creating best available engine: {preferred}")
        return build_engine(preferred, lang_code)
# Backward compatibility function
def get_tts_engine(lang_code: str = 'a') -> TTSEngineBase:
    """Get or create a TTS engine instance (backward compatibility function).

    When Streamlit can be imported, the engine is obtained through an
    ``st.cache_resource``-decorated helper so it is reused instead of being
    rebuilt; otherwise a fresh engine is created on every call.

    Args:
        lang_code (str): Language code for the pipeline. It is passed to the
            cached helper as an argument so each language gets its own cache
            entry.

    Returns:
        TTSEngineBase: Engine instance produced by ``TTSFactory.create_engine``.
    """
    logger.info(f"Requesting TTS engine with language code: {lang_code}")

    # Only the import itself is guarded: an ImportError raised while building
    # an engine must propagate, not be mistaken for "Streamlit missing".
    try:
        import streamlit as st
    except ImportError:
        logger.info("Streamlit not available, creating direct TTS engine instance")
        engine = TTSFactory.create_engine(lang_code=lang_code)
        logger.info(f"Direct TTS engine created with type: {engine.__class__.__name__}")
        return engine

    logger.info("Streamlit detected, using cached TTS engine")

    # lang_code must be an explicit parameter: Streamlit keys cached values on
    # the decorated function's arguments, not on variables captured from the
    # enclosing scope. Capturing it would return the first language's engine
    # for every subsequent lang_code.
    @st.cache_resource
    def _get_engine(cache_lang_code: str) -> TTSEngineBase:
        logger.info("Creating cached TTS engine instance")
        engine = TTSFactory.create_engine(lang_code=cache_lang_code)
        logger.info(f"Cached TTS engine created with type: {engine.__class__.__name__}")
        return engine

    engine = _get_engine(lang_code)
    logger.info(f"Retrieved TTS engine from cache with type: {engine.__class__.__name__}")
    return engine
# Backward compatibility function
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
    """Synthesize ``text`` using the engine from ``get_tts_engine()`` (backward compatibility function).

    Args:
        text (str): Text to turn into speech.
        voice (str): Voice identifier forwarded to the engine.
        speed (float): Speech-rate multiplier forwarded to the engine.

    Returns:
        str: Path returned by the engine's ``generate_speech``.

    Raises:
        Exception: Any error from engine lookup or synthesis is logged (with its
        traceback, type name, and the innermost frame's file/line) and re-raised
        unchanged.
    """
    logger.info(f"Public generate_speech called with text length: {len(text)}, voice: {voice}, speed: {speed}")
    try:
        logger.info("Getting TTS engine instance")
        tts = get_tts_engine()
        logger.info(f"Using TTS engine type: {tts.__class__.__name__}")

        logger.info("Calling engine.generate_speech")
        result_path = tts.generate_speech(text, voice, speed)
        logger.info(f"Speech generation complete, output path: {result_path}")
        return result_path
    except Exception as err:
        # logger.exception == logger.error(..., exc_info=True)
        logger.exception(f"Error in public generate_speech function: {str(err)}")
        logger.error(f"Error type: {type(err).__name__}")
        if hasattr(err, '__traceback__'):
            # Follow the traceback chain to its last entry: the frame that raised.
            frame_tb = err.__traceback__
            while frame_tb.tb_next:
                frame_tb = frame_tb.tb_next
            logger.error(
                f"Error occurred in file: {frame_tb.tb_frame.f_code.co_filename}, line {frame_tb.tb_lineno}"
            )
        raise