File size: 44,276 Bytes
ac5de5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
# server.py
# Main FastAPI server for Dia TTS

import sys
import logging
import time
import os
import io
import uuid
import sys
import shutil  # For file copying
import yaml  # For loading presets
from datetime import datetime
from contextlib import asynccontextmanager
from typing import Optional, Literal, List, Dict, Any
import webbrowser
import threading
import time

from fastapi import (
    FastAPI,
    HTTPException,
    Request,
    Response,
    Form,
    UploadFile,
    File,
    BackgroundTasks,
)
from fastapi.responses import (
    StreamingResponse,
    JSONResponse,
    HTMLResponse,
    RedirectResponse,
)
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uvicorn
import numpy as np

# Internal imports
from config import (
    config_manager,
    get_host,
    get_port,
    get_output_path,
    get_reference_audio_path,
    # register_config_routes is now defined locally
    get_model_cache_path,
    get_model_repo_id,
    get_model_config_filename,
    get_model_weights_filename,
    # Generation default getters
    get_gen_default_speed_factor,
    get_gen_default_cfg_scale,
    get_gen_default_temperature,
    get_gen_default_top_p,
    get_gen_default_cfg_filter_top_k,
    DEFAULT_CONFIG,
)
from models import OpenAITTSRequest, CustomTTSRequest, ErrorResponse
import engine
from engine import (
    load_model as load_dia_model,
    generate_speech,
    EXPECTED_SAMPLE_RATE,
)
from utils import encode_audio, save_audio_to_file, PerformanceMonitor

# --- Logging ---
# One process-wide basic configuration. Noisy third-party loggers can be
# quieted here if needed, for example:
#   logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
#   logging.getLogger("watchfiles").setLevel(logging.WARNING)
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)  # Logger for this module

# --- Global Variables & Constants ---
# Location of the YAML presets file, relative to the working directory.
PRESETS_FILE = "ui/presets.yaml"
# In-memory presets cache; (re)populated by load_presets() during startup.
loaded_presets: List[Dict[str, Any]] = []
# Set by the lifespan handler once startup has finished successfully.
startup_complete_event = threading.Event()

# --- Helper Functions ---


def load_presets():
    """Load UI presets from ``PRESETS_FILE`` into the ``loaded_presets`` cache.

    The file is expected to hold a YAML list. Outcomes:
      * file missing or empty  -> no presets, warning logged;
      * top-level value is not a list, or the file cannot be parsed/read
                               -> no presets, error logged;
      * otherwise              -> the parsed list becomes the cache.

    Errors are logged rather than raised. The parsed data is validated in a
    local variable and the global is assigned exactly once at the end, so
    ``loaded_presets`` is always a list and never briefly holds invalid data.
    """
    global loaded_presets
    presets: List[Dict[str, Any]] = []
    try:
        if not os.path.exists(PRESETS_FILE):
            logger.warning(
                f"Presets file not found at '{PRESETS_FILE}'. No presets will be available."
            )
        else:
            with open(PRESETS_FILE, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
            if data is None:
                # safe_load returns None for an empty (or comment-only) file.
                logger.warning(
                    f"Presets file '{PRESETS_FILE}' is empty. No presets loaded."
                )
            elif not isinstance(data, list):
                logger.error(
                    f"Presets file '{PRESETS_FILE}' should contain a list, but found {type(data)}. No presets loaded."
                )
            else:
                presets = data
                logger.info(
                    f"Successfully loaded {len(presets)} presets from {PRESETS_FILE}."
                )
    except yaml.YAMLError as e:
        logger.error(
            f"Error parsing presets YAML file '{PRESETS_FILE}': {e}", exc_info=True
        )
    except Exception as e:
        logger.error(f"Error loading presets file '{PRESETS_FILE}': {e}", exc_info=True)
    loaded_presets = presets


def get_valid_reference_files() -> list[str]:
    """Return the sorted names of ``.wav``/``.mp3`` entries in the reference audio directory.

    The extension match is case-insensitive and only the entry name is
    inspected (no file-type or content check), so a subdirectory whose name
    ends in one of the extensions would also be listed. If the directory is
    missing or cannot be read, the problem is logged and an empty list is
    returned.
    """
    audio_dir = get_reference_audio_path()
    audio_suffixes = (".wav", ".mp3")
    found: list[str] = []
    try:
        if not os.path.isdir(audio_dir):
            logger.warning(f"Reference audio directory not found: {audio_dir}")
        else:
            found = [
                entry
                for entry in os.listdir(audio_dir)
                if entry.lower().endswith(audio_suffixes)
            ]
    except Exception as exc:
        logger.error(
            f"Error reading reference audio directory '{audio_dir}': {exc}",
            exc_info=True,
        )
    return sorted(found)


# Characters allowed in a sanitized filename; everything else becomes "_".
_SAFE_FILENAME_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-"
)


def sanitize_filename(filename: str, max_len: int = 100) -> str:
    """Return a filesystem-safe version of an untrusted filename.

    Steps:
      1. Drop directory components (``os.path.basename``).
      2. Replace every character outside ``[A-Za-z0-9._-]`` with ``_``.
      3. If nothing meaningful remains (empty, or only dots/underscores),
         return a random fallback ``uploaded_file_<8 hex chars>``.
      4. Prefix a leading dot with ``_`` so the result is never a hidden
         file (``.env`` -> ``_.env``).
      5. Truncate to at most *max_len* characters, keeping the extension
         when it is shorter than the limit.

    Args:
        filename: Client-supplied filename (may contain path components).
        max_len: Maximum length of the returned name (default 100).

    Returns:
        The sanitized filename.
    """
    base = os.path.basename(filename)
    sanitized = "".join(c if c in _SAFE_FILENAME_CHARS else "_" for c in base)

    # Spaces were already replaced, so stripping "._" covers "only dots/spaces".
    if not sanitized.strip("._"):
        return f"uploaded_file_{uuid.uuid4().hex[:8]}"

    if sanitized.startswith("."):
        sanitized = "_" + sanitized

    if len(sanitized) > max_len:
        name, ext = os.path.splitext(sanitized)
        if len(ext) < max_len:
            sanitized = name[: max_len - len(ext)] + ext
        else:
            # Pathologically long "extension": a plain cut is the only way to
            # respect the limit (the old negative slice could exceed it).
            sanitized = sanitized[:max_len]
    return sanitized


# --- Application Lifespan (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run server startup before requests are served, and log shutdown.

    Startup, in order:
      1. create the output, reference-audio, model-cache, ``ui`` and
         ``static`` directories;
      2. load UI presets (``load_presets``);
      3. load the Dia model -- failure raises ``RuntimeError``;
      4. start a daemon thread that opens the web UI in a browser;
      5. set ``startup_complete_event``.

    Any startup failure is logged and re-raised so the ASGI server aborts
    startup; ``startup_complete_event`` is not set in that case. Only startup
    is covered by the error handler, so exceptions arriving after ``yield``
    are not misreported as startup failures.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).

    Raises:
        RuntimeError: If the Dia model fails to load.
    """
    try:
        try:
            logger.info("Starting Dia TTS server initialization...")
            # Ensure base directories exist
            for directory in (
                get_output_path(),
                get_reference_audio_path(),
                get_model_cache_path(),
                "ui",
                "static",
            ):
                os.makedirs(directory, exist_ok=True)

            load_presets()

            if not load_dia_model():
                error_msg = (
                    "CRITICAL: Failed to load Dia model on startup. Server cannot start."
                )
                logger.critical(error_msg)
                # Raising stops the ASGI server's startup cleanly.
                raise RuntimeError(error_msg)
            logger.info("Dia model loaded successfully.")

            # Started only after the model is loaded so the UI opens on a
            # server that can actually serve requests.
            browser_thread = threading.Thread(
                target=_delayed_browser_open,
                args=(get_host(), get_port()),
                daemon=True,
            )
            browser_thread.start()

            logger.info("Application startup sequence finished. Signaling readiness.")
            startup_complete_event.set()
        except Exception as e:
            logger.error(f"Fatal error during application startup: {e}", exc_info=True)
            raise  # Let the ASGI server know startup failed

        yield  # Application runs here
    finally:
        logger.info("Application shutdown initiated...")
        # Add any specific cleanup needed
        logger.info("Application shutdown complete.")


def _delayed_browser_open(host, port, delay: float = 2.0):
    """Open the web UI in the default browser after a short delay.

    Meant to run in a daemon thread started during application startup.
    Wildcard or empty bind addresses (``0.0.0.0``, ``::``, ``""``) are
    replaced with ``localhost`` because they cannot be browsed to, and IPv6
    literals are wrapped in brackets so the URL is valid. Progress and any
    ``Exception`` are appended to ``browser_thread_debug.log`` in the current
    working directory instead of propagating.

    Args:
        host: Host the server is bound to.
        port: Port the server listens on.
        delay: Seconds to wait first, giving the server time to start
            accepting connections (default 2).
    """
    debug_log = "browser_thread_debug.log"
    try:
        time.sleep(delay)

        display_host = str(host)
        if display_host in ("0.0.0.0", "::", ""):
            display_host = "localhost"
        elif ":" in display_host and not display_host.startswith("["):
            display_host = f"[{display_host}]"  # IPv6 literal
        browser_url = f"http://{display_host}:{port}/"

        # Log to file for debugging
        with open(debug_log, "a", encoding="utf-8") as f:
            f.write(f"[{time.time()}] Opening browser at {browser_url}\n")

        try:
            logger.info(f"Opening browser at {browser_url}")
        except Exception:
            pass  # Best effort: logging must not prevent opening the browser.

        webbrowser.open(browser_url)

    except Exception as e:
        with open(debug_log, "a", encoding="utf-8") as f:
            f.write(f"[{time.time()}] Browser open error: {str(e)}\n")


# --- FastAPI App Initialization ---
app = FastAPI(
    title="Dia TTS Server",
    description="Text-to-Speech server using the Dia model, providing API and Web UI.",
    version="1.1.0",  # Incremented version
    lifespan=lifespan,
)

# Directories that must exist before the StaticFiles mounts below are built.
# StaticFiles verifies its directory when constructed -- at import time,
# before the lifespan handler runs -- so the configured output path and "ui"
# must be created here, not only the historical default folder names.
# dict.fromkeys de-duplicates while preserving order.
folders = list(
    dict.fromkeys(
        [
            "reference_audio",
            "model_cache",
            "outputs",
            get_reference_audio_path(),
            get_model_cache_path(),
            get_output_path(),
            "ui",
        ]
    )
)

for folder in folders:
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        logger.info(f"Created directory: {folder}")

# --- Static Files and Templates ---
# Serve generated audio files from the configured output path
app.mount("/outputs", StaticFiles(directory=get_output_path()), name="outputs")
# Serve UI files (CSS, JS) from the 'ui' directory
app.mount("/ui", StaticFiles(directory="ui"), name="ui_static")
# Initialize Jinja2 templates to look in the 'ui' directory
templates = Jinja2Templates(directory="ui")


# --- Configuration Routes Definition ---
# Defined locally now instead of importing from config.py
def register_config_routes(app: FastAPI):
    """Register the configuration-management endpoints on *app*.

    Endpoints added:
      * ``GET  /get_config``               -- current configuration as JSON.
      * ``POST /save_config``              -- apply and persist server settings.
      * ``POST /save_generation_defaults`` -- persist default generation params.
      * ``POST /restart_server``           -- exit the process so an
        auto-reloader or process manager can restart it.

    Args:
        app: The FastAPI application to attach the routes to.
    """
    logger.info(
        "Registering configuration routes (/get_config, /save_config, /restart_server, /save_generation_defaults)."
    )

    @app.get(
        "/get_config",
        tags=["Configuration"],
        summary="Get current server configuration",
    )
    async def get_current_config():
        """Return every value reported by ``config_manager.get_all()`` as JSON."""
        logger.info("Request received for /get_config")
        return JSONResponse(content=config_manager.get_all())

    @app.post(
        "/save_config", tags=["Configuration"], summary="Save server configuration"
    )
    async def save_new_config(request: Request):
        """Apply and persist server configuration values from a JSON object body.

        Only keys present in ``DEFAULT_CONFIG`` are applied; any others are
        logged and ignored. Values are applied to ``config_manager`` in memory
        and then persisted with ``config_manager.save()`` (the .env file).
        Most changes only take effect after a server restart.

        Raises:
            HTTPException: 400 if the body is not a JSON object (malformed
                JSON raises ``json.JSONDecodeError``, a ``ValueError``
                subclass, and is reported the same way); 500 if saving fails
                or an unexpected error occurs.
        """
        logger.info("Request received for /save_config")
        try:
            new_config_data = await request.json()
            if not isinstance(new_config_data, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received server config data to save: {new_config_data}")

            # Only accept keys the server knows about.
            filtered_data = {
                k: v for k, v in new_config_data.items() if k in DEFAULT_CONFIG
            }
            unknown_keys = new_config_data.keys() - filtered_data.keys()
            if unknown_keys:
                logger.warning(
                    f"Ignoring unknown keys in save_config request: {unknown_keys}"
                )

            config_manager.update(filtered_data)  # Update in memory first
            if not config_manager.save():  # Then persist to .env
                logger.error("Failed to save server configuration to .env file.")
                raise HTTPException(
                    status_code=500, detail="Failed to save configuration file."
                )

            logger.info("Server configuration saved successfully to .env.")
            return JSONResponse(
                content={
                    "message": "Server configuration saved. Restart server to apply changes."
                }
            )
        except HTTPException:
            # Already carries the intended status/detail. Without this clause
            # the generic handler below would re-wrap it as a different 500.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_config: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            ) from ve
        except Exception as e:
            logger.error(f"Error processing /save_config request: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            ) from e

    @app.post(
        "/save_generation_defaults",
        tags=["Configuration"],
        summary="Save default generation parameters",
    )
    async def save_generation_defaults(request: Request):
        """Persist default generation parameters sent as a JSON object body.

        Recognised keys are ``speed_factor``, ``cfg_scale``, ``temperature``,
        ``top_p`` and ``cfg_filter_top_k``; each one present is stored
        (stringified, not validated) under its ``GEN_DEFAULT_*`` config key.
        Missing keys are logged and skipped. The whole configuration is then
        persisted with ``config_manager.save()``.

        Raises:
            HTTPException: 400 if the body is not a JSON object or contains
                none of the recognised keys; 500 if saving fails or an
                unexpected error occurs.
        """
        logger.info("Request received for /save_generation_defaults")
        try:
            gen_params = await request.json()
            if not isinstance(gen_params, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received generation defaults to save: {gen_params}")

            # UI parameter name -> config (.env) key it is stored under.
            key_map = {
                "speed_factor": "GEN_DEFAULT_SPEED_FACTOR",
                "cfg_scale": "GEN_DEFAULT_CFG_SCALE",
                "temperature": "GEN_DEFAULT_TEMPERATURE",
                "top_p": "GEN_DEFAULT_TOP_P",
                "cfg_filter_top_k": "GEN_DEFAULT_CFG_FILTER_TOP_K",
            }
            defaults_to_save = {}
            for ui_key, env_key in key_map.items():
                if ui_key not in gen_params:
                    logger.warning(
                        f"Missing expected key '{ui_key}' in save_generation_defaults request."
                    )
                    continue
                # Stored as a string; values are not range/type checked here.
                defaults_to_save[env_key] = str(gen_params[ui_key])

            if not defaults_to_save:
                raise ValueError("No valid generation parameters found in the request.")

            config_manager.update(defaults_to_save)  # Update in memory
            # Saves all current config (including these) to .env
            if not config_manager.save():
                logger.error("Failed to save generation defaults to .env file.")
                raise HTTPException(
                    status_code=500, detail="Failed to save configuration file."
                )

            logger.info("Generation defaults saved successfully to .env.")
            return JSONResponse(content={"message": "Generation defaults saved."})
        except HTTPException:
            # Propagate unchanged instead of re-wrapping in the generic 500.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_generation_defaults: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            ) from ve
        except Exception as e:
            logger.error(
                f"Error processing /save_generation_defaults request: {e}",
                exc_info=True,
            )
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            ) from e

    @app.post(
        "/restart_server",
        tags=["Configuration"],
        summary="Attempt to restart the server",
    )
    async def trigger_server_restart(background_tasks: BackgroundTasks):
        """Schedule a process exit so the server can be restarted.

        The exit is queued as a background task and waits about one second
        before calling ``sys.exit(0)``, so the response can be delivered
        first. Nothing here restarts the server itself: it relies on an
        auto-reloader or a process manager (systemd, supervisor, ...) to start
        it again. Without one, the server simply stops.
        """
        logger.warning("Received request to restart server via API.")

        def _do_restart():
            time.sleep(1)  # Short delay to allow response to be sent
            logger.warning("Attempting clean exit for restart...")
            # NOTE(review): background tasks may run sync callables in a worker
            # thread; whether SystemExit raised there ends the whole process
            # depends on the server runtime -- confirm. A self-contained
            # alternative is os.execv(sys.executable, [sys.executable] + sys.argv).
            sys.exit(0)

        background_tasks.add_task(_do_restart)
        return JSONResponse(
            content={
                "message": "Restart signal sent. Server should restart shortly if run with auto-reload."
            }
        )


# --- Register Configuration Routes ---
# Attaches /get_config, /save_config, /save_generation_defaults and
# /restart_server to the app at import time.
register_config_routes(app)


# --- API Endpoints ---


@app.post(
    "/v1/audio/speech",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (OpenAI compatible)",
)
async def openai_tts_endpoint(request: OpenAITTSRequest):
    """Generate speech from an OpenAI-TTS-shaped request.

    ``voice`` is mapped, case-insensitively, onto Dia's voice modes:
      * ``"dialogue"``           -> ``dialogue``
      * ``"S1"`` / ``"S2"``      -> ``single_s1`` / ``single_s2``
      * a bare ``.wav``/``.mp3`` filename that exists in the reference audio
        directory -> ``clone`` with that file
    Anything else falls back to ``single_s1`` with a warning. This includes
    filenames with directory components (e.g. ``../x.wav``), which are
    rejected so clients cannot reach audio outside the reference directory,
    and files that do not exist.

    ``speed`` is passed as the post-processing speed factor. All other
    generation parameters come from the saved defaults. The audio is encoded
    to ``response_format`` and streamed back as ``audio/opus`` for ``opus``,
    otherwise as ``audio/wav``.

    Raises:
        HTTPException: 500 if generation or encoding fails.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received OpenAI request: voice='{request.voice}', speed={request.speed}, format='{request.response_format}'"
    )
    logger.debug(f"Input text (start): '{request.input[:100]}...'")

    voice_mode = "single_s1"  # Default if mapping fails
    clone_ref_file = None
    ref_path = get_reference_audio_path()

    # --- Map OpenAI 'voice' parameter to Dia's modes ---
    voice_param = request.voice.strip()
    voice_key = voice_param.lower()
    if voice_key == "dialogue":
        voice_mode = "dialogue"
    elif voice_key == "s1":
        voice_mode = "single_s1"
    elif voice_key == "s2":
        voice_mode = "single_s2"
    elif voice_key.endswith((".wav", ".mp3")):
        # Looks like a clone reference filename. Only bare names are accepted:
        # "../x.wav" or an absolute path would otherwise resolve outside the
        # reference directory.
        if os.path.basename(voice_param) != voice_param:
            logger.warning(
                f"Reference file '{voice_param}' in OpenAI request must be a plain filename inside '{ref_path}'. Defaulting voice mode."
            )
        elif os.path.isfile(os.path.join(ref_path, voice_param)):
            voice_mode = "clone"
            clone_ref_file = voice_param
            logger.info(
                f"OpenAI request mapped to clone mode with file: {clone_ref_file}"
            )
        else:
            logger.warning(
                f"Reference file '{voice_param}' specified in OpenAI request not found in '{ref_path}'. Defaulting voice mode."
            )
    else:
        logger.warning(
            f"Unrecognized OpenAI voice parameter '{voice_param}'. Defaulting voice mode to 'single_s1'."
        )

    monitor.record("Parameters processed")

    try:
        # NOTE(review): generate_speech is called synchronously inside an async
        # endpoint, so it blocks the event loop for the whole generation.
        # Moving it off-loop requires confirming the engine is thread-safe.
        result = generate_speech(
            text=request.input,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=request.speed,  # Post-processing speed factor
            max_tokens=None,  # Let Dia use its default
            cfg_scale=get_gen_default_cfg_scale(),  # Saved defaults
            temperature=get_gen_default_temperature(),
            top_p=get_gen_default_top_p(),
            cfg_filter_top_k=get_gen_default_cfg_filter_top_k(),
        )
        monitor.record("Generation complete")

        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result

        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, but expected {EXPECTED_SAMPLE_RATE}. Encoding might assume {EXPECTED_SAMPLE_RATE}."
            )
            # Encode at the rate the model is expected to produce.
            sample_rate = EXPECTED_SAMPLE_RATE

        encoded_audio = encode_audio(audio_array, sample_rate, request.response_format)
        monitor.record("Audio encoding complete")

        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.response_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.response_format}",
            )

        # OpenAI uses audio/opus (not audio/ogg;codecs=opus); match it.
        media_type = "audio/opus" if request.response_format == "opus" else "audio/wav"

        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.response_format}"
        )
        logger.debug(monitor.report())

        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)

    except HTTPException as http_exc:
        logger.error(f"HTTP exception during OpenAI request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing OpenAI TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e


@app.post(
    "/tts",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (Custom parameters)",
)
async def custom_tts_endpoint(request: CustomTTSRequest):
    """Generate speech using explicit Dia parameters from the request.

    In ``clone`` mode, ``clone_reference_filename`` must be a bare filename
    (no directory components) of a file in the reference audio directory.
    The audio is encoded to ``output_format`` and streamed back as
    ``audio/opus`` for ``opus``, otherwise as ``audio/wav``.

    Raises:
        HTTPException: 400 if clone mode lacks a reference filename or the
            filename contains path components; 404 if the reference file does
            not exist; 500 if generation or encoding fails.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received custom TTS request: mode='{request.voice_mode}', format='{request.output_format}'"
    )
    logger.debug(f"Input text (start): '{request.text[:100]}...'")
    logger.debug(
        f"Params: max_tokens={request.max_tokens}, cfg={request.cfg_scale}, temp={request.temperature}, top_p={request.top_p}, speed={request.speed_factor}, top_k={request.cfg_filter_top_k}"
    )

    clone_ref_file = None
    if request.voice_mode == "clone":
        requested_file = request.clone_reference_filename
        if not requested_file:
            raise HTTPException(
                status_code=400,
                detail="Missing 'clone_reference_filename' which is required for clone mode.",
            )
        # Reject path components so a client cannot point at audio outside
        # the reference directory (e.g. "../x.wav" or an absolute path).
        if os.path.basename(requested_file) != requested_file:
            logger.error(
                f"Rejected reference audio filename with path components: {requested_file}"
            )
            raise HTTPException(
                status_code=400,
                detail=f"Invalid reference audio filename: {requested_file}",
            )
        ref_path = get_reference_audio_path()
        potential_path = os.path.join(ref_path, requested_file)
        if not os.path.isfile(potential_path):
            logger.error(
                f"Reference audio file not found for clone mode: {potential_path}"
            )
            raise HTTPException(
                status_code=404,
                detail=f"Reference audio file not found: {requested_file}",
            )
        clone_ref_file = requested_file
        logger.info(f"Custom request using clone mode with file: {clone_ref_file}")

    monitor.record("Parameters processed")

    try:
        # NOTE(review): synchronous call inside an async endpoint; it blocks the
        # event loop for the whole generation (see the OpenAI endpoint).
        result = generate_speech(
            text=request.text,
            voice_mode=request.voice_mode,
            clone_reference_filename=clone_ref_file,
            max_tokens=request.max_tokens,  # User value or None
            cfg_scale=request.cfg_scale,
            temperature=request.temperature,
            top_p=request.top_p,
            speed_factor=request.speed_factor,  # Post-processing
            cfg_filter_top_k=request.cfg_filter_top_k,
        )
        monitor.record("Generation complete")

        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result

        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, expected {EXPECTED_SAMPLE_RATE}. Encoding will use {EXPECTED_SAMPLE_RATE}."
            )
            sample_rate = EXPECTED_SAMPLE_RATE

        encoded_audio = encode_audio(audio_array, sample_rate, request.output_format)
        monitor.record("Audio encoding complete")

        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.output_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.output_format}",
            )

        media_type = "audio/opus" if request.output_format == "opus" else "audio/wav"

        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.output_format}"
        )
        logger.debug(monitor.report())

        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)

    except HTTPException as http_exc:
        logger.error(f"HTTP exception during custom TTS request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing custom TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e


# --- Web UI Endpoints ---


@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def get_web_ui(request: Request):
    """Render the TTS web interface (index.html) in its initial state.

    The template is given the reference files available for clone mode, the
    current server config, the loaded presets and the configured default
    generation parameters. Every result/submission field starts out empty, and
    the voice mode defaults to "dialogue".
    """
    logger.info("Serving TTS Web UI (index.html)")

    # Values are read in the same order as before: reference files, config,
    # then the individual generation defaults.
    context = {
        "request": request,
        "reference_files": get_valid_reference_files(),
        "config": config_manager.get_all(),
        "presets": loaded_presets,
        "default_gen_params": {
            "speed_factor": get_gen_default_speed_factor(),
            "cfg_scale": get_gen_default_cfg_scale(),
            "temperature": get_gen_default_temperature(),
            "top_p": get_gen_default_top_p(),
            "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
        },
        # Nothing has been generated or submitted yet.
        "error": None,
        "success": None,
        "output_file_url": None,
        "generation_time": None,
        "submitted_text": "",
        "submitted_voice_mode": "dialogue",
        "submitted_clone_file": None,
    }
    return templates.TemplateResponse("index.html", context)


def _web_ui_default_gen_params() -> dict:
    """Return the configured default generation parameters for the web form."""
    return {
        "speed_factor": get_gen_default_speed_factor(),
        "cfg_scale": get_gen_default_cfg_scale(),
        "temperature": get_gen_default_temperature(),
        "top_p": get_gen_default_top_p(),
        "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
    }


def _web_ui_resolve_reference_path(ref_dir: str, filename: str) -> Optional[str]:
    """Return the absolute path of ``filename`` under ``ref_dir``, or None if it escapes.

    ``filename`` is taken verbatim from the submitted form, so values such as
    ``../../secret.wav`` or an absolute path must not be allowed to point
    outside the reference-audio directory.
    """
    base = os.path.abspath(ref_dir)
    candidate = os.path.abspath(os.path.join(base, filename))
    try:
        is_inside = os.path.commonpath([base, candidate]) == base
    except ValueError:
        # Raised e.g. for paths on different drives on Windows.
        return None
    if not is_inside or candidate == base:
        return None
    return candidate


def _web_ui_render(
    request: Request,
    *,
    text: str,
    voice_mode: str,
    clone_file: Optional[str],
    submitted_gen_params: dict,
    error: Optional[str] = None,
    success: Optional[str] = None,
    output_file_url: Optional[str] = None,
    generation_time: Optional[str] = None,
):
    """Render index.html with fresh server state plus the user's submitted values.

    Reference files, config and default generation parameters are re-read on
    every render so the page reflects the current server state.
    """
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "error": error,
            "success": success,
            "output_file_url": output_file_url,
            "generation_time": generation_time,
            "reference_files": get_valid_reference_files(),
            "config": config_manager.get_all(),
            "presets": loaded_presets,
            "default_gen_params": _web_ui_default_gen_params(),  # Base defaults
            # Submitted values, used to repopulate the form
            "submitted_text": text,
            "submitted_voice_mode": voice_mode,
            "submitted_clone_file": clone_file,
            "submitted_gen_params": submitted_gen_params,
        },
    )


@app.post("/web/generate", response_class=HTMLResponse, include_in_schema=False)
async def handle_web_ui_generate(
    request: Request,
    text: str = Form(...),
    voice_mode: Literal["dialogue", "clone"] = Form(...),  # Updated modes
    clone_reference_select: Optional[str] = Form(None),
    # Generation parameters from form
    speed_factor: float = Form(...),  # Make required or use Depends with default
    cfg_scale: float = Form(...),
    temperature: float = Form(...),
    top_p: float = Form(...),
    cfg_filter_top_k: int = Form(...),
):
    """Handle a generation request submitted from the web UI form.

    Validates the form, calls ``generate_speech``, saves the result as a WAV
    file in the output directory and re-renders index.html. On success the page
    gets a ``/outputs/<file>`` URL and the elapsed time. On failure it gets an
    error message. The submitted values are always passed back so the form can
    be repopulated.

    NOTE(review): ``generate_speech`` is called directly inside this ``async``
    handler. If it is a blocking call, it stalls the event loop (including
    ``/health``) for the whole generation. Confirm before moving it to a
    threadpool, because the engine may not be safe to call concurrently.
    """
    logger.info(f"Web UI generation request: mode='{voice_mode}'")
    monitor = PerformanceMonitor()
    monitor.record("Web request received")

    # Echoed back to the template on every render path.
    submitted_gen_params = {
        "speed_factor": speed_factor,
        "cfg_scale": cfg_scale,
        "temperature": temperature,
        "top_p": top_p,
        "cfg_filter_top_k": cfg_filter_top_k,
    }

    # --- Pre-generation Validation ---
    error_message = None
    if not text.strip():
        error_message = "Please enter some text to synthesize."

    clone_ref_file = None
    if voice_mode == "clone":
        if not clone_reference_select or clone_reference_select == "none":
            error_message = "Please select a reference audio file for clone mode."
        else:
            # The selection comes from the client: reject anything that resolves
            # outside the reference directory, then verify it still exists
            # (files may have been deleted since the page was rendered).
            potential_path = _web_ui_resolve_reference_path(
                get_reference_audio_path(), clone_reference_select
            )
            if potential_path is None or not os.path.isfile(potential_path):
                error_message = f"Selected reference file '{clone_reference_select}' no longer exists. Please refresh or upload."
                clone_reference_select = None  # Don't re-render an invalid choice
            else:
                clone_ref_file = clone_reference_select
                logger.info(f"Using selected reference file: {clone_ref_file}")

    if error_message:
        logger.warning(f"Web UI validation error: {error_message}")
        return _web_ui_render(
            request,
            error=error_message,
            text=text,
            voice_mode=voice_mode,
            clone_file=clone_reference_select,  # Possibly invalidated above
            submitted_gen_params=submitted_gen_params,
        )

    # --- Generation ---
    output_file_url = None
    generation_time = None
    success_message = None
    try:
        monitor.record("Parameters processed")
        result = generate_speech(
            text=text,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=speed_factor,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
            max_tokens=None,  # Use model default for UI simplicity
        )
        monitor.record("Generation complete")

        if result:
            audio_array, sample_rate = result
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Descriptive filename: mode (plus truncated reference name in clone mode) + timestamp.
            mode_tag = voice_mode
            if voice_mode == "clone" and clone_ref_file:
                safe_ref_name = sanitize_filename(os.path.splitext(clone_ref_file)[0])
                mode_tag = f"clone_{safe_ref_name[:20]}"  # Limit length
            output_filename = f"{mode_tag}_{timestamp}.wav"  # Always saved as WAV
            output_filepath = os.path.join(get_output_path(), output_filename)

            saved = save_audio_to_file(audio_array, sample_rate, output_filepath)
            monitor.record("Audio saved")

            if saved:
                output_file_url = f"/outputs/{output_filename}"  # URL path for browser access
                # Time from monitor start until the most recent event (the save).
                generation_time = monitor.events[-1][1] - monitor.start_time
                success_message = "Audio generated successfully!"
                logger.info(f"Web UI generated audio saved to: {output_filepath}")
            else:
                error_message = "Failed to save generated audio file."
                logger.error("Failed to save audio file from web UI request.")
        else:
            error_message = "Speech generation failed (engine returned None)."
            logger.error("Speech generation failed for web UI request.")

    except Exception as e:
        logger.error(f"Error processing web UI TTS request: {e}", exc_info=True)
        error_message = f"An unexpected error occurred: {str(e)}"

    logger.debug(monitor.report())

    # --- Re-render Template with Results ---
    return _web_ui_render(
        request,
        error=error_message,
        success=success_message,
        output_file_url=output_file_url,
        # `is not None` so a (theoretical) 0.0 duration is still shown.
        generation_time=(
            f"{generation_time:.2f}" if generation_time is not None else None
        ),
        text=text,
        voice_mode=voice_mode,
        clone_file=clone_ref_file,  # The validated filename
        submitted_gen_params=submitted_gen_params,
    )


# --- Reference Audio Upload Endpoint ---
@app.post(
    "/upload_reference", tags=["Web UI Helpers"], summary="Upload reference audio files"
)
async def upload_reference_audio(files: List[UploadFile] = File(...)):
    """Save uploaded reference audio files (.wav, .mp3) for voice cloning.

    Each file is checked by its sanitized extension and by the content type the
    client declared, then written into the reference-audio directory. An
    existing file with the same name is never overwritten; it is reported as
    available instead. Every ``UploadFile`` is closed, whether it was saved,
    skipped or rejected.

    Returns:
        JSONResponse with ``message``, ``uploaded_files`` (names made available
        by this request, including ones that already existed),
        ``all_reference_files`` and ``errors``. The status is 400 when there are
        at least as many errors as files, otherwise 200.
    """
    logger.info(f"Received request to upload {len(files)} reference audio file(s).")
    ref_path = get_reference_audio_path()
    uploaded_filenames = []
    errors = []
    # Browsers disagree on the WAV MIME type; accept the common spellings.
    allowed_mime_types = [
        "audio/wav",
        "audio/wave",
        "audio/x-wav",
        "audio/vnd.wave",
        "audio/mpeg",
    ]
    allowed_extensions = [".wav", ".mp3"]

    for file in files:
        try:
            if not file.filename:
                errors.append("Received file with no filename.")
                continue

            safe_filename = sanitize_filename(file.filename)
            _, ext = os.path.splitext(safe_filename)
            if ext.lower() not in allowed_extensions:
                errors.append(
                    f"File '{file.filename}' has unsupported extension '{ext}'. Allowed: {allowed_extensions}"
                )
                continue

            # content_type is supplied by the client, so this is only a sanity
            # check, not proof of the actual file format.
            if file.content_type not in allowed_mime_types:
                errors.append(
                    f"File '{file.filename}' has unsupported content type '{file.content_type}'. Allowed: {allowed_mime_types}"
                )
                continue

            destination_path = os.path.join(ref_path, safe_filename)

            created = False
            try:
                # "xb" = exclusive create. It raises FileExistsError instead of
                # overwriting, with no gap between an existence check and the write.
                with open(destination_path, "xb") as buffer:
                    created = True
                    shutil.copyfileobj(file.file, buffer)
            except FileExistsError:
                logger.warning(
                    f"Reference file '{safe_filename}' already exists. Skipping upload."
                )
                # Report it so the UI knows it's available, even though it wasn't written now.
                if safe_filename not in uploaded_filenames:
                    uploaded_filenames.append(safe_filename)
                continue
            except Exception as save_exc:
                errors.append(f"Failed to save file '{safe_filename}': {save_exc}")
                logger.error(
                    f"Failed to save uploaded file '{safe_filename}' to '{destination_path}': {save_exc}",
                    exc_info=True,
                )
                if created:
                    # Remove the truncated file. Otherwise it would make every
                    # later upload of this name be skipped as "already exists".
                    try:
                        os.remove(destination_path)
                    except OSError as cleanup_exc:
                        logger.warning(
                            f"Could not remove partial file '{destination_path}': {cleanup_exc}"
                        )
                continue

            logger.info(f"Successfully saved reference file: {destination_path}")
            uploaded_filenames.append(safe_filename)

        except Exception as e:
            errors.append(
                f"Error processing file '{getattr(file, 'filename', 'unknown')}': {e}"
            )
            logger.error(
                f"Unexpected error processing uploaded file: {e}", exc_info=True
            )
        finally:
            # Runs on every path, including the early `continue`s above, so no
            # UploadFile is left open.
            await file.close()

    # Get the updated list of all valid files in the directory
    updated_file_list = get_valid_reference_files()

    response_data = {
        "message": f"Processed {len(files)} file(s).",
        "uploaded_files": uploaded_filenames,  # Saved now or already present
        "all_reference_files": updated_file_list,  # Complete current list
        "errors": errors,
    }

    # OK if at least one file did not error, else Bad Request.
    status_code = 200 if not errors or len(errors) < len(files) else 400
    if errors:
        logger.warning(f"Upload completed with errors: {errors}")

    return JSONResponse(content=response_data, status_code=status_code)


# --- Health Check Endpoint ---
@app.get("/health", tags=["Server Status"], summary="Check server health")
async def health_check():
    """Report that the server is up and whether the engine says the model is loaded.

    ``engine.MODEL_LOADED`` is looked up on every call, so the response follows
    the engine module's live value rather than one captured at import time.
    A missing attribute is treated as ``False``.
    """
    model_loaded = getattr(engine, "MODEL_LOADED", False)
    logger.debug(f"Health check returning model_loaded status: {model_loaded}")
    payload = {"status": "healthy", "model_loaded": model_loaded}
    return payload


# --- Main Execution ---
if __name__ == "__main__":
    host = get_host()
    port = get_port()
    logger.info(f"Starting Dia TTS server on {host}:{port}")
    logger.info(f"Model Repository: {get_model_repo_id()}")
    logger.info(f"Model Config File: {get_model_config_filename()}")
    logger.info(f"Model Weights File: {get_model_weights_filename()}")
    logger.info(f"Model Cache Path: {get_model_cache_path()}")
    logger.info(f"Reference Audio Path: {get_reference_audio_path()}")
    logger.info(f"Output Path: {get_output_path()}")

    # A wildcard bind address isn't browsable, so show localhost instead.
    # IPv6 literals must be bracketed inside a URL.
    display_host = "localhost" if host in ("0.0.0.0", "::") else host
    if ":" in display_host and not display_host.startswith("["):
        display_host = f"[{display_host}]"
    logger.info(f"Web UI will be available at http://{display_host}:{port}/")
    logger.info(f"API Docs available at http://{display_host}:{port}/docs")

    # Ensure the UI directory and index.html exist so the Web UI can render.
    ui_dir = "ui"
    index_file = os.path.join(ui_dir, "index.html")
    if not os.path.isdir(ui_dir) or not os.path.isfile(index_file):
        logger.warning(
            f"'{ui_dir}' directory or '{index_file}' not found. Web UI may not work."
        )
        # Create a placeholder so startup isn't blocked by a missing template.
        os.makedirs(ui_dir, exist_ok=True)
        if not os.path.isfile(index_file):
            try:
                with open(index_file, "w", encoding="utf-8") as f:
                    f.write(
                        "<html><body>Web UI template missing. See project source for index.html.</body></html>"
                    )
                logger.info(f"Created dummy {index_file}.")
            except Exception as e:
                logger.error(f"Failed to create dummy {index_file}: {e}")

    # Run Uvicorn. lifespan="on" makes uvicorn run the app's lifespan handler
    # at startup.
    # NOTE(review): with an import string, uvicorn imports the module by name
    # ("server"). If this file is server.py, that is a second import, separate
    # from this __main__ module, so module-level state here is not shared with
    # the serving app.
    uvicorn.run(
        "server:app",  # Use the format 'module:app_instance'
        host=host,
        port=port,
        reload=False,  # Set reload as needed for development/production
        # reload_dirs=[".", "ui"], # Only use reload=True with reload_dirs/includes for development
        # reload_includes=[
        #     "*.py",
        #     "*.html",
        #     "*.css",
        #     "*.js",
        #     ".env",
        #     "*.yaml",
        # ],
        lifespan="on",  # Use the lifespan context manager defined in this file
        # workers=1 # Keep workers=1 when using reload=True or complex global state/models
    )