Spaces:
Running
Running
File size: 44,276 Bytes
ac5de5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 |
# server.py
# Main FastAPI server for Dia TTS
import sys
import logging
import time
import os
import io
import uuid
import sys
import shutil # For file copying
import yaml # For loading presets
from datetime import datetime
from contextlib import asynccontextmanager
from typing import Optional, Literal, List, Dict, Any
import webbrowser
import threading
import time
from fastapi import (
FastAPI,
HTTPException,
Request,
Response,
Form,
UploadFile,
File,
BackgroundTasks,
)
from fastapi.responses import (
StreamingResponse,
JSONResponse,
HTMLResponse,
RedirectResponse,
)
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uvicorn
import numpy as np
# Internal imports
from config import (
config_manager,
get_host,
get_port,
get_output_path,
get_reference_audio_path,
# register_config_routes is now defined locally
get_model_cache_path,
get_model_repo_id,
get_model_config_filename,
get_model_weights_filename,
# Generation default getters
get_gen_default_speed_factor,
get_gen_default_cfg_scale,
get_gen_default_temperature,
get_gen_default_top_p,
get_gen_default_cfg_filter_top_k,
DEFAULT_CONFIG,
)
from models import OpenAITTSRequest, CustomTTSRequest, ErrorResponse
import engine
from engine import (
load_model as load_dia_model,
generate_speech,
EXPECTED_SAMPLE_RATE,
)
from utils import encode_audio, save_audio_to_file, PerformanceMonitor
# Configure logging (Basic setup, can be enhanced)
# Root logger: INFO level, single-line timestamped format shared by all modules.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
# Reduce verbosity of noisy libraries if needed
# logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# logging.getLogger("watchfiles").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)  # Logger for this module

# --- Global Variables & Constants ---
# Relative path: the server is expected to run from the project root.
PRESETS_FILE = "ui/presets.yaml"
loaded_presets: List[Dict[str, Any]] = []  # Cache presets in memory; filled by load_presets()
# Set at the end of lifespan() startup to signal readiness.
# NOTE(review): nothing in this file waits on this event — presumably an
# external runner or test polls it; confirm before removing.
startup_complete_event = threading.Event()
# --- Helper Functions ---
def load_presets():
    """Load UI presets from PRESETS_FILE into the module-level cache.

    On any failure (missing file, YAML parse error, wrong top-level type)
    the cache is reset to an empty list and the problem is logged.
    """
    global loaded_presets
    try:
        # Guard clause: no presets file means no presets — not an error.
        if not os.path.exists(PRESETS_FILE):
            logger.warning(
                f"Presets file not found at '{PRESETS_FILE}'. No presets will be available."
            )
            loaded_presets = []
            return
        with open(PRESETS_FILE, "r", encoding="utf-8") as fh:
            loaded_presets = yaml.safe_load(fh)
        if isinstance(loaded_presets, list):
            logger.info(
                f"Successfully loaded {len(loaded_presets)} presets from {PRESETS_FILE}."
            )
        else:
            # The file parsed but is not a list of presets — discard it.
            logger.error(
                f"Presets file '{PRESETS_FILE}' should contain a list, but found {type(loaded_presets)}. No presets loaded."
            )
            loaded_presets = []
    except yaml.YAMLError as e:
        logger.error(
            f"Error parsing presets YAML file '{PRESETS_FILE}': {e}", exc_info=True
        )
        loaded_presets = []
    except Exception as e:
        logger.error(f"Error loading presets file '{PRESETS_FILE}': {e}", exc_info=True)
        loaded_presets = []
def get_valid_reference_files() -> list[str]:
    """Gets a list of valid audio files (.wav, .mp3) from the reference directory."""
    directory = get_reference_audio_path()
    allowed_extensions = (".wav", ".mp3")
    found: list[str] = []
    try:
        if not os.path.isdir(directory):
            logger.warning(f"Reference audio directory not found: {directory}")
        else:
            # Keep only files whose (case-insensitive) extension is allowed.
            found = [
                name
                for name in os.listdir(directory)
                if name.lower().endswith(allowed_extensions)
            ]
    except Exception as e:
        logger.error(
            f"Error reading reference audio directory '{directory}': {e}", exc_info=True
        )
    return sorted(found)
def sanitize_filename(filename: str) -> str:
    """Return a filesystem-safe version of *filename*.

    Strips directory components, replaces characters outside
    [A-Za-z0-9._-] with underscores, substitutes a random fallback name
    for empty/degenerate results, and caps the length at 100 characters,
    preserving the extension when it fits.
    """
    # Remove directory separators (e.g. "../../etc/passwd" -> "passwd").
    filename = os.path.basename(filename)
    # Keep only alphanumeric, underscore, hyphen, dot. Replace others with underscore.
    safe_chars = set(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-"
    )
    sanitized = "".join(c if c in safe_chars else "_" for c in filename)
    # Prevent names starting with dot or consisting only of dots/underscores/spaces.
    if not sanitized or sanitized.lstrip("._ ") == "":
        return f"uploaded_file_{uuid.uuid4().hex[:8]}"  # Generate a safe fallback name
    # Limit length.
    max_len = 100
    if len(sanitized) > max_len:
        name, ext = os.path.splitext(sanitized)
        if len(ext) >= max_len:
            # Pathological input: the extension alone exceeds the limit.
            # The slice below would go negative and could return a string
            # longer than max_len, so truncate the whole name instead.
            sanitized = sanitized[:max_len]
        else:
            sanitized = name[: max_len - len(ext)] + ext
    return sanitized
# --- Application Lifespan (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup/shutdown.

    Startup: creates required directories, loads UI presets, loads the Dia
    TTS model (fatal if it fails), then spawns a daemon thread that opens
    the web UI in a browser, and finally sets `startup_complete_event`.
    Shutdown: only logs; no explicit resource cleanup is performed here.
    """
    # NOTE(review): this flag is written but never read in this file.
    model_loaded_successfully = False  # Flag to track success
    try:
        logger.info("Starting Dia TTS server initialization...")
        # Ensure base directories exist before anything tries to use them.
        os.makedirs(get_output_path(), exist_ok=True)
        os.makedirs(get_reference_audio_path(), exist_ok=True)
        os.makedirs(get_model_cache_path(), exist_ok=True)
        os.makedirs("ui", exist_ok=True)
        os.makedirs("static", exist_ok=True)
        # Load presets from YAML file (best effort; failures are logged).
        load_presets()
        # Load the main TTS model during startup.
        if not load_dia_model():
            # Model loading failed — the server is useless without it.
            error_msg = (
                "CRITICAL: Failed to load Dia model on startup. Server cannot start."
            )
            logger.critical(error_msg)
            # Option 1: Raise an exception to stop Uvicorn startup cleanly
            raise RuntimeError(error_msg)
            # Option 2: Force exit (less clean, might bypass some Uvicorn shutdown)
            # sys.exit(1)
        else:
            logger.info("Dia model loaded successfully.")
            model_loaded_successfully = True
        # Create and start a delayed browser opening thread.
        # IMPORTANT: Create this thread AFTER model loading completes
        # (the lambda captures host/port by value via locals).
        host = get_host()
        port = get_port()
        browser_thread = threading.Thread(
            target=lambda: _delayed_browser_open(host, port), daemon=True
        )
        browser_thread.start()
        # --- Signal completion AFTER potentially long operations ---
        logger.info("Application startup sequence finished. Signaling readiness.")
        startup_complete_event.set()
        yield  # Application runs here
    except Exception as e:
        # Catch the RuntimeError we raised or any other startup error.
        logger.error(f"Fatal error during application startup: {e}", exc_info=True)
        # Do NOT set the event here if startup failed.
        # Re-raise the exception or exit to ensure the server stops.
        raise e  # Re-raising ensures Uvicorn knows startup failed
        # Alternatively: sys.exit(1)
    finally:
        # Cleanup on shutdown (runs on both normal exit and startup failure).
        logger.info("Application shutdown initiated...")
        # Add any specific cleanup needed
        logger.info("Application shutdown complete.")
def _delayed_browser_open(host, port):
    """Opens browser after a short delay to ensure server is ready"""
    try:
        # Give Uvicorn a moment to finish binding before we navigate to it.
        time.sleep(2)
        display_host = "localhost" if host == "0.0.0.0" else host
        browser_url = f"http://{display_host}:{port}/"
        # This runs in a daemon thread; append to a debug file so the
        # attempt is traceable even if logging is not usable here.
        with open("browser_thread_debug.log", "a") as debug_log:
            debug_log.write(f"[{time.time()}] Opening browser at {browser_url}\n")
        try:
            logger.info(f"Opening browser at {browser_url}")
        except:
            pass  # Logging may not be available from this thread; ignore.
        # Open browser directly without health checks.
        webbrowser.open(browser_url)
    except Exception as e:
        with open("browser_thread_debug.log", "a") as debug_log:
            debug_log.write(f"[{time.time()}] Browser open error: {str(e)}\n")
# --- FastAPI App Initialization ---
app = FastAPI(
    title="Dia TTS Server",
    description="Text-to-Speech server using the Dia model, providing API and Web UI.",
    version="1.1.0",  # Incremented version
    lifespan=lifespan,
)

# Folders that must exist at import time so the StaticFiles mounts below
# do not fail on a missing directory. (lifespan() also creates the
# *configured* paths, but it only runs after the app object is built.)
folders = ["reference_audio", "model_cache", "outputs"]
for folder in folders:
    if not os.path.exists(folder):
        # exist_ok=True closes the check-then-create race with a concurrent
        # process creating the same directory between the check and makedirs.
        os.makedirs(folder, exist_ok=True)
        logger.info(f"Created directory: {folder}")  # was print(); use module logger

# --- Static Files and Templates ---
# Serve generated audio files from the configured output path
app.mount("/outputs", StaticFiles(directory=get_output_path()), name="outputs")
# Serve UI files (CSS, JS) from the 'ui' directory
app.mount("/ui", StaticFiles(directory="ui"), name="ui_static")
# Initialize Jinja2 templates to look in the 'ui' directory
templates = Jinja2Templates(directory="ui")
# --- Configuration Routes Definition ---
# Defined locally now instead of importing from config.py
def register_config_routes(app: FastAPI):
    """Adds configuration management endpoints to the FastAPI app.

    Registers: GET /get_config, POST /save_config,
    POST /save_generation_defaults, POST /restart_server.
    """
    logger.info(
        "Registering configuration routes (/get_config, /save_config, /restart_server, /save_generation_defaults)."
    )

    @app.get(
        "/get_config",
        tags=["Configuration"],
        summary="Get current server configuration",
    )
    async def get_current_config():
        """Returns the current server configuration values (from .env or defaults)."""
        logger.info("Request received for /get_config")
        return JSONResponse(content=config_manager.get_all())

    @app.post(
        "/save_config", tags=["Configuration"], summary="Save server configuration"
    )
    async def save_new_config(request: Request):
        """
        Saves updated server configuration values (Host, Port, Model paths, etc.)
        to the .env file. Requires server restart to apply most changes.
        """
        logger.info("Request received for /save_config")
        try:
            new_config_data = await request.json()
            if not isinstance(new_config_data, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received server config data to save: {new_config_data}")
            # Filter data to only include keys present in DEFAULT_CONFIG
            filtered_data = {
                k: v for k, v in new_config_data.items() if k in DEFAULT_CONFIG
            }
            unknown_keys = set(new_config_data.keys()) - set(filtered_data.keys())
            if unknown_keys:
                logger.warning(
                    f"Ignoring unknown keys in save_config request: {unknown_keys}"
                )
            config_manager.update(filtered_data)  # Update in memory first
            if config_manager.save():  # Attempt to save to .env
                logger.info("Server configuration saved successfully to .env.")
                return JSONResponse(
                    content={
                        "message": "Server configuration saved. Restart server to apply changes."
                    }
                )
            else:
                logger.error("Failed to save server configuration to .env file.")
                # NOTE(review): this HTTPException is raised inside the try
                # block, so it is re-caught by `except Exception` below and
                # re-wrapped with a different detail message.
                raise HTTPException(
                    status_code=500, detail="Failed to save configuration file."
                )
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_config: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(f"Error processing /save_config request: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post(
        "/save_generation_defaults",
        tags=["Configuration"],
        summary="Save default generation parameters",
    )
    async def save_generation_defaults(request: Request):
        """
        Saves the provided generation parameters (speed, cfg, temp, etc.)
        as the new defaults in the .env file. These are loaded by the UI on startup.
        """
        logger.info("Request received for /save_generation_defaults")
        try:
            gen_params = await request.json()
            if not isinstance(gen_params, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received generation defaults to save: {gen_params}")
            # Map received keys (e.g., 'speed_factor') to .env keys (e.g., 'GEN_DEFAULT_SPEED_FACTOR')
            defaults_to_save = {}
            key_map = {
                "speed_factor": "GEN_DEFAULT_SPEED_FACTOR",
                "cfg_scale": "GEN_DEFAULT_CFG_SCALE",
                "temperature": "GEN_DEFAULT_TEMPERATURE",
                "top_p": "GEN_DEFAULT_TOP_P",
                "cfg_filter_top_k": "GEN_DEFAULT_CFG_FILTER_TOP_K",
            }
            valid_keys_found = False
            for ui_key, env_key in key_map.items():
                if ui_key in gen_params:
                    # Basic validation could be added here (e.g., check if float/int)
                    defaults_to_save[env_key] = str(
                        gen_params[ui_key]
                    )  # Ensure saving as string
                    valid_keys_found = True
                else:
                    # Missing keys are tolerated; only warn.
                    logger.warning(
                        f"Missing expected key '{ui_key}' in save_generation_defaults request."
                    )
            if not valid_keys_found:
                raise ValueError("No valid generation parameters found in the request.")
            config_manager.update(defaults_to_save)  # Update in memory
            if (
                config_manager.save()
            ):  # Save all current config (including these) to .env
                logger.info("Generation defaults saved successfully to .env.")
                return JSONResponse(content={"message": "Generation defaults saved."})
            else:
                logger.error("Failed to save generation defaults to .env file.")
                # NOTE(review): re-caught and re-wrapped by `except Exception`
                # below (same pattern as /save_config).
                raise HTTPException(
                    status_code=500, detail="Failed to save configuration file."
                )
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_generation_defaults: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(
                f"Error processing /save_generation_defaults request: {e}",
                exc_info=True,
            )
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post(
        "/restart_server",
        tags=["Configuration"],
        summary="Attempt to restart the server",
    )
    async def trigger_server_restart(background_tasks: BackgroundTasks):
        """
        Attempts to restart the server process.
        NOTE: This is highly dependent on how the server is run (e.g., with uvicorn --reload,
        or managed by systemd/supervisor). A simple exit might just stop the process.
        This implementation attempts a clean exit, relying on the runner to restart it.
        """
        logger.warning("Received request to restart server via API.")

        def _do_restart():
            # Runs as a background task after the HTTP response is sent.
            time.sleep(1)  # Short delay to allow response to be sent
            logger.warning("Attempting clean exit for restart...")
            # Option 1: Clean exit (relies on Uvicorn reload or process manager)
            sys.exit(0)
            # Option 2: Forceful re-execution (use with caution, might not work as expected)
            # try:
            #     logger.warning("Attempting os.execv for restart...")
            #     os.execv(sys.executable, ['python'] + sys.argv)
            # except Exception as exec_e:
            #     logger.error(f"os.execv failed: {exec_e}. Server may not restart automatically.")
            #     # Fallback to sys.exit if execv fails
            #     sys.exit(1)

        background_tasks.add_task(_do_restart)
        return JSONResponse(
            content={
                "message": "Restart signal sent. Server should restart shortly if run with auto-reload."
            }
        )


# --- Register Configuration Routes ---
register_config_routes(app)
# --- API Endpoints ---
@app.post(
    "/v1/audio/speech",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (OpenAI compatible)",
)
async def openai_tts_endpoint(request: OpenAITTSRequest):
    """
    Generates speech audio from text, compatible with the OpenAI TTS API structure.
    Maps the 'voice' parameter to Dia's voice modes ('S1', 'S2', 'dialogue', or filename for clone).

    Returns a StreamingResponse with the encoded audio bytes.
    Raises HTTPException 500 if generation or encoding fails.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received OpenAI request: voice='{request.voice}', speed={request.speed}, format='{request.response_format}'"
    )
    logger.debug(f"Input text (start): '{request.input[:100]}...'")
    voice_mode = "single_s1"  # Default if mapping fails
    clone_ref_file = None
    ref_path = get_reference_audio_path()
    # --- Map OpenAI 'voice' parameter to Dia's modes ---
    voice_param = request.voice.strip()
    if voice_param.lower() == "dialogue":
        voice_mode = "dialogue"
    elif voice_param.lower() == "s1":
        voice_mode = "single_s1"
    elif voice_param.lower() == "s2":
        voice_mode = "single_s2"
    # Check if it looks like a filename for cloning (allow .wav or .mp3)
    elif voice_param.lower().endswith((".wav", ".mp3")):
        potential_path = os.path.join(ref_path, voice_param)
        # Check if the file actually exists in the reference directory
        if os.path.isfile(potential_path):
            voice_mode = "clone"
            clone_ref_file = voice_param  # Use the provided filename
            logger.info(
                f"OpenAI request mapped to clone mode with file: {clone_ref_file}"
            )
        else:
            logger.warning(
                f"Reference file '{voice_param}' specified in OpenAI request not found in '{ref_path}'. Defaulting voice mode."
            )
            # Fallback to default 'single_s1' if file not found
    else:
        logger.warning(
            f"Unrecognized OpenAI voice parameter '{voice_param}'. Defaulting voice mode to 'single_s1'."
        )
        # Fallback for any other value
    monitor.record("Parameters processed")
    try:
        # Call the core engine function using mapped parameters
        result = generate_speech(
            text=request.input,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=request.speed,  # Pass speed factor for post-processing
            # Use Dia's configured defaults for other generation params unless mapped
            max_tokens=None,  # Let Dia use its default unless specified otherwise
            cfg_scale=get_gen_default_cfg_scale(),  # Use saved defaults
            temperature=get_gen_default_temperature(),
            top_p=get_gen_default_top_p(),
            cfg_filter_top_k=get_gen_default_cfg_filter_top_k(),
        )
        monitor.record("Generation complete")
        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")
        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, but expected {EXPECTED_SAMPLE_RATE}. Encoding might assume {EXPECTED_SAMPLE_RATE}."
            )
            # Use EXPECTED_SAMPLE_RATE for encoding as it's what the model is trained for
            sample_rate = EXPECTED_SAMPLE_RATE
        # Encode the audio in memory to the requested format
        encoded_audio = encode_audio(audio_array, sample_rate, request.response_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.response_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.response_format}",
            )
        # Determine the correct media type for the response header
        media_type = "audio/opus" if request.response_format == "opus" else "audio/wav"
        # Note: OpenAI uses audio/opus, not audio/ogg;codecs=opus. Let's match OpenAI.
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.response_format}"
        )
        logger.debug(monitor.report())
        # Stream the encoded audio back to the client
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        # Re-raise HTTPExceptions directly (e.g., from parameter validation)
        logger.error(f"HTTP exception during OpenAI request: {http_exc.detail}")
        raise http_exc
    except Exception as e:
        logger.error(f"Error processing OpenAI TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        # Return generic server error for unexpected issues
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post(
    "/tts",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (Custom parameters)",
)
async def custom_tts_endpoint(request: CustomTTSRequest):
    """
    Generates speech audio from text using explicit Dia parameters.

    Unlike the OpenAI-compatible endpoint, all generation parameters
    (cfg_scale, temperature, top_p, speed_factor, cfg_filter_top_k,
    max_tokens) come from the request body rather than saved defaults.
    Raises 400 if clone mode lacks a reference filename, 404 if the
    reference file is missing, 500 on generation/encoding failure.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received custom TTS request: mode='{request.voice_mode}', format='{request.output_format}'"
    )
    logger.debug(f"Input text (start): '{request.text[:100]}...'")
    logger.debug(
        f"Params: max_tokens={request.max_tokens}, cfg={request.cfg_scale}, temp={request.temperature}, top_p={request.top_p}, speed={request.speed_factor}, top_k={request.cfg_filter_top_k}"
    )
    clone_ref_file = None
    # Clone mode requires a valid reference file; validate before generating.
    if request.voice_mode == "clone":
        if not request.clone_reference_filename:
            raise HTTPException(
                status_code=400,  # Bad request
                detail="Missing 'clone_reference_filename' which is required for clone mode.",
            )
        ref_path = get_reference_audio_path()
        potential_path = os.path.join(ref_path, request.clone_reference_filename)
        if not os.path.isfile(potential_path):
            logger.error(
                f"Reference audio file not found for clone mode: {potential_path}"
            )
            raise HTTPException(
                status_code=404,  # Not found
                detail=f"Reference audio file not found: {request.clone_reference_filename}",
            )
        clone_ref_file = request.clone_reference_filename
        logger.info(f"Custom request using clone mode with file: {clone_ref_file}")
    monitor.record("Parameters processed")
    try:
        # Call the core engine function with parameters from the request
        result = generate_speech(
            text=request.text,
            voice_mode=request.voice_mode,
            clone_reference_filename=clone_ref_file,
            max_tokens=request.max_tokens,  # Pass user value or None
            cfg_scale=request.cfg_scale,
            temperature=request.temperature,
            top_p=request.top_p,
            speed_factor=request.speed_factor,  # For post-processing
            cfg_filter_top_k=request.cfg_filter_top_k,
        )
        monitor.record("Generation complete")
        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")
        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, expected {EXPECTED_SAMPLE_RATE}. Encoding will use {EXPECTED_SAMPLE_RATE}."
            )
            sample_rate = EXPECTED_SAMPLE_RATE
        # Encode the audio in memory
        encoded_audio = encode_audio(audio_array, sample_rate, request.output_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.output_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.output_format}",
            )
        # Determine media type
        media_type = "audio/opus" if request.output_format == "opus" else "audio/wav"
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.output_format}"
        )
        logger.debug(monitor.report())
        # Stream the response
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        logger.error(f"HTTP exception during custom TTS request: {http_exc.detail}")
        raise http_exc
    except Exception as e:
        logger.error(f"Error processing custom TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# --- Web UI Endpoints ---
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def get_web_ui(request: Request):
    """Serves the main TTS web interface."""
    logger.info("Serving TTS Web UI (index.html)")
    # Current saved generation defaults, used to seed the form controls.
    gen_defaults = {
        "speed_factor": get_gen_default_speed_factor(),
        "cfg_scale": get_gen_default_cfg_scale(),
        "temperature": get_gen_default_temperature(),
        "top_p": get_gen_default_top_p(),
        "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
    }
    # Build the full template context for an initial (empty) render.
    context = {
        "request": request,
        "reference_files": get_valid_reference_files(),  # Clone dropdown options
        "config": config_manager.get_all(),  # Current server config
        "presets": loaded_presets,  # Presets loaded at startup
        "default_gen_params": gen_defaults,
        # Initial form/result state.
        "error": None,
        "success": None,
        "output_file_url": None,
        "generation_time": None,
        "submitted_text": "",
        "submitted_voice_mode": "dialogue",  # Combined mode is the default
        "submitted_clone_file": None,
    }
    return templates.TemplateResponse("index.html", context)
@app.post("/web/generate", response_class=HTMLResponse, include_in_schema=False)
async def handle_web_ui_generate(
    request: Request,
    text: str = Form(...),
    voice_mode: Literal["dialogue", "clone"] = Form(...),  # Updated modes
    clone_reference_select: Optional[str] = Form(None),
    # Generation parameters from form
    speed_factor: float = Form(...),  # Make required or use Depends with default
    cfg_scale: float = Form(...),
    temperature: float = Form(...),
    top_p: float = Form(...),
    cfg_filter_top_k: int = Form(...),
):
    """Handles the generation request from the web UI form.

    Validates input, runs generation, saves the result as a WAV under the
    output path, and re-renders index.html with either the result or an
    error message plus the submitted form values.
    """
    logger.info(f"Web UI generation request: mode='{voice_mode}'")
    monitor = PerformanceMonitor()
    monitor.record("Web request received")
    output_file_url = None
    generation_time = None
    error_message = None
    success_message = None
    # NOTE(review): this variable is assigned but never used below.
    output_filename_base = "dia_output"  # Default base name
    # --- Pre-generation Validation ---
    if not text.strip():
        error_message = "Please enter some text to synthesize."
    clone_ref_file = None
    # NOTE(review): a clone-mode validation error overwrites the empty-text
    # error above, so only one message is ever shown.
    if voice_mode == "clone":
        if not clone_reference_select or clone_reference_select == "none":
            error_message = "Please select a reference audio file for clone mode."
        else:
            # Verify selected file still exists (important if files can be deleted)
            ref_path = get_reference_audio_path()
            potential_path = os.path.join(ref_path, clone_reference_select)
            if not os.path.isfile(potential_path):
                error_message = f"Selected reference file '{clone_reference_select}' no longer exists. Please refresh or upload."
                # Invalidate selection
                clone_ref_file = None
                clone_reference_select = None  # Clear submitted value for re-rendering
            else:
                clone_ref_file = clone_reference_select
                logger.info(f"Using selected reference file: {clone_ref_file}")
    # If validation failed, re-render the page with error and submitted values
    if error_message:
        logger.warning(f"Web UI validation error: {error_message}")
        reference_files = get_valid_reference_files()
        current_config = config_manager.get_all()
        default_gen_params = {  # Pass defaults again for consistency
            "speed_factor": get_gen_default_speed_factor(),
            "cfg_scale": get_gen_default_cfg_scale(),
            "temperature": get_gen_default_temperature(),
            "top_p": get_gen_default_top_p(),
            "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
        }
        # Pass back the values the user submitted
        submitted_gen_params = {
            "speed_factor": speed_factor,
            "cfg_scale": cfg_scale,
            "temperature": temperature,
            "top_p": top_p,
            "cfg_filter_top_k": cfg_filter_top_k,
        }
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": error_message,
                "reference_files": reference_files,
                "config": current_config,
                "presets": loaded_presets,
                "default_gen_params": default_gen_params,  # Base defaults
                # Submitted values to repopulate form
                "submitted_text": text,
                "submitted_voice_mode": voice_mode,
                "submitted_clone_file": clone_reference_select,  # Use potentially invalidated value
                "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
                # Ensure other necessary template variables are passed
                "success": None,
                "output_file_url": None,
                "generation_time": None,
            },
        )
    # --- Generation ---
    try:
        monitor.record("Parameters processed")
        # Call the core engine function
        result = generate_speech(
            text=text,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=speed_factor,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
            max_tokens=None,  # Use model default for UI simplicity
        )
        monitor.record("Generation complete")
        if result:
            audio_array, sample_rate = result
            output_path_base = get_output_path()
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Create a more descriptive filename
            mode_tag = voice_mode
            if voice_mode == "clone" and clone_ref_file:
                safe_ref_name = sanitize_filename(os.path.splitext(clone_ref_file)[0])
                mode_tag = f"clone_{safe_ref_name[:20]}"  # Limit length
            output_filename = (
                f"{mode_tag}_{timestamp}.wav"  # Always save as WAV for simplicity
            )
            output_filepath = os.path.join(output_path_base, output_filename)
            # Save the audio to a WAV file
            saved = save_audio_to_file(audio_array, sample_rate, output_filepath)
            monitor.record("Audio saved")
            if saved:
                output_file_url = (
                    f"/outputs/{output_filename}"  # URL path for browser access
                )
                # Elapsed seconds up to the save step.
                # NOTE(review): assumes monitor.events entries are
                # (label, timestamp) tuples — confirm against utils.PerformanceMonitor.
                generation_time = (
                    monitor.events[-1][1] - monitor.start_time
                )  # Time until save complete
                success_message = f"Audio generated successfully!"
                logger.info(f"Web UI generated audio saved to: {output_filepath}")
            else:
                error_message = "Failed to save generated audio file."
                logger.error("Failed to save audio file from web UI request.")
        else:
            error_message = "Speech generation failed (engine returned None)."
            logger.error("Speech generation failed for web UI request.")
    except Exception as e:
        logger.error(f"Error processing web UI TTS request: {e}", exc_info=True)
        error_message = f"An unexpected error occurred: {str(e)}"
        logger.debug(monitor.report())
    # --- Re-render Template with Results ---
    reference_files = get_valid_reference_files()
    current_config = config_manager.get_all()
    default_gen_params = {
        "speed_factor": get_gen_default_speed_factor(),
        "cfg_scale": get_gen_default_cfg_scale(),
        "temperature": get_gen_default_temperature(),
        "top_p": get_gen_default_top_p(),
        "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
    }
    # Pass back submitted values to repopulate form correctly
    submitted_gen_params = {
        "speed_factor": speed_factor,
        "cfg_scale": cfg_scale,
        "temperature": temperature,
        "top_p": top_p,
        "cfg_filter_top_k": cfg_filter_top_k,
    }
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "error": error_message,
            "success": success_message,
            "output_file_url": output_file_url,
            "generation_time": f"{generation_time:.2f}" if generation_time else None,
            "reference_files": reference_files,
            "config": current_config,
            "presets": loaded_presets,
            "default_gen_params": default_gen_params,  # Base defaults
            # Pass back submitted values
            "submitted_text": text,
            "submitted_voice_mode": voice_mode,
            "submitted_clone_file": clone_ref_file,  # Pass the validated filename back
            "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
        },
    )
# --- Reference Audio Upload Endpoint ---
@app.post(
    "/upload_reference", tags=["Web UI Helpers"], summary="Upload reference audio files"
)
async def upload_reference_audio(files: List[UploadFile] = File(...)):
    """Handles uploading of reference audio files (.wav, .mp3) for voice cloning.

    Each file is validated (filename present, extension and MIME type allowed),
    its name sanitized, and the content saved into the reference-audio
    directory. Existing files are never overwritten; they are skipped but still
    reported so the UI knows they are available.

    Returns:
        JSONResponse with the newly saved filenames, the complete current list
        of reference files, and any per-file error messages. Status is 200 if
        at least one file succeeded (or there were no errors), 400 otherwise.
    """
    logger.info(f"Received request to upload {len(files)} reference audio file(s).")
    ref_path = get_reference_audio_path()
    uploaded_filenames = []
    errors = []
    allowed_mime_types = [
        "audio/wav",
        "audio/mpeg",
        "audio/x-wav",
    ]  # Common WAV/MP3 types
    allowed_extensions = [".wav", ".mp3"]
    for file in files:
        try:
            # Basic validation: an UploadFile can arrive without a filename.
            if not file.filename:
                errors.append("Received file with no filename.")
                continue
            # Sanitize before the name is ever used in a filesystem path.
            safe_filename = sanitize_filename(file.filename)
            _, ext = os.path.splitext(safe_filename)
            if ext.lower() not in allowed_extensions:
                errors.append(
                    f"File '{file.filename}' has unsupported extension '{ext}'. Allowed: {allowed_extensions}"
                )
                continue
            # Check MIME type (more reliable than extension)
            if file.content_type not in allowed_mime_types:
                errors.append(
                    f"File '{file.filename}' has unsupported content type '{file.content_type}'. Allowed: {allowed_mime_types}"
                )
                continue
            # Construct full save path
            destination_path = os.path.join(ref_path, safe_filename)
            # Prevent overwriting existing files: skip the write, but report
            # the name so the UI still knows the file is available.
            if os.path.exists(destination_path):
                logger.warning(
                    f"Reference file '{safe_filename}' already exists. Skipping upload."
                )
                if safe_filename not in uploaded_filenames:
                    uploaded_filenames.append(safe_filename)
                continue
            # Save the file using shutil.copyfileobj for efficiency with large files
            try:
                with open(destination_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                logger.info(f"Successfully saved reference file: {destination_path}")
                uploaded_filenames.append(safe_filename)
            except Exception as save_exc:
                errors.append(f"Failed to save file '{safe_filename}': {save_exc}")
                logger.error(
                    f"Failed to save uploaded file '{safe_filename}' to '{destination_path}': {save_exc}",
                    exc_info=True,
                )
        except Exception as e:
            errors.append(
                f"Error processing file '{getattr(file, 'filename', 'unknown')}': {e}"
            )
            logger.error(
                f"Unexpected error processing uploaded file: {e}", exc_info=True
            )
        finally:
            # BUGFIX: the original closed the UploadFile only after a save
            # attempt or on an unexpected exception, leaking the spooled temp
            # file whenever a validation check hit `continue` (no filename,
            # bad extension/MIME, duplicate). Closing in a loop-level
            # `finally` covers every exit path exactly once.
            try:
                await file.close()
            except Exception:
                pass  # Best-effort close; never mask the real result/error.
    # Get the updated list of all valid files in the directory
    updated_file_list = get_valid_reference_files()
    response_data = {
        "message": f"Processed {len(files)} file(s).",
        "uploaded_files": uploaded_filenames,  # Successfully saved *new* files this request
        "all_reference_files": updated_file_list,  # Complete current list
        "errors": errors,
    }
    status_code = (
        200 if not errors or len(errors) < len(files) else 400
    )  # OK if at least one succeeded, else Bad Request
    if errors:
        logger.warning(f"Upload completed with errors: {errors}")
    return JSONResponse(content=response_data, status_code=status_code)
# --- Health Check Endpoint ---
@app.get("/health", tags=["Server Status"], summary="Check server health")
async def health_check():
    """Basic health check, indicates if the server is running and if the model is loaded."""
    # Read engine.MODEL_LOADED fresh on every call (not cached at import time)
    # so the response reflects the engine's current state; getattr with a
    # False default keeps this safe if the attribute does not exist yet.
    model_loaded = getattr(engine, "MODEL_LOADED", False)
    logger.debug(f"Health check returning model_loaded status: {model_loaded}")
    return {"status": "healthy", "model_loaded": model_loaded}
# --- Main Execution ---
if __name__ == "__main__":
    # Direct-run entry point: read bind address from config, log the effective
    # settings, sanity-check the UI assets, then launch uvicorn.
    host = get_host()
    port = get_port()
    logger.info(f"Starting Dia TTS server on {host}:{port}")
    logger.info(f"Model Repository: {get_model_repo_id()}")
    logger.info(f"Model Config File: {get_model_config_filename()}")
    logger.info(f"Model Weights File: {get_model_weights_filename()}")
    logger.info(f"Model Cache Path: {get_model_cache_path()}")
    logger.info(f"Reference Audio Path: {get_reference_audio_path()}")
    logger.info(f"Output Path: {get_output_path()}")
    # Determine the host to display in logs and use for browser opening
    # (0.0.0.0 binds all interfaces but is not a browsable address).
    display_host = "localhost" if host == "0.0.0.0" else host
    logger.info(f"Web UI will be available at http://{display_host}:{port}/")
    logger.info(f"API Docs available at http://{display_host}:{port}/docs")
    # Ensure UI directory and index.html exist for UI
    ui_dir = "ui"
    index_file = os.path.join(ui_dir, "index.html")
    if not os.path.isdir(ui_dir) or not os.path.isfile(index_file):
        logger.warning(
            f"'{ui_dir}' directory or '{index_file}' not found. Web UI may not work."
        )
        # Optionally create dummy files/dirs if needed for startup
        # (placeholder page so template rendering fails gracefully).
        os.makedirs(ui_dir, exist_ok=True)
        if not os.path.isfile(index_file):
            try:
                with open(index_file, "w") as f:
                    f.write(
                        "<html><body>Web UI template missing. See project source for index.html.</body></html>"
                    )
                logger.info(f"Created dummy {index_file}.")
            except Exception as e:
                logger.error(f"Failed to create dummy {index_file}: {e}")
    # --- Create synchronization event ---
    # This event will be set by the lifespan manager once startup (incl. model loading) is complete.
    # NOTE(review): this binds `startup_complete_event` in the `__main__`
    # namespace, but uvicorn imports "server:app" as a separate module, so a
    # lifespan function defined in that module would see its own global, not
    # this one — confirm the lifespan actually references this event.
    startup_complete_event = threading.Event()
    # Run Uvicorn server
    # The lifespan context manager ('lifespan="on"') will run during startup.
    # The 'lifespan' function is responsible for loading models and setting the 'startup_complete_event'.
    uvicorn.run(
        "server:app",  # Use the format 'module:app_instance'
        host=host,
        port=port,
        reload=False,  # Set reload as needed for development/production
        # reload_dirs=[".", "ui"],  # Only use reload=True with reload_dirs/includes for development
        # reload_includes=[
        #     "*.py",
        #     "*.html",
        #     "*.css",
        #     "*.js",
        #     ".env",
        #     "*.yaml",
        # ],
        lifespan="on",  # Use the lifespan context manager defined in this file
        # workers=1  # Keep workers=1 when using reload=True or complex global state/models
    )