Michael Hu
refactor tts part
3ed3b5a
raw
history blame
3.65 kB
import os
import time
import logging
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional
from dia.model import Dia
# Model defaults.
# DEFAULT_MODEL_NAME is the identifier handed to Dia.from_pretrained by _get_model().
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
# Not referenced in the code visible here; presumably the output rate in Hz — confirm with callers.
DEFAULT_SAMPLE_RATE = 44100

# Module logger, plus an INFO-level root configuration applied at import time.
# NOTE(review): configuring the root logger from a library module affects the
# whole process; consider moving basicConfig to the application entry point.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Lazily populated cache for the Dia instance (filled in by _get_model()).
_model = None  # type: Optional[Dia]
def _describe_model_device(model: object) -> str:
    """Return a best-effort device description for *model*, for logging only.

    When *model* exposes a callable ``parameters()`` (PyTorch-module style),
    the device of the first yielded parameter is reported. If the method is
    missing, yields nothing, or its first item has no ``device``, a fixed
    fallback message is returned instead, so a diagnostic log line can never
    abort a model load.
    """
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        # next(..., None) avoids StopIteration when there are no parameters.
        first_param = next(iter(parameters()), None)
        device = getattr(first_param, "device", None)
        if device is not None:
            return str(device)
    return "Device information not available for Dia model"


def _get_model() -> Dia:
    """Lazy-load the Dia model to avoid loading it until needed.

    The model is built with ``Dia.from_pretrained(DEFAULT_MODEL_NAME,
    compute_dtype="float16")`` on the first call and cached in the
    module-level ``_model``; subsequent calls return the cached instance.

    Returns:
        Dia: The cached (or freshly loaded) model instance.

    Raises:
        ImportError: Re-raised from the load after logging a dependency hint.
        FileNotFoundError: Re-raised from the load after logging a path hint.
        Exception: Any other load error is logged with its traceback and
            re-raised. In every failure case the cache stays empty, so a
            later call retries the load.
    """
    global _model
    if _model is not None:
        return _model

    # NOTE(review): this check-then-load is not guarded by a lock, so
    # concurrent first calls could each load the model; add one if callers
    # are multi-threaded.
    logger.info("Loading Dia model...")
    try:
        # Record the runtime environment up front; the generic handler below
        # points at version / CUDA problems, so this context helps triage.
        logger.info("PyTorch version: %s", torch.__version__)
        logger.info("CUDA available: %s", torch.cuda.is_available())
        if torch.cuda.is_available():
            logger.info("CUDA version: %s", torch.version.cuda)
            logger.info("GPU device: %s", torch.cuda.get_device_name(0))
        # The identifier is passed straight to from_pretrained; nothing here
        # verifies beforehand that it exists or is reachable.
        logger.info("Attempting to load model from: %s", DEFAULT_MODEL_NAME)
        logger.info("Initializing Dia model...")
        model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
    except ImportError as import_err:
        logger.error("Import error loading Dia model: %s", import_err)
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error("File not found error loading Dia model: %s", file_err)
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.exception("Error loading Dia model: %s", e)
        logger.error("Error type: %s", type(e).__name__)
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    logger.info("Dia model loaded successfully")
    logger.info("Model type: %s", type(model).__name__)
    logger.info("Model device: %s", _describe_model_device(model))
    # Publish to the cache only after the load fully succeeded, so a failed
    # attempt never leaves a cached instance behind while reporting an error.
    _model = model
    return _model
def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using the Dia model.

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    Any failure in the Dia path, including failure to import the Dia
    engine, is logged with its traceback. The text is then synthesized with
    ``DummyTTSEngine`` instead.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code forwarded to ``DiaTTSEngine``; per the
            original note it is not used by the Dia model and is kept for API
            compatibility.

    Returns:
        str: Path to the generated audio file, produced by ``DiaTTSEngine``
            or, on fallback, by ``DummyTTSEngine``.
    """
    logger.info("Legacy Dia generate_speech called with text length: %d", len(text))
    try:
        # Use the new implementation via the factory-pattern engine. The import
        # lives inside the try so that a missing or broken Dia backend also
        # takes the dummy fallback instead of escaping as an ImportError.
        from utils.tts_engines import DiaTTSEngine

        dia_engine = DiaTTSEngine(language)
        return dia_engine.generate_speech(text)
    except Exception as e:
        logger.exception("Error in legacy Dia generate_speech: %s", e)
        # Fall back to dummy TTS.
        from utils.tts_base import DummyTTSEngine

        return DummyTTSEngine().generate_speech(text)