# server.py
# Main FastAPI server for Dia TTS
import sys | |
import logging | |
import time | |
import os | |
import io | |
import uuid | |
import sys | |
import shutil # For file copying | |
import yaml # For loading presets | |
from datetime import datetime | |
from contextlib import asynccontextmanager | |
from typing import Optional, Literal, List, Dict, Any | |
import webbrowser | |
import threading | |
import time | |
from fastapi import ( | |
FastAPI, | |
HTTPException, | |
Request, | |
Response, | |
Form, | |
UploadFile, | |
File, | |
BackgroundTasks, | |
) | |
from fastapi.responses import ( | |
StreamingResponse, | |
JSONResponse, | |
HTMLResponse, | |
RedirectResponse, | |
) | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
import uvicorn | |
import numpy as np | |
# Internal imports | |
from config import ( | |
config_manager, | |
get_host, | |
get_port, | |
get_output_path, | |
get_reference_audio_path, | |
# register_config_routes is now defined locally | |
get_model_cache_path, | |
get_model_repo_id, | |
get_model_config_filename, | |
get_model_weights_filename, | |
# Generation default getters | |
get_gen_default_speed_factor, | |
get_gen_default_cfg_scale, | |
get_gen_default_temperature, | |
get_gen_default_top_p, | |
get_gen_default_cfg_filter_top_k, | |
DEFAULT_CONFIG, | |
) | |
from models import OpenAITTSRequest, CustomTTSRequest, ErrorResponse | |
import engine | |
from engine import ( | |
load_model as load_dia_model, | |
generate_speech, | |
EXPECTED_SAMPLE_RATE, | |
) | |
from utils import encode_audio, save_audio_to_file, PerformanceMonitor | |
# Configure root logging once, at import time. The format includes the logger
# name so records from this module can be told apart from library output.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
# Reduce verbosity of noisy libraries if needed
# logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# logging.getLogger("watchfiles").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)  # Logger for this module

# --- Global Variables & Constants ---
# Location of the YAML presets file read by load_presets().
PRESETS_FILE = "ui/presets.yaml"
# In-memory preset cache; reassigned as a whole by load_presets().
loaded_presets: List[Dict[str, Any]] = []
# Set by the lifespan handler once the startup sequence has finished.
startup_complete_event = threading.Event()
# --- Helper Functions --- | |
def load_presets():
    """Refresh the module-level ``loaded_presets`` cache from ``PRESETS_FILE``.

    The cache is left as an empty list when the file is missing, cannot be
    read or parsed, or does not contain a top-level YAML list. Failures are
    logged here instead of being propagated to the caller.
    """
    global loaded_presets
    try:
        # Guard clause: nothing to load when the file is absent.
        if not os.path.exists(PRESETS_FILE):
            logger.warning(
                f"Presets file not found at '{PRESETS_FILE}'. No presets will be available."
            )
            loaded_presets = []
            return

        with open(PRESETS_FILE, "r", encoding="utf-8") as presets_stream:
            loaded_presets = yaml.safe_load(presets_stream)

        if isinstance(loaded_presets, list):
            logger.info(
                f"Successfully loaded {len(loaded_presets)} presets from {PRESETS_FILE}."
            )
            return

        # Anything other than a list (including an empty document) is rejected.
        logger.error(
            f"Presets file '{PRESETS_FILE}' should contain a list, but found {type(loaded_presets)}. No presets loaded."
        )
        loaded_presets = []
    except yaml.YAMLError as yaml_err:
        logger.error(
            f"Error parsing presets YAML file '{PRESETS_FILE}': {yaml_err}",
            exc_info=True,
        )
        loaded_presets = []
    except Exception as load_err:
        logger.error(
            f"Error loading presets file '{PRESETS_FILE}': {load_err}", exc_info=True
        )
        loaded_presets = []
def get_valid_reference_files() -> list[str]:
    """Return the sorted names of ``.wav``/``.mp3`` entries in the reference directory.

    The extension match is case-insensitive. A missing directory is logged as
    a warning and a listing error is logged as an error; both yield an empty
    list.
    """
    audio_dir = get_reference_audio_path()
    audio_suffixes = (".wav", ".mp3")
    matches = []
    try:
        if not os.path.isdir(audio_dir):
            logger.warning(f"Reference audio directory not found: {audio_dir}")
        else:
            # str.endswith accepts a tuple, so one call covers both suffixes.
            matches = [
                entry
                for entry in os.listdir(audio_dir)
                if entry.lower().endswith(audio_suffixes)
            ]
    except Exception as listing_err:
        logger.error(
            f"Error reading reference audio directory '{audio_dir}': {listing_err}",
            exc_info=True,
        )
    matches.sort()
    return matches
def sanitize_filename(filename: str) -> str:
    """Return a filesystem-safe version of ``filename``.

    Steps:
      1. Directory components are dropped (``os.path.basename``).
      2. Every character outside ``[A-Za-z0-9._-]`` becomes ``_``.
      3. If nothing meaningful remains (empty, or only dots/underscores/spaces),
         a generated ``uploaded_file_<8 hex chars>`` name is returned instead.
      4. The result is capped at 100 characters, keeping the extension when
         the extension itself fits within the cap.

    Args:
        filename: Client-supplied file name, possibly including a path.

    Returns:
        A name of at most 100 characters made only of safe characters.
    """
    # Remove directory separators
    filename = os.path.basename(filename)
    # Keep only alphanumeric, underscore, hyphen, dot. Replace others with underscore.
    safe_chars = set(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-"
    )
    sanitized = "".join(c if c in safe_chars else "_" for c in filename)
    # Reject names that are empty or consist only of dots/underscores/spaces.
    if not sanitized or sanitized.lstrip("._ ") == "":
        return f"uploaded_file_{uuid.uuid4().hex[:8]}"  # Generate a safe fallback name
    # Limit length
    max_len = 100
    if len(sanitized) > max_len:
        name, ext = os.path.splitext(sanitized)
        if len(ext) > max_len:
            # The "extension" alone exceeds the cap; a negative slice bound
            # would leave the result over-long, so hard-truncate instead.
            sanitized = sanitized[:max_len]
        else:
            sanitized = name[: max_len - len(ext)] + ext
    return sanitized
# --- Application Lifespan (Startup/Shutdown) --- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup/shutdown.

    Startup (before the ``yield``):
      * creates the output, reference-audio, model-cache, ``ui`` and ``static``
        directories if they are missing,
      * loads the UI presets,
      * loads the Dia model and aborts startup with ``RuntimeError`` if that fails,
      * starts a daemon thread that opens the web UI in a browser,
      * sets ``startup_complete_event``.

    Shutdown (after the ``yield``): only logs; there is no resource cleanup yet.

    Raises:
        RuntimeError: if the model cannot be loaded.
        Exception: any other startup error is logged and re-raised so the
            server does not start in a half-initialised state.
    """
    try:
        logger.info("Starting Dia TTS server initialization...")
        # Ensure base directories exist
        for directory in (
            get_output_path(),
            get_reference_audio_path(),
            get_model_cache_path(),
            "ui",
            "static",
        ):
            os.makedirs(directory, exist_ok=True)
        # Load presets from YAML file
        load_presets()
        # Load the main TTS model during startup
        if not load_dia_model():
            error_msg = (
                "CRITICAL: Failed to load Dia model on startup. Server cannot start."
            )
            logger.critical(error_msg)
            # Raising (rather than sys.exit) lets the ASGI server report the failure.
            raise RuntimeError(error_msg)
        logger.info("Dia model loaded successfully.")
        # Start the browser thread only AFTER model loading has completed.
        host = get_host()
        port = get_port()
        browser_thread = threading.Thread(
            target=_delayed_browser_open, args=(host, port), daemon=True
        )
        browser_thread.start()
        # --- Signal completion AFTER potentially long operations ---
        logger.info("Application startup sequence finished. Signaling readiness.")
        startup_complete_event.set()
        yield  # Application runs here
    except Exception as e:
        # Catch the RuntimeError raised above or any other startup error.
        # The readiness event is deliberately NOT set on this path.
        logger.error(f"Fatal error during application startup: {e}", exc_info=True)
        raise  # Bare raise keeps the original traceback intact
    finally:
        # Cleanup on shutdown
        logger.info("Application shutdown initiated...")
        # Add any specific cleanup needed
        logger.info("Application shutdown complete.")
def _delayed_browser_open(host, port):
    """Open the web UI in the default browser after a two-second delay.

    Intended to run in a background thread. A host of ``0.0.0.0`` is shown as
    ``localhost`` in the URL. Progress and failures are appended to
    ``browser_thread_debug.log`` on a best-effort basis; a failure to write
    that file never prevents the browser from being opened.

    Args:
        host: Host the server is bound to.
        port: Port the server is listening on.
    """

    def _debug_note(message):
        # Breadcrumb file for debugging; ignore I/O problems (e.g. read-only cwd).
        try:
            with open("browser_thread_debug.log", "a") as f:
                f.write(f"[{time.time()}] {message}\n")
        except OSError:
            pass

    try:
        # Small delay to ensure Uvicorn is fully ready
        time.sleep(2)
        display_host = "localhost" if host == "0.0.0.0" else host
        browser_url = f"http://{display_host}:{port}/"
        _debug_note(f"Opening browser at {browser_url}")
        # Logging is deliberately best-effort here: it must not block the launch.
        try:
            logger.info(f"Opening browser at {browser_url}")
        except Exception:
            pass
        # Open browser directly without health checks
        webbrowser.open(browser_url)
    except Exception as e:
        _debug_note(f"Browser open error: {str(e)}")
# --- FastAPI App Initialization --- | |
app = FastAPI(
    title="Dia TTS Server",
    description="Text-to-Speech server using the Dia model, providing API and Web UI.",
    version="1.1.0",  # Incremented version
    lifespan=lifespan,
)

# List of folders to check/create
folders = ["reference_audio", "model_cache", "outputs"]
# StaticFiles checks that its directory exists when it is constructed, and the
# mounts below run at import time (before the lifespan handler). So the
# configured paths and the 'ui' directory must be created here as well, not
# only the default folder names.
_required_dirs = folders + [
    get_reference_audio_path(),
    get_model_cache_path(),
    get_output_path(),
    "ui",
]
# Check each folder and create if it doesn't exist
for folder in _required_dirs:
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        print(f"Created directory: {folder}")

# --- Static Files and Templates ---
# Serve generated audio files from the configured output path
app.mount("/outputs", StaticFiles(directory=get_output_path()), name="outputs")
# Serve UI files (CSS, JS) from the 'ui' directory
app.mount("/ui", StaticFiles(directory="ui"), name="ui_static")
# Initialize Jinja2 templates to look in the 'ui' directory
templates = Jinja2Templates(directory="ui")
# --- Configuration Routes Definition --- | |
# Defined locally now instead of importing from config.py | |
def register_config_routes(app: FastAPI):
    """Adds configuration management endpoints to the FastAPI app.

    Registers four routes on ``app``:
      * ``GET  /get_config``               - current configuration values
      * ``POST /save_config``              - persist server settings
      * ``POST /save_generation_defaults`` - persist default generation parameters
      * ``POST /restart_server``           - schedule a process exit for restart

    Args:
        app: The FastAPI application the routes are attached to.
    """
    # NOTE(review): the handlers were defined without any visible registration;
    # the paths below come from the log message, the HTTP methods are inferred
    # from how each handler uses the request - confirm against the UI's calls.
    logger.info(
        "Registering configuration routes (/get_config, /save_config, /restart_server, /save_generation_defaults)."
    )

    @app.get("/get_config")
    async def get_current_config():
        """Returns the current server configuration values (from .env or defaults)."""
        logger.info("Request received for /get_config")
        return JSONResponse(content=config_manager.get_all())

    @app.post("/save_config")
    async def save_new_config(request: Request):
        """
        Saves updated server configuration values (Host, Port, Model paths, etc.)
        to the .env file. Requires server restart to apply most changes.

        Keys not present in DEFAULT_CONFIG are ignored with a warning.
        Responds 400 for a malformed body and 500 if saving fails.
        """
        logger.info("Request received for /save_config")
        try:
            new_config_data = await request.json()
            if not isinstance(new_config_data, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received server config data to save: {new_config_data}")
            # Filter data to only include keys present in DEFAULT_CONFIG
            filtered_data = {
                k: v for k, v in new_config_data.items() if k in DEFAULT_CONFIG
            }
            unknown_keys = set(new_config_data.keys()) - set(filtered_data.keys())
            if unknown_keys:
                logger.warning(
                    f"Ignoring unknown keys in save_config request: {unknown_keys}"
                )
            config_manager.update(filtered_data)  # Update in memory first
            if config_manager.save():  # Attempt to save to .env
                logger.info("Server configuration saved successfully to .env.")
                return JSONResponse(
                    content={
                        "message": "Server configuration saved. Restart server to apply changes."
                    }
                )
            logger.error("Failed to save server configuration to .env file.")
            raise HTTPException(
                status_code=500, detail="Failed to save configuration file."
            )
        except HTTPException:
            # Pass deliberate HTTP errors through; otherwise the generic handler
            # below would re-wrap them and garble the detail text.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_config: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(f"Error processing /save_config request: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post("/save_generation_defaults")
    async def save_generation_defaults(request: Request):
        """
        Saves the provided generation parameters (speed, cfg, temp, etc.)
        as the new defaults in the .env file. These are loaded by the UI on startup.

        At least one recognised parameter must be present; values are stored
        as strings. Responds 400 for a malformed body and 500 if saving fails.
        """
        logger.info("Request received for /save_generation_defaults")
        try:
            gen_params = await request.json()
            if not isinstance(gen_params, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received generation defaults to save: {gen_params}")
            # Map received keys (e.g., 'speed_factor') to .env keys (e.g., 'GEN_DEFAULT_SPEED_FACTOR')
            key_map = {
                "speed_factor": "GEN_DEFAULT_SPEED_FACTOR",
                "cfg_scale": "GEN_DEFAULT_CFG_SCALE",
                "temperature": "GEN_DEFAULT_TEMPERATURE",
                "top_p": "GEN_DEFAULT_TOP_P",
                "cfg_filter_top_k": "GEN_DEFAULT_CFG_FILTER_TOP_K",
            }
            defaults_to_save = {}
            for ui_key, env_key in key_map.items():
                if ui_key in gen_params:
                    # Basic validation could be added here (e.g., check if float/int)
                    defaults_to_save[env_key] = str(gen_params[ui_key])
                else:
                    logger.warning(
                        f"Missing expected key '{ui_key}' in save_generation_defaults request."
                    )
            if not defaults_to_save:
                raise ValueError("No valid generation parameters found in the request.")
            config_manager.update(defaults_to_save)  # Update in memory
            # Save all current config (including these) to .env
            if config_manager.save():
                logger.info("Generation defaults saved successfully to .env.")
                return JSONResponse(content={"message": "Generation defaults saved."})
            logger.error("Failed to save generation defaults to .env file.")
            raise HTTPException(
                status_code=500, detail="Failed to save configuration file."
            )
        except HTTPException:
            # See save_new_config: do not re-wrap deliberate HTTP errors.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_generation_defaults: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(
                f"Error processing /save_generation_defaults request: {e}",
                exc_info=True,
            )
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post("/restart_server")
    async def trigger_server_restart(background_tasks: BackgroundTasks):
        """
        Attempts to restart the server process.
        NOTE: This is highly dependent on how the server is run (e.g., with uvicorn --reload,
        or managed by systemd/supervisor). A simple exit might just stop the process.
        This implementation attempts a clean exit, relying on the runner to restart it.
        """
        logger.warning("Received request to restart server via API.")

        def _do_restart():
            time.sleep(1)  # Short delay to allow response to be sent
            logger.warning("Attempting clean exit for restart...")
            # Clean exit; an external reloader or process manager must restart us.
            sys.exit(0)

        background_tasks.add_task(_do_restart)
        return JSONResponse(
            content={
                "message": "Restart signal sent. Server should restart shortly if run with auto-reload."
            }
        )
# --- Register Configuration Routes ---
# Runs at import time, right after `app` and register_config_routes are defined.
register_config_routes(app)
# --- API Endpoints --- | |
async def openai_tts_endpoint(request: OpenAITTSRequest):
    """
    Generates speech audio from text, compatible with the OpenAI TTS API structure.
    Maps the 'voice' parameter to Dia's voice modes ('S1', 'S2', 'dialogue', or filename for clone).

    Voice mapping (case-insensitive):
      * ``dialogue``          -> ``dialogue``
      * ``s1`` / ``s2``       -> ``single_s1`` / ``single_s2``
      * ``*.wav`` / ``*.mp3`` -> ``clone`` when that plain file name exists in
        the reference audio directory
      * anything else, a missing file, or a name containing path components
        -> ``single_s1`` (with a warning)

    Returns:
        StreamingResponse carrying the encoded audio.

    Raises:
        HTTPException: 500 when generation or encoding fails.
    """
    # NOTE(review): no route decorator is visible on this handler in this view -
    # confirm it is registered on `app`.
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received OpenAI request: voice='{request.voice}', speed={request.speed}, format='{request.response_format}'"
    )
    logger.debug(f"Input text (start): '{request.input[:100]}...'")
    voice_mode = "single_s1"  # Default if mapping fails
    clone_ref_file = None
    ref_path = get_reference_audio_path()
    # --- Map OpenAI 'voice' parameter to Dia's modes ---
    voice_param = request.voice.strip()
    voice_key = voice_param.lower()
    if voice_key == "dialogue":
        voice_mode = "dialogue"
    elif voice_key == "s1":
        voice_mode = "single_s1"
    elif voice_key == "s2":
        voice_mode = "single_s2"
    # Check if it looks like a filename for cloning (allow .wav or .mp3)
    elif voice_key.endswith((".wav", ".mp3")):
        if os.path.basename(voice_param) != voice_param:
            # Client-controlled value: refuse anything that could escape the
            # reference directory (e.g. '../x.wav' or an absolute path).
            logger.warning(
                f"Reference file name '{voice_param}' in OpenAI request contains path components and was rejected. Defaulting voice mode."
            )
        elif os.path.isfile(os.path.join(ref_path, voice_param)):
            voice_mode = "clone"
            clone_ref_file = voice_param  # Use the provided filename
            logger.info(
                f"OpenAI request mapped to clone mode with file: {clone_ref_file}"
            )
        else:
            # Fallback to default 'single_s1' if file not found
            logger.warning(
                f"Reference file '{voice_param}' specified in OpenAI request not found in '{ref_path}'. Defaulting voice mode."
            )
    else:
        # Fallback for any other value
        logger.warning(
            f"Unrecognized OpenAI voice parameter '{voice_param}'. Defaulting voice mode to 'single_s1'."
        )
    monitor.record("Parameters processed")
    try:
        # Call the core engine function using mapped parameters
        result = generate_speech(
            text=request.input,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=request.speed,  # Pass speed factor for post-processing
            # Use Dia's configured defaults for other generation params unless mapped
            max_tokens=None,  # Let Dia use its default unless specified otherwise
            cfg_scale=get_gen_default_cfg_scale(),  # Use saved defaults
            temperature=get_gen_default_temperature(),
            top_p=get_gen_default_top_p(),
            cfg_filter_top_k=get_gen_default_cfg_filter_top_k(),
        )
        monitor.record("Generation complete")
        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")
        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, but expected {EXPECTED_SAMPLE_RATE}. Encoding might assume {EXPECTED_SAMPLE_RATE}."
            )
            # Use EXPECTED_SAMPLE_RATE for encoding as it's what the model is trained for
            sample_rate = EXPECTED_SAMPLE_RATE
        # Encode the audio in memory to the requested format
        encoded_audio = encode_audio(audio_array, sample_rate, request.response_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.response_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.response_format}",
            )
        # Determine the correct media type for the response header
        # Note: OpenAI uses audio/opus, not audio/ogg;codecs=opus. Let's match OpenAI.
        media_type = "audio/opus" if request.response_format == "opus" else "audio/wav"
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.response_format}"
        )
        logger.debug(monitor.report())
        # Stream the encoded audio back to the client
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        # Re-raise HTTPExceptions directly so their status/detail reach the client
        logger.error(f"HTTP exception during OpenAI request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing OpenAI TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        # Return generic server error for unexpected issues
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
async def custom_tts_endpoint(request: CustomTTSRequest):
    """
    Generates speech audio from text using explicit Dia parameters.

    In ``clone`` mode ``clone_reference_filename`` must be the plain name of a
    file inside the reference audio directory.

    Returns:
        StreamingResponse carrying the encoded audio.

    Raises:
        HTTPException: 400 when the clone file name is missing or contains path
            components, 404 when the file does not exist, 500 when generation
            or encoding fails.
    """
    # NOTE(review): no route decorator is visible on this handler in this view -
    # confirm it is registered on `app`.
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received custom TTS request: mode='{request.voice_mode}', format='{request.output_format}'"
    )
    logger.debug(f"Input text (start): '{request.text[:100]}...'")
    logger.debug(
        f"Params: max_tokens={request.max_tokens}, cfg={request.cfg_scale}, temp={request.temperature}, top_p={request.top_p}, speed={request.speed_factor}, top_k={request.cfg_filter_top_k}"
    )
    clone_ref_file = None
    if request.voice_mode == "clone":
        requested_file = request.clone_reference_filename
        if not requested_file:
            raise HTTPException(
                status_code=400,  # Bad request
                detail="Missing 'clone_reference_filename' which is required for clone mode.",
            )
        if os.path.basename(requested_file) != requested_file:
            # Client-controlled value: refuse anything that could escape the
            # reference directory (e.g. '../x.wav' or an absolute path).
            logger.error(
                f"Rejected clone reference file name with path components: {requested_file}"
            )
            raise HTTPException(
                status_code=400,
                detail="'clone_reference_filename' must be a plain file name without path components.",
            )
        ref_path = get_reference_audio_path()
        potential_path = os.path.join(ref_path, requested_file)
        if not os.path.isfile(potential_path):
            logger.error(
                f"Reference audio file not found for clone mode: {potential_path}"
            )
            raise HTTPException(
                status_code=404,  # Not found
                detail=f"Reference audio file not found: {requested_file}",
            )
        clone_ref_file = requested_file
        logger.info(f"Custom request using clone mode with file: {clone_ref_file}")
    monitor.record("Parameters processed")
    try:
        # Call the core engine function with parameters from the request
        result = generate_speech(
            text=request.text,
            voice_mode=request.voice_mode,
            clone_reference_filename=clone_ref_file,
            max_tokens=request.max_tokens,  # Pass user value or None
            cfg_scale=request.cfg_scale,
            temperature=request.temperature,
            top_p=request.top_p,
            speed_factor=request.speed_factor,  # For post-processing
            cfg_filter_top_k=request.cfg_filter_top_k,
        )
        monitor.record("Generation complete")
        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")
        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, expected {EXPECTED_SAMPLE_RATE}. Encoding will use {EXPECTED_SAMPLE_RATE}."
            )
            sample_rate = EXPECTED_SAMPLE_RATE
        # Encode the audio in memory
        encoded_audio = encode_audio(audio_array, sample_rate, request.output_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.output_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.output_format}",
            )
        # Determine media type
        media_type = "audio/opus" if request.output_format == "opus" else "audio/wav"
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.output_format}"
        )
        logger.debug(monitor.report())
        # Stream the response
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        # Let deliberate HTTP errors through unchanged.
        logger.error(f"HTTP exception during custom TTS request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing custom TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# --- Web UI Endpoints --- | |
async def get_web_ui(request: Request):
    """Render index.html in its initial state (nothing submitted yet).

    The template receives the current reference-file list, the server
    configuration, the cached presets and the saved default generation
    parameters; all result/feedback fields start out empty.
    """
    # NOTE(review): no route decorator is visible on this handler in this view -
    # confirm it is registered on `app`.
    logger.info("Serving TTS Web UI (index.html)")
    # Reference files populate the clone dropdown.
    available_refs = get_valid_reference_files()
    server_config = config_manager.get_all()
    generation_defaults = {
        "speed_factor": get_gen_default_speed_factor(),
        "cfg_scale": get_gen_default_cfg_scale(),
        "temperature": get_gen_default_temperature(),
        "top_p": get_gen_default_top_p(),
        "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
    }
    page_context = {
        "request": request,
        "reference_files": available_refs,
        "config": server_config,
        "presets": loaded_presets,
        "default_gen_params": generation_defaults,
        # Blank result/feedback fields for the first render.
        "error": None,
        "success": None,
        "output_file_url": None,
        "generation_time": None,
        "submitted_text": "",
        "submitted_voice_mode": "dialogue",  # Default to combined mode
        "submitted_clone_file": None,
    }
    return templates.TemplateResponse("index.html", page_context)
async def handle_web_ui_generate(
    request: Request,
    text: str = Form(...),
    voice_mode: Literal["dialogue", "clone"] = Form(...),  # Updated modes
    clone_reference_select: Optional[str] = Form(None),
    # Generation parameters from form
    speed_factor: float = Form(...),  # Make required or use Depends with default
    cfg_scale: float = Form(...),
    temperature: float = Form(...),
    top_p: float = Form(...),
    cfg_filter_top_k: int = Form(...),
):
    """Handles the generation request from the web UI form.

    Validates the form, runs generation, saves the result as a WAV file in
    the output directory and re-renders index.html with either the audio
    URL or an error message. The submitted values are always passed back so
    the form can be repopulated.

    In ``clone`` mode the selected reference must be the plain name of an
    existing file in the reference audio directory.
    """
    # NOTE(review): no route decorator is visible on this handler in this view -
    # confirm it is registered on `app`.
    logger.info(f"Web UI generation request: mode='{voice_mode}'")
    monitor = PerformanceMonitor()
    monitor.record("Web request received")

    def _render(*, error, success, output_file_url, generation_time, submitted_clone_file):
        """Render index.html with fresh server state plus the submitted form values."""
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": error,
                "success": success,
                "output_file_url": output_file_url,
                "generation_time": generation_time,
                "reference_files": get_valid_reference_files(),
                "config": config_manager.get_all(),
                "presets": loaded_presets,
                # Base defaults
                "default_gen_params": {
                    "speed_factor": get_gen_default_speed_factor(),
                    "cfg_scale": get_gen_default_cfg_scale(),
                    "temperature": get_gen_default_temperature(),
                    "top_p": get_gen_default_top_p(),
                    "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
                },
                # Submitted values to repopulate the form
                "submitted_text": text,
                "submitted_voice_mode": voice_mode,
                "submitted_clone_file": submitted_clone_file,
                "submitted_gen_params": {
                    "speed_factor": speed_factor,
                    "cfg_scale": cfg_scale,
                    "temperature": temperature,
                    "top_p": top_p,
                    "cfg_filter_top_k": cfg_filter_top_k,
                },
            },
        )

    output_file_url = None
    generation_time = None
    error_message = None
    success_message = None
    # --- Pre-generation Validation ---
    if not text.strip():
        error_message = "Please enter some text to synthesize."
    clone_ref_file = None
    if voice_mode == "clone":
        if not clone_reference_select or clone_reference_select == "none":
            error_message = "Please select a reference audio file for clone mode."
        elif os.path.basename(clone_reference_select) != clone_reference_select:
            # Form value is client-controlled: refuse anything that could
            # escape the reference directory (e.g. '../x.wav').
            error_message = "Invalid reference file name."
            clone_reference_select = None  # Clear submitted value for re-rendering
        else:
            # Verify selected file still exists (important if files can be deleted)
            ref_path = get_reference_audio_path()
            potential_path = os.path.join(ref_path, clone_reference_select)
            if not os.path.isfile(potential_path):
                error_message = f"Selected reference file '{clone_reference_select}' no longer exists. Please refresh or upload."
                clone_reference_select = None  # Clear submitted value for re-rendering
            else:
                clone_ref_file = clone_reference_select
                logger.info(f"Using selected reference file: {clone_ref_file}")
    # If validation failed, re-render the page with error and submitted values
    if error_message:
        logger.warning(f"Web UI validation error: {error_message}")
        return _render(
            error=error_message,
            success=None,
            output_file_url=None,
            generation_time=None,
            submitted_clone_file=clone_reference_select,  # Possibly invalidated above
        )
    # --- Generation ---
    try:
        monitor.record("Parameters processed")
        # Call the core engine function
        result = generate_speech(
            text=text,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=speed_factor,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
            max_tokens=None,  # Use model default for UI simplicity
        )
        monitor.record("Generation complete")
        if result:
            audio_array, sample_rate = result
            output_path_base = get_output_path()
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Create a more descriptive filename
            mode_tag = voice_mode
            if voice_mode == "clone" and clone_ref_file:
                safe_ref_name = sanitize_filename(os.path.splitext(clone_ref_file)[0])
                mode_tag = f"clone_{safe_ref_name[:20]}"  # Limit length
            # Always save as WAV for simplicity
            output_filename = f"{mode_tag}_{timestamp}.wav"
            output_filepath = os.path.join(output_path_base, output_filename)
            # Save the audio to a WAV file
            saved = save_audio_to_file(audio_array, sample_rate, output_filepath)
            monitor.record("Audio saved")
            if saved:
                # URL path for browser access
                output_file_url = f"/outputs/{output_filename}"
                # Time until save complete
                generation_time = monitor.events[-1][1] - monitor.start_time
                success_message = "Audio generated successfully!"
                logger.info(f"Web UI generated audio saved to: {output_filepath}")
            else:
                error_message = "Failed to save generated audio file."
                logger.error("Failed to save audio file from web UI request.")
        else:
            error_message = "Speech generation failed (engine returned None)."
            logger.error("Speech generation failed for web UI request.")
    except Exception as e:
        logger.error(f"Error processing web UI TTS request: {e}", exc_info=True)
        error_message = f"An unexpected error occurred: {str(e)}"
    logger.debug(monitor.report())
    # --- Re-render Template with Results ---
    return _render(
        error=error_message,
        success=success_message,
        output_file_url=output_file_url,
        generation_time=f"{generation_time:.2f}" if generation_time else None,
        submitted_clone_file=clone_ref_file,  # Pass the validated filename back
    )
# --- Reference Audio Upload Endpoint --- | |
async def upload_reference_audio(files: List[UploadFile] = File(...)):
    """Store uploaded reference audio files (.wav, .mp3) for voice cloning.

    Each file is validated (non-empty name, allowed extension, allowed
    client-declared content type) and then written into the directory returned
    by ``get_reference_audio_path()`` under its sanitized filename. A file
    whose sanitized name already exists on disk is not overwritten; it is
    still reported in ``uploaded_files`` so the caller knows it is available.

    Args:
        files: The uploaded files from the multipart request.

    Returns:
        A ``JSONResponse`` whose body has the keys ``message``,
        ``uploaded_files``, ``all_reference_files`` and ``errors``. The status
        is 200 when there are no errors or fewer errors than files, and 400
        otherwise.
    """
    logger.info(f"Received request to upload {len(files)} reference audio file(s).")
    ref_path = get_reference_audio_path()
    uploaded_filenames = []
    errors = []
    allowed_mime_types = [
        "audio/wav",
        "audio/mpeg",
        "audio/x-wav",
    ]  # Common WAV/MP3 types
    allowed_extensions = [".wav", ".mp3"]

    for file in files:
        try:
            # Basic validation
            if not file.filename:
                errors.append("Received file with no filename.")
                continue

            # Never trust the client-supplied name as a path component.
            safe_filename = sanitize_filename(file.filename)
            _, ext = os.path.splitext(safe_filename)
            if ext.lower() not in allowed_extensions:
                errors.append(
                    f"File '{file.filename}' has unsupported extension '{ext}'. Allowed: {allowed_extensions}"
                )
                continue

            # The content type is declared by the client, so this is only a
            # second sanity check alongside the extension, not proof of content.
            if file.content_type not in allowed_mime_types:
                errors.append(
                    f"File '{file.filename}' has unsupported content type '{file.content_type}'. Allowed: {allowed_mime_types}"
                )
                continue

            destination_path = os.path.join(ref_path, safe_filename)

            # Existing files are never overwritten; skip but still report the
            # name so the UI knows it is available.
            if os.path.exists(destination_path):
                logger.warning(
                    f"Reference file '{safe_filename}' already exists. Skipping upload."
                )
                if safe_filename not in uploaded_filenames:
                    uploaded_filenames.append(safe_filename)
                continue

            # NOTE(review): shutil.copyfileobj is a synchronous copy inside an
            # async handler; consider offloading to a worker thread for very
            # large uploads.
            try:
                with open(destination_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
            except Exception as save_exc:
                # We only get here for a path that did not exist before this
                # request, so anything on disk now is our own partial write.
                # Remove it, otherwise the truncated file would later be
                # treated as "already exists" and block a clean re-upload.
                try:
                    os.remove(destination_path)
                except FileNotFoundError:
                    pass  # open() itself failed; nothing was created
                except OSError as cleanup_exc:
                    logger.warning(
                        f"Could not remove partial file '{destination_path}': {cleanup_exc}"
                    )
                errors.append(f"Failed to save file '{safe_filename}': {save_exc}")
                logger.error(
                    f"Failed to save uploaded file '{safe_filename}' to '{destination_path}': {save_exc}",
                    exc_info=True,
                )
            else:
                logger.info(f"Successfully saved reference file: {destination_path}")
                uploaded_filenames.append(safe_filename)
        except Exception as e:
            errors.append(
                f"Error processing file '{getattr(file, 'filename', 'unknown')}': {e}"
            )
            logger.error(
                f"Unexpected error processing uploaded file: {e}", exc_info=True
            )
        finally:
            # Runs on every path, including the early `continue`s above, so
            # rejected and skipped uploads are released too. A close failure
            # is logged but does not turn a saved file into a reported error.
            try:
                await file.close()
            except Exception as close_exc:
                logger.warning(
                    f"Failed to close upload '{getattr(file, 'filename', 'unknown')}': {close_exc}"
                )

    # Get the updated list of all valid files in the directory
    updated_file_list = get_valid_reference_files()
    response_data = {
        "message": f"Processed {len(files)} file(s).",
        "uploaded_files": uploaded_filenames,  # Saved or already-present names from this request
        "all_reference_files": updated_file_list,  # Complete current list
        "errors": errors,
    }
    # OK if at least one file did not error, else Bad Request
    status_code = 200 if not errors or len(errors) < len(files) else 400
    if errors:
        logger.warning(f"Upload completed with errors: {errors}")
    return JSONResponse(content=response_data, status_code=status_code)
# --- Health Check Endpoint --- | |
async def health_check():
    """Report that the server is up and whether the TTS model is loaded.

    The flag is looked up on the ``engine`` module at call time (falling back
    to ``False`` when the attribute is missing), so the response reflects the
    current state rather than a value captured at import.
    """
    model_is_loaded = getattr(engine, "MODEL_LOADED", False)
    logger.debug(f"Health check returning model_loaded status: {model_is_loaded}")
    payload = {"status": "healthy", "model_loaded": model_is_loaded}
    return payload
# --- Main Execution --- | |
if __name__ == "__main__":
    # Resolve the bind address once; it is reused for logging and for uvicorn.
    host = get_host()
    port = get_port()

    # Log the effective configuration so a startup log is self-describing.
    logger.info(f"Starting Dia TTS server on {host}:{port}")
    logger.info(f"Model Repository: {get_model_repo_id()}")
    logger.info(f"Model Config File: {get_model_config_filename()}")
    logger.info(f"Model Weights File: {get_model_weights_filename()}")
    logger.info(f"Model Cache Path: {get_model_cache_path()}")
    logger.info(f"Reference Audio Path: {get_reference_audio_path()}")
    logger.info(f"Output Path: {get_output_path()}")

    # "0.0.0.0" is a bind-all address, not something a browser can open,
    # so show "localhost" in the advertised URLs instead.
    display_host = host if host != "0.0.0.0" else "localhost"
    logger.info(f"Web UI will be available at http://{display_host}:{port}/")
    logger.info(f"API Docs available at http://{display_host}:{port}/docs")

    # Make sure the UI directory and its index.html exist; fall back to a
    # placeholder page so startup does not depend on the real template.
    ui_dir = "ui"
    index_file = os.path.join(ui_dir, "index.html")
    if not (os.path.isdir(ui_dir) and os.path.isfile(index_file)):
        logger.warning(
            f"'{ui_dir}' directory or '{index_file}' not found. Web UI may not work."
        )
        os.makedirs(ui_dir, exist_ok=True)
        if not os.path.isfile(index_file):
            try:
                with open(index_file, "w") as placeholder:
                    placeholder.write(
                        "<html><body>Web UI template missing. See project source for index.html.</body></html>"
                    )
                logger.info(f"Created dummy {index_file}.")
            except Exception as exc:
                # Best effort only: a missing placeholder must not stop the server.
                logger.error(f"Failed to create dummy {index_file}: {exc}")

    # Synchronization event intended to be set once startup (including model
    # loading) has finished.
    startup_complete_event = threading.Event()

    # Start the server. The app is referenced by import string
    # ('module:app_instance'); auto-reload stays off (enable it, with
    # reload_dirs/reload_includes, for development only), and the lifespan
    # handler defined in this file is always run.
    uvicorn.run(
        "server:app",
        host=host,
        port=port,
        reload=False,
        lifespan="on",
    )