# server.py
# Main FastAPI server for Dia TTS

import sys
import logging
import time
import os
import io
import uuid
import shutil  # For file copying
import yaml  # For loading presets
from datetime import datetime
from contextlib import asynccontextmanager
from typing import Optional, Literal, List, Dict, Any
import webbrowser
import threading

from fastapi import (
    FastAPI,
    HTTPException,
    Request,
    Response,
    Form,
    UploadFile,
    File,
    BackgroundTasks,
)
from fastapi.responses import (
    StreamingResponse,
    JSONResponse,
    HTMLResponse,
    RedirectResponse,
)
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uvicorn
import numpy as np

# Internal imports
from config import (
    config_manager,
    get_host,
    get_port,
    get_output_path,
    get_reference_audio_path,
    # register_config_routes is now defined locally
    get_model_cache_path,
    get_model_repo_id,
    get_model_config_filename,
    get_model_weights_filename,
    # Generation default getters
    get_gen_default_speed_factor,
    get_gen_default_cfg_scale,
    get_gen_default_temperature,
    get_gen_default_top_p,
    get_gen_default_cfg_filter_top_k,
    DEFAULT_CONFIG,
)
from models import OpenAITTSRequest, CustomTTSRequest, ErrorResponse
import engine
from engine import (
    load_model as load_dia_model,
    generate_speech,
    EXPECTED_SAMPLE_RATE,
)
from utils import encode_audio, save_audio_to_file, PerformanceMonitor

# Configure logging (Basic setup, can be enhanced)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
# Reduce verbosity of noisy libraries if needed
# logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# logging.getLogger("watchfiles").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)  # Logger for this module

# --- Global Variables & Constants ---
PRESETS_FILE = "ui/presets.yaml"
loaded_presets: List[Dict[str, Any]] = []  # Cache presets in memory
# Set by the lifespan handler once startup (including model loading) succeeded.
startup_complete_event = threading.Event()


# --- Helper Functions ---
def load_presets():
    """Load UI presets from PRESETS_FILE into the module-level ``loaded_presets``.

    The cache is reset to an empty list when the file is missing, cannot be
    parsed, or does not contain a YAML list. Errors are logged, never raised.
    """
    global loaded_presets
    try:
        if os.path.exists(PRESETS_FILE):
            with open(PRESETS_FILE, "r", encoding="utf-8") as f:
                loaded_presets = yaml.safe_load(f)
            if not isinstance(loaded_presets, list):
                logger.error(
                    f"Presets file '{PRESETS_FILE}' should contain a list, but found {type(loaded_presets)}. No presets loaded."
                )
                loaded_presets = []
            else:
                logger.info(
                    f"Successfully loaded {len(loaded_presets)} presets from {PRESETS_FILE}."
                )
        else:
            logger.warning(
                f"Presets file not found at '{PRESETS_FILE}'. No presets will be available."
            )
            loaded_presets = []
    except yaml.YAMLError as e:
        logger.error(
            f"Error parsing presets YAML file '{PRESETS_FILE}': {e}", exc_info=True
        )
        loaded_presets = []
    except Exception as e:
        logger.error(f"Error loading presets file '{PRESETS_FILE}': {e}", exc_info=True)
        loaded_presets = []


def get_valid_reference_files() -> list[str]:
    """Return the sorted names of .wav/.mp3 files in the reference audio directory.

    Only the top level of the directory is listed. A missing directory or a
    listing error is logged and yields an empty (or partial) list.
    """
    ref_path = get_reference_audio_path()
    valid_files = []
    allowed_extensions = (".wav", ".mp3")
    try:
        if os.path.isdir(ref_path):
            for filename in os.listdir(ref_path):
                if filename.lower().endswith(allowed_extensions):
                    # Optional: Add check for file size or basic validity if needed
                    valid_files.append(filename)
        else:
            logger.warning(f"Reference audio directory not found: {ref_path}")
    except Exception as e:
        logger.error(
            f"Error reading reference audio directory '{ref_path}': {e}", exc_info=True
        )
    return sorted(valid_files)


def sanitize_filename(filename: str) -> str:
    """Return a filesystem-safe version of ``filename``.

    Path components are dropped, every character outside ``[A-Za-z0-9._-]`` is
    replaced by an underscore, leading dots are stripped (so the result is
    never a hidden file or a ``.``/``..`` entry) and the result is capped at
    100 characters while keeping the extension when possible. If nothing
    meaningful remains, a random ``uploaded_file_<hex>`` name is returned.
    """
    # Remove directory separators
    filename = os.path.basename(filename)
    # Keep only alphanumeric, underscore, hyphen, dot. Replace others with underscore.
    safe_chars = set(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-"
    )
    sanitized = "".join(c if c in safe_chars else "_" for c in filename)
    # Prevent names starting with a dot (hidden files, "." and "..")
    sanitized = sanitized.lstrip(".")
    # Reject names consisting only of dots/underscores/spaces
    if not sanitized or sanitized.lstrip("._ ") == "":
        return f"uploaded_file_{uuid.uuid4().hex[:8]}"  # Generate a safe fallback name
    # Limit length
    max_len = 100
    if len(sanitized) > max_len:
        name, ext = os.path.splitext(sanitized)
        if len(ext) < max_len:
            sanitized = name[: max_len - len(ext)] + ext
        else:
            # Pathological extension: a negative slice bound would not shorten
            # the name, so fall back to a plain cut.
            sanitized = sanitized[:max_len]
    return sanitized


# --- Application Lifespan (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup/shutdown.

    Startup creates the working directories, loads the presets and the Dia
    model, and schedules the browser launch. If the model cannot be loaded a
    RuntimeError is raised so the server does not start. On success
    ``startup_complete_event`` is set before control is handed to the app.
    """
    try:
        try:
            logger.info("Starting Dia TTS server initialization...")
            # Ensure base directories exist
            os.makedirs(get_output_path(), exist_ok=True)
            os.makedirs(get_reference_audio_path(), exist_ok=True)
            os.makedirs(get_model_cache_path(), exist_ok=True)
            os.makedirs("ui", exist_ok=True)
            os.makedirs("static", exist_ok=True)

            # Load presets from YAML file
            load_presets()

            # Load the main TTS model during startup
            if not load_dia_model():
                error_msg = (
                    "CRITICAL: Failed to load Dia model on startup. Server cannot start."
                )
                logger.critical(error_msg)
                # Raising stops the server startup cleanly
                raise RuntimeError(error_msg)
            logger.info("Dia model loaded successfully.")

            # Start the delayed browser-opening thread only AFTER model loading
            browser_thread = threading.Thread(
                target=_delayed_browser_open,
                args=(get_host(), get_port()),
                daemon=True,
            )
            browser_thread.start()
        except Exception as e:
            # Do NOT set the readiness event if startup failed; re-raise so the
            # server knows startup failed.
            logger.error(f"Fatal error during application startup: {e}", exc_info=True)
            raise

        # --- Signal completion AFTER potentially long operations ---
        logger.info("Application startup sequence finished. Signaling readiness.")
        startup_complete_event.set()
        # Application runs here. The yield sits outside the startup handler so
        # that later errors are not reported as startup failures.
        yield
    finally:
        # Cleanup on shutdown
        logger.info("Application shutdown initiated...")
        # Add any specific cleanup needed
        logger.info("Application shutdown complete.")


def _write_browser_debug_log(message: str) -> None:
    """Append ``message`` to browser_thread_debug.log; failures are ignored."""
    try:
        with open("browser_thread_debug.log", "a") as f:
            f.write(f"[{time.time()}] {message}\n")
    except OSError:
        # The debug file is a convenience only; it must never block the launch.
        pass


def _delayed_browser_open(host, port):
    """Open the web UI in the default browser after a short delay.

    Runs in a daemon thread started by ``lifespan``. ``0.0.0.0`` is shown as
    ``localhost`` because a wildcard bind address is not browsable.
    """
    try:
        # Small delay to ensure Uvicorn is fully ready
        time.sleep(2)
        display_host = "localhost" if host == "0.0.0.0" else host
        browser_url = f"http://{display_host}:{port}/"
        # Log to file for debugging
        _write_browser_debug_log(f"Opening browser at {browser_url}")
        logger.info(f"Opening browser at {browser_url}")
        # Open browser directly without health checks
        webbrowser.open(browser_url)
    except Exception as e:
        logger.warning(f"Browser open error: {e}")
        _write_browser_debug_log(f"Browser open error: {str(e)}")


# --- FastAPI App Initialization ---
app = FastAPI(
    title="Dia TTS Server",
    description="Text-to-Speech server using the Dia model, providing API and Web UI.",
    version="1.1.0",  # Incremented version
    lifespan=lifespan,
)

# Make sure the configured working directories exist before anything is
# mounted: StaticFiles checks its directory at instantiation, which happens at
# import time, i.e. before the lifespan handler runs.
folders = [get_reference_audio_path(), get_model_cache_path(), get_output_path()]
for folder in folders:
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        print(f"Created directory: {folder}")

# --- Static Files and Templates ---
# Serve generated audio files from the configured output path
app.mount("/outputs", StaticFiles(directory=get_output_path()), name="outputs")

# Serve
# UI files (CSS, JS) from the 'ui' directory
# StaticFiles checks its directory when instantiated, so create it first.
os.makedirs("ui", exist_ok=True)
app.mount("/ui", StaticFiles(directory="ui"), name="ui_static")

# Initialize Jinja2 templates to look in the 'ui' directory
templates = Jinja2Templates(directory="ui")


# --- Shared Request Helpers ---
def _get_default_gen_params() -> Dict[str, Any]:
    """Return the currently configured default generation parameters for the UI."""
    return {
        "speed_factor": get_gen_default_speed_factor(),
        "cfg_scale": get_gen_default_cfg_scale(),
        "temperature": get_gen_default_temperature(),
        "top_p": get_gen_default_top_p(),
        "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
    }


def _resolve_reference_file(filename: Optional[str]) -> Optional[str]:
    """Return ``filename`` if it names an existing file directly inside the
    reference audio directory, otherwise ``None``.

    Only bare file names are accepted. Anything containing a path component
    (``../x.wav``, ``/etc/x.wav``) is rejected so that client-supplied names
    cannot point outside the reference directory.
    """
    if not filename:
        return None
    if os.path.basename(filename) != filename or filename in (".", ".."):
        return None
    if not os.path.isfile(os.path.join(get_reference_audio_path(), filename)):
        return None
    return filename


# --- Configuration Routes Definition ---
# Defined locally now instead of importing from config.py
def register_config_routes(app: FastAPI):
    """Adds configuration management endpoints to the FastAPI app."""
    logger.info(
        "Registering configuration routes (/get_config, /save_config, /restart_server, /save_generation_defaults)."
    )

    @app.get(
        "/get_config",
        tags=["Configuration"],
        summary="Get current server configuration",
    )
    async def get_current_config():
        """Returns the current server configuration values (from .env or defaults)."""
        logger.info("Request received for /get_config")
        return JSONResponse(content=config_manager.get_all())

    @app.post(
        "/save_config", tags=["Configuration"], summary="Save server configuration"
    )
    async def save_new_config(request: Request):
        """
        Saves updated server configuration values (Host, Port, Model paths, etc.)
        to the .env file. Requires server restart to apply most changes.

        Responds 400 when the body is not a JSON object and 500 when the
        configuration cannot be saved.
        """
        logger.info("Request received for /save_config")
        try:
            new_config_data = await request.json()
            if not isinstance(new_config_data, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received server config data to save: {new_config_data}")

            # Filter data to only include keys present in DEFAULT_CONFIG
            filtered_data = {
                k: v for k, v in new_config_data.items() if k in DEFAULT_CONFIG
            }
            unknown_keys = set(new_config_data.keys()) - set(filtered_data.keys())
            if unknown_keys:
                logger.warning(
                    f"Ignoring unknown keys in save_config request: {unknown_keys}"
                )

            config_manager.update(filtered_data)  # Update in memory first
            if config_manager.save():  # Attempt to save to .env
                logger.info("Server configuration saved successfully to .env.")
                return JSONResponse(
                    content={
                        "message": "Server configuration saved. Restart server to apply changes."
                    }
                )
            logger.error("Failed to save server configuration to .env file.")
            raise HTTPException(
                status_code=500, detail="Failed to save configuration file."
            )
        except HTTPException:
            # Already a well-formed HTTP error; do not re-wrap it below.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_config: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(f"Error processing /save_config request: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post(
        "/save_generation_defaults",
        tags=["Configuration"],
        summary="Save default generation parameters",
    )
    async def save_generation_defaults(request: Request):
        """
        Saves the provided generation parameters (speed, cfg, temp, etc.)
        as the new defaults in the .env file. These are loaded by the UI on startup.

        Responds 400 when the body is not a JSON object or contains none of
        the known parameters, and 500 when the configuration cannot be saved.
        """
        logger.info("Request received for /save_generation_defaults")
        try:
            gen_params = await request.json()
            if not isinstance(gen_params, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received generation defaults to save: {gen_params}")

            # Map received keys (e.g., 'speed_factor') to .env keys (e.g., 'GEN_DEFAULT_SPEED_FACTOR')
            key_map = {
                "speed_factor": "GEN_DEFAULT_SPEED_FACTOR",
                "cfg_scale": "GEN_DEFAULT_CFG_SCALE",
                "temperature": "GEN_DEFAULT_TEMPERATURE",
                "top_p": "GEN_DEFAULT_TOP_P",
                "cfg_filter_top_k": "GEN_DEFAULT_CFG_FILTER_TOP_K",
            }
            defaults_to_save = {}
            for ui_key, env_key in key_map.items():
                if ui_key in gen_params:
                    # Basic validation could be added here (e.g., check if float/int)
                    # Ensure saving as string
                    defaults_to_save[env_key] = str(gen_params[ui_key])
                else:
                    logger.warning(
                        f"Missing expected key '{ui_key}' in save_generation_defaults request."
                    )
            if not defaults_to_save:
                raise ValueError("No valid generation parameters found in the request.")

            config_manager.update(defaults_to_save)  # Update in memory
            # Save all current config (including these) to .env
            if config_manager.save():
                logger.info("Generation defaults saved successfully to .env.")
                return JSONResponse(content={"message": "Generation defaults saved."})
            logger.error("Failed to save generation defaults to .env file.")
            raise HTTPException(
                status_code=500, detail="Failed to save configuration file."
            )
        except HTTPException:
            # Already a well-formed HTTP error; do not re-wrap it below.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_generation_defaults: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(
                f"Error processing /save_generation_defaults request: {e}",
                exc_info=True,
            )
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post(
        "/restart_server",
        tags=["Configuration"],
        summary="Attempt to restart the server",
    )
    async def trigger_server_restart(background_tasks: BackgroundTasks):
        """
        Attempts to restart the server process.
        NOTE: This is highly dependent on how the server is run (e.g., with uvicorn --reload,
        or managed by systemd/supervisor). A simple exit might just stop the process.
        This implementation attempts a clean exit, relying on the runner to restart it.
        """
        logger.warning("Received request to restart server via API.")

        def _do_restart():
            time.sleep(1)  # Short delay to allow response to be sent
            logger.warning("Attempting clean exit for restart...")
            # Clean exit (relies on Uvicorn reload or a process manager).
            # NOTE(review): this runs as a background task; whether SystemExit
            # raised here actually terminates the whole process depends on how
            # the framework executes the task - confirm. A forceful alternative
            # would be os.execv(sys.executable, ['python'] + sys.argv).
            sys.exit(0)

        background_tasks.add_task(_do_restart)
        return JSONResponse(
            content={
                "message": "Restart signal sent. Server should restart shortly if run with auto-reload."
            }
        )


# --- Register Configuration Routes ---
register_config_routes(app)


# --- API Endpoints ---
@app.post(
    "/v1/audio/speech",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (OpenAI compatible)",
)
async def openai_tts_endpoint(request: OpenAITTSRequest):
    """
    Generates speech audio from text, compatible with the OpenAI TTS API structure.
    Maps the 'voice' parameter to Dia's voice modes ('S1', 'S2', 'dialogue', or filename for clone).
    Unknown voices and missing reference files fall back to 'single_s1'.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received OpenAI request: voice='{request.voice}', speed={request.speed}, format='{request.response_format}'"
    )
    logger.debug(f"Input text (start): '{request.input[:100]}...'")

    voice_mode = "single_s1"  # Default if mapping fails
    clone_ref_file = None
    ref_path = get_reference_audio_path()

    # --- Map OpenAI 'voice' parameter to Dia's modes ---
    voice_param = request.voice.strip()
    voice_key = voice_param.lower()
    if voice_key == "dialogue":
        voice_mode = "dialogue"
    elif voice_key == "s1":
        voice_mode = "single_s1"
    elif voice_key == "s2":
        voice_mode = "single_s2"
    # Check if it looks like a filename for cloning (allow .wav or .mp3)
    elif voice_key.endswith((".wav", ".mp3")):
        # Only a bare filename that exists in the reference directory is accepted
        clone_ref_file = _resolve_reference_file(voice_param)
        if clone_ref_file is not None:
            voice_mode = "clone"
            logger.info(
                f"OpenAI request mapped to clone mode with file: {clone_ref_file}"
            )
        else:
            # Fallback to default 'single_s1' if file not found
            logger.warning(
                f"Reference file '{voice_param}' specified in OpenAI request not found in '{ref_path}'. Defaulting voice mode."
            )
    else:
        # Fallback for any other value
        logger.warning(
            f"Unrecognized OpenAI voice parameter '{voice_param}'. Defaulting voice mode to 'single_s1'."
        )
    monitor.record("Parameters processed")

    try:
        # Call the core engine function using mapped parameters
        result = generate_speech(
            text=request.input,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=request.speed,  # Pass speed factor for post-processing
            # Use Dia's configured defaults for other generation params unless mapped
            max_tokens=None,  # Let Dia use its default unless specified otherwise
            cfg_scale=get_gen_default_cfg_scale(),  # Use saved defaults
            temperature=get_gen_default_temperature(),
            top_p=get_gen_default_top_p(),
            cfg_filter_top_k=get_gen_default_cfg_filter_top_k(),
        )
        monitor.record("Generation complete")

        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, but expected {EXPECTED_SAMPLE_RATE}. Encoding might assume {EXPECTED_SAMPLE_RATE}."
            )
            # Use EXPECTED_SAMPLE_RATE for encoding as it's what the model is trained for
            sample_rate = EXPECTED_SAMPLE_RATE

        # Encode the audio in memory to the requested format
        encoded_audio = encode_audio(audio_array, sample_rate, request.response_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.response_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.response_format}",
            )

        # Determine the correct media type for the response header
        # Note: OpenAI uses audio/opus, not audio/ogg;codecs=opus. Let's match OpenAI.
        media_type = "audio/opus" if request.response_format == "opus" else "audio/wav"
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.response_format}"
        )
        logger.debug(monitor.report())
        # Stream the encoded audio back to the client
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        # Re-raise HTTPExceptions directly (e.g., from parameter validation)
        logger.error(f"HTTP exception during OpenAI request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing OpenAI TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        # Return generic server error for unexpected issues
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.post(
    "/tts",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (Custom parameters)",
)
async def custom_tts_endpoint(request: CustomTTSRequest):
    """
    Generates speech audio from text using explicit Dia parameters.

    In clone mode, responds 400 when no reference filename is given and 404
    when it does not name a file in the reference audio directory.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received custom TTS request: mode='{request.voice_mode}', format='{request.output_format}'"
    )
    logger.debug(f"Input text (start): '{request.text[:100]}...'")
    logger.debug(
        f"Params: max_tokens={request.max_tokens}, cfg={request.cfg_scale}, temp={request.temperature}, top_p={request.top_p}, speed={request.speed_factor}, top_k={request.cfg_filter_top_k}"
    )

    clone_ref_file = None
    if request.voice_mode == "clone":
        if not request.clone_reference_filename:
            raise HTTPException(
                status_code=400,  # Bad request
                detail="Missing 'clone_reference_filename' which is required for clone mode.",
            )
        # Only a bare filename that exists in the reference directory is accepted
        clone_ref_file = _resolve_reference_file(request.clone_reference_filename)
        if clone_ref_file is None:
            logger.error(
                f"Reference audio file not found for clone mode: {request.clone_reference_filename!r}"
            )
            raise HTTPException(
                status_code=404,  # Not found
                detail=f"Reference audio file not found: {request.clone_reference_filename}",
            )
        logger.info(f"Custom request using clone mode with file: {clone_ref_file}")
    monitor.record("Parameters processed")

    try:
        # Call the core engine function with parameters from the request
        result = generate_speech(
            text=request.text,
            voice_mode=request.voice_mode,
            clone_reference_filename=clone_ref_file,
            max_tokens=request.max_tokens,  # Pass user value or None
            cfg_scale=request.cfg_scale,
            temperature=request.temperature,
            top_p=request.top_p,
            speed_factor=request.speed_factor,  # For post-processing
            cfg_filter_top_k=request.cfg_filter_top_k,
        )
        monitor.record("Generation complete")

        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, expected {EXPECTED_SAMPLE_RATE}. Encoding will use {EXPECTED_SAMPLE_RATE}."
            )
            sample_rate = EXPECTED_SAMPLE_RATE

        # Encode the audio in memory
        encoded_audio = encode_audio(audio_array, sample_rate, request.output_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.output_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.output_format}",
            )

        # Determine media type
        media_type = "audio/opus" if request.output_format == "opus" else "audio/wav"
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.output_format}"
        )
        logger.debug(monitor.report())
        # Stream the response
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        logger.error(f"HTTP exception during custom TTS request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing custom TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


# --- Web UI Endpoints ---
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def get_web_ui(request: Request):
    """Serves the main TTS web interface."""
    logger.info("Serving TTS Web UI (index.html)")
    return templates.TemplateResponse(
        "index.html",  # Use the renamed file
        {
            "request": request,
            # Current list of reference files for the clone dropdown
            "reference_files": get_valid_reference_files(),
            "config": config_manager.get_all(),  # Pass current server config
            "presets": loaded_presets,  # Pass loaded presets
            "default_gen_params": _get_default_gen_params(),  # Pass default gen params
            # Add other variables needed by the template for initial state
            "error": None,
            "success": None,
            "output_file_url": None,
            "generation_time": None,
            "submitted_text": "",
            "submitted_voice_mode": "dialogue",  # Default to combined mode
            "submitted_clone_file": None,
            # Initial generation params will be set by default_gen_params
        },
    )


@app.post("/web/generate", response_class=HTMLResponse, include_in_schema=False)
async def handle_web_ui_generate(
    request: Request,
    text: str = Form(...),
    voice_mode: Literal["dialogue", "clone"] = Form(...),  # Updated modes
    clone_reference_select: Optional[str] = Form(None),
    # Generation parameters from form
    speed_factor: float = Form(...),  # Make required or use Depends with default
    cfg_scale: float = Form(...),
    temperature: float = Form(...),
    top_p: float = Form(...),
    cfg_filter_top_k: int = Form(...),
):
    """Handles the generation request from the web UI form.

    Validates the form, runs the generation, saves the result as a WAV file
    in the output directory and re-renders index.html with either the audio
    URL or an error message. The submitted values are passed back so the form
    can be repopulated.
    """
    logger.info(f"Web UI generation request: mode='{voice_mode}'")
    monitor = PerformanceMonitor()
    monitor.record("Web request received")

    output_file_url = None
    generation_time = None
    error_message = None
    success_message = None

    # --- Pre-generation Validation ---
    if not text.strip():
        error_message = "Please enter some text to synthesize."

    clone_ref_file = None
    if voice_mode == "clone":
        if not clone_reference_select or clone_reference_select == "none":
            error_message = "Please select a reference audio file for clone mode."
        else:
            # Verify selected file still exists (important if files can be deleted)
            # and that it is a bare filename inside the reference directory.
            clone_ref_file = _resolve_reference_file(clone_reference_select)
            if clone_ref_file is None:
                error_message = f"Selected reference file '{clone_reference_select}' no longer exists. Please refresh or upload."
                # Invalidate selection: clear submitted value for re-rendering
                clone_reference_select = None
            else:
                logger.info(f"Using selected reference file: {clone_ref_file}")

    # Pass back the values the user submitted
    submitted_gen_params = {
        "speed_factor": speed_factor,
        "cfg_scale": cfg_scale,
        "temperature": temperature,
        "top_p": top_p,
        "cfg_filter_top_k": cfg_filter_top_k,
    }

    # If validation failed, re-render the page with error and submitted values
    if error_message:
        logger.warning(f"Web UI validation error: {error_message}")
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": error_message,
                "reference_files": get_valid_reference_files(),
                "config": config_manager.get_all(),
                "presets": loaded_presets,
                "default_gen_params": _get_default_gen_params(),  # Base defaults
                # Submitted values to repopulate form
                "submitted_text": text,
                "submitted_voice_mode": voice_mode,
                "submitted_clone_file": clone_reference_select,  # Use potentially invalidated value
                "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
                # Ensure other necessary template variables are passed
                "success": None,
                "output_file_url": None,
                "generation_time": None,
            },
        )

    # --- Generation ---
    try:
        monitor.record("Parameters processed")
        # Call the core engine function
        result = generate_speech(
            text=text,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=speed_factor,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
            max_tokens=None,  # Use model default for UI simplicity
        )
        monitor.record("Generation complete")

        if result:
            audio_array, sample_rate = result
            output_path_base = get_output_path()
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Create a more descriptive filename
            mode_tag = voice_mode
            if voice_mode == "clone" and clone_ref_file:
                safe_ref_name = sanitize_filename(os.path.splitext(clone_ref_file)[0])
                mode_tag = f"clone_{safe_ref_name[:20]}"  # Limit length
            # Always save as WAV for simplicity
            output_filename = f"{mode_tag}_{timestamp}.wav"
            output_filepath = os.path.join(output_path_base, output_filename)

            # Save the audio to a WAV file
            saved = save_audio_to_file(audio_array, sample_rate, output_filepath)
            monitor.record("Audio saved")
            if saved:
                # URL path for browser access
                output_file_url = f"/outputs/{output_filename}"
                # Time until save complete
                generation_time = monitor.events[-1][1] - monitor.start_time
                success_message = "Audio generated successfully!"
                logger.info(f"Web UI generated audio saved to: {output_filepath}")
            else:
                error_message = "Failed to save generated audio file."
                logger.error("Failed to save audio file from web UI request.")
        else:
            error_message = "Speech generation failed (engine returned None)."
            logger.error("Speech generation failed for web UI request.")
    except Exception as e:
        logger.error(f"Error processing web UI TTS request: {e}", exc_info=True)
        error_message = f"An unexpected error occurred: {str(e)}"

    logger.debug(monitor.report())

    # --- Re-render Template with Results ---
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "error": error_message,
            "success": success_message,
            "output_file_url": output_file_url,
            # "is not None" so that a legitimate 0.0 is still displayed
            "generation_time": (
                f"{generation_time:.2f}" if generation_time is not None else None
            ),
            "reference_files": get_valid_reference_files(),
            "config": config_manager.get_all(),
            "presets": loaded_presets,
            "default_gen_params": _get_default_gen_params(),  # Base defaults
            # Pass back submitted values to repopulate form correctly
            "submitted_text": text,
            "submitted_voice_mode": voice_mode,
            "submitted_clone_file": clone_ref_file,  # Pass the validated filename back
            "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
        },
    )


# --- Reference Audio Upload Endpoint ---
@app.post(
    "/upload_reference", tags=["Web UI Helpers"], summary="Upload reference audio files"
)
async def upload_reference_audio(files: List[UploadFile] = File(...)):
    """Handles uploading of reference audio files (.wav, .mp3) for voice cloning.

    Each file is validated and saved independently. Existing files are never
    overwritten. The response lists the files available from this request, the
    complete reference file list and per-file errors; the status is 400 only
    when there is at least one error per submitted file.
    """
    logger.info(f"Received request to upload {len(files)} reference audio file(s).")
    ref_path = get_reference_audio_path()
    uploaded_filenames = []
    errors = []
    allowed_mime_types = [
        "audio/wav",
        "audio/mpeg",
        "audio/x-wav",
    ]  # Common WAV/MP3 types
    allowed_extensions = [".wav", ".mp3"]

    for file in files:
        try:
            # Basic validation
            if not file.filename:
                errors.append("Received file with no filename.")
                continue

            # Sanitize filename
            safe_filename = sanitize_filename(file.filename)
            _, ext = os.path.splitext(safe_filename)
            if ext.lower() not in allowed_extensions:
                errors.append(
                    f"File '{file.filename}' has unsupported extension '{ext}'. Allowed: {allowed_extensions}"
                )
                continue

            # Check the declared MIME type as a second filter. It is supplied
            # by the client, so it is a sanity check, not proof of content.
            if file.content_type not in allowed_mime_types:
                errors.append(
                    f"File '{file.filename}' has unsupported content type '{file.content_type}'. Allowed: {allowed_mime_types}"
                )
                continue

            # Construct full save path
            destination_path = os.path.join(ref_path, safe_filename)

            # Prevent overwriting existing files (simple approach: skip if exists)
            if os.path.exists(destination_path):
                logger.warning(
                    f"Reference file '{safe_filename}' already exists. Skipping upload."
                )
                # Add to list so UI knows it's available, even if not newly uploaded this time
                if safe_filename not in uploaded_filenames:
                    uploaded_filenames.append(safe_filename)
                continue

            # Save the file using shutil.copyfileobj for efficiency with large files
            try:
                with open(destination_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
            except Exception as save_exc:
                errors.append(f"Failed to save file '{safe_filename}': {save_exc}")
                logger.error(
                    f"Failed to save uploaded file '{safe_filename}' to '{destination_path}': {save_exc}",
                    exc_info=True,
                )
            else:
                logger.info(f"Successfully saved reference file: {destination_path}")
                uploaded_filenames.append(safe_filename)
        except Exception as e:
            errors.append(
                f"Error processing file '{getattr(file, 'filename', 'unknown')}': {e}"
            )
            logger.error(
                f"Unexpected error processing uploaded file: {e}", exc_info=True
            )
        finally:
            # Close the upload on every path, including the early 'continue's
            await file.close()

    # Get the updated list of all valid files in the directory
    updated_file_list = get_valid_reference_files()
    response_data = {
        "message": f"Processed {len(files)} file(s).",
        "uploaded_files": uploaded_filenames,  # Files from this request that are available
        "all_reference_files": updated_file_list,  # Complete current list
        "errors": errors,
    }
    # OK if at least one succeeded, else Bad Request
    status_code = 200 if not errors or len(errors) < len(files) else 400
    if errors:
        logger.warning(f"Upload completed with errors: {errors}")
    return JSONResponse(content=response_data, status_code=status_code)


# --- Health Check Endpoint ---
@app.get("/health", tags=["Server Status"], summary="Check server health")
async def health_check():
    """Basic health check, indicates if the server is running and if the model is loaded."""
    # Read MODEL_LOADED from the engine module on every call so the current
    # status is reported; defaults to False if the attribute is missing.
    current_model_status = getattr(engine, "MODEL_LOADED", False)
    logger.debug(
        f"Health check returning model_loaded status: {current_model_status}"
    )
    return {"status": "healthy", "model_loaded": current_model_status}


# --- Main Execution ---
if __name__ == "__main__":
    host = get_host()
    port = get_port()
    logger.info(f"Starting Dia TTS server on {host}:{port}")
    logger.info(f"Model Repository: {get_model_repo_id()}")
    logger.info(f"Model Config File: {get_model_config_filename()}")
    logger.info(f"Model Weights File: {get_model_weights_filename()}")
    logger.info(f"Model Cache Path: {get_model_cache_path()}")
    logger.info(f"Reference Audio Path: {get_reference_audio_path()}")
    logger.info(f"Output Path: {get_output_path()}")

    # Determine the host to display in logs
    display_host = "localhost" if host == "0.0.0.0" else host
    logger.info(f"Web UI will be available at http://{display_host}:{port}/")
    logger.info(f"API Docs available at http://{display_host}:{port}/docs")

    # Ensure UI directory and index.html exist for UI
    ui_dir = "ui"
    index_file = os.path.join(ui_dir, "index.html")
    if not os.path.isdir(ui_dir) or not os.path.isfile(index_file):
        logger.warning(
            f"'{ui_dir}' directory or '{index_file}' not found. Web UI may not work."
        )
        # Create a placeholder so the template lookup does not fail outright
        os.makedirs(ui_dir, exist_ok=True)
        if not os.path.isfile(index_file):
            try:
                with open(index_file, "w", encoding="utf-8") as f:
                    f.write(
                        "Web UI template missing. See project source for index.html."
                    )
                logger.info(f"Created dummy {index_file}.")
            except Exception as e:
                logger.error(f"Failed to create dummy {index_file}: {e}")

    # Run Uvicorn server.
    # The app is passed as an import string, so it is loaded from the 'server'
    # module rather than from this '__main__' module. The 'lifespan' function
    # defined in this file loads the model and sets 'startup_complete_event'
    # there, so no event needs to be created here.
    uvicorn.run(
        "server:app",  # Use the format 'module:app_instance'
        host=host,
        port=port,
        # Set reload=True only for development, together with
        # reload_dirs/reload_includes (e.g. "*.py", "*.html", "*.yaml", ".env").
        reload=False,
        lifespan="on",  # Use the lifespan context manager defined in this file
        # Keep a single worker when using reload or complex global state/models
    )