File size: 9,044 Bytes
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c2d9e7
 
 
 
1f9c751
 
0c2d9e7
 
 
1f9c751
0c2d9e7
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c2d9e7
1f9c751
0c2d9e7
 
 
 
 
 
1f9c751
 
6514731
1f9c751
 
6514731
1f9c751
 
6514731
1f9c751
0c2d9e7
 
 
 
1f9c751
 
 
 
 
 
 
 
 
 
 
 
0c2d9e7
 
1f9c751
0c2d9e7
1f9c751
 
 
 
0c2d9e7
1f9c751
 
0c2d9e7
 
1f9c751
 
0c2d9e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f9c751
 
0c2d9e7
1f9c751
 
0c2d9e7
 
1f9c751
0c2d9e7
1f9c751
0c2d9e7
 
1f9c751
 
 
6514731
1f9c751
 
 
 
 
 
 
 
 
0c2d9e7
1f9c751
 
 
 
0c2d9e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f9c751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c2d9e7
1f9c751
 
0c2d9e7
1f9c751
 
 
0c2d9e7
1f9c751
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""CosyVoice2 TTS provider implementation."""

import logging
import numpy as np
import soundfile as sf
import io
from typing import Iterator, TYPE_CHECKING

if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest

from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException

logger = logging.getLogger(__name__)

# Flag to track CosyVoice2 availability; flipped to True only when every
# optional dependency below imports successfully.
COSYVOICE2_AVAILABLE = False
# Sample rate (Hz) the provider uses by default for generated audio.
DEFAULT_SAMPLE_RATE = 24000

# Probe the optional CosyVoice2 dependencies once, at import time.
try:
    import torch
    import torchaudio
    # Import CosyVoice2 from the correct package
    # Based on https://github.com/FunAudioLLM/CosyVoice
    from cosyvoice.cli.cosyvoice import CosyVoice
except ImportError as e:
    # ModuleNotFoundError is a subclass of ImportError, so this single handler
    # covers both "module missing" and "module present but broken import".
    logger.warning("CosyVoice2 TTS engine is not available - ImportError: %s", e)
else:
    COSYVOICE2_AVAILABLE = True
    logger.info("CosyVoice2 TTS engine is available")


class CosyVoice2TTSProvider(TTSProviderBase):
    """CosyVoice2 TTS provider implementation.

    The underlying CosyVoice model is loaded lazily on the first availability
    check and cached on the instance as ``self.model``.
    """

    # Speaker id passed to ``inference_sft`` (and re-used by the fallback call).
    _DEFAULT_SPEAKER = '中文女'
    # Prompt text passed to the ``inference_zero_shot`` fallback.
    _FALLBACK_PROMPT_TEXT = '请输入提示文本'

    def __init__(self, lang_code: str = 'z'):
        """Initialize the CosyVoice2 TTS provider.

        Args:
            lang_code: Language code stored on the instance as ``lang_code``.
        """
        super().__init__(
            provider_name="CosyVoice2",
            supported_languages=['en', 'z']  # CosyVoice2 supports English and multilingual
        )
        self.lang_code = lang_code
        self.model = None

    def _ensure_model(self):
        """Load the CosyVoice model if it is not loaded yet.

        Loading is only attempted when the module-level import probe succeeded.
        Any failure is logged (with traceback) and leaves ``self.model`` as
        ``None`` so that a later call can retry.

        Returns:
            bool: True when a model instance is available on ``self.model``.
        """
        if self.model is None and COSYVOICE2_AVAILABLE:
            try:
                logger.info("Loading CosyVoice2 model...")
                from cosyvoice.cli.cosyvoice import CosyVoice

                # Initialize CosyVoice with the correct model path
                # You may need to adjust the model path based on your installation
                self.model = CosyVoice('pretrained_models/CosyVoice-300M')
                logger.info("CosyVoice2 model successfully loaded")
            # NOTE: the stdlib logger has no ``exception=`` keyword; passing it
            # raises TypeError from inside the handler. ``logger.exception``
            # logs at ERROR level and attaches the active traceback instead.
            except ImportError as e:
                logger.exception("Failed to import CosyVoice2 dependencies: %s", e)
                self.model = None
            except FileNotFoundError as e:
                logger.exception("Failed to load CosyVoice2 model files: %s", e)
                self.model = None
            except Exception as e:
                logger.exception("Failed to initialize CosyVoice2 model: %s", e)
                self.model = None

        model_available = self.model is not None
        logger.info(f"CosyVoice2 model availability check: {model_available}")
        return model_available

    def is_available(self) -> bool:
        """Check if CosyVoice2 TTS is available (dependencies importable and model loaded)."""
        return COSYVOICE2_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for CosyVoice2."""
        # CosyVoice2 typically uses a default voice
        return ['default']

    def _collect_audio(self, output) -> 'np.ndarray | None':
        """Normalise a CosyVoice inference result to a numpy array.

        Accepted shapes of ``output``:
          * ``None``                      -> ``None``
          * dict with a ``'tts_speech'``  -> its value, normalised recursively
          * numpy array                   -> used as is
          * tensor-like (``detach``/``cpu``) -> moved to CPU and converted
          * iterator/generator of any of the above -> chunks concatenated
          * anything else                 -> ``np.asarray(output)``

        A 2-D array with a single leading row, ``(1, T)``, is flattened to
        ``(T,)`` so that it is written as T mono frames rather than one frame
        with T channels.

        Returns:
            The audio samples, or ``None`` when no audio was produced.
        """
        if output is None:
            return None
        if isinstance(output, dict):
            return self._collect_audio(output.get('tts_speech'))

        if isinstance(output, np.ndarray):
            audio = output
        elif hasattr(output, 'detach') and hasattr(output, 'cpu'):
            # Duck-typed torch.Tensor; detach() guards against tensors that
            # still require grad, which cannot be converted with .numpy().
            audio = output.detach().cpu().numpy()
        elif isinstance(output, Iterator):
            # The CosyVoice inference_* methods are generators; consuming them
            # here keeps lazily-raised inference errors inside the caller's try.
            chunks = []
            for item in output:
                chunk = self._collect_audio(item)
                if chunk is not None and chunk.size:
                    chunks.append(chunk)
            if not chunks:
                return None
            return np.concatenate(chunks, axis=0)
        else:
            audio = np.asarray(output)

        if audio.ndim == 2 and audio.shape[0] == 1:
            audio = audio[0]
        return audio

    def _run_inference(self, text: str) -> np.ndarray:
        """Synthesise ``text`` and return the audio samples.

        Tries ``inference_sft`` first and falls back to ``inference_zero_shot``
        when the first call fails.

        Raises:
            SpeechSynthesisException: when both calls fail, or when the model
                yields no audio.
        """
        try:
            # Use the inference method from CosyVoice
            audio = self._collect_audio(
                self.model.inference_sft(text, self._DEFAULT_SPEAKER)
            )
            logger.info("CosyVoice2 model inference completed")
        except Exception as api_error:
            logger.error(f"CosyVoice2 API error: {str(api_error)}")
            # Try alternative API if the first one fails
            # NOTE(review): the third argument is the speaker name kept from the
            # original code; upstream examples pass prompt audio here — confirm.
            try:
                logger.info("Trying alternative CosyVoice2 API")
                audio = self._collect_audio(
                    self.model.inference_zero_shot(
                        text, self._FALLBACK_PROMPT_TEXT, self._DEFAULT_SPEAKER
                    )
                )
                logger.info("CosyVoice2 alternative API succeeded")
            except Exception as alt_error:
                logger.error(f"CosyVoice2 alternative API also failed: {str(alt_error)}")
                raise SpeechSynthesisException(
                    f"CosyVoice2 inference failed: {str(api_error)}"
                ) from alt_error

        if audio is None:
            logger.error("CosyVoice2 model returned None for audio output")
            raise SpeechSynthesisException("CosyVoice2 model returned None for audio output")
        return audio

    def _output_sample_rate(self) -> int:
        """Return the model's own ``sample_rate`` if it exposes a positive one, else the default."""
        rate = getattr(self.model, 'sample_rate', None)
        if isinstance(rate, (int, float)) and rate > 0:
            return int(rate)
        return DEFAULT_SAMPLE_RATE

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using CosyVoice2 TTS.

        Args:
            request: Synthesis request; ``request.text_content.text`` is the
                text to speak and ``request.voice_settings`` is logged.

        Returns:
            ``(wav_bytes, sample_rate)``.

        Raises:
            SpeechSynthesisException: when the engine is not available. Other
                failures are logged and delegated to ``_handle_provider_error``.
        """
        logger.info("Starting CosyVoice2 audio generation")

        if not self.is_available():
            logger.error("CosyVoice2 TTS engine is not available")
            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")

        try:
            # Extract parameters from request
            text = request.text_content.text
            logger.info(f"CosyVoice2 generating audio for text length: {len(text)}")
            logger.info(f"Voice settings: voice_id={request.voice_settings.voice_id}, speed={request.voice_settings.speed}")

            # Generate audio using CosyVoice2
            logger.info("Starting CosyVoice2 model inference")
            audio = self._run_inference(text)
            logger.info(f"CosyVoice2 generated audio array shape: {audio.shape}")

            # Convert numpy array to bytes
            sample_rate = self._output_sample_rate()
            logger.info("Converting CosyVoice2 audio to bytes")
            audio_bytes = self._numpy_to_bytes(audio, sample_rate=sample_rate)
            logger.info(f"CosyVoice2 audio conversion completed, bytes length: {len(audio_bytes)}")

            return audio_bytes, sample_rate

        except Exception as e:
            logger.exception("CosyVoice2 audio generation failed: %s", e)
            # Presumably raises a provider-level error — verify against TTSProviderBase.
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using CosyVoice2 TTS.

        Yields a single ``(wav_bytes, sample_rate, True)`` tuple: the full
        audio is synthesised first and emitted as one final chunk.

        Raises:
            SpeechSynthesisException: when the engine is not available. Other
                failures are delegated to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")

        try:
            # Extract parameters from request
            text = request.text_content.text

            # Generate audio using CosyVoice2
            audio = self._run_inference(text)

            # Convert numpy array to bytes
            sample_rate = self._output_sample_rate()
            audio_bytes = self._numpy_to_bytes(audio, sample_rate=sample_rate)
            # CosyVoice2 generates complete audio in one go
            yield audio_bytes, sample_rate, True

        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to WAV-encoded bytes.

        Args:
            audio_array: Audio samples to encode.
            sample_rate: Sample rate written to the WAV header.

        Raises:
            SpeechSynthesisException: when encoding fails.
        """
        try:
            # Write audio data to an in-memory buffer as WAV
            with io.BytesIO() as buffer:
                sf.write(buffer, audio_array, sample_rate, format='WAV')
                return buffer.getvalue()

        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e