# models.py
# Pydantic models for API requests and potentially responses.

from typing import Optional, Literal

from pydantic import BaseModel, Field


# --- Request Models ---


class OpenAITTSRequest(BaseModel):
    """Request model compatible with the OpenAI TTS API."""

    # Accepted so OpenAI-style clients can send it; per the description the
    # server does not act on the value.
    model: str = Field(
        "dia-1.6b",
        description="Model identifier (ignored by this server, always uses Dia). Included for compatibility.",
    )

    # Required field: Ellipsis means no default is supplied.
    input: str = Field(..., description="The text to synthesize.")

    # Plain string rather than a Literal because it doubles as either a voice
    # mode name or a reference audio filename (see the description examples).
    voice: str = Field(
        "S1",
        description="Voice mode or reference audio filename. Examples: 'S1', 'S2', 'dialogue', 'my_reference.wav'.",
    )

    # Validation accepts exactly these two format names.
    response_format: Literal["opus", "wav"] = Field(
        "opus", description="The desired audio output format."
    )

    # Author's note: Dia's speed factor range seems narrower, hence the
    # 0.8-1.2 bounds enforced here.
    speed: float = Field(
        1.0,
        ge=0.8,
        le=1.2,
        description="Adjusts the speed of the generated audio (0.8 to 1.2).",
    )


class CustomTTSRequest(BaseModel):
    """Request model for the custom /tts endpoint."""

    # Required field: Ellipsis means no default is supplied.
    text: str = Field(
        ...,
        description="The text to synthesize. For 'dialogue' mode, include [S1]/[S2] tags.",
    )

    # Closed set of generation modes; any other value fails validation.
    voice_mode: Literal["dialogue", "single_s1", "single_s2", "clone"] = Field(
        "single_s1", description="Specifies the generation mode."
    )

    # NOTE(review): the "required if voice_mode is 'clone'" rule is stated only
    # in the description; nothing in this model enforces it. Presumably the
    # endpoint checks it -- verify against the caller.
    clone_reference_filename: Optional[str] = Field(
        None,
        description="Filename of the reference audio within the configured reference path (required if voice_mode is 'clone').",
    )

    # Same two format names as OpenAITTSRequest.response_format.
    output_format: Literal["opus", "wav"] = Field(
        "opus", description="The desired audio output format."
    )

    # --- Dia-specific generation parameters ---

    # None means "not specified"; when given, the value must be strictly
    # positive (gt=0).
    max_tokens: Optional[int] = Field(
        None,
        gt=0,
        description="Maximum number of audio tokens to generate (defaults to model's internal config value).",
    )

    cfg_scale: float = Field(
        3.0,
        ge=1.0,
        le=5.0,
        description="Classifier-Free Guidance scale (1.0-5.0).",
    )

    temperature: float = Field(
        1.3,
        ge=1.0,
        le=1.5,
        description="Sampling temperature (1.0-1.5).",
    )

    top_p: float = Field(
        0.95,
        ge=0.8,
        le=1.0,
        description="Nucleus sampling probability (0.8-1.0).",
    )

    # Author's note: Dia's default range seems to be <= 1.0.
    # NOTE(review): OpenAITTSRequest.speed allows values up to 1.2 while this
    # field caps at 1.0 -- confirm the difference is intentional.
    speed_factor: float = Field(
        0.94,
        ge=0.8,
        le=1.0,
        description="Adjusts the speed of the generated audio (0.8 to 1.0).",
    )

    cfg_filter_top_k: int = Field(
        35,
        ge=15,
        le=50,
        description="Top k filter for CFG guidance (15-50).",
    )


# --- Response Models (Optional, can be simple dicts too) ---


class TTSResponse(BaseModel):
    """Basic response model for successful generation (if returning JSON)."""

    # Required: no default.
    request_id: str
    # Defaults to "completed" when the caller does not set it.
    status: str = "completed"
    # Required: no default.
    generation_time_sec: float
    # Only populated if the file is saved and a URL is returned.
    output_url: Optional[str] = None


class ErrorResponse(BaseModel):
    """Error response model."""

    # Single required string field carrying the error text.
    detail: str