Spaces:
Running
Running
File size: 5,444 Bytes
a316f58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import time
import logging
import requests
import numpy as np
import soundfile as sf
from typing import Optional, Tuple, Generator
# --- Logging -------------------------------------------------------------------
# NOTE(review): basicConfig at import time configures the *root* logger for the
# whole process (it is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Dia Space defaults ----------------------------------------------------------
# Base URL of the hosted Dia TTS server; endpoint paths are appended to it.
DEFAULT_API_URL: str = "https://droolingpanda-dia-tts-server.hf.space"
# Sent as the "model" field of the speech request payload.
DEFAULT_MODEL: str = "dia-1.6b"
# Presumably Hz; not referenced in this part of the file (TODO confirm usage).
DEFAULT_SAMPLE_RATE: int = 44100

# Process-wide HTTP session, populated lazily by _get_client().
_client = None
def _get_client():
    """Return the shared ``requests.Session`` used to talk to the Dia Space API.

    The session is created lazily on the first call and cached in the
    module-level ``_client``; later calls return the cached session.

    On creation, a ``GET {DEFAULT_API_URL}/docs`` request is issued as a
    reachability probe. A non-200 status is only logged as a warning; the
    session is still cached and returned.

    Returns:
        requests.Session: The cached session.

    Raises:
        Exception: Any error raised while creating the session or probing the
            API (e.g. ``requests.RequestException`` / ``requests.Timeout``) is
            logged and re-raised. In that case the session is closed and
            nothing is cached, so the next call starts over.
    """
    global _client
    if _client is not None:
        return _client

    # Seconds to wait for the probe. Without a timeout, requests blocks
    # forever on an unresponsive host.
    probe_timeout = 10.0

    logger.info("Loading Dia Space client...")
    session = None
    try:
        logger.info("Initializing Dia Space client")
        session = requests.Session()
        # Probe the API so connectivity problems surface early.
        response = session.get(f"{DEFAULT_API_URL}/docs", timeout=probe_timeout)
    except Exception as e:
        # Do not cache (or leak) a session whose initialization failed.
        if session is not None:
            session.close()
        logger.error(f"Error loading Dia Space client: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        raise

    if response.status_code == 200:
        logger.info("Dia Space client loaded successfully")
        logger.info(f"Client type: {type(session).__name__}")
    else:
        logger.warning(f"Dia Space API returned status code {response.status_code}")

    # Publish the session only after the probe completed without raising.
    _client = session
    return _client
def generate_speech(text: str, language: str = "zh", voice: str = "S1", response_format: str = "wav", speed: float = 1.0) -> str:
    """Public interface for TTS generation using the Dia Space API.

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code. It is forwarded to ``DiaSpaceTTSEngine``.
            The Dia Space request payload built in this module has no
            language field.
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone).
        response_format (str): Audio format ('wav', 'mp3', 'opus').
        speed (float): Speech speed multiplier.

    Returns:
        str: Path to the generated audio file. If anything fails, including
        loading the Dia Space engine module, the path comes from
        ``DummyTTSEngine`` instead.
    """
    logger.info(f"Legacy Dia Space generate_speech called with text length: {len(text)}")
    try:
        # Import inside the try so that a failure to load the engine module
        # (e.g. a missing dependency) also triggers the dummy fallback below.
        from utils.tts_engines import DiaSpaceTTSEngine

        dia_space_engine = DiaSpaceTTSEngine(language)
        return dia_space_engine.generate_speech(text, voice, speed, response_format)
    except Exception as e:
        logger.error(f"Error in legacy Dia Space generate_speech: {e}", exc_info=True)
        # Fall back to dummy TTS so callers still receive a file path.
        from utils.tts_base import DummyTTSEngine

        dummy_engine = DummyTTSEngine()
        return dummy_engine.generate_speech(text)
def _create_output_dir() -> str:
    """Ensure the audio output directory exists and return its path.

    The directory is relative to the current working directory and is created
    (including parents) if missing; an existing directory is left as-is.

    Returns:
        str: The relative path ``"temp/outputs"``.
    """
    target = "temp/outputs"
    os.makedirs(name=target, exist_ok=True)
    return target
def _generate_output_path(prefix: str = "output", extension: str = "wav") -> str:
    """Generate a unique output path for a new audio file.

    The filename has the form ``{prefix}_{unix_seconds}_{random8hex}.{extension}``.
    The timestamp alone has one-second resolution, so the random suffix keeps
    two files created within the same second from overwriting each other.
    Creates the output directory (via ``_create_output_dir``) if needed.

    Args:
        prefix (str): Prefix for the output filename.
        extension (str): File extension, with or without a leading dot
            (``"wav"`` and ``".wav"`` are equivalent).

    Returns:
        str: Path to the output file inside the output directory.
    """
    import uuid

    output_dir = _create_output_dir()
    timestamp = int(time.time())
    unique_suffix = uuid.uuid4().hex[:8]
    # Tolerate ".wav" style input so the name never ends up with "..wav".
    ext = extension.lstrip(".")
    return f"{output_dir}/{prefix}_{timestamp}_{unique_suffix}.{ext}"
def _call_dia_api(
    text: str,
    voice: str = "S1",
    response_format: str = "wav",
    speed: float = 1.0,
    timeout: Optional[float] = 300.0,
) -> bytes:
    """Call the Dia Space API to generate speech.

    Sends a POST request to ``{DEFAULT_API_URL}/v1/audio/speech``. The JSON
    payload carries ``model``, ``input``, ``voice``, ``response_format`` and
    ``speed``.

    Args:
        text (str): Input text to synthesize.
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone).
        response_format (str): Audio format ('wav', 'mp3', 'opus').
        speed (float): Speech speed multiplier.
        timeout (Optional[float]): Seconds to wait for the server. It bounds
            both connecting and each read of the response. ``None`` waits
            indefinitely, which was the previous behavior.

    Returns:
        bytes: Raw response body, i.e. the encoded audio data.

    Raises:
        RuntimeError: If the API responds with a status other than 200.
        requests.RequestException: On transport errors, including
            ``requests.Timeout``. The exception is logged and re-raised.
        Exception: Anything raised by ``_get_client()`` while creating the
            session.
    """
    client = _get_client()

    payload = {
        "model": DEFAULT_MODEL,
        "input": text,
        "voice": voice,
        "response_format": response_format,
        "speed": speed,
    }

    logger.info(f"Calling Dia Space API with voice: {voice}, format: {response_format}, speed: {speed}")
    # Keep the try narrow: only the network call is guarded, so the status-code
    # error below is not caught and logged a second time.
    try:
        response = client.post(
            f"{DEFAULT_API_URL}/v1/audio/speech",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=timeout,
        )
    except Exception as e:
        logger.error(f"Error calling Dia Space API: {e}", exc_info=True)
        raise

    if response.status_code != 200:
        logger.error(f"Dia Space API returned error: {response.status_code}")
        logger.error(f"Response: {response.text}")
        raise RuntimeError(f"Dia Space API error: {response.status_code}")

    logger.info("Dia Space API call successful")
    return response.content