teachingAssistant / utils /tts_dia_space.py
Michael Hu
handle dia model not available
a316f58
raw
history blame
5.44 kB
import os
import time
import logging
import requests
import numpy as np
import soundfile as sf
from typing import Optional, Tuple, Generator
# Configure logging.
# NOTE: basicConfig() runs at import time and targets the root logger, so
# importing this module can affect logging for the whole process (it is a
# no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
# Sample rate in Hz. Not referenced in the visible part of this module;
# presumably the rate of the audio returned by the Dia server -- TODO confirm.
DEFAULT_SAMPLE_RATE = 44100
# Base URL of the Dia TTS server; "/docs" and "/v1/audio/speech" are appended
# to it by _get_client() and _call_dia_api() respectively.
DEFAULT_API_URL = "https://droolingpanda-dia-tts-server.hf.space"
# Value sent as the "model" field of the speech request payload.
DEFAULT_MODEL = "dia-1.6b"
# Global client instance (lazy loaded): None until _get_client() creates the
# shared requests.Session on first use.
_client = None
def _get_client():
    """Return the shared HTTP session for the Dia Space API, creating it on first use.

    The session is cached in the module-level ``_client``. When it is first
    created, the server's ``/docs`` endpoint is probed once:

    * a 200 response is logged as success;
    * any other status code is logged as a warning, and the session is still
      cached and returned;
    * an exception during the probe (connection error, timeout, ...) is
      logged and re-raised. In that case the session is closed and NOT
      cached, so the next call starts from scratch instead of silently
      reusing a session whose probe never completed.

    Returns:
        requests.Session: The shared session used for all Dia Space requests.

    Raises:
        ImportError: If ``requests`` cannot be imported.
        Exception: Any other error raised while creating the session or
            probing the server is logged and re-raised unchanged.
    """
    global _client
    if _client is not None:
        return _client

    # Upper bound (seconds) for the availability probe. Without a timeout,
    # requests waits indefinitely on an unresponsive server.
    # NOTE(review): 30s is a guess that leaves room for a slow-starting
    # server -- tune as needed.
    probe_timeout_seconds = 30

    logger.info("Loading Dia Space client...")
    session = None
    try:
        # Kept local so a missing dependency is reported by the ImportError
        # branch below.
        import requests

        # Initialize the client (just a session for now)
        logger.info("Initializing Dia Space client")
        session = requests.Session()

        # Test connection to the API
        response = session.get(
            f"{DEFAULT_API_URL}/docs",
            timeout=probe_timeout_seconds,
        )
        if response.status_code == 200:
            logger.info("Dia Space client loaded successfully")
            logger.info(f"Client type: {type(session).__name__}")
        else:
            logger.warning(f"Dia Space API returned status code {response.status_code}")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia Space client: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia Space client: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        # Release the connection pool of the session we are abandoning.
        if session is not None:
            session.close()
        raise

    # Publish the session only after the probe finished without raising.
    _client = session
    return _client
def generate_speech(
    text: str,
    language: str = "zh",
    voice: str = "S1",
    response_format: str = "wav",
    speed: float = 1.0,
) -> str:
    """Synthesize ``text`` through the Dia Space engine (legacy entry point).

    Kept for backward compatibility; new code should use the factory-pattern
    engine classes directly. The call is delegated to ``DiaSpaceTTSEngine``;
    if that raises, the error is logged and ``DummyTTSEngine`` is used instead.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code, forwarded to the ``DiaSpaceTTSEngine``
            constructor.
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone).
        response_format (str): Audio format ('wav', 'mp3', 'opus').
        speed (float): Speech speed multiplier.

    Returns:
        str: Path to the generated audio file, as returned by the engine
        that handled the request.
    """
    logger.info(f"Legacy Dia Space generate_speech called with text length: {len(text)}")

    # Delegate to the factory-pattern implementation.
    from utils.tts_engines import DiaSpaceTTSEngine

    try:
        return DiaSpaceTTSEngine(language).generate_speech(text, voice, speed, response_format)
    except Exception as err:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception(f"Error in legacy Dia Space generate_speech: {err!s}")

        # Fall back to the dummy engine so the caller still gets a result.
        from utils.tts_base import DummyTTSEngine

        return DummyTTSEngine().generate_speech(text)
def _create_output_dir() -> str:
    """Ensure the audio output directory exists and return its path.

    The directory is ``temp/outputs``, relative to the current working
    directory. Missing parent directories are created, and an existing
    directory is left untouched.

    Returns:
        str: Path to the output directory
    """
    audio_dir = "temp/outputs"
    # exist_ok=True makes repeated calls safe.
    os.makedirs(audio_dir, exist_ok=True)
    return audio_dir
def _generate_output_path(prefix: str = "output", extension: str = "wav") -> str:
    """Generate a unique output path for audio files

    The file name has the form ``{prefix}_{timestamp}_{suffix}.{extension}``,
    where ``timestamp`` is the current Unix time in whole seconds and
    ``suffix`` is a short random hex string. The timestamp alone has
    one-second resolution, so without the suffix two calls in the same second
    with the same prefix would return the same path and the later file would
    overwrite the earlier one.

    The output directory is created if it does not exist yet.

    Args:
        prefix (str): Prefix for the output filename
        extension (str): File extension for the output file; a leading dot
            is ignored, so "wav" and ".wav" are equivalent

    Returns:
        str: Path to the output file
    """
    # Local import: uuid is not among this module's top-level imports.
    import uuid

    output_dir = _create_output_dir()
    timestamp = int(time.time())
    # 8 hex chars (32 random bits) is enough to separate calls that land in
    # the same second.
    unique_suffix = uuid.uuid4().hex[:8]
    # Avoid a doubled dot ("..wav") when the caller passes ".wav".
    clean_extension = extension.lstrip(".")
    return f"{output_dir}/{prefix}_{timestamp}_{unique_suffix}.{clean_extension}"
def _call_dia_api(text: str, voice: str = "S1", response_format: str = "wav", speed: float = 1.0) -> bytes:
    """Call the Dia Space API to generate speech

    Sends a JSON POST to ``{DEFAULT_API_URL}/v1/audio/speech`` using the
    shared session from ``_get_client()`` and returns the raw response body.

    Args:
        text (str): Input text to synthesize
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone)
        response_format (str): Audio format ('wav', 'mp3', 'opus')
        speed (float): Speech speed multiplier

    Returns:
        bytes: Audio data (the response body of a 200 reply)

    Raises:
        Exception: If the server answers with a non-200 status code. Errors
            raised by the HTTP request itself (connection failure, timeout,
            ...) and by ``_get_client()`` are re-raised unchanged.
    """
    client = _get_client()

    # Prepare the request payload
    payload = {
        "model": DEFAULT_MODEL,
        "input": text,
        "voice": voice,
        "response_format": response_format,
        "speed": speed,
    }

    # (connect, read) timeouts in seconds. Without a timeout, requests waits
    # forever on a stalled server.
    # NOTE(review): the read timeout is deliberately generous because
    # synthesis may be slow; the exact values are a guess -- tune as needed.
    request_timeout = (10, 300)

    # Make the API request
    logger.info(f"Calling Dia Space API with voice: {voice}, format: {response_format}, speed: {speed}")
    try:
        response = client.post(
            f"{DEFAULT_API_URL}/v1/audio/speech",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=request_timeout,
        )
    except Exception as e:
        logger.error(f"Error calling Dia Space API: {str(e)}", exc_info=True)
        raise

    # The status check lives outside the try block so an HTTP error is logged
    # once, not a second time by the transport-error handler above.
    if response.status_code != 200:
        logger.error(f"Dia Space API returned error: {response.status_code}")
        logger.error(f"Response: {response.text}")
        raise Exception(f"Dia Space API error: {response.status_code}")

    logger.info("Dia Space API call successful")
    return response.content