File size: 8,738 Bytes
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
462d6de
1f9c751
 
 
 
 
 
 
 
462d6de
 
fdc056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462d6de
 
fdc056d
 
 
462d6de
fdc056d
 
462d6de
fdc056d
 
 
 
 
 
 
 
462d6de
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
 
 
 
 
462d6de
fdc056d
462d6de
 
fdc056d
462d6de
fdc056d
462d6de
fdc056d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f9c751
 
 
fdc056d
 
 
 
 
 
 
 
 
 
 
 
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
1f9c751
 
fdc056d
1f9c751
 
 
fdc056d
1f9c751
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""Dia TTS provider implementation."""

import logging
import numpy as np
import soundfile as sf
import io
from typing import Iterator, TYPE_CHECKING

if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest

from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Sample rate (Hz) passed to soundfile when encoding Dia output as WAV.
DEFAULT_SAMPLE_RATE: int = 24000

# True once torch and dia.model have been imported successfully;
# maintained by _check_dia_dependencies().
DIA_AVAILABLE: bool = False

# Try to import Dia dependencies
def _check_dia_dependencies():
    """Check if Dia dependencies are available."""
    global DIA_AVAILABLE

    logger.info("πŸ” Checking Dia TTS dependencies...")

    try:
        logger.info("Attempting to import torch...")
        import torch
        logger.info("βœ“ Successfully imported torch")

        logger.info("Attempting to import dia.model...")
        from dia.model import Dia
        logger.info("βœ“ Successfully imported dia.model")

        DIA_AVAILABLE = True
        logger.info("βœ… Dia TTS engine is available")
        return True
    except ImportError as e:
        logger.warning(f"⚠️ Dia TTS engine dependencies not available: {e}")
        logger.info(f"ImportError details: {type(e).__name__}: {e}")
        DIA_AVAILABLE = False
        return False
    except ModuleNotFoundError as e:
        if "dac" in str(e):
            logger.warning("❌ Dia TTS engine is not available due to missing 'dac' module")
            logger.info("Please install descript-audio-codec: pip install descript-audio-codec")
        elif "dia" in str(e):
            logger.warning("❌ Dia TTS engine is not available due to missing 'dia' module")
            logger.info("Please install dia: pip install git+https://github.com/nari-labs/dia.git")
        else:
            logger.warning(f"❌ Dia TTS engine is not available: {str(e)}")
        logger.info(f"ModuleNotFoundError details: {type(e).__name__}: {e}")
        DIA_AVAILABLE = False
        return False

# Initial check: probe the optional dependencies once at import time so that
# DIA_AVAILABLE is set before any provider is constructed. Import failures are
# logged and turned into DIA_AVAILABLE = False rather than raised.
logger.info("πŸš€ Initializing Dia TTS provider...")
_check_dia_dependencies()


class DiaTTSProvider(TTSProviderBase):
    """Dia TTS provider implementation.

    Wraps the optional ``dia`` package behind the ``TTSProviderBase``
    interface. The model is loaded lazily on the first availability check
    and reused afterwards; synthesized audio is returned as WAV bytes at
    ``DEFAULT_SAMPLE_RATE``.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS provider.

        Args:
            lang_code: Language code stored on the instance. It is not
                currently passed to the Dia model.
        """
        super().__init__(
            provider_name="Dia",
            supported_languages=['en', 'z']  # Dia supports English and multilingual
        )
        self.lang_code = lang_code
        # Loaded lazily by _ensure_model(); stays None until a load succeeds.
        self.model = None

    def _ensure_model(self) -> bool:
        """Ensure the model is loaded, loading it on first use.

        If the import-time dependency check failed, the dependencies are
        probed again first. A failed load leaves ``self.model`` as None, so
        the next call retries.

        Returns:
            True if a model instance is available, False otherwise.
        """
        if self.model is None:
            logger.info("πŸ”„ Ensuring Dia model is loaded...")

            # If Dia is not available, check dependencies again.
            # _check_dia_dependencies() updates the global DIA_AVAILABLE itself.
            if not DIA_AVAILABLE:
                logger.info("⚠️ Dia not available, checking dependencies again...")
                if not _check_dia_dependencies():
                    logger.error("❌ Dependencies still not available")
                    return False
                logger.info("βœ… Dependencies are now available")

            self.model = self._load_model()

        is_available = self.model is not None
        logger.info(f"Model availability check result: {is_available}")
        return is_available

    def _load_model(self):
        """Instantiate the pretrained Dia model; return None (after logging) on failure."""
        try:
            logger.info("πŸ“₯ Loading Dia model from pretrained...")
            from dia.model import Dia
            model = Dia.from_pretrained()
            logger.info("πŸŽ‰ Dia model successfully loaded")
            return model
        except ImportError as e:
            logger.error(f"❌ Failed to import Dia dependencies: {str(e)}")
        except FileNotFoundError as e:
            logger.error(f"❌ Failed to load Dia model files: {str(e)}")
            logger.info("ℹ️ This might be the first time loading the model. It will be downloaded automatically.")
        except Exception as e:
            logger.error(f"❌ Failed to initialize Dia model: {str(e)}")
            logger.info(f"Model initialization error: {type(e).__name__}: {e}")
        return None

    def is_available(self) -> bool:
        """Check if Dia TTS is available (dependencies importable and model loaded)."""
        logger.info(f"πŸ” Checking Dia availability: DIA_AVAILABLE={DIA_AVAILABLE}")

        if not DIA_AVAILABLE:
            logger.info("❌ Dia dependencies not available")
            return False

        model_available = self._ensure_model()
        logger.info(f"πŸ” Model availability: {model_available}")

        # DIA_AVAILABLE is known to be True here, so the model state decides.
        result = model_available
        logger.info(f"🎯 Dia TTS availability result: {result}")
        return result

    def get_available_voices(self) -> list[str]:
        """Get available voices for Dia (a single 'default' voice)."""
        return ['default']

    def _synthesize_wav(self, request: 'SpeechSynthesisRequest') -> bytes:
        """Run Dia on the request text and return WAV-encoded audio bytes.

        Shared by the one-shot and streaming paths so both use identical
        sampling settings.

        Raises:
            SpeechSynthesisException: If the model returns no audio or WAV
                encoding fails.
        """
        import torch

        text = request.text_content.text

        # Inference only: no autograd tracking needed.
        with torch.inference_mode():
            output_audio_np = self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,
                verbose=False
            )

        if output_audio_np is None:
            raise SpeechSynthesisException("Dia model returned None for audio output")

        return self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)

    def _handle_generation_error(self, error: Exception, operation: str) -> None:
        """Map an unexpected generation failure onto the provider's error path.

        A missing 'dac' package is reported with a dedicated message; all
        other errors go to the inherited ``_handle_provider_error``.
        """
        if isinstance(error, ModuleNotFoundError) and (error.name or "").partition(".")[0] == "dac":
            raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from error
        self._handle_provider_error(error, operation)

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Dia TTS.

        Returns:
            A ``(wav_bytes, sample_rate)`` tuple.

        Raises:
            SpeechSynthesisException: If the engine is unavailable, the model
                returns no audio, or 'dac' is missing. Other failures are
                passed to the inherited ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")

        try:
            audio_bytes = self._synthesize_wav(request)
            return audio_bytes, DEFAULT_SAMPLE_RATE
        except SpeechSynthesisException:
            # Already a domain error (empty output, WAV encoding); don't re-wrap it.
            raise
        except Exception as e:
            self._handle_generation_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Dia TTS.

        Dia produces the complete clip in one call, so the stream consists of
        a single ``(wav_bytes, sample_rate, is_final=True)`` chunk.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")

        try:
            audio_bytes = self._synthesize_wav(request)
        except SpeechSynthesisException:
            # Already a domain error (empty output, WAV encoding); don't re-wrap it.
            raise
        except Exception as e:
            self._handle_generation_error(e, "streaming audio generation")
            return

        # Yield outside the try so exceptions thrown into the generator by the
        # consumer are not misreported as generation failures.
        yield audio_bytes, DEFAULT_SAMPLE_RATE, True

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Encode a NumPy audio array as an in-memory WAV file.

        Raises:
            SpeechSynthesisException: If soundfile cannot encode the array.
        """
        try:
            with io.BytesIO() as buffer:
                sf.write(buffer, audio_array, sample_rate, format='WAV')
                return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e