Spaces:
Running
Running
import os | |
import time | |
import logging | |
import requests | |
import numpy as np | |
import soundfile as sf | |
from typing import Optional, Tuple, Generator | |
# --- Logging -----------------------------------------------------------------
# basicConfig() runs as an import-time side effect; per the logging docs it is
# a no-op when the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

# --- Module constants --------------------------------------------------------
# Sample rate constant; not referenced by the functions visible in this part
# of the module.
DEFAULT_SAMPLE_RATE: int = 44100
# Base URL of the Dia TTS server; endpoint paths are appended to it.
DEFAULT_API_URL: str = "https://droolingpanda-dia-tts-server.hf.space"
# Model identifier sent in the "model" field of speech requests.
DEFAULT_MODEL: str = "dia-1.6b"

# --- Shared client -----------------------------------------------------------
# Holds the lazily created requests.Session; stays None until _get_client()
# builds it on first use.
_client = None
def _get_client():
    """Return the shared HTTP session for the Dia Space API, creating it on first use.

    On the first call a ``requests.Session`` is created and a GET request to
    ``{DEFAULT_API_URL}/docs`` is issued as a connectivity probe. A non-200
    probe response is only logged as a warning; the session is still cached
    and returned. If the probe raises (connection error, timeout, ...), the
    session is closed, nothing is cached, and the exception propagates, so a
    later call retries the whole initialization instead of silently reusing a
    session whose setup failed.

    Returns:
        requests.Session: The cached session shared by all callers.

    Raises:
        ImportError: If ``requests`` cannot be imported.
        Exception: Any error raised while creating the session or probing
            the API (re-raised unchanged after logging).
    """
    global _client
    if _client is not None:
        return _client

    # Seconds to wait for the connectivity probe. Without a timeout,
    # requests waits indefinitely on an unresponsive server.
    probe_timeout = 10

    logger.info("Loading Dia Space client...")
    session = None
    try:
        # Import requests if not already imported
        import requests

        # Initialize the client (just a session for now)
        logger.info("Initializing Dia Space client")
        session = requests.Session()

        # Test connection to the API
        response = session.get(f"{DEFAULT_API_URL}/docs", timeout=probe_timeout)
        if response.status_code == 200:
            logger.info("Dia Space client loaded successfully")
            logger.info(f"Client type: {type(session).__name__}")
        else:
            logger.warning(f"Dia Space API returned status code {response.status_code}")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia Space client: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia Space client: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        # Release the half-initialized session rather than leaking it.
        if session is not None:
            session.close()
        raise

    # Publish the session only after initialization finished without raising.
    _client = session
    return _client
def generate_speech(
    text: str,
    language: str = "zh",
    voice: str = "S1",
    response_format: str = "wav",
    speed: float = 1.0,
) -> str:
    """Legacy entry point for Dia Space text-to-speech.

    Kept for backward compatibility; it delegates to ``DiaSpaceTTSEngine``
    and, if that raises, to ``DummyTTSEngine``. New code should use the
    engine classes directly.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code; passed to the ``DiaSpaceTTSEngine``
            constructor.
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone).
        response_format (str): Audio format ('wav', 'mp3', 'opus').
        speed (float): Speech speed multiplier.

    Returns:
        str: Path to the generated audio file, as returned by whichever
        engine produced it.
    """
    logger.info(f"Legacy Dia Space generate_speech called with text length: {len(text)}")

    # Resolved at call time; a failure of this import is NOT covered by the
    # fallback below and propagates to the caller.
    from utils.tts_engines import DiaSpaceTTSEngine

    try:
        audio_path = DiaSpaceTTSEngine(language).generate_speech(
            text, voice, speed, response_format
        )
    except Exception as e:
        logger.error(f"Error in legacy Dia Space generate_speech: {str(e)}", exc_info=True)
        # Any engine failure falls back to the dummy TTS engine.
        from utils.tts_base import DummyTTSEngine

        return DummyTTSEngine().generate_speech(text)
    return audio_path
def _create_output_dir() -> str:
    """Ensure the audio output directory exists and return its path.

    The directory is the relative path ``temp/outputs`` (resolved against the
    current working directory). Missing parent directories are created, and
    an already existing directory is not an error.

    Returns:
        str: The relative path ``"temp/outputs"``.
    """
    target_dir = "temp/outputs"
    os.makedirs(target_dir, exist_ok=True)
    return target_dir
def _generate_output_path(prefix: str = "output", extension: str = "wav") -> str:
    """Generate a unique output path for audio files.

    The filename has the form ``{prefix}_{unix_seconds}_{random8}.{extension}``
    inside the directory returned by ``_create_output_dir()``. The timestamp
    alone has one-second resolution, so two calls within the same second
    would otherwise return the same path and overwrite each other's audio;
    the random suffix keeps the paths distinct.

    Args:
        prefix (str): Prefix for the output filename.
        extension (str): File extension for the output file (without the dot).

    Returns:
        str: Path to the output file.
    """
    # Local import keeps this block self-contained; uuid is standard library.
    import uuid

    output_dir = _create_output_dir()
    timestamp = int(time.time())
    # 8 hex chars (32 random bits) are ample to separate same-second calls.
    unique_suffix = uuid.uuid4().hex[:8]
    return f"{output_dir}/{prefix}_{timestamp}_{unique_suffix}.{extension}"
def _call_dia_api(text: str, voice: str = "S1", response_format: str = "wav", speed: float = 1.0) -> bytes:
    """Call the Dia Space API to generate speech.

    Sends a JSON POST to ``{DEFAULT_API_URL}/v1/audio/speech`` using the
    shared session from ``_get_client()`` and returns the raw response body.

    Args:
        text (str): Input text to synthesize.
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone).
        response_format (str): Audio format ('wav', 'mp3', 'opus').
        speed (float): Speech speed multiplier.

    Returns:
        bytes: Audio data (the response body of a 200 response).

    Raises:
        Exception: If the API answers with a non-200 status code, or any
            exception raised by the HTTP request itself (connection error,
            timeout, ...), re-raised unchanged after logging.
    """
    client = _get_client()

    # Prepare the request payload
    payload = {
        "model": DEFAULT_MODEL,
        "input": text,
        "voice": voice,
        "response_format": response_format,
        "speed": speed,
    }

    # (connect, read) timeouts in seconds. Without a timeout requests waits
    # indefinitely; the read timeout is generous because synthesis can be slow.
    request_timeout = (10, 300)

    # Make the API request
    logger.info(f"Calling Dia Space API with voice: {voice}, format: {response_format}, speed: {speed}")
    try:
        response = client.post(
            f"{DEFAULT_API_URL}/v1/audio/speech",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=request_timeout,
        )
    except Exception as e:
        logger.error(f"Error calling Dia Space API: {str(e)}", exc_info=True)
        raise

    # The status check is outside the try block so an HTTP error status is
    # logged once, instead of being caught and logged again as a request error.
    if response.status_code != 200:
        logger.error(f"Dia Space API returned error: {response.status_code}")
        logger.error(f"Response: {response.text}")
        raise Exception(f"Dia Space API error: {response.status_code}")

    logger.info("Dia Space API call successful")
    return response.content