File size: 3,647 Bytes
030c851
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91223c9
 
 
 
 
 
 
 
 
 
 
 
030c851
91223c9
 
 
 
2d176f4
 
 
 
 
91223c9
 
 
 
 
 
 
 
030c851
 
91223c9
 
030c851
 
 
c549dab
030c851
 
 
3ed3b5a
 
 
030c851
 
 
 
 
 
 
3ed3b5a
4a9bb1a
3ed3b5a
 
4a9bb1a
 
3ed3b5a
 
 
 
 
4a9bb1a
3ed3b5a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import time
import logging
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional

from dia.model import Dia

# Configure logging
# NOTE: basicConfig runs as an import-time side effect and sets the root
# logger to INFO; per the logging docs it does nothing if the root logger
# already has handlers configured.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)

# Constants
# NOTE(review): presumably the output sample rate in Hz; nothing in the
# visible code reads this value - confirm it is still needed.
DEFAULT_SAMPLE_RATE = 44100
# Model identifier handed to Dia.from_pretrained() by _get_model().
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"

# Global model instance (lazy loaded)
# Stays None until _get_model() succeeds in loading; afterwards it holds the
# loaded model so later calls reuse it.
_model = None


def _get_model() -> Dia:
    """Return the shared Dia model, loading it on first use.

    The loaded instance is cached in the module-level ``_model`` so that
    subsequent calls return it immediately. Before loading, the PyTorch
    version and CUDA details are logged to make load failures easier to
    diagnose.

    Returns:
        Dia: The cached model instance.

    Raises:
        ImportError: Logged and re-raised when loading hits an import error.
        FileNotFoundError: Logged and re-raised when a required file is missing.
        Exception: Any other failure is logged with its traceback and re-raised.
    """
    global _model

    # Fast path: an earlier call already populated the cache.
    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        # Environment diagnostics first, so they appear even if loading fails.
        logger.info(f"PyTorch version: {torch.__version__}")
        has_cuda = torch.cuda.is_available()
        logger.info(f"CUDA available: {has_cuda}")
        if has_cuda:
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")

        logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
        logger.info("Initializing Dia model...")
        # Assign to the global straight away: the cache is populated before
        # the follow-up diagnostics below run.
        _model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")

        logger.info("Dia model loaded successfully")
        logger.info(f"Model type: {type(_model).__name__}")
        # The Dia wrapper may not expose parameters() the way an nn.Module
        # does, so only report the device when that method exists.
        if not hasattr(_model, "parameters"):
            logger.info("Model device: Device information not available for Dia model")
        else:
            logger.info(f"Model device: {next(_model.parameters()).device}")
    except ImportError as missing_dep:
        logger.error(f"Import error loading Dia model: {missing_dep}")
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as missing_file:
        logger.error(f"File not found error loading Dia model: {missing_file}")
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as load_err:
        # Catch-all boundary: record the traceback, then propagate unchanged.
        logger.error(f"Error loading Dia model: {load_err}", exc_info=True)
        logger.error(f"Error type: {type(load_err).__name__}")
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    return _model


def generate_speech(text: str, language: str = "zh") -> str:
    """Synthesize speech for *text* through the legacy Dia entry point.

    Kept for backward compatibility; new code should use the factory-based
    engine classes directly. The call is delegated to ``DiaTTSEngine``, and
    if that raises, the error is logged and ``DummyTTSEngine`` is used
    instead.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code, forwarded to ``DiaTTSEngine``. It is
            kept mainly for API compatibility.

    Returns:
        str: Path to the generated audio file, as returned by the engine
        that handled the request.
    """
    logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")

    # Imported at call time rather than at module level. This import sits
    # outside the try block, so a failure here propagates to the caller.
    from utils.tts_engines import DiaTTSEngine

    try:
        engine = DiaTTSEngine(language)
        audio_path = engine.generate_speech(text)
    except Exception as err:
        logger.error(f"Error in legacy Dia generate_speech: {err!s}", exc_info=True)
        # Fall back to the dummy engine so the caller still gets a result.
        from utils.tts_base import DummyTTSEngine

        fallback_engine = DummyTTSEngine()
        return fallback_engine.generate_speech(text)
    else:
        return audio_path