# Author: Michael Hu
# Use Dia TTS as the fallback model if Kokoro is not available.
# Commit: 7b25fdd
import os
import logging
import time
import soundfile as sf

logger = logging.getLogger(__name__)

# Availability flags for the two real TTS backends; TTSEngine reads these
# to decide which engine to construct.  Both default to False and are only
# flipped on a successful import below.
KOKORO_AVAILABLE = False
DIA_AVAILABLE = False

# Try to import Kokoro first (the preferred engine).
try:
    from kokoro import KPipeline
    KOKORO_AVAILABLE = True
    logger.info("Kokoro TTS engine is available")
except AttributeError as e:
    # Specifically catch the EspeakWrapper.set_data_path error: a known
    # environment issue that should degrade to the fallback engine rather
    # than crash the import of this module.
    if "EspeakWrapper" in str(e) and "set_data_path" in str(e):
        logger.warning("Kokoro import failed due to EspeakWrapper.set_data_path issue")
    else:
        # Re-raise if it's a different error — an unexpected AttributeError
        # here would otherwise be silently misclassified as "unavailable".
        logger.error(f"Kokoro import failed with unexpected error: {str(e)}")
        raise
except ImportError:
    logger.warning("Kokoro TTS engine is not available")

# Try to import Dia as fallback — only probed when Kokoro is missing.
if not KOKORO_AVAILABLE:
    try:
        from utils.tts_dia import _get_model as get_dia_model
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available as fallback")
    except ImportError as e:
        logger.warning(f"Dia TTS engine is not available: {str(e)}")
        logger.warning("Will use dummy TTS implementation as fallback")
class TTSEngine:
    """Text-to-speech engine that prefers Kokoro, falls back to Dia, and
    finally to a dummy sine-wave generator when no real backend is available.
    """

    def __init__(self, lang_code='z'):
        """Initialize TTS Engine with Kokoro or Dia as fallback.

        Args:
            lang_code (str): Language code ('a' for US English, 'b' for British English,
                'j' for Japanese, 'z' for Mandarin Chinese)
                Note: lang_code is only used for Kokoro, not for Dia
        """
        logger.info("Initializing TTS Engine")
        self.engine_type = None
        if KOKORO_AVAILABLE:
            self.pipeline = KPipeline(lang_code=lang_code)
            self.engine_type = "kokoro"
            logger.info("TTS engine initialized with Kokoro")
        elif DIA_AVAILABLE:
            # For Dia, we don't need to initialize anything here:
            # the model is lazy-loaded on first use.
            self.pipeline = None
            self.engine_type = "dia"
            logger.info("TTS engine initialized with Dia (lazy loading)")
        else:
            logger.warning("Using dummy TTS implementation as no TTS engines are available")
            self.pipeline = None
            self.engine_type = "dummy"

    def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
        """Generate speech from text using the available TTS engine.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
                Note: voice parameter is only used for Kokoro, not for Dia
            speed (float): Speech speed multiplier (0.5 to 2.0)
                Note: speed parameter is only used for Kokoro, not for Dia

        Returns:
            str: Path to the generated audio file

        Raises:
            Exception: Re-raises any unexpected failure after logging it.
        """
        logger.info(f"Generating speech for text length: {len(text)}")
        try:
            # Create output directory if it doesn't exist
            os.makedirs("temp/outputs", exist_ok=True)
            # Millisecond resolution: int(time.time()) collided when two
            # requests arrived within the same second, overwriting output.
            output_path = f"temp/outputs/output_{int(time.time() * 1000)}.wav"

            if self.engine_type == "kokoro":
                # Local import mirrors _generate_dummy_audio's style and keeps
                # numpy out of module import time.
                import numpy as np

                generator = self.pipeline(text, voice=voice, speed=speed)
                # BUG FIX: the previous implementation wrote only the first
                # segment and then broke out of the loop, truncating any text
                # long enough to produce multiple segments.  Collect every
                # segment and write their concatenation instead.
                segments = [audio for _, _, audio in generator]
                if not segments:
                    # Nothing produced — degrade to the dummy tone.
                    return self._generate_dummy_audio(output_path)
                logger.info(f"Saving Kokoro audio to {output_path}")
                sf.write(output_path, np.concatenate(segments), 24000)
            elif self.engine_type == "dia":
                try:
                    # Import here to avoid circular imports
                    from utils.tts_dia import generate_speech as dia_generate_speech
                    # Dia chooses its own output path; adopt it.
                    output_path = dia_generate_speech(text)
                    logger.info(f"Generated audio with Dia: {output_path}")
                except Exception as dia_error:
                    logger.error(f"Dia TTS generation failed: {str(dia_error)}", exc_info=True)
                    # Fall back to dummy audio if Dia fails
                    return self._generate_dummy_audio(output_path)
            else:
                # Generate dummy audio as fallback
                return self._generate_dummy_audio(output_path)

            logger.info(f"Audio generation complete: {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio(self, output_path):
        """Generate a dummy audio file containing a 440 Hz sine wave.

        Args:
            output_path (str): Path to save the dummy audio file

        Returns:
            str: Path to the generated dummy audio file
        """
        import numpy as np
        sample_rate = 24000
        duration = 3.0  # seconds
        t = np.linspace(0, duration, int(sample_rate * duration), False)
        tone = np.sin(2 * np.pi * 440 * t) * 0.3  # 0.3 keeps amplitude well below clipping
        logger.info(f"Saving dummy audio to {output_path}")
        sf.write(output_path, tone, sample_rate)
        logger.info(f"Dummy audio generation complete: {output_path}")
        return output_path

    def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0):
        """Generate speech from text and yield each segment as it is produced.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
            speed (float): Speech speed multiplier (0.5 to 2.0)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        try:
            if self.engine_type == "kokoro":
                # Kokoro streams natively; forward each segment at 24 kHz.
                generator = self.pipeline(text, voice=voice, speed=speed)
                for _, _, audio in generator:
                    yield 24000, audio
            elif self.engine_type == "dia":
                # Dia doesn't support streaming natively, so we generate the
                # full audio and then yield it as a single chunk.
                try:
                    # Import here to avoid circular imports
                    import torch
                    from utils.tts_dia import _get_model, DEFAULT_SAMPLE_RATE

                    model = _get_model()
                    with torch.inference_mode():
                        output_audio_np = model.generate(
                            text,
                            max_tokens=None,
                            cfg_scale=3.0,
                            temperature=1.3,
                            top_p=0.95,
                            cfg_filter_top_k=35,
                            use_torch_compile=False,
                            verbose=False
                        )
                    if output_audio_np is not None:
                        yield DEFAULT_SAMPLE_RATE, output_audio_np
                    else:
                        # Fall back to dummy audio if Dia produced nothing
                        yield from self._generate_dummy_audio_stream()
                except Exception as dia_error:
                    logger.error(f"Dia TTS streaming failed: {str(dia_error)}", exc_info=True)
                    # Fall back to dummy audio if Dia fails
                    yield from self._generate_dummy_audio_stream()
            else:
                # Generate dummy audio chunks as fallback
                yield from self._generate_dummy_audio_stream()
        except Exception as e:
            logger.error(f"TTS streaming failed: {str(e)}", exc_info=True)
            raise

    def _generate_dummy_audio_stream(self):
        """Generate dummy audio chunks with simple sine waves.

        Yields:
            tuple: (sample_rate, audio_data) pairs for each dummy segment
        """
        import numpy as np
        sample_rate = 24000
        duration = 1.0  # seconds per chunk
        # Create 3 chunks of dummy audio at 440/660/880 Hz so successive
        # chunks are audibly distinct.
        for i in range(3):
            t = np.linspace(0, duration, int(sample_rate * duration), False)
            freq = 440 + (i * 220)  # Different frequency for each chunk
            tone = np.sin(2 * np.pi * freq * t) * 0.3
            yield sample_rate, tone
# Engine factory: cached with Streamlit's resource cache when Streamlit is
# installed, otherwise constructed fresh on each call.
def get_tts_engine(lang_code='a'):
    """Get or create a TTS engine instance.

    Args:
        lang_code (str): Language code for the pipeline

    Returns:
        TTSEngine: Initialized TTS engine instance
    """
    try:
        import streamlit as st
    except ImportError:
        # No Streamlit available — plain, uncached construction.
        return TTSEngine(lang_code)

    @st.cache_resource
    def _cached_engine():
        return TTSEngine(lang_code)

    return _cached_engine()
def generate_speech(text: str, voice: str = 'af_heart', speed: float = 1.0) -> str:
    """Public interface for TTS generation.

    Args:
        text (str): Input text to synthesize
        voice (str): Voice ID to use
        speed (float): Speech speed multiplier

    Returns:
        str: Path to generated audio file
    """
    # Delegate to the (possibly cached) engine instance.
    return get_tts_engine().generate_speech(text, voice, speed)