File size: 6,774 Bytes
030c851
 
 
7495571
 
aaa0814
030c851
 
 
 
cb90410
 
7495571
cb90410
7495571
cb90410
 
7495571
 
 
cb90410
7495571
 
 
 
 
 
cb90410
 
030c851
cb90410
7495571
cb90410
7495571
 
 
cb90410
7495571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e22e786
 
7495571
 
 
e22e786
 
cb90410
030c851
7495571
91223c9
7495571
 
91223c9
7495571
 
 
 
 
 
 
 
 
 
 
 
91223c9
7495571
 
 
 
 
2d176f4
e22e786
 
7495571
 
 
e22e786
7495571
 
e22e786
030c851
7495571
e22e786
030c851
7495571
 
030c851
7495571
 
 
 
 
 
 
 
 
cb90410
7495571
 
e22e786
7495571
 
 
 
e22e786
7495571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e22e786
 
7495571
 
 
e22e786
7495571
 
e22e786
7495571
 
e22e786
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import logging
import numpy as np
import soundfile as sf
from typing import Optional, Generator, Tuple

from utils.tts_base import TTSBase

# Configure logging
logger = logging.getLogger(__name__)

# Flag to track Dia availability; flipped to True only when the optional
# imports below succeed.
DIA_AVAILABLE = False
# Sample rate used when writing/yielding generated audio.
# NOTE(review): Dia's reference implementation appears to emit 44100 Hz audio;
# confirm that 24000 is the intended rate for this model.
DEFAULT_SAMPLE_RATE = 24000

# Try to import Dia dependencies
try:
    import torch
    from dia.model import Dia
    DIA_AVAILABLE = True
    logger.info("Dia TTS engine is available")
except ModuleNotFoundError as e:
    # ModuleNotFoundError is a subclass of ImportError, so it has to be
    # handled first; otherwise the generic ImportError clause swallows it and
    # the specific diagnostics below are never emitted.
    if "dac" in str(e):
        logger.warning("Dia TTS engine is not available due to missing 'dac' module")
    else:
        logger.warning(f"Dia TTS engine is not available: {str(e)}")
except ImportError:
    # Any other import failure (e.g. a module that exists but fails to load).
    logger.warning("Dia TTS engine is not available")


def _get_model():
    """Create a Dia model instance on demand.

    The module-level availability flag is checked first; when it is unset a
    warning is logged and nothing is loaded. Otherwise the model is built via
    ``Dia.from_pretrained()``. Any failure during import or construction is
    logged as an error and reported to the caller as ``None`` rather than
    raised.

    Returns:
        Dia or None: The loaded model, or None if Dia is unavailable or
        loading failed.
    """
    if not DIA_AVAILABLE:
        logger.warning("Dia TTS engine is not available")
        return None

    try:
        import torch
        from dia.model import Dia

        # Build the model with the library's default pretrained weights
        dia_model = Dia.from_pretrained()
        logger.info("Dia model successfully loaded")
        return dia_model
    except ImportError as import_err:
        logger.error(f"Failed to import Dia dependencies: {import_err}")
    except FileNotFoundError as missing_err:
        logger.error(f"Failed to load Dia model files: {missing_err}")
    except Exception as init_err:
        logger.error(f"Failed to initialize Dia model: {init_err}")

    # Every failure path above falls through to here.
    return None


class DiaTTS(TTSBase):
    """Dia TTS engine implementation

    This engine uses the Dia model for TTS generation. The model is loaded
    lazily on the first generation request, and all failures are logged and
    reported as ``None`` / an empty stream rather than raised.
    """

    # Sampling settings passed to every ``model.generate`` call. Kept in one
    # place so file-based and streamed generation cannot drift apart.
    _GENERATION_KWARGS = {
        "max_tokens": None,
        "cfg_scale": 3.0,
        "temperature": 1.3,
        "top_p": 0.95,
        "cfg_filter_top_k": 35,
        "use_torch_compile": False,
        "verbose": False,
    }

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS engine

        Args:
            lang_code (str): Language code for the engine (forwarded to the
                base class)
        """
        super().__init__(lang_code)
        # Loaded on first use by _ensure_model()
        self.model = None

    def _ensure_model(self):
        """Ensure the model is loaded

        Returns:
            bool: True if model is available, False otherwise
        """
        if self.model is None:
            self.model = _get_model()

        return self.model is not None

    @staticmethod
    def _is_blank(text) -> bool:
        """Return True when there is nothing to synthesize (None, empty or whitespace-only text)."""
        return not text or (isinstance(text, str) and not text.strip())

    def _is_ready(self) -> bool:
        """Check that Dia is importable and the model is loaded, logging the reason when it is not."""
        if not DIA_AVAILABLE:
            logger.error("Dia TTS engine is not available")
            return False

        if not self._ensure_model():
            logger.error("Failed to load Dia model")
            return False

        return True

    def _synthesize(self, text: str):
        """Run the Dia model on ``text`` and return its audio output, or None if the model returned None.

        Exceptions are deliberately not handled here; the public entry points
        log them with their own context.
        """
        import torch

        # Inference only: no autograd bookkeeping is needed
        with torch.inference_mode():
            audio = self.model.generate(text, **self._GENERATION_KWARGS)

        if audio is None:
            logger.error("Dia model returned None for audio output")
            return None

        logger.info(f"Successfully generated audio with Dia (length: {len(audio)})")
        return audio

    @staticmethod
    def _log_failure(error: Exception, what: str) -> None:
        """Log a generation failure; ``what`` names the operation (e.g. "speech").

        Must be called from inside an ``except`` block so that the traceback
        requested via ``exc_info=True`` is available.
        """
        if isinstance(error, ModuleNotFoundError):
            if "dac" in str(error):
                logger.error("Dia TTS engine failed due to missing 'dac' module")
            else:
                logger.error(f"Module not found error in Dia TTS: {str(error)}")
        else:
            logger.error(f"Error generating {what} with Dia: {str(error)}", exc_info=True)

    def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
        """Generate speech using Dia TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia)
            speed (float): Speech speed multiplier (not used in Dia)

        Returns:
            Optional[str]: Path to the generated audio file or None if generation fails
                (including when ``text`` is None, empty or whitespace-only)
        """
        # Reject blank input up front: len(None) would raise outside the
        # error handling below, and there is nothing to synthesize anyway.
        if self._is_blank(text):
            logger.error("Cannot generate speech with Dia: input text is empty")
            return None

        logger.info(f"Generating speech with Dia for text length: {len(text)}")

        if not self._is_ready():
            return None

        try:
            # Generate unique output path (helper is not defined in this
            # class; presumably provided by TTSBase)
            output_path = self._generate_output_path(prefix="dia")

            audio = self._synthesize(text)
            if audio is None:
                return None

            sf.write(output_path, audio, DEFAULT_SAMPLE_RATE)
            logger.info(f"Dia audio generation complete: {output_path}")
            return output_path
        except Exception as e:
            self._log_failure(e, "speech")
            return None

    def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using Dia TTS engine

        The whole utterance is generated in one model call, so at most one
        pair is yielded; nothing is yielded when generation fails or ``text``
        is None, empty or whitespace-only.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (not used in Dia)
            speed (float): Speech speed multiplier (not used in Dia)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        # Same blank-input guard as generate_speech
        if self._is_blank(text):
            logger.error("Cannot generate speech stream with Dia: input text is empty")
            return

        logger.info(f"Generating speech stream with Dia for text length: {len(text)}")

        if not self._is_ready():
            return

        try:
            audio = self._synthesize(text)
            if audio is None:
                return

            yield DEFAULT_SAMPLE_RATE, audio
        except Exception as e:
            self._log_failure(e, "speech stream")
            return