Spaces:
Build error
Build error
import logging | |
from typing import Optional, Generator, Tuple, List, Dict, Any | |
import numpy as np | |
# Import the base class and dummy implementation | |
from utils.tts_base import TTSBase | |
from utils.tts_dummy import DummyTTS | |
# Import the specific TTS implementations | |
from utils.tts_kokoro import KokoroTTS, KOKORO_AVAILABLE | |
from utils.tts_dia import DiaTTS, DIA_AVAILABLE | |
from utils.tts_cosyvoice2 import CosyVoice2TTS, COSYVOICE2_AVAILABLE | |
# Configure logging | |
logger = logging.getLogger(__name__) | |
def get_available_engines() -> List[str]:
    """Return the names of the TTS engines usable in this environment.

    Each optional engine is included when its ``*_AVAILABLE`` flag is truthy,
    in the fixed order kokoro, dia, cosyvoice2. ``'dummy'`` is always appended
    last because it is unconditionally usable.

    Returns:
        List[str]: Engine names, e.g. ``['kokoro', 'dummy']``.
    """
    candidates = (
        ('kokoro', KOKORO_AVAILABLE),
        ('dia', DIA_AVAILABLE),
        ('cosyvoice2', COSYVOICE2_AVAILABLE),
    )
    engines = [name for name, is_ready in candidates if is_ready]
    # The dummy engine has no availability flag, so it is always offered.
    engines.append('dummy')
    return engines
def get_tts_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSBase:
    """Get a TTS engine instance.

    Args:
        engine_type (str, optional): Type of engine to create ('kokoro', 'dia',
            'cosyvoice2', 'dummy'). Matching ignores letter case and
            surrounding whitespace. If None, or if the requested engine is not
            available or raises while being constructed, the best available
            engine is used instead.
        lang_code (str): Language code passed to the engine constructor.

    Returns:
        TTSBase: An instance of a TTS engine. Engines are tried in priority
        order cosyvoice2, kokoro, dia, dummy; a non-dummy engine whose
        constructor raises is logged and skipped.

    Raises:
        Exception: Only if constructing the dummy engine itself fails.
    """
    # Single registry: engine name -> (label used in log messages, class).
    engine_classes = {
        'kokoro': ('Kokoro', KokoroTTS),
        'dia': ('Dia', DiaTTS),
        'cosyvoice2': ('CosyVoice2', CosyVoice2TTS),
        'dummy': ('Dummy', DummyTTS),
    }
    priority_order = ['cosyvoice2', 'kokoro', 'dia', 'dummy']

    # Get available engines
    available_engines = get_available_engines()
    logger.info(f"Available TTS engines: {available_engines}")

    def _instantiate(name: str) -> Optional[TTSBase]:
        """Construct engine ``name``; return None if a non-dummy engine raises."""
        engine_cls = engine_classes[name][1]
        if name == 'dummy':
            # Last resort: let any failure here propagate to the caller.
            return engine_cls(lang_code)
        try:
            return engine_cls(lang_code)
        except Exception:
            # An engine flagged as available can still fail in its constructor;
            # record why and let the caller move on to the next candidate.
            logger.exception(f"Failed to initialise '{name}' TTS engine")
            return None

    failed = set()

    # If engine_type is specified, try to create that specific engine
    if engine_type is not None:
        requested = str(engine_type).strip().lower()
        if requested in engine_classes and requested in available_engines:
            logger.info(f"Creating {engine_classes[requested][0]} TTS engine")
            engine = _instantiate(requested)
            if engine is not None:
                return engine
            # Do not retry the same failing engine in the priority loop below.
            failed.add(requested)
        else:
            logger.warning(f"Requested engine '{engine_type}' is not available")

    # If no specific engine is requested, or the requested engine could not be
    # used, use the best available engine based on priority
    for name in priority_order:
        if name in available_engines and name not in failed:
            logger.info(f"Using best available engine: {name}")
            engine = _instantiate(name)
            if engine is not None:
                return engine

    # Defensive: only reachable if get_available_engines() stops listing 'dummy'.
    logger.warning("No TTS engines available, falling back to dummy engine")
    return DummyTTS(lang_code)
def generate_speech(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
                    voice: str = 'default', speed: float = 1.0) -> Optional[str]:
    """Synthesize ``text`` in a single call and return the engine's result.

    The engine is obtained via ``get_tts_engine(engine_type, lang_code)``, and
    its ``generate_speech`` method receives the text, voice and speed
    positionally. Its return value is passed back unchanged.

    Args:
        text (str): Input text to synthesize.
        engine_type (str, optional): Engine name to request; None selects the
            best available engine.
        lang_code (str): Language code for the engine.
        voice (str): Voice ID to use.
        speed (float): Speech speed multiplier.

    Returns:
        Optional[str]: Whatever the engine returns. This is documented as the
        path to the generated audio file, or None if generation fails.
    """
    tts = get_tts_engine(engine_type, lang_code)
    result = tts.generate_speech(text, voice, speed)
    return result
def generate_speech_stream(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
                           voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
    """Stream synthesized audio segments for ``text``.

    This is a generator. The engine is only looked up (via
    ``get_tts_engine(engine_type, lang_code)``) once iteration starts. Every
    item produced by the engine's ``generate_speech_stream`` is re-yielded
    through ``yield from``, so send/throw/close are delegated to the engine's
    generator.

    Args:
        text (str): Input text to synthesize.
        engine_type (str, optional): Engine name to request; None selects the
            best available engine.
        lang_code (str): Language code for the engine.
        voice (str): Voice ID to use.
        speed (float): Speech speed multiplier.

    Yields:
        tuple: (sample_rate, audio_data) pairs for each segment, as produced
        by the engine.
    """
    tts = get_tts_engine(engine_type, lang_code)
    segments = tts.generate_speech_stream(text, voice, speed)
    yield from segments