Michael Hu
fix path
b2b15db
raw
history blame
7.91 kB
import logging
import numpy as np
import soundfile as sf
from typing import Optional, Generator, Tuple
from utils.tts import TTSBase, DummyTTS
# Configure logging
logger = logging.getLogger(__name__)

# Flag to track Dia availability. It is flipped to True only when both
# `torch` and `dia.model` import cleanly in the guarded block below.
DIA_AVAILABLE = False

# Sample rate (Hz) passed along with the audio that this module writes/streams.
# NOTE(review): Dia's published usage examples save audio at 44100 Hz --
# confirm that 24000 is really the rate of the arrays returned by the model.
DEFAULT_SAMPLE_RATE = 24000

# Try to import Dia dependencies
try:
    import torch
    from dia.model import Dia
    DIA_AVAILABLE = True
    logger.info("Dia TTS engine is available")
except ModuleNotFoundError as e:
    # This handler must come BEFORE `except ImportError`: ModuleNotFoundError
    # is a subclass of ImportError, so listing the broader class first would
    # make this branch (and its more specific diagnostics) unreachable.
    if "dac" in str(e):
        logger.warning("Dia TTS engine is not available due to missing 'dac' module")
    else:
        logger.warning(f"Dia TTS engine is not available: {str(e)}")
    DIA_AVAILABLE = False
except ImportError:
    # Any other import failure (e.g. a module that exists but fails to
    # initialise); DIA_AVAILABLE keeps its default of False.
    logger.warning("Dia TTS engine is not available")
def _get_model():
    """Create a Dia model instance via ``Dia.from_pretrained()``.

    Nothing is cached here: every call attempts a fresh load, so callers
    that want to reuse the model must keep the returned reference.

    Returns:
        Dia or None: The loaded model, or None when Dia is flagged as
        unavailable or when loading fails for any reason (the failure is
        logged, never raised).
    """
    if not DIA_AVAILABLE:
        logger.warning("Dia TTS engine is not available")
        return None

    loaded_model = None
    try:
        # Imported locally so a broken installation surfaces here as a
        # logged ImportError instead of propagating to the caller.
        import torch
        from dia.model import Dia

        candidate = Dia.from_pretrained()
        logger.info("Dia model successfully loaded")
        loaded_model = candidate
    except ImportError as err:
        logger.error(f"Failed to import Dia dependencies: {str(err)}")
    except FileNotFoundError as err:
        logger.error(f"Failed to load Dia model files: {str(err)}")
    except Exception as err:
        logger.error(f"Failed to initialize Dia model: {str(err)}")
    return loaded_model
class DiaTTS(TTSBase):
    """Dia TTS engine implementation.

    This engine uses the Dia model for TTS generation. The model is loaded
    on first use and then kept on the instance. Whenever Dia cannot be used
    (dependencies missing, model fails to load, generation raises, or the
    model returns None) the request is delegated to ``DummyTTS`` instead of
    raising to the caller.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS engine.

        Args:
            lang_code (str): Language code for the engine; forwarded to the
                base class and to the ``DummyTTS`` fallback.
        """
        super().__init__(lang_code)
        # Filled in lazily by _ensure_model().
        self.model = None

    def _ensure_model(self):
        """Ensure the model is loaded.

        Returns:
            bool: True if model is available, False otherwise
        """
        if self.model is None:
            self.model = _get_model()
        return self.model is not None

    def _dia_unavailable_reason(self) -> Optional[str]:
        """Return the warning text explaining why Dia cannot run, or None if it can.

        Checks the module-level availability flag first and only then tries
        to load the model, so no load is attempted when imports failed.
        """
        if not DIA_AVAILABLE:
            return "Dia TTS engine is not available, falling back to dummy TTS"
        if not self._ensure_model():
            return "Failed to load Dia model, falling back to dummy TTS"
        return None

    def _run_dia(self, text: str):
        """Run Dia inference on *text* with the engine's fixed sampling settings.

        Single home for the generation parameters so the file-based and the
        streaming entry points cannot drift apart. Exceptions propagate to
        the caller, which decides how to fall back.
        """
        import torch

        with torch.inference_mode():
            return self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,
                verbose=False,
            )

    @staticmethod
    def _log_dia_failure(exc: Exception, what: str) -> None:
        """Log a Dia generation failure; *what* is "speech" or "speech stream"."""
        if isinstance(exc, ModuleNotFoundError):
            if "dac" in str(exc):
                logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
            else:
                logger.error(f"Module not found error in Dia TTS: {str(exc)}")
            return
        # Pass the exception explicitly so the traceback is attached even
        # though this helper is not itself the `except` block.
        logger.error(f"Error generating {what} with Dia: {str(exc)}", exc_info=exc)
        logger.warning("Dia TTS engine failed, falling back to dummy TTS")

    def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
        """Generate speech using Dia TTS engine.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia; forwarded to the fallback)
            speed (float): Speech speed multiplier (not used in Dia; forwarded
                to the fallback)

        Returns:
            Optional[str]: Path to the generated audio file, or whatever the
            ``DummyTTS`` fallback returns when Dia cannot produce audio.
        """
        logger.info(f"Generating speech with Dia for text length: {len(text)}")

        unavailable = self._dia_unavailable_reason()
        if unavailable is not None:
            logger.warning(unavailable)
            return DummyTTS(self.lang_code).generate_speech(text, voice, speed)

        try:
            # Generate unique output path
            output_path = self._generate_output_path(prefix="dia")
            output_audio_np = self._run_dia(text)
            if output_audio_np is not None:
                logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
                sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
                logger.info(f"Dia audio generation complete: {output_path}")
                return output_path
            logger.warning("Dia model returned None for audio output")
            logger.warning("Falling back to dummy TTS")
        except Exception as e:
            self._log_dia_failure(e, "speech")

        # Reached only when Dia produced nothing usable. The fallback runs
        # outside the try so a failing fallback is not invoked a second time.
        return DummyTTS(self.lang_code).generate_speech(text, voice, speed)

    def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Dia TTS engine.

        Dia output is yielded as one segment; on any failure the segments of
        the ``DummyTTS`` stream are yielded instead.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia; forwarded to the fallback)
            speed (float): Speech speed multiplier (not used in Dia; forwarded
                to the fallback)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Generating speech stream with Dia for text length: {len(text)}")

        unavailable = self._dia_unavailable_reason()
        if unavailable is not None:
            logger.warning(unavailable)
            yield from DummyTTS(self.lang_code).generate_speech_stream(text, voice, speed)
            return

        output_audio_np = None
        try:
            output_audio_np = self._run_dia(text)
            if output_audio_np is None:
                logger.warning("Dia model returned None for audio output")
                logger.warning("Falling back to dummy TTS")
            else:
                logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
        except Exception as e:
            output_audio_np = None
            self._log_dia_failure(e, "speech stream")

        # Yield outside the try: an exception thrown into the generator by
        # the consumer must not be misreported as a Dia generation failure.
        if output_audio_np is not None:
            yield DEFAULT_SAMPLE_RATE, output_audio_np
        else:
            yield from DummyTTS(self.lang_code).generate_speech_stream(text, voice, speed)