# Michael Hu
# fix import issue
# cb90410
import os
import time
import logging
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# True only when both torch and the Dia package import cleanly.
# _get_model() and generate_speech() check this before touching Dia.
DIA_AVAILABLE = False

# Import the optional heavy dependencies. torch and Dia are imported in separate
# steps so that a failure inside the Dia import is never reported as "Torch not
# available" (previously any non-ModuleNotFoundError ImportError from dia fell
# through to the torch handler).
try:
    import torch
except ImportError:
    logger.warning("Torch not available, Dia TTS engine cannot be used")
else:
    try:
        # Importing Dia may in turn try to import the `dac` package.
        from dia.model import Dia
    except ImportError as e:
        # ImportError also covers ModuleNotFoundError.
        if "dac" in str(e):
            logger.warning("Dia TTS engine is not available due to missing 'dac' module")
        else:
            logger.warning("Dia TTS engine is not available: %s", e)
    else:
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available")

# Constants
# Not referenced in this section; presumably the Dia output sample rate in Hz — verify.
DEFAULT_SAMPLE_RATE = 44100
# Model identifier passed to Dia.from_pretrained() by _get_model().
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"

# Global model instance (lazy loaded by _get_model(); None until first load).
_model = None
def _get_model():
    """Lazy-load the Dia model so it is not loaded until actually needed.

    The first successful call loads ``DEFAULT_MODEL_NAME`` with
    ``Dia.from_pretrained(..., compute_dtype="float16")`` and caches it in the
    module-level ``_model``. Later calls return the cached instance.

    Returns:
        The cached Dia model instance.

    Raises:
        ImportError: If Dia was not importable when this module loaded
            (``DIA_AVAILABLE`` is False), or if loading hits a missing dependency.
        FileNotFoundError: If loading cannot find the model files.
        Exception: Any other loading error is logged and re-raised unchanged.
    """
    global _model

    # Fail fast: `torch` and `Dia` are only bound at module level when
    # DIA_AVAILABLE is True.
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, cannot load model")
        raise ImportError("Dia module is not available")

    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        # Environment diagnostics to help debug torch version / CUDA mismatches.
        logger.info("PyTorch version: %s", torch.__version__)
        cuda_available = torch.cuda.is_available()
        logger.info("CUDA available: %s", cuda_available)
        if cuda_available:
            logger.info("CUDA version: %s", torch.version.cuda)
            logger.info("GPU device: %s", torch.cuda.get_device_name(0))

        logger.info("Attempting to load model from: %s", DEFAULT_MODEL_NAME)
        logger.info("Initializing Dia model...")
        model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")

        logger.info("Dia model loaded successfully")
        logger.info("Model type: %s", type(model).__name__)

        # The Dia wrapper may not expose ``parameters()``. Even when it does, an
        # empty iterator must not raise StopIteration and be reported as a load
        # failure after the model has already loaded.
        params = model.parameters() if hasattr(model, "parameters") else iter(())
        first_param = next(iter(params), None)
        if first_param is not None:
            logger.info("Model device: %s", first_param.device)
        else:
            logger.info("Model device: Device information not available for Dia model")
    except ImportError as import_err:
        logger.error("Import error loading Dia model: %s", import_err)
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error("File not found error loading Dia model: %s", file_err)
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.error("Error loading Dia model: %s", e, exc_info=True)
        logger.error("Error type: %s", type(e).__name__)
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    # Publish to the cache only after loading and diagnostics completed, so a
    # failed call never leaves a half-initialized global behind.
    _model = model
    return _model
def _generate_with_dummy(text: str, language: str) -> str:
    """Synthesize *text* with the placeholder DummyTTSEngine and return its result."""
    # Imported lazily, as at every original fallback site.
    from utils.tts_base import DummyTTSEngine

    return DummyTTSEngine(language).generate_speech(text)


def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using the Dia model.

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    If Dia is unavailable, or the Dia engine fails for any reason, the request
    is served by the dummy TTS engine instead. This function therefore always
    returns the result of some engine's ``generate_speech``.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code (not used by the Dia model; kept for API
            compatibility and forwarded to the engine constructors).

    Returns:
        str: Path to the generated audio file.
    """
    logger.info("Legacy Dia generate_speech called with text length: %d", len(text))

    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, falling back to dummy TTS engine")
        return _generate_with_dummy(text, language)

    # Use the new implementation via the factory pattern.
    try:
        # Imported here to avoid circular imports.
        from utils.tts_engines import DiaTTSEngine

        return DiaTTSEngine(language).generate_speech(text)
    except ModuleNotFoundError as e:
        logger.error("Module not found error in Dia generate_speech: %s", e)
        if "dac" in str(e):
            logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
    except Exception as e:
        logger.error("Error in legacy Dia generate_speech: %s", e, exc_info=True)

    # Every failure path above ends up here, so the function never returns None.
    return _generate_with_dummy(text, language)