Spaces:
Build error
Build error
| """CosyVoice2 TTS provider implementation.""" | |
| import logging | |
| import numpy as np | |
| import soundfile as sf | |
| import io | |
| from typing import Iterator, TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
| from ..base.tts_provider_base import TTSProviderBase | |
| from ...domain.exceptions import SpeechSynthesisException | |
logger = logging.getLogger(__name__)

# Flag to track CosyVoice2 availability; flipped to True only when the
# optional dependencies below import cleanly.
COSYVOICE2_AVAILABLE = False

# Sample rate (Hz) reported alongside every generated audio payload.
DEFAULT_SAMPLE_RATE = 24000

# Probe for the optional CosyVoice2 dependencies once, at import time.
# NOTE(review): the `cosyvoice2.model` import path and the Dia-like API used
# by the provider are assumptions carried over from the original code --
# confirm them against the CosyVoice2 package that is actually installed.
try:
    import torch
    from cosyvoice2.model import CosyVoice2
    COSYVOICE2_AVAILABLE = True
    logger.info("CosyVoice2 TTS engine is available")
except ImportError as e:
    # ModuleNotFoundError is a subclass of ImportError, so a single handler
    # covers both; a separate ModuleNotFoundError clause placed after this
    # one could never run and the failure reason would never be logged.
    logger.warning(f"CosyVoice2 TTS engine is not available: {str(e)}")
    COSYVOICE2_AVAILABLE = False
class CosyVoice2TTSProvider(TTSProviderBase):
    """CosyVoice2 TTS provider implementation.

    The model is loaded lazily on first use and cached on the instance.
    Audio is produced in a single pass and returned as WAV-encoded bytes at
    ``DEFAULT_SAMPLE_RATE``.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the CosyVoice2 TTS provider.

        Args:
            lang_code: Language code stored on the instance as
                ``self.lang_code``. It is not read by any method in this
                class.
        """
        super().__init__(
            provider_name="CosyVoice2",
            # Per the original note: English plus a multilingual code.
            # TODO confirm what 'z' stands for.
            supported_languages=['en', 'z'],
        )
        self.lang_code = lang_code
        # Populated lazily by _ensure_model().
        self.model = None

    def _ensure_model(self):
        """Ensure the model is loaded.

        Attempts to load the model when it is not cached yet and the
        CosyVoice2 dependencies imported successfully. A failed load is
        logged and leaves ``self.model`` as ``None``, so the next call
        tries again.

        Returns:
            bool: True when a model instance is available, False otherwise.
        """
        if self.model is not None:
            return True
        if not COSYVOICE2_AVAILABLE:
            return False

        try:
            from cosyvoice2.model import CosyVoice2
            self.model = CosyVoice2.from_pretrained()
            logger.info("CosyVoice2 model successfully loaded")
        except ImportError as e:
            logger.error("Failed to import CosyVoice2 dependencies: %s", e)
        except FileNotFoundError as e:
            logger.error("Failed to load CosyVoice2 model files: %s", e)
        except Exception as e:
            # Best-effort: any other load failure marks the provider as
            # unavailable instead of propagating to the caller.
            logger.error("Failed to initialize CosyVoice2 model: %s", e)

        # self.model is only assigned when from_pretrained() succeeded.
        return self.model is not None

    def is_available(self) -> bool:
        """Check if CosyVoice2 TTS is available.

        Note that this triggers a model load attempt when the dependencies
        are importable and no model is cached yet.
        """
        return COSYVOICE2_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for CosyVoice2.

        Returns:
            A single-element list containing ``'default'``.
        """
        return ['default']

    def _synthesize(self, request: 'SpeechSynthesisRequest') -> bytes:
        """Run the model on the request text and return WAV-encoded bytes.

        Shared by the one-shot and streaming entry points so the generation
        parameters live in exactly one place.

        Raises:
            SpeechSynthesisException: If the model returns ``None`` or the
                output cannot be encoded.
        """
        import torch

        text = request.text_content.text

        with torch.inference_mode():
            # Assuming CosyVoice2 has a similar API to Dia -- TODO confirm
            # these keyword arguments against the installed package.
            output_audio_np = self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                use_torch_compile=False,
                verbose=False,
            )

        if output_audio_np is None:
            raise SpeechSynthesisException("CosyVoice2 model returned None for audio output")

        return self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using CosyVoice2 TTS.

        Args:
            request: Synthesis request; ``request.text_content.text`` is the
                text that gets spoken.

        Returns:
            A ``(wav_bytes, sample_rate)`` tuple.

        Raises:
            SpeechSynthesisException: If the engine is not available.
                Failures during generation are delegated to
                ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")

        try:
            return self._synthesize(request), DEFAULT_SAMPLE_RATE
        except Exception as e:
            # NOTE(review): presumably _handle_provider_error always raises;
            # if it can return, this method falls through and returns None.
            # Verify against TTSProviderBase.
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using CosyVoice2 TTS.

        Being a generator, nothing (including the availability check) runs
        until the first item is requested.

        Args:
            request: Synthesis request; ``request.text_content.text`` is the
                text that gets spoken.

        Yields:
            A single ``(wav_bytes, sample_rate, True)`` tuple; the third
            element is always ``True`` because the complete audio is
            generated in one go.

        Raises:
            SpeechSynthesisException: If the engine is not available.
                Failures during generation are delegated to
                ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")

        try:
            audio_bytes = self._synthesize(request)
            # CosyVoice2 generates complete audio in one go
            yield audio_bytes, DEFAULT_SAMPLE_RATE, True
        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to bytes.

        Args:
            audio_array: Audio samples to encode.
            sample_rate: Sample rate written into the WAV data.

        Returns:
            The WAV-encoded audio as bytes.

        Raises:
            SpeechSynthesisException: If encoding fails for any reason; the
                original error is chained as the cause.
        """
        try:
            # Encode into an in-memory buffer; the context manager releases
            # it once the bytes have been copied out.
            with io.BytesIO() as buffer:
                sf.write(buffer, audio_array, sample_rate, format='WAV')
                # getvalue() returns the whole buffer regardless of the
                # current stream position.
                return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e