Spaces:
Build error
Build error
"""CosyVoice2 TTS provider implementation.""" | |
import logging | |
import numpy as np | |
import soundfile as sf | |
import io | |
from typing import Iterator, TYPE_CHECKING | |
if TYPE_CHECKING: | |
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest | |
from ..base.tts_provider_base import TTSProviderBase | |
from ...domain.exceptions import SpeechSynthesisException | |
logger = logging.getLogger(__name__)

# Flag to track CosyVoice2 availability; flipped to True only when every
# optional dependency below imports successfully.
COSYVOICE2_AVAILABLE = False
# Sample rate (Hz) used when labelling generated audio.
DEFAULT_SAMPLE_RATE = 24000

# Probe for the optional CosyVoice2 dependencies. The imports are performed
# purely to find out whether they succeed.
# Based on https://github.com/FunAudioLLM/CosyVoice
try:
    import torch  # noqa: F401
    import torchaudio  # noqa: F401
    from cosyvoice.cli.cosyvoice import CosyVoice  # noqa: F401
except ImportError as e:
    # ModuleNotFoundError is a subclass of ImportError, so a separate handler
    # placed after this one could never run. Log the concrete class name so a
    # missing module is still distinguishable from other import failures.
    logger.warning(
        f"CosyVoice2 TTS engine is not available - {type(e).__name__}: {str(e)}"
    )
    COSYVOICE2_AVAILABLE = False
else:
    COSYVOICE2_AVAILABLE = True
    logger.info("CosyVoice2 TTS engine is available")
class CosyVoice2TTSProvider(TTSProviderBase):
    """CosyVoice2 TTS provider implementation.

    The underlying ``CosyVoice`` model is loaded lazily on the first
    availability check and cached on the instance in ``self.model``.
    """

    # Speaker id handed to ``inference_sft`` (and, unchanged from the original
    # code, reused as the last argument of the zero-shot fallback).
    _DEFAULT_SPEAKER_ID = '中文女'
    # Prompt text handed to the zero-shot fallback.
    _FALLBACK_PROMPT_TEXT = '请输入提示文本'
    # Model directory used when the caller does not supply one.
    _DEFAULT_MODEL_DIR = 'pretrained_models/CosyVoice-300M'

    def __init__(self, lang_code: str = 'z', model_dir: str = _DEFAULT_MODEL_DIR):
        """Initialize the CosyVoice2 TTS provider.

        Args:
            lang_code: Language code stored on the instance.
            model_dir: Path passed to the ``CosyVoice`` constructor when the
                model is loaded. Defaults to the previously hard-coded path.
        """
        super().__init__(
            provider_name="CosyVoice2",
            supported_languages=['en', 'z']  # CosyVoice2 supports English and multilingual
        )
        self.lang_code = lang_code
        self.model_dir = model_dir
        self.model = None  # populated lazily by _ensure_model()

    def _ensure_model(self) -> bool:
        """Load the model if necessary and report whether it is usable.

        Returns:
            True when ``self.model`` holds a loaded model, False otherwise.
            Loading failures are logged and never propagate to the caller.
        """
        if self.model is None and COSYVOICE2_AVAILABLE:
            try:
                logger.info("Loading CosyVoice2 model...")
                from cosyvoice.cli.cosyvoice import CosyVoice

                self.model = CosyVoice(self.model_dir)
                logger.info("CosyVoice2 model successfully loaded")
            except ImportError as e:
                # exc_info attaches the traceback; the stdlib logger does not
                # accept an ``exception=`` keyword and would raise TypeError.
                logger.error("Failed to import CosyVoice2 dependencies: %s", e, exc_info=True)
                self.model = None
            except FileNotFoundError as e:
                logger.error("Failed to load CosyVoice2 model files: %s", e, exc_info=True)
                self.model = None
            except Exception as e:
                logger.error("Failed to initialize CosyVoice2 model: %s", e, exc_info=True)
                self.model = None
        model_available = self.model is not None
        logger.info(f"CosyVoice2 model availability check: {model_available}")
        return model_available

    def is_available(self) -> bool:
        """Check if CosyVoice2 TTS is available (dependencies importable and model loaded)."""
        return COSYVOICE2_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for CosyVoice2."""
        # CosyVoice2 typically uses a default voice
        return ['default']

    @staticmethod
    def _collect_waveform(result) -> 'np.ndarray | None':
        """Normalise a model inference result into a single numpy waveform.

        Accepts a tensor, a numpy array, a dict carrying the audio under
        ``'tts_speech'``, or any iterable of those (e.g. a generator of
        per-segment outputs). Size-1 axes are squeezed away so a ``(1, T)``
        chunk becomes ``(T,)`` before the chunks are concatenated.

        NOTE(review): the dict/generator shape follows the upstream CosyVoice
        README — confirm against the installed version.

        Returns:
            The concatenated waveform, or None when ``result`` is None or
            yields no chunks.
        """
        if result is None:
            return None
        import torch

        # A tensor/array/dict is one chunk; anything else is iterated. The
        # tensor check must come first because tensors are themselves iterable.
        if isinstance(result, (torch.Tensor, np.ndarray, dict)):
            chunks = [result]
        else:
            chunks = list(result)

        pieces = []
        for chunk in chunks:
            if isinstance(chunk, dict):
                chunk = chunk['tts_speech']
            if isinstance(chunk, torch.Tensor):
                chunk = chunk.detach().cpu().numpy()
            pieces.append(np.atleast_1d(np.squeeze(np.asarray(chunk))))

        if not pieces:
            return None
        return pieces[0] if len(pieces) == 1 else np.concatenate(pieces)

    def _synthesize_waveform(self, text: str) -> 'np.ndarray | None':
        """Run model inference for ``text``, falling back to the zero-shot API.

        The result is collected inside the ``try`` so that errors raised while
        consuming a lazy (generator) result also trigger the fallback.

        Raises:
            SpeechSynthesisException: when both inference attempts fail. The
                message carries the first error; the fallback error is chained.
        """
        try:
            waveform = self._collect_waveform(
                self.model.inference_sft(text, self._DEFAULT_SPEAKER_ID)
            )
            logger.info("CosyVoice2 model inference completed")
            return waveform
        except Exception as api_error:
            logger.error(f"CosyVoice2 API error: {str(api_error)}")
            # Try alternative API if the first one fails
            try:
                logger.info("Trying alternative CosyVoice2 API")
                # NOTE(review): the third argument is a speaker-id string here;
                # upstream appears to expect prompt audio — confirm.
                waveform = self._collect_waveform(
                    self.model.inference_zero_shot(
                        text, self._FALLBACK_PROMPT_TEXT, self._DEFAULT_SPEAKER_ID
                    )
                )
                logger.info("CosyVoice2 alternative API succeeded")
                return waveform
            except Exception as alt_error:
                logger.error(f"CosyVoice2 alternative API also failed: {str(alt_error)}")
                raise SpeechSynthesisException(
                    f"CosyVoice2 inference failed: {str(api_error)}"
                ) from alt_error

    def _output_sample_rate(self) -> int:
        """Return the model's own sample rate when it exposes one, else the default."""
        rate = getattr(self.model, 'sample_rate', None)
        if isinstance(rate, int) and rate > 0:
            return rate
        return DEFAULT_SAMPLE_RATE

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using CosyVoice2 TTS.

        Args:
            request: Synthesis request; ``request.text_content.text`` is the
                text to speak and ``request.voice_settings`` is logged.

        Returns:
            ``(wav_bytes, sample_rate)``.

        Raises:
            SpeechSynthesisException: when the engine is unavailable. Other
                failures are logged and delegated to ``_handle_provider_error``.
        """
        logger.info("Starting CosyVoice2 audio generation")
        if not self.is_available():
            logger.error("CosyVoice2 TTS engine is not available")
            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")
        try:
            # Extract parameters from request
            text = request.text_content.text
            logger.info(f"CosyVoice2 generating audio for text length: {len(text)}")
            logger.info(f"Voice settings: voice_id={request.voice_settings.voice_id}, speed={request.voice_settings.speed}")
            logger.info("Starting CosyVoice2 model inference")
            output_audio_np = self._synthesize_waveform(text)
            if output_audio_np is None:
                logger.error("CosyVoice2 model returned None for audio output")
                raise SpeechSynthesisException("CosyVoice2 model returned None for audio output")
            logger.info(f"CosyVoice2 generated audio array shape: {output_audio_np.shape if hasattr(output_audio_np, 'shape') else 'unknown'}")
            # Convert numpy array to bytes
            logger.info("Converting CosyVoice2 audio to bytes")
            sample_rate = self._output_sample_rate()
            audio_bytes = self._numpy_to_bytes(output_audio_np, sample_rate=sample_rate)
            logger.info(f"CosyVoice2 audio conversion completed, bytes length: {len(audio_bytes)}")
            return audio_bytes, sample_rate
        except Exception as e:
            logger.error("CosyVoice2 audio generation failed: %s", e, exc_info=True)
            # NOTE(review): presumably this re-raises; if it returns normally
            # this method falls through and returns None — confirm in the base class.
            self._handle_provider_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using CosyVoice2 TTS.

        Yields a single ``(wav_bytes, sample_rate, True)`` tuple: the whole
        utterance is synthesised first and emitted as one final chunk.

        Raises:
            SpeechSynthesisException: when the engine is unavailable. Other
                failures are delegated to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("CosyVoice2 TTS engine is not available")
        try:
            # Extract parameters from request
            text = request.text_content.text
            output_audio_np = self._synthesize_waveform(text)
            if output_audio_np is None:
                raise SpeechSynthesisException("CosyVoice2 model returned None for audio output")
            # Convert numpy array to bytes
            sample_rate = self._output_sample_rate()
            audio_bytes = self._numpy_to_bytes(output_audio_np, sample_rate=sample_rate)
            # CosyVoice2 generates complete audio in one go
            yield audio_bytes, sample_rate, True
        except Exception as e:
            self._handle_provider_error(e, "streaming audio generation")

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert numpy audio array to WAV-encoded bytes.

        Raises:
            SpeechSynthesisException: when encoding fails; the original error
                is chained as the cause.
        """
        try:
            # Encode into an in-memory buffer rather than touching the filesystem.
            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e