Spaces:
Sleeping
Sleeping
| import logging | |
| import time | |
| import os | |
| import numpy as np | |
| import soundfile as sf | |
| from typing import Dict, List, Optional, Tuple, Generator, Any, Union | |
| from utils.tts_base import TTSEngineBase, DummyTTSEngine | |
# Configure logging
logger = logging.getLogger(__name__)

# Flags to track TTS engine availability.
# The local engines start out disabled and are switched on only when their
# import probe below succeeds; the hosted "Space" variants are enabled
# unconditionally in this module.
KOKORO_AVAILABLE = False
KOKORO_SPACE_AVAILABLE = True
DIA_AVAILABLE = False
DIA_SPACE_AVAILABLE = True

# Try to import Kokoro
try:
    from kokoro import KPipeline
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")
except AttributeError as e:
    # Specifically catch the EspeakWrapper.set_data_path error
    if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
        logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue, falling back to Kokoro FastAPI server")
    else:
        # Re-raise if it's a different error
        logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
        raise
except ImportError:
    logger.warning("Kokoro TTS engine is not available")

# Try to import Dia dependencies to check availability
try:
    import torch
    from dia.model import Dia
    DIA_AVAILABLE = True
    logger.info("Dia TTS engine is available")
except ModuleNotFoundError as e:
    # ModuleNotFoundError is a subclass of ImportError, so this handler has to
    # come first; listed after `except ImportError` it could never run and the
    # more specific diagnostics below would be lost.
    if "dac" in str(e):
        logger.warning("Dia TTS engine is not available due to missing 'dac' module")
    else:
        logger.warning(f"Dia TTS engine is not available: {str(e)}")
    DIA_AVAILABLE = False
except ImportError:
    # Any other import failure (e.g. a name missing from an installed module).
    logger.warning("Dia TTS engine is not available")
class KokoroTTSEngine(TTSEngineBase):
    """Kokoro TTS engine implementation

    This engine uses the Kokoro library (``KPipeline``) for TTS generation.
    """

    # Sample rate (Hz) used when writing and streaming Kokoro audio.
    SAMPLE_RATE = 24000

    def __init__(self, lang_code: str = 'z'):
        """Create the Kokoro pipeline for the given language.

        Args:
            lang_code (str): Language code forwarded to the base class and to
                ``KPipeline``.

        Raises:
            Exception: Any error raised while building the pipeline is logged
                and re-raised unchanged.
        """
        super().__init__(lang_code)
        try:
            self.pipeline = KPipeline(lang_code=lang_code)
            logger.info("Kokoro TTS engine successfully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Kokoro pipeline: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise

    @staticmethod
    def _as_numpy(audio: Any) -> np.ndarray:
        """Convert one audio segment yielded by the pipeline to a NumPy array."""
        # NOTE(review): segments are assumed to be either array-likes or
        # tensor-like objects exposing detach()/cpu()/numpy() - confirm
        # against the installed Kokoro version.
        if hasattr(audio, "detach"):
            audio = audio.detach().cpu().numpy()
        return np.asarray(audio)

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
        """Generate speech using Kokoro TTS engine

        All segments produced by the pipeline are concatenated and written to
        a single file, so text that the pipeline splits into several segments
        is no longer cut off after the first one.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Returns:
            Optional[str]: Path to the generated audio file, or None if the
            pipeline produced no audio
        """
        logger.info(f"Generating speech with Kokoro for text length: {len(text)}")

        # Generate unique output path
        output_path = self._generate_output_path()

        # Collect every segment instead of stopping after the first one.
        segments = [
            self._as_numpy(audio)
            for _, _, audio in self.pipeline(text, voice=voice, speed=speed)
            if audio is not None
        ]
        if not segments:
            # Nothing was synthesized; do not hand back a path to a file that
            # was never written.
            logger.warning("Kokoro produced no audio for the given text")
            return None

        logger.info(f"Saving Kokoro audio to {output_path}")
        sf.write(output_path, np.concatenate(segments), self.SAMPLE_RATE)

        logger.info(f"Kokoro audio generation complete: {output_path}")
        return output_path

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Kokoro TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use
            speed (float): Speech speed multiplier

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")

        # Each segment is forwarded as soon as the pipeline yields it.
        for _, _, audio in self.pipeline(text, voice=voice, speed=speed):
            yield self.SAMPLE_RATE, audio
class KokoroSpaceTTSEngine(TTSEngineBase):
    """Kokoro Space TTS engine implementation

    Speech is produced remotely: requests are sent through a
    ``gradio_client.Client`` bound to the ``Remsky/Kokoro-TTS-Zero`` Space.
    """

    def __init__(self, lang_code: str = 'z'):
        """Open the client connection to the Kokoro Space.

        Args:
            lang_code (str): Language code forwarded to the base class.

        Raises:
            Exception: Any failure while importing ``gradio_client`` or
                creating the client is logged and re-raised unchanged.
        """
        super().__init__(lang_code)
        try:
            # Imported inside the try block so that a missing gradio_client
            # goes through the same log-and-re-raise path as a failed connect.
            from gradio_client import Client
            self.client = Client("Remsky/Kokoro-TTS-Zero")
            logger.info("Kokoro Space TTS engine successfully initialized")
        except Exception as init_error:
            logger.error(f"Failed to initialize Kokoro Space client: {str(init_error)}")
            logger.error(f"Error type: {type(init_error).__name__}")
            raise

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
        """Generate speech using Kokoro Space TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.);
                'af_heart' is replaced by 'af_nova' before the request is sent
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Returns:
            Optional[str]: The file path returned by the Space when it is a
            string naming an existing file; None for any other result or when
            the request raises
        """
        logger.info(f"Generating speech with Kokoro Space for text length: {len(text)}")
        # Only the first 50 characters of long inputs are logged.
        if len(text) > 50:
            logger.info(f"Text to generate speech on is: {text[:50]}...")
        else:
            logger.info(f"Text to generate speech on is: {text}")

        # NOTE(review): the generated path is never used below - the Space's
        # own result path is returned instead. The call is kept because any
        # side effects of _generate_output_path() are not visible from here.
        self._generate_output_path()

        try:
            # Use af_nova as the default voice for Kokoro Space
            remote_voice = voice if voice != 'af_heart' else 'af_nova'

            response = self.client.predict(
                text=text,
                voice_names=remote_voice,
                speed=speed,
                api_name="/generate_speech_from_ui"
            )
            logger.info(f"Received audio from Kokoro FastAPI server: {response}")

            # Anything other than a path to an existing file is rejected.
            if not (isinstance(response, str) and os.path.exists(response)):
                logger.warning("Unexpected result from Kokoro Space")
                return None
            return response
        except Exception as request_error:
            logger.error(f"Failed to generate speech from Kokoro FastAPI server: {str(request_error)}")
            logger.error(f"Error type: {type(request_error).__name__}")
            logger.info("Kokoro Space TTS engine failed")
            return None
class DiaTTSEngine(TTSEngineBase):
    """Dia TTS engine implementation

    Synthesis is delegated to ``utils.tts_dia``, which is imported only when
    speech is actually requested.
    """

    def __init__(self, lang_code: str = 'z'):
        """Record the language code; no Dia resources are loaded here.

        Args:
            lang_code (str): Language code forwarded to the base class.
        """
        super().__init__(lang_code)
        # Dia doesn't need initialization here, it will be lazy-loaded when needed
        logger.info("Dia TTS engine initialized (lazy loading)")

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
        """Generate speech using Dia TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia)
            speed (float): Speech speed multiplier (not used in Dia)

        Returns:
            Optional[str]: Path to the generated audio file or None if generation fails

        Raises:
            ModuleNotFoundError: When a module other than 'dac' is missing;
                a missing 'dac' module is logged and reported as None instead.
        """
        logger.info(f"Generating speech with Dia for text length: {len(text)}")

        try:
            # Import here to avoid circular imports
            from utils.tts_dia import generate_speech as dia_generate_speech, DIA_AVAILABLE

            # Bail out early when the Dia backend reports itself unavailable.
            if not DIA_AVAILABLE:
                logger.warning("Dia TTS engine is not available")
                return None

            logger.info("Successfully imported Dia speech generation function")

            # Dia's function takes a language parameter, not voice or speed.
            audio_path = dia_generate_speech(text, language=self.lang_code)
            logger.info(f"Generated audio with Dia: {audio_path}")
            return audio_path
        except ModuleNotFoundError as missing_module:
            # Only the known 'dac' case is swallowed; anything else propagates.
            if "dac" not in str(missing_module):
                raise
            logger.warning("Dia TTS engine failed due to missing 'dac' module")
            return None
        except Exception as dia_failure:
            logger.error(f"Error generating speech with Dia: {str(dia_failure)}", exc_info=True)
            logger.warning("Dia TTS engine failed")
            return None
class DiaSpaceTTSEngine(TTSEngineBase):
    """Dia Space TTS engine implementation

    This engine uses the Dia TTS Server API (via ``utils.tts_dia_space``) for
    speech generation.
    """

    def __init__(self, lang_code: str = 'z'):
        """Obtain the Dia Space client.

        Args:
            lang_code (str): Language code forwarded to the base class.

        Raises:
            Exception: Any failure while importing the helper module or
                creating the client is logged and re-raised unchanged.
        """
        super().__init__(lang_code)
        try:
            # Import here to avoid circular imports
            from utils.tts_dia_space import _get_client
            self.client = _get_client()
            logger.info("Dia Space TTS engine successfully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Dia Space client: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            raise

    def generate_speech(self, text: str, voice: str = 'S1', speed: float = 1.0, response_format: str = 'wav') -> Optional[str]:
        """Generate speech using Dia Space TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone)
            speed (float): Speech speed multiplier
            response_format (str): Audio format ('wav', 'mp3', 'opus')

        Returns:
            Optional[str]: Path to the generated audio file or None if generation fails
        """
        logger.info(f"Generating speech with Dia Space for text length: {len(text)}")

        try:
            # Import here to avoid circular imports
            from utils.tts_dia_space import _call_dia_api, _generate_output_path

            # Call the Dia Space API
            audio_data = _call_dia_api(text, voice, response_format, speed)

            # Save the audio data to a file
            output_path = _generate_output_path(prefix="dia_space", extension=response_format)
            with open(output_path, 'wb') as f:
                f.write(audio_data)

            logger.info(f"Generated audio with Dia Space: {output_path}")
            return output_path
        except Exception as e:
            # A single broad handler covers import failures as well: any
            # further `except` clause after this one would be unreachable.
            logger.error(f"Failed to generate speech from Dia Space API: {str(e)}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.info("Dia Space TTS engine failed")
            return None

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Dia TTS engine

        NOTE(review): although this method sits on the Space engine, it
        drives the local Dia model from ``utils.tts_dia`` rather than the
        Space API - confirm that this placement is intended.

        Every failure path falls back to the stream produced by
        ``DummyTTSEngine``.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia; forwarded to the dummy fallback)
            speed (float): Speech speed multiplier (not used in Dia; forwarded to the dummy fallback)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Generating speech stream with Dia for text length: {len(text)}")

        try:
            # Import required modules
            from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE, DIA_AVAILABLE

            # Check if Dia is available
            if not DIA_AVAILABLE:
                logger.warning("Dia TTS engine is not available, falling back to dummy audio stream")
                yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
                return

            import torch

            # Get the Dia model
            model = _get_model()

            # Generate audio
            with torch.inference_mode():
                output_audio_np = model.generate(
                    text,
                    max_tokens=None,
                    cfg_scale=3.0,
                    temperature=1.3,
                    top_p=0.95,
                    cfg_filter_top_k=35,
                    use_torch_compile=False,
                    verbose=False
                )

            if output_audio_np is not None:
                logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
                yield DEFAULT_SAMPLE_RATE, output_audio_np
            else:
                logger.warning("Dia model returned None for audio output")
                logger.warning("Falling back to dummy audio stream")
                yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
        except ModuleNotFoundError as e:
            # Listed before ImportError because it is the more specific class.
            if "dac" in str(e):
                logger.warning("Dia TTS streaming failed due to missing 'dac' module, falling back to dummy audio stream")
            else:
                logger.error(f"Module not found error in Dia TTS streaming: {str(e)}")
            yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
        except ImportError as import_err:
            logger.error(f"Dia TTS streaming failed due to import error: {str(import_err)}")
            logger.error("Falling back to dummy audio stream")
            yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
        except Exception as dia_error:
            logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
            logger.error(f"Error type: {type(dia_error).__name__}")
            logger.error("Falling back to dummy audio stream")
            yield from DummyTTSEngine(self.lang_code).generate_speech_stream(text, voice, speed)
def get_available_engines() -> List[str]:
    """Get a list of available TTS engines

    Engines are reported in a fixed order ('kokoro', 'kokoro_space', 'dia'),
    each only when its module-level availability flag is set, followed by
    'dummy', which is always present.

    Returns:
        List[str]: List of available engine names
    """
    flagged_engines = (
        ('kokoro', KOKORO_AVAILABLE),
        ('kokoro_space', KOKORO_SPACE_AVAILABLE),
        ('dia', DIA_AVAILABLE),
    )
    engine_names = [name for name, enabled in flagged_engines if enabled]

    # Dummy is always available
    engine_names.append('dummy')
    return engine_names
def create_engine(engine_type: str, lang_code: str = 'z') -> TTSEngineBase:
    """Create a specific TTS engine

    Args:
        engine_type (str): Type of engine to create ('kokoro', 'kokoro_space',
            'dia', 'dia_space', 'dummy')
        lang_code (str): Language code for the engine

    Returns:
        TTSEngineBase: An instance of the requested TTS engine

    Raises:
        ValueError: If the requested engine type is not supported, or if it is
            supported but its availability flag is not set
    """
    if engine_type == 'kokoro':
        if not KOKORO_AVAILABLE:
            raise ValueError("Kokoro TTS engine is not available")
        return KokoroTTSEngine(lang_code)
    elif engine_type == 'kokoro_space':
        if not KOKORO_SPACE_AVAILABLE:
            raise ValueError("Kokoro Space TTS engine is not available")
        return KokoroSpaceTTSEngine(lang_code)
    elif engine_type == 'dia':
        if not DIA_AVAILABLE:
            raise ValueError("Dia TTS engine is not available")
        return DiaTTSEngine(lang_code)
    elif engine_type == 'dia_space':
        # DiaSpaceTTSEngine and its flag are defined in this module but were
        # previously unreachable through the factory.
        if not DIA_SPACE_AVAILABLE:
            raise ValueError("Dia Space TTS engine is not available")
        return DiaSpaceTTSEngine(lang_code)
    elif engine_type == 'dummy':
        return DummyTTSEngine(lang_code)
    else:
        raise ValueError(f"Unsupported TTS engine type: {engine_type}")