Michael Hu
refactor tts part
3ed3b5a
raw
history blame
5.06 kB
import logging
# Configure logging
logger = logging.getLogger(__name__)
# Import from the new factory pattern implementation
from utils.tts_factory import get_tts_engine, generate_speech, TTSFactory
from utils.tts_engines import get_available_engines
# For backward compatibility
from utils.tts_engines import KOKORO_AVAILABLE, KOKORO_SPACE_AVAILABLE, DIA_AVAILABLE
# Backward compatibility class
class TTSEngine:
    """Legacy TTSEngine class for backward compatibility.

    This class is maintained for backward compatibility with existing code.
    New code should use the factory pattern implementation directly.

    Every synthesis call is delegated to the engine object returned by
    ``TTSFactory.create_engine``; this wrapper only adds logging and the
    legacy attributes (``engine_type``, ``pipeline``, ``client``).
    """

    def __init__(self, lang_code='z'):
        """Initialize TTS Engine using the factory pattern.

        Args:
            lang_code (str): Language code ('a' for US English, 'b' for British
                English, 'j' for Japanese, 'z' for Mandarin Chinese)
        """
        logger.info("Initializing legacy TTSEngine wrapper")
        logger.info(f"Available engines - Kokoro: {KOKORO_AVAILABLE}, Dia: {DIA_AVAILABLE}")

        # Create the appropriate engine using the factory
        self._engine = TTSFactory.create_engine(lang_code=lang_code)

        # Set engine_type for backward compatibility. The legacy label is
        # derived from the concrete engine's class name; the "Space" check must
        # come first because such a name also contains "Kokoro".
        engine_class = self._engine.__class__.__name__
        if 'Kokoro' in engine_class and 'Space' in engine_class:
            self.engine_type = "kokoro_space"
        elif 'Kokoro' in engine_class:
            self.engine_type = "kokoro"
        elif 'Dia' in engine_class:
            self.engine_type = "dia"
        else:
            # Any other engine class is reported as the dummy engine.
            self.engine_type = "dummy"

        # Set pipeline and client attributes for backward compatibility.
        # These are snapshots taken at construction time and are None when the
        # underlying engine does not expose the attribute.
        self.pipeline = getattr(self._engine, 'pipeline', None)
        self.client = getattr(self._engine, 'client', None)

        logger.info(f"Legacy TTSEngine wrapper initialized with engine type: {self.engine_type}")

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech from text using available TTS engine.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Returns:
            str: Path to the generated audio file
        """
        logger.info(f"Legacy TTSEngine wrapper calling generate_speech for text length: {len(text)}")
        return self._engine.generate_speech(text, voice, speed)

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
        """Generate speech from text and yield each segment.

        This is a generator: nothing (including the log line) runs until the
        caller starts iterating.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use
            speed (float): Speech speed multiplier

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Legacy TTSEngine wrapper calling generate_speech_stream for text length: {len(text)}")
        yield from self._engine.generate_speech_stream(text, voice, speed)

    @staticmethod
    def _copy_to_requested_path(generated, output_path):
        """Copy a generated audio file to ``output_path`` when one was requested.

        Args:
            generated: Value returned by the engine; expected to be a file path.
            output_path: Destination requested by the caller (may be None/empty).

        Returns:
            ``output_path`` when the file was copied there; otherwise
            ``generated`` unchanged (no destination requested, ``generated`` or
            ``output_path`` is not a path, both already name the same file, or
            the copy failed).
        """
        # Local imports keep this block self-contained, matching the
        # function-scope imports already used by the dummy helpers below.
        import os
        import shutil

        path_types = (str, os.PathLike)
        if not output_path or not isinstance(output_path, path_types):
            return generated
        if not isinstance(generated, path_types):
            return generated
        if os.path.abspath(generated) == os.path.abspath(output_path):
            return generated

        try:
            parent_dir = os.path.dirname(os.path.abspath(output_path))
            os.makedirs(parent_dir, exist_ok=True)
            # Copy rather than move so the engine's own file stays intact.
            shutil.copyfile(generated, output_path)
        except OSError as exc:
            # Best effort: the caller still receives a usable audio path.
            logger.warning(f"Could not copy dummy audio to {output_path}: {exc}")
            return generated
        return output_path

    # For backward compatibility
    def _generate_dummy_audio(self, output_path):
        """Generate a dummy audio file (backward compatibility).

        The audio is produced by ``DummyTTSEngine.generate_speech`` and then
        copied to ``output_path`` so the legacy contract (file saved where the
        caller asked) is honoured.

        Args:
            output_path (str): Path to save the dummy audio file. When empty or
                None, the file is left wherever the dummy engine wrote it.

        Returns:
            str: Path to the generated dummy audio file (``output_path`` when
            the copy succeeded, otherwise the dummy engine's own path)
        """
        from utils.tts_base import DummyTTSEngine

        dummy_engine = DummyTTSEngine()
        generated = dummy_engine.generate_speech("", "", 1.0)
        return self._copy_to_requested_path(generated, output_path)

    # For backward compatibility
    def _generate_dummy_audio_stream(self):
        """Generate dummy audio chunks (backward compatibility).

        Yields:
            tuple: (sample_rate, audio_data) pairs for each dummy segment
        """
        from utils.tts_base import DummyTTSEngine

        dummy_engine = DummyTTSEngine()
        yield from dummy_engine.generate_speech_stream("", "", 1.0)
# Reference copies of the new implementations
# get_tts_engine and generate_speech are imported at the top of this file from utils.tts_factory
# They are kept here as comments for reference
# def get_tts_engine(lang_code='a'):
# """Get or create TTS engine instance
#
# Args:
# lang_code (str): Language code for the pipeline
#
# Returns:
# TTSEngineBase: Initialized TTS engine instance
# """
# # Implementation is imported from utils.tts_factory
# pass
# def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
# """Public interface for TTS generation
#
# Args:
# text (str): Input text to synthesize
# voice (str): Voice ID to use
# speed (float): Speech speed multiplier
#
# Returns:
# str: Path to generated audio file
# """
# # Implementation is imported from utils.tts_factory
# pass