# TTS engine wrapper: prefers Kokoro, falls back to Dia, then to a dummy sine-wave generator.
import os | |
import logging | |
import time | |
import soundfile as sf | |
logger = logging.getLogger(__name__) | |
# Availability flags for the supported TTS backends, resolved once at import time.
KOKORO_AVAILABLE = False
DIA_AVAILABLE = False

# Probe for Kokoro first -- it is the preferred engine.
try:
    from kokoro import KPipeline
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")
except AttributeError as e:
    # Specifically catch the EspeakWrapper.set_data_path error
    message = str(e)
    if "EspeakWrapper" in message and "set_data_path" in message:
        logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue")
    else:
        # Re-raise if it's a different error
        logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
        raise
except ImportError:
    logger.warning("Kokoro TTS engine is not available")

# Probe for Dia only when Kokoro could not be loaded.
if not KOKORO_AVAILABLE:
    try:
        from utils.tts_dia import _get_model as get_dia_model
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available as fallback")
    except ImportError as e:
        logger.warning(f"Dia TTS engine is not available: {str(e)}")
        logger.warning("Will use dummy TTS implementation as fallback")
class TTSEngine:
    """Text-to-speech engine with a three-level fallback chain.

    Uses Kokoro when importable, otherwise Dia (lazy-loaded), otherwise a
    dummy sine-wave generator so callers always get a playable audio file.
    """

    def __init__(self, lang_code='z'):
        """Initialize TTS Engine with Kokoro or Dia as fallback

        Args:
            lang_code (str): Language code ('a' for US English, 'b' for British English,
                'j' for Japanese, 'z' for Mandarin Chinese)
                Note: lang_code is only used for Kokoro, not for Dia
        """
        logger.info("Initializing TTS Engine")
        self.engine_type = None
        if KOKORO_AVAILABLE:
            self.pipeline = KPipeline(lang_code=lang_code)
            self.engine_type = "kokoro"
            logger.info("TTS engine initialized with Kokoro")
        elif DIA_AVAILABLE:
            # For Dia we don't initialize anything here; the model is
            # lazy-loaded on first use inside generate_speech/_stream.
            self.pipeline = None
            self.engine_type = "dia"
            logger.info("TTS engine initialized with Dia (lazy loading)")
        else:
            logger.warning("Using dummy TTS implementation as no TTS engines are available")
            self.pipeline = None
            self.engine_type = "dummy"

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech from text using available TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
                Note: voice parameter is only used for Kokoro, not for Dia
            speed (float): Speech speed multiplier (0.5 to 2.0)
                Note: speed parameter is only used for Kokoro, not for Dia

        Returns:
            str: Path to the generated audio file

        Raises:
            Exception: re-raised after logging if generation fails entirely.
        """
        logger.info(f"Generating speech for text length: {len(text)}")
        try:
            # Create output directory if it doesn't exist
            os.makedirs("temp/outputs", exist_ok=True)
            # Unique-enough output path (second-resolution timestamp)
            output_path = f"temp/outputs/output_{int(time.time())}.wav"

            if self.engine_type == "kokoro":
                # Kokoro yields (graphemes, phonemes, audio) tuples; only the
                # first audio segment is kept.
                generator = self.pipeline(text, voice=voice, speed=speed)
                wrote_audio = False
                for _, _, audio in generator:
                    logger.info(f"Saving Kokoro audio to {output_path}")
                    sf.write(output_path, audio, 24000)
                    wrote_audio = True
                    break
                if not wrote_audio:
                    # BUGFIX: previously a path to a never-written file was
                    # returned when the pipeline produced no segments.
                    return self._generate_dummy_audio(output_path)
            elif self.engine_type == "dia":
                try:
                    # Import here to avoid circular imports
                    from utils.tts_dia import generate_speech as dia_generate_speech
                    # Dia chooses (and returns) its own output path.
                    output_path = dia_generate_speech(text)
                    logger.info(f"Generated audio with Dia: {output_path}")
                except Exception as dia_error:
                    logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
                    # Fall back to dummy audio if Dia fails
                    return self._generate_dummy_audio(output_path)
            else:
                # No real engine available: emit a dummy tone instead.
                return self._generate_dummy_audio(output_path)

            logger.info(f"Audio generation complete: {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio(self, output_path):
        """Generate a dummy audio file with a simple sine wave

        Args:
            output_path (str): Path to save the dummy audio file

        Returns:
            str: Path to the generated dummy audio file
        """
        import numpy as np
        sample_rate = 24000
        duration = 3.0  # seconds
        t = np.linspace(0, duration, int(sample_rate * duration), False)
        # 440 Hz tone at 30% amplitude -- audible but not startling.
        tone = np.sin(2 * np.pi * 440 * t) * 0.3
        logger.info(f"Saving dummy audio to {output_path}")
        sf.write(output_path, tone, sample_rate)
        logger.info(f"Dummy audio generation complete: {output_path}")
        return output_path

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
        """Generate speech from text and yield each segment

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        try:
            if self.engine_type == "kokoro":
                # Kokoro supports true streaming: yield every segment at 24 kHz.
                generator = self.pipeline(text, voice=voice, speed=speed)
                for _, _, audio in generator:
                    yield 24000, audio
            elif self.engine_type == "dia":
                # Dia doesn't support streaming natively, so we generate the
                # full audio and yield it as a single chunk.
                try:
                    # Import here to avoid circular imports
                    import torch
                    from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE
                    model = _get_model()
                    with torch.inference_mode():
                        output_audio_np = model.generate(
                            text,
                            max_tokens=None,
                            cfg_scale=3.0,
                            temperature=1.3,
                            top_p=0.95,
                            cfg_filter_top_k=35,
                            use_torch_compile=False,
                            verbose=False
                        )
                    if output_audio_np is not None:
                        yield DEFAULT_SAMPLE_RATE, output_audio_np
                    else:
                        # Fall back to dummy audio if Dia fails
                        yield from self._generate_dummy_audio_stream()
                except Exception as dia_error:
                    logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
                    # Fall back to dummy audio if Dia fails
                    yield from self._generate_dummy_audio_stream()
            else:
                # Generate dummy audio chunks as fallback
                yield from self._generate_dummy_audio_stream()
        except Exception as e:
            logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio_stream(self):
        """Generate dummy audio chunks with simple sine waves

        Yields:
            tuple: (sample_rate, audio_data) pairs for each dummy segment
        """
        import numpy as np
        sample_rate = 24000
        duration = 1.0  # seconds per chunk
        # Three one-second chunks at rising pitch (440/660/880 Hz).
        for i in range(3):
            t = np.linspace(0, duration, int(sample_rate * duration), False)
            freq = 440 + (i * 220)
            tone = np.sin(2 * np.pi * freq * t) * 0.3
            yield sample_rate, tone
# Initialize TTS engine, cached when running under Streamlit.
def get_tts_engine(lang_code='a'):
    """Get or create TTS engine instance

    Args:
        lang_code (str): Language code for the pipeline

    Returns:
        TTSEngine: Initialized TTS engine instance
    """
    try:
        import streamlit as st

        # BUGFIX: the cache decorator promised above was never applied, so a
        # new (expensive) engine was constructed on every call. Cache the
        # engine per lang_code across Streamlit reruns.
        @st.cache_resource
        def _get_engine(code):
            return TTSEngine(code)

        return _get_engine(lang_code)
    except ImportError:
        # Streamlit not installed: construct directly (no caching layer).
        return TTSEngine(lang_code)
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
    """Public interface for TTS generation

    Args:
        text (str): Input text to synthesize
        voice (str): Voice ID to use
        speed (float): Speech speed multiplier

    Returns:
        str: Path to generated audio file
    """
    # Delegate to the (possibly cached) module-level engine instance.
    return get_tts_engine().generate_speech(text, voice, speed)