Spaces:
Running
Running
# models.py
# Pydantic models for API requests and potentially responses
from typing import Literal, Optional

from pydantic import BaseModel, Field

# --- Request Models ---
class OpenAITTSRequest(BaseModel):
    """Request model compatible with the OpenAI TTS API."""

    # Accepted only so OpenAI-style clients can send it unchanged; per the
    # field description the server always uses Dia regardless of this value.
    model: str = Field(
        default="dia-1.6b",
        description="Model identifier (ignored by this server, always uses Dia). Included for compatibility.",
    )
    # Required: Field(...) has no default, so validation fails if it is omitted.
    input: str = Field(..., description="The text to synthesize.")
    # Either a voice mode (e.g. 'S1', 'S2', 'dialogue') or a reference-audio
    # filename. NOTE(review): the value is not constrained here; if it is
    # joined onto a filesystem path downstream, confirm that code rejects
    # '..' segments and absolute paths.
    voice: str = Field(
        default="S1",
        description="Voice mode or reference audio filename. Examples: 'S1', 'S2', 'dialogue', 'my_reference.wav'.",
    )
    # Only these two encodings validate; any other value is rejected.
    response_format: Literal["opus", "wav"] = Field(
        default="opus", description="The desired audio output format."
    )
    # Inclusive bounds (ge/le). This allows up to 1.2, whereas
    # CustomTTSRequest.speed_factor is capped at 1.0.
    speed: float = Field(
        default=1.0,
        ge=0.8,
        le=1.2,  # Dia speed factor range seems narrower
        description="Adjusts the speed of the generated audio (0.8 to 1.2).",
    )
class CustomTTSRequest(BaseModel):
    """Request model for the custom /tts endpoint."""

    # --- What to synthesize, and in which mode ---
    # Mandatory input; in 'dialogue' mode speaker turns are marked inline.
    text: str = Field(
        ...,
        description="The text to synthesize. For 'dialogue' mode, include [S1]/[S2] tags.",
    )
    voice_mode: Literal["dialogue", "single_s1", "single_s2", "clone"] = Field(
        "single_s1",
        description="Specifies the generation mode.",
    )
    # NOTE(review): "required if voice_mode is 'clone'" is only documented in
    # the description; this model does not enforce it, so any such check must
    # happen outside the model. The value is also unconstrained. Confirm that
    # whatever resolves it against the reference directory rejects '..'
    # segments and absolute paths.
    clone_reference_filename: Optional[str] = Field(
        None,
        description="Filename of the reference audio within the configured reference path (required if voice_mode is 'clone').",
    )
    output_format: Literal["opus", "wav"] = Field(
        "opus",
        description="The desired audio output format.",
    )

    # --- Dia-specific generation parameters (ge/le/gt bounds are enforced) ---
    # None means "use the model's own config value" (per the description);
    # an explicit value must be strictly positive.
    max_tokens: Optional[int] = Field(
        None,
        gt=0,
        description="Maximum number of audio tokens to generate (defaults to model's internal config value).",
    )
    cfg_scale: float = Field(
        3.0, ge=1.0, le=5.0, description="Classifier-Free Guidance scale (1.0-5.0)."
    )
    temperature: float = Field(
        1.3, ge=1.0, le=1.5, description="Sampling temperature (1.0-1.5)."
    )
    top_p: float = Field(
        0.95, ge=0.8, le=1.0, description="Nucleus sampling probability (0.8-1.0)."
    )
    # Upper bound kept at 1.0: Dia's default range seems to be <= 1.0.
    speed_factor: float = Field(
        0.94,
        ge=0.8,
        le=1.0,
        description="Adjusts the speed of the generated audio (0.8 to 1.0).",
    )
    cfg_filter_top_k: int = Field(
        35, ge=15, le=50, description="Top k filter for CFG guidance (15-50)."
    )
# --- Response Models (Optional, can be simple dicts too) --- | |
class TTSResponse(BaseModel):
    """Basic response model for successful generation (if returning JSON)."""

    # No default: callers must always supply these two.
    request_id: str
    # Optional on construction; falls back to the success marker.
    status: str = Field(default="completed")
    generation_time_sec: float
    # If saving file and returning URL; omitted (None) otherwise.
    output_url: Optional[str] = Field(default=None)
class ErrorResponse(BaseModel):
    """Error response model."""
    # Error message text; required (no default).
    # NOTE(review): presumably named to mirror the {"detail": ...} body shape
    # of FastAPI error responses — confirm against the app's error handlers.
    detail: str