Michael Hu committed on
Commit 05b45a5 · 1 Parent(s): e55a2a8

initial check in

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. Dockerfile +66 -0
  2. api/__init__.py +1 -0
  3. api/src/builds/v1_0/config.json +172 -0
  4. api/src/core/__init__.py +3 -0
  5. api/src/core/config.py +85 -0
  6. api/src/core/don_quixote.txt +9 -0
  7. api/src/core/model_config.py +50 -0
  8. api/src/core/openai_mappings.json +18 -0
  9. api/src/core/paths.py +413 -0
  10. api/src/inference/__init__.py +12 -0
  11. api/src/inference/base.py +127 -0
  12. api/src/inference/kokoro_v1.py +370 -0
  13. api/src/inference/model_manager.py +171 -0
  14. api/src/inference/voice_manager.py +115 -0
  15. api/src/main.py +152 -0
  16. api/src/models/v1_0/config.json +150 -0
  17. api/src/routers/__init__.py +1 -0
  18. api/src/routers/debug.py +209 -0
  19. api/src/routers/development.py +408 -0
  20. api/src/routers/openai_compatible.py +662 -0
  21. api/src/routers/web_player.py +49 -0
  22. api/src/services/__init__.py +3 -0
  23. api/src/services/audio.py +248 -0
  24. api/src/services/streaming_audio_writer.py +100 -0
  25. api/src/services/temp_manager.py +170 -0
  26. api/src/services/text_processing/__init__.py +21 -0
  27. api/src/services/text_processing/normalizer.py +415 -0
  28. api/src/services/text_processing/phonemizer.py +102 -0
  29. api/src/services/text_processing/text_processor.py +276 -0
  30. api/src/services/text_processing/vocabulary.py +40 -0
  31. api/src/services/tts_service.py +459 -0
  32. api/src/structures/__init__.py +17 -0
  33. api/src/structures/custom_responses.py +50 -0
  34. api/src/structures/model_schemas.py +16 -0
  35. api/src/structures/schemas.py +158 -0
  36. api/src/structures/text_schemas.py +41 -0
  37. api/tests/__init__.py +1 -0
  38. api/tests/conftest.py +71 -0
  39. api/tests/test_audio_service.py +256 -0
  40. api/tests/test_data/generate_test_data.py +23 -0
  41. api/tests/test_data/test_audio.npy +0 -0
  42. api/tests/test_development.py +34 -0
  43. api/tests/test_kokoro_v1.py +165 -0
  44. api/tests/test_normalizer.py +179 -0
  45. api/tests/test_openai_endpoints.py +499 -0
  46. api/tests/test_paths.py +138 -0
  47. api/tests/test_text_processor.py +105 -0
  48. api/tests/test_tts_service.py +126 -0
  49. charts/kokoro-fastapi/.helmignore +23 -0
  50. charts/kokoro-fastapi/Chart.yaml +12 -0
Dockerfile ADDED
@@ -0,0 +1,66 @@
+ FROM python:3.10-slim
+
+ # Install dependencies and check espeak location
+ RUN apt-get update && apt-get install -y \
+     espeak-ng \
+     espeak-ng-data \
+     git \
+     libsndfile1 \
+     curl \
+     ffmpeg \
+     g++ \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/* \
+     && mkdir -p /usr/share/espeak-ng-data \
+     && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
+
+ # Install UV using the installer script
+ RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+     mv /root/.local/bin/uv /usr/local/bin/ && \
+     mv /root/.local/bin/uvx /usr/local/bin/
+
+ # Create non-root user and set up directories and permissions
+ RUN useradd -m -u 1000 appuser && \
+     mkdir -p /app/api/src/models/v1_0 && \
+     chown -R appuser:appuser /app
+
+ USER appuser
+ WORKDIR /app
+
+ # Copy dependency files
+ COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
+
+ # Install Rust (required to build sudachipy and pyopenjtalk-plus)
+ RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
+ ENV PATH="/home/appuser/.cargo/bin:$PATH"
+
+ # Install dependencies
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     uv venv --python 3.10 && \
+     uv sync --extra cpu
+
+ # Copy project files including models
+ COPY --chown=appuser:appuser api ./api
+ COPY --chown=appuser:appuser web ./web
+ COPY --chown=appuser:appuser docker/scripts/ ./
+ RUN chmod +x ./entrypoint.sh
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     PYTHONPATH=/app:/app/api \
+     PATH="/app/.venv/bin:$PATH" \
+     UV_LINK_MODE=copy \
+     USE_GPU=false \
+     PHONEMIZER_ESPEAK_PATH=/usr/bin \
+     PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
+     ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
+
+ ENV DOWNLOAD_MODEL=true
+ # Download model if enabled
+ RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
+     python download_model.py --output api/src/models/v1_0; \
+     fi
+
+ ENV DEVICE="cpu"
+ # Run FastAPI server through entrypoint.sh
+ CMD ["./entrypoint.sh"]
api/__init__.py ADDED
@@ -0,0 +1 @@
+ # Make api directory a Python package
api/src/builds/v1_0/config.json ADDED
@@ -0,0 +1,172 @@
+ {
+     "istftnet": {
+         "upsample_kernel_sizes": [20, 12],
+         "upsample_rates": [10, 6],
+         "gen_istft_hop_size": 5,
+         "gen_istft_n_fft": 20,
+         "resblock_dilation_sizes": [
+             [1, 3, 5],
+             [1, 3, 5],
+             [1, 3, 5]
+         ],
+         "resblock_kernel_sizes": [3, 7, 11],
+         "upsample_initial_channel": 512
+     },
+     "dim_in": 64,
+     "dropout": 0.2,
+     "hidden_dim": 512,
+     "max_conv_dim": 512,
+     "max_dur": 50,
+     "multispeaker": true,
+     "n_layer": 3,
+     "n_mels": 80,
+     "n_token": 178,
+     "style_dim": 128,
+     "text_encoder_kernel_size": 5,
+     "plbert": {
+         "hidden_size": 768,
+         "num_attention_heads": 12,
+         "intermediate_size": 2048,
+         "max_position_embeddings": 512,
+         "num_hidden_layers": 12,
+         "dropout": 0.1
+     },
+     "vocab": {
+         ";": 1,
+         ":": 2,
+         ",": 3,
+         ".": 4,
+         "!": 5,
+         "?": 6,
+         "—": 9,
+         "…": 10,
+         "\"": 11,
+         "(": 12,
+         ")": 13,
+         "“": 14,
+         "”": 15,
+         " ": 16,
+         "̃": 17,
+         "ʣ": 18,
+         "ʥ": 19,
+         "ʦ": 20,
+         "ʨ": 21,
+         "ᵝ": 22,
+         "ꭧ": 23,
+         "A": 24,
+         "I": 25,
+         "O": 31,
+         "Q": 33,
+         "S": 35,
+         "T": 36,
+         "W": 39,
+         "Y": 41,
+         "ᵊ": 42,
+         "a": 43,
+         "b": 44,
+         "c": 45,
+         "d": 46,
+         "e": 47,
+         "f": 48,
+         "h": 50,
+         "i": 51,
+         "j": 52,
+         "k": 53,
+         "l": 54,
+         "m": 55,
+         "n": 56,
+         "o": 57,
+         "p": 58,
+         "q": 59,
+         "r": 60,
+         "s": 61,
+         "t": 62,
+         "u": 63,
+         "v": 64,
+         "w": 65,
+         "x": 66,
+         "y": 67,
+         "z": 68,
+         "ɑ": 69,
+         "ɐ": 70,
+         "ɒ": 71,
+         "æ": 72,
+         "β": 75,
+         "ɔ": 76,
+         "ɕ": 77,
+         "ç": 78,
+         "ɖ": 80,
+         "ð": 81,
+         "ʤ": 82,
+         "ə": 83,
+         "ɚ": 85,
+         "ɛ": 86,
+         "ɜ": 87,
+         "ɟ": 90,
+         "ɡ": 92,
+         "ɥ": 99,
+         "ɨ": 101,
+         "ɪ": 102,
+         "ʝ": 103,
+         "ɯ": 110,
+         "ɰ": 111,
+         "ŋ": 112,
+         "ɳ": 113,
+         "ɲ": 114,
+         "ɴ": 115,
+         "ø": 116,
+         "ɸ": 118,
+         "θ": 119,
+         "œ": 120,
+         "ɹ": 123,
+         "ɾ": 125,
+         "ɻ": 126,
+         "ʁ": 128,
+         "ɽ": 129,
+         "ʂ": 130,
+         "ʃ": 131,
+         "ʈ": 132,
+         "ʧ": 133,
+         "ʊ": 135,
+         "ʋ": 136,
+         "ʌ": 138,
+         "ɣ": 139,
+         "ɤ": 140,
+         "χ": 142,
+         "ʎ": 143,
+         "ʒ": 147,
+         "ʔ": 148,
+         "ˈ": 156,
+         "ˌ": 157,
+         "ː": 158,
+         "ʰ": 162,
+         "ʲ": 164,
+         "↓": 169,
+         "→": 171,
+         "↗": 172,
+         "↘": 173,
+         "ᵻ": 177
+     }
+ }
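For reference, a short sketch (not part of this commit) of how a "vocab" table like this is typically consumed: each IPA phoneme or punctuation symbol maps to an integer token id below "n_token" (178). The phoneme string below is a made-up example.

    import json

    with open("api/src/builds/v1_0/config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    vocab = cfg["vocab"]
    phonemes = "həˈloʊ"  # hypothetical IPA input
    # Look up the token id for each symbol, skipping anything not in the table
    ids = [vocab[p] for p in phonemes if p in vocab]
    print(ids)  # [50, 83, 156, 54, 57, 135], all below cfg["n_token"]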
api/src/core/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .config import settings
+
+ __all__ = ["settings"]
api/src/core/config.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ from pydantic_settings import BaseSettings
+
+
+ class Settings(BaseSettings):
+     # API Settings
+     api_title: str = "Kokoro TTS API"
+     api_description: str = "API for text-to-speech generation using Kokoro"
+     api_version: str = "1.0.0"
+     host: str = "0.0.0.0"
+     port: int = 8880
+
+     # Application Settings
+     output_dir: str = "output"
+     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
+     default_voice: str = "af_heart"
+     default_voice_code: str | None = (
+         None  # If set, overrides the first letter of voice name, though api call param still takes precedence
+     )
+     use_gpu: bool = True  # Whether to use GPU acceleration if available
+     device_type: str | None = (
+         None  # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
+     )
+     allow_local_voice_saving: bool = (
+         False  # Whether to allow saving combined voices locally
+     )
+
+     # Container absolute paths
+     model_dir: str = "/app/api/src/models"  # Absolute path in container
+     voices_dir: str = "/app/api/src/voices/v1_0"  # Absolute path in container
+
+     # Audio Settings
+     sample_rate: int = 24000
+     # Text Processing Settings
+     target_min_tokens: int = 175  # Target minimum tokens per chunk
+     target_max_tokens: int = 250  # Target maximum tokens per chunk
+     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
+     advanced_text_normalization: bool = True  # Preprocesses the text before misaki
+     voice_weight_normalization: bool = (
+         True  # Normalize the voice weights so they add up to 1
+     )
+
+     gap_trim_ms: int = (
+         1  # Base amount to trim from streaming chunk ends in milliseconds
+     )
+     dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
+     dynamic_gap_trim_padding_char_multiplier: dict[str, float] = {
+         ".": 1,
+         "!": 0.9,
+         "?": 1,
+         ",": 0.8,
+     }
+
+     # Web Player Settings
+     enable_web_player: bool = True  # Whether to serve the web player UI
+     web_player_path: str = "web"  # Path to web player static files
+     cors_origins: list[str] = ["*"]  # CORS origins for web player
+     cors_enabled: bool = True  # Whether to enable CORS
+
+     # Temp File Settings for web UI
+     temp_file_dir: str = "api/temp_files"  # Directory for temporary audio files (relative to project root)
+     max_temp_dir_size_mb: int = 2048  # Maximum size of temp directory (2GB)
+     max_temp_dir_age_hours: int = 1  # Remove temp files older than 1 hour
+     max_temp_dir_count: int = 3  # Maximum number of temp files to keep
+
+     class Config:
+         env_file = ".env"
+
+     def get_device(self) -> str:
+         """Get the appropriate device based on settings and availability"""
+         if not self.use_gpu:
+             return "cpu"
+
+         if self.device_type:
+             return self.device_type
+
+         # Auto-detect device
+         if torch.backends.mps.is_available():
+             return "mps"
+         elif torch.cuda.is_available():
+             return "cuda"
+         return "cpu"
+
+
+ settings = Settings()
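Since Settings extends pydantic's BaseSettings, every field above can be overridden through an environment variable (or the .env file named in Config.env_file), matched case-insensitively. A minimal sketch, assuming the project root is on PYTHONPATH:

    import os

    # Environment overrides must be set before Settings() is instantiated
    os.environ["USE_GPU"] = "false"
    os.environ["DEFAULT_VOICE"] = "af_sky"

    from api.src.core.config import Settings

    s = Settings()
    assert s.use_gpu is False
    print(s.get_device())  # "cpu", because use_gpu is False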
api/src/core/don_quixote.txt ADDED
@@ -0,0 +1,9 @@
+ In a village of La Mancha, the name of which I have no desire to call
+ to mind, there lived not long since one of those gentlemen that keep a
+ lance in the lance-rack, an old buckler, a lean hack, and a greyhound
+ for coursing. An olla of rather more beef than mutton, a salad on most
+ nights, scraps on Saturdays, lentils on Fridays, and a pigeon or so
+ extra on Sundays, made away with three-quarters of his income. The rest
+ of it went in a doublet of fine cloth and velvet breeches and shoes to
+ match for holidays, while on week-days he made a brave figure in his
+ best homespun.
api/src/core/model_config.py ADDED
@@ -0,0 +1,50 @@
+ """Model configuration for Kokoro V1.
+
+ This module provides model-specific configuration settings that complement the application-level
+ settings in config.py. While config.py handles general application settings (API, paths, etc.),
+ this module focuses on memory management and model file paths.
+ """
+
+ from pydantic import BaseModel, Field
+
+
+ class KokoroV1Config(BaseModel):
+     """Kokoro V1 configuration."""
+
+     languages: list[str] = ["en"]
+
+     class Config:
+         frozen = True
+
+
+ class PyTorchConfig(BaseModel):
+     """PyTorch backend configuration."""
+
+     memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
+     retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
+
+     class Config:
+         frozen = True
+
+
+ class ModelConfig(BaseModel):
+     """Kokoro V1 model configuration."""
+
+     # General settings
+     cache_voices: bool = Field(True, description="Whether to cache voice tensors")
+     voice_cache_size: int = Field(2, description="Maximum number of cached voices")
+
+     # Model filename
+     pytorch_kokoro_v1_file: str = Field(
+         "v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename"
+     )
+
+     # Backend config
+     pytorch_gpu: PyTorchConfig = Field(default_factory=PyTorchConfig)
+
+     class Config:
+         frozen = True
+
+
+ # Global instance
+ model_config = ModelConfig()
api/src/core/openai_mappings.json ADDED
@@ -0,0 +1,18 @@
+ {
+     "models": {
+         "tts-1": "kokoro-v1_0",
+         "tts-1-hd": "kokoro-v1_0",
+         "kokoro": "kokoro-v1_0"
+     },
+     "voices": {
+         "alloy": "am_v0adam",
+         "ash": "af_v0nicole",
+         "coral": "bf_v0emma",
+         "echo": "af_v0bella",
+         "fable": "af_sarah",
+         "onyx": "bm_george",
+         "nova": "bf_isabella",
+         "sage": "am_michael",
+         "shimmer": "af_sky"
+     }
+ }
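This mapping is what lets OpenAI-style requests address Kokoro assets: model aliases such as "tts-1" resolve to "kokoro-v1_0", and OpenAI voice names resolve to bundled voice packs. A small sketch (not part of the commit) of reading it directly:

    import json

    with open("api/src/core/openai_mappings.json", encoding="utf-8") as f:
        mappings = json.load(f)

    # Resolve an OpenAI-style model alias and voice name to Kokoro equivalents
    assert mappings["models"]["tts-1"] == "kokoro-v1_0"
    assert mappings["voices"]["alloy"] == "am_v0adam"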
api/src/core/paths.py ADDED
@@ -0,0 +1,413 @@
+ """Async file and path operations."""
+
+ import io
+ import json
+ import os
+ import time
+ from typing import Callable, List, Optional, Set
+
+ import aiofiles
+ import aiofiles.os
+ import torch
+ from loguru import logger
+
+ from .config import settings
+
+
+ async def _find_file(
+     filename: str,
+     search_paths: List[str],
+     filter_fn: Optional[Callable[[str], bool]] = None,
+ ) -> str:
+     """Find file in search paths.
+
+     Args:
+         filename: Name of file to find
+         search_paths: List of paths to search in
+         filter_fn: Optional function to filter files
+
+     Returns:
+         Absolute path to file
+
+     Raises:
+         FileNotFoundError: If file not found
+     """
+     if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
+         return filename
+
+     for path in search_paths:
+         full_path = os.path.join(path, filename)
+         if await aiofiles.os.path.exists(full_path):
+             if filter_fn is None or filter_fn(full_path):
+                 return full_path
+
+     raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}")
+
+
+ async def _scan_directories(
+     search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None
+ ) -> Set[str]:
+     """Scan directories for files.
+
+     Args:
+         search_paths: List of paths to scan
+         filter_fn: Optional function to filter files
+
+     Returns:
+         Set of matching filenames
+     """
+     results = set()
+
+     for path in search_paths:
+         if not await aiofiles.os.path.exists(path):
+             continue
+
+         try:
+             # Get directory entries first
+             entries = await aiofiles.os.scandir(path)
+             # Then process entries after await completes
+             for entry in entries:
+                 if filter_fn is None or filter_fn(entry.name):
+                     results.add(entry.name)
+         except Exception as e:
+             logger.warning(f"Error scanning {path}: {e}")
+
+     return results
+
+
+ async def get_model_path(model_name: str) -> str:
+     """Get path to model file.
+
+     Args:
+         model_name: Name of model file
+
+     Returns:
+         Absolute path to model file
+
+     Raises:
+         FileNotFoundError: If model not found
+     """
+     # Get api directory path (two levels up from core)
+     api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+     # Construct model directory path relative to api directory
+     model_dir = os.path.join(api_dir, settings.model_dir)
+
+     # Ensure model directory exists
+     os.makedirs(model_dir, exist_ok=True)
+
+     # Search in model directory
+     search_paths = [model_dir]
+     logger.debug(f"Searching for model in path: {model_dir}")
+
+     return await _find_file(model_name, search_paths)
+
+
+ async def get_voice_path(voice_name: str) -> str:
+     """Get path to voice file.
+
+     Args:
+         voice_name: Name of voice file (without .pt extension)
+
+     Returns:
+         Absolute path to voice file
+
+     Raises:
+         FileNotFoundError: If voice not found
+     """
+     # Get api directory path
+     api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+     # Construct voice directory path relative to api directory
+     voice_dir = os.path.join(api_dir, settings.voices_dir)
+
+     # Ensure voice directory exists
+     os.makedirs(voice_dir, exist_ok=True)
+
+     voice_file = f"{voice_name}.pt"
+
+     # Search in voice directory
+     search_paths = [voice_dir]
+     logger.debug(f"Searching for voice in path: {voice_dir}")
+
+     return await _find_file(voice_file, search_paths)
+
+
+ async def list_voices() -> List[str]:
+     """List available voice files.
+
+     Returns:
+         List of voice names (without .pt extension)
+     """
+     # Get api directory path
+     api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+     # Construct voice directory path relative to api directory
+     voice_dir = os.path.join(api_dir, settings.voices_dir)
+
+     # Ensure voice directory exists
+     os.makedirs(voice_dir, exist_ok=True)
+
+     # Search in voice directory
+     search_paths = [voice_dir]
+     logger.debug(f"Scanning for voices in path: {voice_dir}")
+
+     def filter_voice_files(name: str) -> bool:
+         return name.endswith(".pt")
+
+     voices = await _scan_directories(search_paths, filter_voice_files)
+     return sorted([name[:-3] for name in voices])  # Remove .pt extension
+
+
+ async def load_voice_tensor(
+     voice_path: str, device: str = "cpu", weights_only=False
+ ) -> torch.Tensor:
+     """Load voice tensor from file.
+
+     Args:
+         voice_path: Path to voice file
+         device: Device to load tensor to
+         weights_only: Whether to restrict loading to tensor data only
+
+     Returns:
+         Voice tensor
+
+     Raises:
+         RuntimeError: If file cannot be read
+     """
+     try:
+         async with aiofiles.open(voice_path, "rb") as f:
+             data = await f.read()
+             return torch.load(
+                 io.BytesIO(data), map_location=device, weights_only=weights_only
+             )
+     except Exception as e:
+         raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
+
+
+ async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
+     """Save voice tensor to file.
+
+     Args:
+         tensor: Voice tensor to save
+         voice_path: Path to save voice file
+
+     Raises:
+         RuntimeError: If file cannot be written
+     """
+     try:
+         buffer = io.BytesIO()
+         torch.save(tensor, buffer)
+         async with aiofiles.open(voice_path, "wb") as f:
+             await f.write(buffer.getvalue())
+     except Exception as e:
+         raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
+
+
+ async def load_json(path: str) -> dict:
+     """Load JSON file asynchronously.
+
+     Args:
+         path: Path to JSON file
+
+     Returns:
+         Parsed JSON data
+
+     Raises:
+         RuntimeError: If file cannot be read or parsed
+     """
+     try:
+         async with aiofiles.open(path, "r", encoding="utf-8") as f:
+             content = await f.read()
+             return json.loads(content)
+     except Exception as e:
+         raise RuntimeError(f"Failed to load JSON file {path}: {e}")
+
+
+ async def load_model_weights(path: str, device: str = "cpu") -> dict:
+     """Load model weights asynchronously.
+
+     Args:
+         path: Path to model file (.pth or .onnx)
+         device: Device to load model to
+
+     Returns:
+         Model weights
+
+     Raises:
+         RuntimeError: If file cannot be read
+     """
+     try:
+         async with aiofiles.open(path, "rb") as f:
+             data = await f.read()
+             return torch.load(io.BytesIO(data), map_location=device, weights_only=True)
+     except Exception as e:
+         raise RuntimeError(f"Failed to load model weights from {path}: {e}")
+
+
+ async def read_file(path: str) -> str:
+     """Read text file asynchronously.
+
+     Args:
+         path: Path to file
+
+     Returns:
+         File contents as string
+
+     Raises:
+         RuntimeError: If file cannot be read
+     """
+     try:
+         async with aiofiles.open(path, "r", encoding="utf-8") as f:
+             return await f.read()
+     except Exception as e:
+         raise RuntimeError(f"Failed to read file {path}: {e}")
+
+
+ async def read_bytes(path: str) -> bytes:
+     """Read file as bytes asynchronously.
+
+     Args:
+         path: Path to file
+
+     Returns:
+         File contents as bytes
+
+     Raises:
+         RuntimeError: If file cannot be read
+     """
+     try:
+         async with aiofiles.open(path, "rb") as f:
+             return await f.read()
+     except Exception as e:
+         raise RuntimeError(f"Failed to read file {path}: {e}")
+
+
+ async def get_web_file_path(filename: str) -> str:
+     """Get path to web static file.
+
+     Args:
+         filename: Name of file in web directory
+
+     Returns:
+         Absolute path to file
+
+     Raises:
+         FileNotFoundError: If file not found
+     """
+     # Construct web directory path under the container root
+     web_dir = os.path.join("/app", settings.web_player_path)
+
+     # Search in web directory
+     search_paths = [web_dir]
+     logger.debug(f"Searching for web file in path: {web_dir}")
+
+     return await _find_file(filename, search_paths)
+
+
+ async def get_content_type(path: str) -> str:
+     """Get content type for file.
+
+     Args:
+         path: Path to file
+
+     Returns:
+         Content type string
+     """
+     ext = os.path.splitext(path)[1].lower()
+     return {
+         ".html": "text/html",
+         ".js": "application/javascript",
+         ".css": "text/css",
+         ".png": "image/png",
+         ".jpg": "image/jpeg",
+         ".jpeg": "image/jpeg",
+         ".gif": "image/gif",
+         ".svg": "image/svg+xml",
+         ".ico": "image/x-icon",
+     }.get(ext, "application/octet-stream")
+
+
+ async def verify_model_path(model_path: str) -> bool:
+     """Verify model file exists at path."""
+     return await aiofiles.os.path.exists(model_path)
+
+
+ async def cleanup_temp_files() -> None:
+     """Clean up old temp files on startup"""
+     try:
+         if not await aiofiles.os.path.exists(settings.temp_file_dir):
+             await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
+             return
+
+         entries = await aiofiles.os.scandir(settings.temp_file_dir)
+         for entry in entries:
+             if entry.is_file():
+                 stat = await aiofiles.os.stat(entry.path)
+                 max_age = stat.st_mtime + (settings.max_temp_dir_age_hours * 3600)
+                 # Remove the file once the current time passes its maximum age
+                 if time.time() > max_age:
+                     try:
+                         await aiofiles.os.remove(entry.path)
+                         logger.info(f"Cleaned up old temp file: {entry.name}")
+                     except Exception as e:
+                         logger.warning(
+                             f"Failed to delete old temp file {entry.name}: {e}"
+                         )
+     except Exception as e:
+         logger.warning(f"Error cleaning temp files: {e}")
+
+
+ async def get_temp_file_path(filename: str) -> str:
+     """Get path to temporary audio file.
+
+     Args:
+         filename: Name of temp file
+
+     Returns:
+         Absolute path to temp file
+
+     Raises:
+         RuntimeError: If temp directory does not exist
+     """
+     temp_path = os.path.join(settings.temp_file_dir, filename)
+
+     # Ensure temp directory exists
+     if not await aiofiles.os.path.exists(settings.temp_file_dir):
+         await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
+
+     return temp_path
+
+
+ async def list_temp_files() -> List[str]:
+     """List temporary audio files.
+
+     Returns:
+         List of temp file names
+     """
+     if not await aiofiles.os.path.exists(settings.temp_file_dir):
+         return []
+
+     entries = await aiofiles.os.scandir(settings.temp_file_dir)
+     return [entry.name for entry in entries if entry.is_file()]
+
+
+ async def get_temp_dir_size() -> int:
+     """Get total size of temp directory in bytes.
+
+     Returns:
+         Size in bytes
+     """
+     if not await aiofiles.os.path.exists(settings.temp_file_dir):
+         return 0
+
+     total = 0
+     entries = await aiofiles.os.scandir(settings.temp_file_dir)
+     for entry in entries:
+         if entry.is_file():
+             stat = await aiofiles.os.stat(entry.path)
+             total += stat.st_size
+     return total
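A minimal usage sketch for the async helpers above (assumes the container layout with voices installed under settings.voices_dir and the project root on PYTHONPATH):

    import asyncio

    from api.src.core import paths

    async def main():
        # List bundled voices, then resolve the first one to an absolute path
        voices = await paths.list_voices()
        print(f"{len(voices)} voices available")
        if voices:
            print(await paths.get_voice_path(voices[0]))

    asyncio.run(main())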
api/src/inference/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """Model inference package."""
+
+ from .base import BaseModelBackend
+ from .kokoro_v1 import KokoroV1
+ from .model_manager import ModelManager, get_manager
+
+ __all__ = [
+     "BaseModelBackend",
+     "ModelManager",
+     "get_manager",
+     "KokoroV1",
+ ]
api/src/inference/base.py ADDED
@@ -0,0 +1,127 @@
+ """Base interface for Kokoro inference."""
+
+ from abc import ABC, abstractmethod
+ from typing import AsyncGenerator, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+
+ class AudioChunk:
+     """Class for audio chunks returned by model backends"""
+
+     def __init__(
+         self,
+         audio: np.ndarray,
+         word_timestamps: Optional[List] = [],
+         output: Optional[Union[bytes, np.ndarray]] = b"",
+     ):
+         self.audio = audio
+         self.word_timestamps = word_timestamps
+         self.output = output
+
+     @staticmethod
+     def combine(audio_chunk_list: List):
+         output = AudioChunk(
+             audio_chunk_list[0].audio, audio_chunk_list[0].word_timestamps
+         )
+
+         for audio_chunk in audio_chunk_list[1:]:
+             output.audio = np.concatenate(
+                 (output.audio, audio_chunk.audio), dtype=np.int16
+             )
+             if output.word_timestamps is not None:
+                 output.word_timestamps += audio_chunk.word_timestamps
+
+         return output
+
+
+ class ModelBackend(ABC):
+     """Abstract base class for model inference backend."""
+
+     @abstractmethod
+     async def load_model(self, path: str) -> None:
+         """Load model from path.
+
+         Args:
+             path: Path to model file
+
+         Raises:
+             RuntimeError: If model loading fails
+         """
+         pass
+
+     @abstractmethod
+     async def generate(
+         self,
+         text: str,
+         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
+         speed: float = 1.0,
+     ) -> AsyncGenerator[AudioChunk, None]:
+         """Generate audio from text.
+
+         Args:
+             text: Input text to synthesize
+             voice: Either a voice path or tuple of (name, tensor/path)
+             speed: Speed multiplier
+
+         Yields:
+             Generated audio chunks
+
+         Raises:
+             RuntimeError: If generation fails
+         """
+         pass
+
+     @abstractmethod
+     def unload(self) -> None:
+         """Unload model and free resources."""
+         pass
+
+     @property
+     @abstractmethod
+     def is_loaded(self) -> bool:
+         """Check if model is loaded.
+
+         Returns:
+             True if model is loaded, False otherwise
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def device(self) -> str:
+         """Get device model is running on.
+
+         Returns:
+             Device string ('cpu' or 'cuda')
+         """
+         pass
+
+
+ class BaseModelBackend(ModelBackend):
+     """Base implementation of model backend."""
+
+     def __init__(self):
+         """Initialize base backend."""
+         self._model: Optional[torch.nn.Module] = None
+         self._device: str = "cpu"
+
+     @property
+     def is_loaded(self) -> bool:
+         """Check if model is loaded."""
+         return self._model is not None
+
+     @property
+     def device(self) -> str:
+         """Get device model is running on."""
+         return self._device
+
+     def unload(self) -> None:
+         """Unload model and free resources."""
+         if self._model is not None:
+             del self._model
+             self._model = None
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 torch.cuda.synchronize()
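A quick illustration (not part of the commit) of AudioChunk.combine, which concatenates the audio arrays and merges word timestamps:

    import numpy as np

    from api.src.inference.base import AudioChunk

    first = AudioChunk(np.zeros(24000, dtype=np.int16), word_timestamps=[])
    second = AudioChunk(np.ones(24000, dtype=np.int16), word_timestamps=[])

    combined = AudioChunk.combine([first, second])
    assert combined.audio.shape == (48000,)  # two one-second chunks at 24 kHz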
api/src/inference/kokoro_v1.py ADDED
@@ -0,0 +1,370 @@
+ """Clean Kokoro implementation with controlled resource management."""
+
+ import os
+ import tempfile
+ from typing import AsyncGenerator, Dict, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from kokoro import KModel, KPipeline
+ from loguru import logger
+
+ from ..core import paths
+ from ..core.config import settings
+ from ..core.model_config import model_config
+ from ..structures.schemas import WordTimestamp
+ from .base import AudioChunk, BaseModelBackend
+
+
+ class KokoroV1(BaseModelBackend):
+     """Kokoro backend with controlled resource management."""
+
+     def __init__(self):
+         """Initialize backend with environment-based configuration."""
+         super().__init__()
+         # Strictly respect settings.use_gpu
+         self._device = settings.get_device()
+         self._model: Optional[KModel] = None
+         self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code
+
+     async def load_model(self, path: str) -> None:
+         """Load pre-baked model.
+
+         Args:
+             path: Path to model file
+
+         Raises:
+             RuntimeError: If model loading fails
+         """
+         try:
+             # Get verified model path
+             model_path = await paths.get_model_path(path)
+             config_path = os.path.join(os.path.dirname(model_path), "config.json")
+
+             if not os.path.exists(config_path):
+                 raise RuntimeError(f"Config file not found: {config_path}")
+
+             logger.info(f"Loading Kokoro model on {self._device}")
+             logger.info(f"Config path: {config_path}")
+             logger.info(f"Model path: {model_path}")
+
+             # Load model and let KModel handle device mapping
+             self._model = KModel(config=config_path, model=model_path).eval()
+             # Move the model to the selected device
+             if self._device == "mps":
+                 logger.info(
+                     "Moving model to MPS device with CPU fallback for unsupported operations"
+                 )
+                 self._model = self._model.to(torch.device("mps"))
+             elif self._device == "cuda":
+                 self._model = self._model.cuda()
+             else:
+                 self._model = self._model.cpu()
+
+         except FileNotFoundError as e:
+             raise e
+         except Exception as e:
+             raise RuntimeError(f"Failed to load Kokoro model: {e}")
+
+     def _get_pipeline(self, lang_code: str) -> KPipeline:
+         """Get or create pipeline for language code.
+
+         Args:
+             lang_code: Language code to use
+
+         Returns:
+             KPipeline instance for the language
+         """
+         if not self._model:
+             raise RuntimeError("Model not loaded")
+
+         if lang_code not in self._pipelines:
+             logger.info(f"Creating new pipeline for language code: {lang_code}")
+             self._pipelines[lang_code] = KPipeline(
+                 lang_code=lang_code, model=self._model, device=self._device
+             )
+         return self._pipelines[lang_code]
+
+     async def generate_from_tokens(
+         self,
+         tokens: str,
+         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
+         speed: float = 1.0,
+         lang_code: Optional[str] = None,
+     ) -> AsyncGenerator[np.ndarray, None]:
+         """Generate audio from phoneme tokens.
+
+         Args:
+             tokens: Input phoneme tokens to synthesize
+             voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
+             speed: Speed multiplier
+             lang_code: Optional language code override
+
+         Yields:
+             Generated audio chunks
+
+         Raises:
+             RuntimeError: If generation fails
+         """
+         if not self.is_loaded:
+             raise RuntimeError("Model not loaded")
+
+         try:
+             # Memory management for GPU
+             if self._device == "cuda":
+                 if self._check_memory():
+                     self._clear_memory()
+
+             # Handle voice input
+             voice_path: str
+             voice_name: str
+             if isinstance(voice, tuple):
+                 voice_name, voice_data = voice
+                 if isinstance(voice_data, str):
+                     voice_path = voice_data
+                 else:
+                     # Save tensor to temporary file
+                     temp_dir = tempfile.gettempdir()
+                     voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
+                     # Save tensor with CPU mapping for portability
+                     torch.save(voice_data.cpu(), voice_path)
+             else:
+                 voice_path = voice
+                 voice_name = os.path.splitext(os.path.basename(voice_path))[0]
+
+             # Load voice tensor with proper device mapping
+             voice_tensor = await paths.load_voice_tensor(
+                 voice_path, device=self._device
+             )
+             # Save back to a temporary file with proper device mapping
+             temp_dir = tempfile.gettempdir()
+             temp_path = os.path.join(
+                 temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
+             )
+             await paths.save_voice_tensor(voice_tensor, temp_path)
+             voice_path = temp_path
+
+             # Use provided lang_code, settings voice code override, or first letter of voice name
+             if lang_code:  # api is given priority
+                 pipeline_lang_code = lang_code
+             elif settings.default_voice_code:  # settings is next priority
+                 pipeline_lang_code = settings.default_voice_code
+             else:  # voice name is default/fallback
+                 pipeline_lang_code = voice_name[0].lower()
+
+             pipeline = self._get_pipeline(pipeline_lang_code)
+
+             logger.debug(
+                 f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}{'...' if len(tokens) > 100 else ''}'"
+             )
+             for result in pipeline.generate_from_tokens(
+                 tokens=tokens, voice=voice_path, speed=speed, model=self._model
+             ):
+                 if result.audio is not None:
+                     logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
+                     yield result.audio.numpy()
+                 else:
+                     logger.warning("No audio in chunk")
+
+         except Exception as e:
+             logger.error(f"Generation failed: {e}")
+             if (
+                 self._device == "cuda"
+                 and model_config.pytorch_gpu.retry_on_oom
+                 and "out of memory" in str(e).lower()
+             ):
+                 self._clear_memory()
+                 async for chunk in self.generate_from_tokens(
+                     tokens, voice, speed, lang_code
+                 ):
+                     yield chunk
+                 # Retry succeeded; do not re-raise the original OOM error
+                 return
+             raise
+
+     async def generate(
+         self,
+         text: str,
+         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
+         speed: float = 1.0,
+         lang_code: Optional[str] = None,
+         return_timestamps: Optional[bool] = False,
+     ) -> AsyncGenerator[AudioChunk, None]:
+         """Generate audio using model.
+
+         Args:
+             text: Input text to synthesize
+             voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
+             speed: Speed multiplier
+             lang_code: Optional language code override
+             return_timestamps: Whether to collect word timestamps
+
+         Yields:
+             Generated audio chunks
+
+         Raises:
+             RuntimeError: If generation fails
+         """
+         if not self.is_loaded:
+             raise RuntimeError("Model not loaded")
+         try:
+             # Memory management for GPU
+             if self._device == "cuda":
+                 if self._check_memory():
+                     self._clear_memory()
+
+             # Handle voice input
+             voice_path: str
+             voice_name: str
+             if isinstance(voice, tuple):
+                 voice_name, voice_data = voice
+                 if isinstance(voice_data, str):
+                     voice_path = voice_data
+                 else:
+                     # Save tensor to temporary file
+                     temp_dir = tempfile.gettempdir()
+                     voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
+                     # Save tensor with CPU mapping for portability
+                     torch.save(voice_data.cpu(), voice_path)
+             else:
+                 voice_path = voice
+                 voice_name = os.path.splitext(os.path.basename(voice_path))[0]
+
+             # Load voice tensor with proper device mapping
+             voice_tensor = await paths.load_voice_tensor(
+                 voice_path, device=self._device
+             )
+             # Save back to a temporary file with proper device mapping
+             temp_dir = tempfile.gettempdir()
+             temp_path = os.path.join(
+                 temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
+             )
+             await paths.save_voice_tensor(voice_tensor, temp_path)
+             voice_path = temp_path
+
+             # Use provided lang_code, settings voice code override, or first letter of voice name
+             pipeline_lang_code = (
+                 lang_code
+                 if lang_code
+                 else (
+                     settings.default_voice_code
+                     if settings.default_voice_code
+                     else voice_name[0].lower()
+                 )
+             )
+             pipeline = self._get_pipeline(pipeline_lang_code)
+
+             logger.debug(
+                 f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}{'...' if len(text) > 100 else ''}'"
+             )
+             for result in pipeline(
+                 text, voice=voice_path, speed=speed, model=self._model
+             ):
+                 if result.audio is not None:
+                     logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
+                     word_timestamps = None
+                     if (
+                         return_timestamps
+                         and hasattr(result, "tokens")
+                         and result.tokens
+                     ):
+                         word_timestamps = []
+                         current_offset = 0.0
+                         logger.debug(
+                             f"Processing chunk timestamps with {len(result.tokens)} tokens"
+                         )
+                         if result.pred_dur is not None:
+                             try:
+                                 # Add timestamps with offset
+                                 for token in result.tokens:
+                                     if not all(
+                                         hasattr(token, attr)
+                                         for attr in [
+                                             "text",
+                                             "start_ts",
+                                             "end_ts",
+                                         ]
+                                     ):
+                                         continue
+                                     if not token.text or not token.text.strip():
+                                         continue
+
+                                     start_time = float(token.start_ts) + current_offset
+                                     end_time = float(token.end_ts) + current_offset
+                                     word_timestamps.append(
+                                         WordTimestamp(
+                                             word=str(token.text).strip(),
+                                             start_time=start_time,
+                                             end_time=end_time,
+                                         )
+                                     )
+                                     logger.debug(
+                                         f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
+                                     )
+
+                             except Exception as e:
+                                 logger.error(
+                                     f"Failed to process timestamps for chunk: {e}"
+                                 )
+
+                     yield AudioChunk(
+                         result.audio.numpy(), word_timestamps=word_timestamps
+                     )
+                 else:
+                     logger.warning("No audio in chunk")
+
+         except Exception as e:
+             logger.error(f"Generation failed: {e}")
+             if (
+                 self._device == "cuda"
+                 and model_config.pytorch_gpu.retry_on_oom
+                 and "out of memory" in str(e).lower()
+             ):
+                 self._clear_memory()
+                 async for chunk in self.generate(text, voice, speed, lang_code):
+                     yield chunk
+                 # Retry succeeded; do not re-raise the original OOM error
+                 return
+             raise
+
+     def _check_memory(self) -> bool:
+         """Check if memory usage is above threshold."""
+         if self._device == "cuda":
+             memory_gb = torch.cuda.memory_allocated() / 1e9
+             return memory_gb > model_config.pytorch_gpu.memory_threshold
+         # MPS doesn't provide memory management APIs
+         return False
+
+     def _clear_memory(self) -> None:
+         """Clear device memory."""
+         if self._device == "cuda":
+             torch.cuda.empty_cache()
+             torch.cuda.synchronize()
+         elif self._device == "mps":
+             # Empty cache if available (future-proofing)
+             if hasattr(torch.mps, "empty_cache"):
+                 torch.mps.empty_cache()
+
+     def unload(self) -> None:
+         """Unload model and free resources."""
+         if self._model is not None:
+             del self._model
+             self._model = None
+         for pipeline in self._pipelines.values():
+             del pipeline
+         self._pipelines.clear()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.synchronize()
+
+     @property
+     def is_loaded(self) -> bool:
+         """Check if model is loaded."""
+         return self._model is not None
+
+     @property
+     def device(self) -> str:
+         """Get device model is running on."""
+         return self._device
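A hedged end-to-end sketch of driving this backend directly (assumes the model and voice files have already been downloaded into the container paths used above):

    import asyncio

    from api.src.core import paths
    from api.src.core.model_config import model_config
    from api.src.inference.kokoro_v1 import KokoroV1

    async def demo():
        backend = KokoroV1()
        # Resolved against settings.model_dir by paths.get_model_path
        await backend.load_model(model_config.pytorch_kokoro_v1_file)
        voice_path = await paths.get_voice_path("af_heart")
        async for chunk in backend.generate("Hello world.", ("af_heart", voice_path)):
            print(chunk.audio.shape)

    asyncio.run(demo())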
api/src/inference/model_manager.py ADDED
@@ -0,0 +1,171 @@
+ """Kokoro V1 model management."""
+
+ from typing import Optional
+
+ from loguru import logger
+
+ from ..core import paths
+ from ..core.config import settings
+ from ..core.model_config import ModelConfig, model_config
+ from .base import BaseModelBackend
+ from .kokoro_v1 import KokoroV1
+
+
+ class ModelManager:
+     """Manages Kokoro V1 model loading and inference."""
+
+     # Singleton instance
+     _instance = None
+
+     def __init__(self, config: Optional[ModelConfig] = None):
+         """Initialize manager.
+
+         Args:
+             config: Optional model configuration override
+         """
+         self._config = config or model_config
+         self._backend: Optional[KokoroV1] = None  # Explicitly type as KokoroV1
+         self._device: Optional[str] = None
+
+     def _determine_device(self) -> str:
+         """Determine device based on settings."""
+         return "cuda" if settings.use_gpu else "cpu"
+
+     async def initialize(self) -> None:
+         """Initialize Kokoro V1 backend."""
+         try:
+             self._device = self._determine_device()
+             logger.info(f"Initializing Kokoro V1 on {self._device}")
+             self._backend = KokoroV1()
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to initialize Kokoro V1: {e}")
+
+     async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
+         """Initialize and warm up model.
+
+         Args:
+             voice_manager: Voice manager instance for warmup
+
+         Returns:
+             Tuple of (device, backend type, voice count)
+
+         Raises:
+             RuntimeError: If initialization fails
+         """
+         import time
+
+         start = time.perf_counter()
+
+         try:
+             # Initialize backend
+             await self.initialize()
+
+             # Load model
+             model_path = self._config.pytorch_kokoro_v1_file
+             await self.load_model(model_path)
+
+             # Use paths module to get voice path
+             try:
+                 voices = await paths.list_voices()
+                 voice_path = await paths.get_voice_path(settings.default_voice)
+
+                 # Warm up with short text
+                 warmup_text = "Warmup text for initialization."
+                 # Use default voice name for warmup
+                 voice_name = settings.default_voice
+                 logger.debug(f"Using default voice '{voice_name}' for warmup")
+                 async for _ in self.generate(warmup_text, (voice_name, voice_path)):
+                     pass
+             except Exception as e:
+                 raise RuntimeError(f"Failed to get default voice: {e}")
+
+             ms = int((time.perf_counter() - start) * 1000)
+             logger.info(f"Warmup completed in {ms}ms")
+
+             return self._device, "kokoro_v1", len(voices)
+         except FileNotFoundError:
+             logger.error("""
+ Model files not found! You need to download the Kokoro V1 model:
+
+ 1. Download model using the script:
+    python docker/scripts/download_model.py --output api/src/models/v1_0
+
+ 2. Or set environment variable in docker-compose:
+    DOWNLOAD_MODEL=true
+ """)
+             exit(0)
+         except Exception as e:
+             raise RuntimeError(f"Warmup failed: {e}")
+
+     def get_backend(self) -> BaseModelBackend:
+         """Get initialized backend.
+
+         Returns:
+             Initialized backend instance
+
+         Raises:
+             RuntimeError: If backend not initialized
+         """
+         if not self._backend:
+             raise RuntimeError("Backend not initialized")
+         return self._backend
+
+     async def load_model(self, path: str) -> None:
+         """Load model using initialized backend.
+
+         Args:
+             path: Path to model file
+
+         Raises:
+             RuntimeError: If loading fails
+         """
+         if not self._backend:
+             raise RuntimeError("Backend not initialized")
+
+         try:
+             await self._backend.load_model(path)
+         except FileNotFoundError as e:
+             raise e
+         except Exception as e:
+             raise RuntimeError(f"Failed to load model: {e}")
+
+     async def generate(self, *args, **kwargs):
+         """Generate audio using initialized backend.
+
+         Raises:
+             RuntimeError: If generation fails
+         """
+         if not self._backend:
+             raise RuntimeError("Backend not initialized")
+
+         try:
+             async for chunk in self._backend.generate(*args, **kwargs):
+                 yield chunk
+         except Exception as e:
+             raise RuntimeError(f"Generation failed: {e}")
+
+     def unload_all(self) -> None:
+         """Unload model and free resources."""
+         if self._backend:
+             self._backend.unload()
+             self._backend = None
+
+     @property
+     def current_backend(self) -> str:
+         """Get current backend type."""
+         return "kokoro_v1"
+
+
+ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
+     """Get model manager instance.
+
+     Args:
+         config: Optional configuration override
+
+     Returns:
+         ModelManager instance
+     """
+     if ModelManager._instance is None:
+         ModelManager._instance = ModelManager(config)
+     return ModelManager._instance
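This mirrors what the FastAPI lifespan hook does at startup; as a standalone sketch (again assuming model and voice files are present):

    import asyncio

    from api.src.inference.model_manager import get_manager
    from api.src.inference.voice_manager import get_manager as get_voice_manager

    async def demo():
        model_manager = await get_manager()
        voice_manager = await get_voice_manager()
        # Loads the model and runs a short warmup generation
        device, backend, voices = await model_manager.initialize_with_warmup(voice_manager)
        print(device, backend, voices)

    asyncio.run(demo())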
api/src/inference/voice_manager.py ADDED
@@ -0,0 +1,115 @@
+ """Voice management with controlled resource handling."""
+
+ from typing import Dict, List, Optional
+
+ import torch
+ from loguru import logger
+
+ from ..core import paths
+ from ..core.config import settings
+
+
+ class VoiceManager:
+     """Manages voice loading and caching with controlled resource usage."""
+
+     # Singleton instance
+     _instance = None
+
+     def __init__(self):
+         """Initialize voice manager."""
+         # Strictly respect settings.use_gpu
+         self._device = settings.get_device()
+         self._voices: Dict[str, torch.Tensor] = {}
+
+     async def get_voice_path(self, voice_name: str) -> str:
+         """Get path to voice file.
+
+         Args:
+             voice_name: Name of voice
+
+         Returns:
+             Path to voice file
+
+         Raises:
+             RuntimeError: If voice not found
+         """
+         return await paths.get_voice_path(voice_name)
+
+     async def load_voice(
+         self, voice_name: str, device: Optional[str] = None
+     ) -> torch.Tensor:
+         """Load voice tensor.
+
+         Args:
+             voice_name: Name of voice to load
+             device: Optional override for target device
+
+         Returns:
+             Voice tensor
+
+         Raises:
+             RuntimeError: If voice not found
+         """
+         try:
+             voice_path = await self.get_voice_path(voice_name)
+             target_device = device or self._device
+             voice = await paths.load_voice_tensor(voice_path, target_device)
+             self._voices[voice_name] = voice
+             return voice
+         except Exception as e:
+             raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
+
+     async def combine_voices(
+         self, voices: List[str], device: Optional[str] = None
+     ) -> torch.Tensor:
+         """Combine multiple voices.
+
+         Args:
+             voices: List of voice names to combine
+             device: Optional override for target device
+
+         Returns:
+             Combined voice tensor
+
+         Raises:
+             RuntimeError: If any voice not found
+         """
+         if len(voices) < 2:
+             raise ValueError("Need at least 2 voices to combine")
+
+         target_device = device or self._device
+         voice_tensors = []
+         for name in voices:
+             voice = await self.load_voice(name, target_device)
+             voice_tensors.append(voice)
+
+         combined = torch.mean(torch.stack(voice_tensors), dim=0)
+         return combined
+
+     async def list_voices(self) -> List[str]:
+         """List available voice names.
+
+         Returns:
+             List of voice names
+         """
+         return await paths.list_voices()
+
+     def cache_info(self) -> Dict[str, int]:
+         """Get cache statistics.
+
+         Returns:
+             Dict with cache statistics
+         """
+         return {"loaded_voices": len(self._voices), "device": self._device}
+
+
+ async def get_manager() -> VoiceManager:
+     """Get voice manager instance.
+
+     Returns:
+         VoiceManager instance
+     """
+     if VoiceManager._instance is None:
+         VoiceManager._instance = VoiceManager()
+     return VoiceManager._instance
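Voice blending here is a plain element-wise mean over the stacked voice tensors, so combining is cheap. A brief sketch (the voice names assume the bundled voice packs are installed):

    import asyncio

    from api.src.inference.voice_manager import get_manager

    async def demo():
        manager = await get_manager()
        # Average two same-shaped voice tensors into a blended voice
        blended = await manager.combine_voices(["af_heart", "af_sky"])
        print(blended.shape, manager.cache_info())

    asyncio.run(demo())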
api/src/main.py ADDED
@@ -0,0 +1,152 @@
+ """
+ FastAPI OpenAI Compatible API
+ """
+
+ import sys
+ from contextlib import asynccontextmanager
+
+ import torch
+ import uvicorn
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from loguru import logger
+
+ from .core.config import settings
+ from .routers.debug import router as debug_router
+ from .routers.development import router as dev_router
+ from .routers.openai_compatible import router as openai_router
+ from .routers.web_player import router as web_router
+
+
+ def setup_logger():
+     """Configure loguru logger with custom formatting"""
+     config = {
+         "handlers": [
+             {
+                 "sink": sys.stdout,
+                 "format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
+                 "{level: <8} | "
+                 "<fg #4169E1>{module}:{line}</fg #4169E1> | "
+                 "{message}",
+                 "colorize": True,
+                 "level": "DEBUG",
+             },
+         ],
+     }
+     logger.remove()
+     logger.configure(**config)
+     logger.level("ERROR", color="<red>")
+
+
+ # Configure logger
+ setup_logger()
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Lifespan context manager for model initialization"""
+     from .inference.model_manager import get_manager
+     from .inference.voice_manager import get_manager as get_voice_manager
+     from .services.temp_manager import cleanup_temp_files
+
+     # Clean old temp files on startup
+     await cleanup_temp_files()
+
+     logger.info("Loading TTS model and voice packs...")
+
+     try:
+         # Initialize managers
+         model_manager = await get_manager()
+         voice_manager = await get_voice_manager()
+
+         # Initialize model with warmup and get status
+         device, model, voicepack_count = await model_manager.initialize_with_warmup(
+             voice_manager
+         )
+
+     except Exception as e:
+         logger.error(f"Failed to initialize model: {e}")
+         raise
+
+     boundary = "░" * 2 * 12
+     startup_msg = f"""
+
+ {boundary}
+
+     ╔═╗┌─┐┌─┐┌┬┐
+     ╠╣ ├─┤└─┐ │
+     ╚ ┴ ┴└─┘ ┴
+     ╦╔═┌─┐┬┌─┌─┐
+     ╠╩╗│ │├┴┐│ │
+     ╩ ╩└─┘┴ ┴└─┘
+
+ {boundary}
+ """
+     startup_msg += f"\nModel warmed up on {device}: {model}"
+     if device == "mps":
+         startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
+     elif device == "cuda":
+         startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
+     else:
+         startup_msg += "\nRunning on CPU"
+     startup_msg += f"\n{voicepack_count} voice packs loaded"
+
+     # Add web player info if enabled
+     if settings.enable_web_player:
+         startup_msg += (
+             f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
+         )
+         startup_msg += f"\nor http://localhost:{settings.port}/web/"
+     else:
+         startup_msg += "\n\nWeb Player: disabled"
+
+     startup_msg += f"\n{boundary}\n"
+     logger.info(startup_msg)
+
+     yield
+
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title=settings.api_title,
+     description=settings.api_description,
+     version=settings.api_version,
+     lifespan=lifespan,
+     openapi_url="/openapi.json",  # Explicitly enable OpenAPI schema
+ )
+
+ # Add CORS middleware if enabled
+ if settings.cors_enabled:
+     app.add_middleware(
+         CORSMiddleware,
+         allow_origins=settings.cors_origins,
+         allow_credentials=True,
+         allow_methods=["*"],
+         allow_headers=["*"],
+     )
+
+ # Include routers
+ app.include_router(openai_router, prefix="/v1")
+ app.include_router(dev_router)  # Development endpoints
+ app.include_router(debug_router)  # Debug endpoints
+ if settings.enable_web_player:
+     app.include_router(web_router, prefix="/web")  # Web player static files
+
+
+ # Health check endpoint
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {"status": "healthy"}
+
+
+ @app.get("/v1/test")
+ async def test_endpoint():
+     """Test endpoint to verify routing"""
+     return {"status": "ok"}
+
+
+ if __name__ == "__main__":
+     uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True)
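Once the server is up (e.g. via entrypoint.sh in the Dockerfile above), the health endpoint gives a quick smoke test. A sketch using requests, assuming the default host and port:

    import requests

    # Assumes the server is already running on the default port 8880
    resp = requests.get("http://localhost:8880/health")
    print(resp.json())  # {"status": "healthy"}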
api/src/models/v1_0/config.json ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "istftnet": {
3
+ "upsample_kernel_sizes": [20, 12],
4
+ "upsample_rates": [10, 6],
5
+ "gen_istft_hop_size": 5,
6
+ "gen_istft_n_fft": 20,
7
+ "resblock_dilation_sizes": [
8
+ [1, 3, 5],
9
+ [1, 3, 5],
10
+ [1, 3, 5]
11
+ ],
12
+ "resblock_kernel_sizes": [3, 7, 11],
13
+ "upsample_initial_channel": 512
14
+ },
15
+ "dim_in": 64,
16
+ "dropout": 0.2,
17
+ "hidden_dim": 512,
18
+ "max_conv_dim": 512,
19
+ "max_dur": 50,
20
+ "multispeaker": true,
21
+ "n_layer": 3,
22
+ "n_mels": 80,
23
+ "n_token": 178,
24
+ "style_dim": 128,
25
+ "text_encoder_kernel_size": 5,
26
+ "plbert": {
27
+ "hidden_size": 768,
28
+ "num_attention_heads": 12,
29
+ "intermediate_size": 2048,
30
+ "max_position_embeddings": 512,
31
+ "num_hidden_layers": 12,
32
+ "dropout": 0.1
33
+ },
34
+ "vocab": {
35
+ ";": 1,
36
+ ":": 2,
37
+ ",": 3,
38
+ ".": 4,
39
+ "!": 5,
40
+ "?": 6,
41
+ "—": 9,
42
+ "…": 10,
43
+ "\"": 11,
44
+ "(": 12,
45
+ ")": 13,
46
+ "“": 14,
47
+ "”": 15,
48
+ " ": 16,
49
+ "\u0303": 17,
50
+ "ʣ": 18,
51
+ "ʥ": 19,
52
+ "ʦ": 20,
53
+ "ʨ": 21,
54
+ "ᵝ": 22,
55
+ "\uAB67": 23,
56
+ "A": 24,
57
+ "I": 25,
58
+ "O": 31,
59
+ "Q": 33,
60
+ "S": 35,
61
+ "T": 36,
62
+ "W": 39,
63
+ "Y": 41,
64
+ "ᵊ": 42,
65
+ "a": 43,
66
+ "b": 44,
67
+ "c": 45,
68
+ "d": 46,
69
+ "e": 47,
70
+ "f": 48,
71
+ "h": 50,
72
+ "i": 51,
73
+ "j": 52,
74
+ "k": 53,
75
+ "l": 54,
76
+ "m": 55,
77
+ "n": 56,
78
+ "o": 57,
79
+ "p": 58,
80
+ "q": 59,
81
+ "r": 60,
82
+ "s": 61,
83
+ "t": 62,
84
+ "u": 63,
85
+ "v": 64,
86
+ "w": 65,
87
+ "x": 66,
88
+ "y": 67,
89
+ "z": 68,
90
+ "ɑ": 69,
91
+ "ɐ": 70,
92
+ "ɒ": 71,
93
+ "æ": 72,
94
+ "β": 75,
95
+ "ɔ": 76,
96
+ "ɕ": 77,
97
+ "ç": 78,
98
+ "ɖ": 80,
99
+ "ð": 81,
100
+ "ʤ": 82,
101
+ "ə": 83,
102
+ "ɚ": 85,
103
+ "ɛ": 86,
104
+ "ɜ": 87,
105
+ "ɟ": 90,
106
+ "ɡ": 92,
107
+ "ɥ": 99,
108
+ "ɨ": 101,
109
+ "ɪ": 102,
110
+ "ʝ": 103,
111
+ "ɯ": 110,
112
+ "ɰ": 111,
113
+ "ŋ": 112,
114
+ "ɳ": 113,
115
+ "ɲ": 114,
116
+ "ɴ": 115,
117
+ "ø": 116,
118
+ "ɸ": 118,
119
+ "θ": 119,
120
+ "œ": 120,
121
+ "ɹ": 123,
122
+ "ɾ": 125,
123
+ "ɻ": 126,
124
+ "ʁ": 128,
125
+ "ɽ": 129,
126
+ "ʂ": 130,
127
+ "ʃ": 131,
128
+ "ʈ": 132,
129
+ "ʧ": 133,
130
+ "ʊ": 135,
131
+ "ʋ": 136,
132
+ "ʌ": 138,
133
+ "ɣ": 139,
134
+ "ɤ": 140,
135
+ "χ": 142,
136
+ "ʎ": 143,
137
+ "ʒ": 147,
138
+ "ʔ": 148,
139
+ "ˈ": 156,
140
+ "ˌ": 157,
141
+ "ː": 158,
142
+ "ʰ": 162,
143
+ "ʲ": 164,
144
+ "↓": 169,
145
+ "→": 171,
146
+ "↗": 172,
147
+ "↘": 173,
148
+ "ᵻ": 177
149
+ }
150
+ }
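The vocab block above maps phoneme characters and punctuation to model token IDs. A minimal sketch of tokenizing a phoneme string against it, assuming the config is read from this repository path:

import json

with open("api/src/models/v1_0/config.json", encoding="utf-8") as f:
    vocab = json.load(f)["vocab"]

phonemes = "hˈɛloʊ"  # hypothetical phonemized input
token_ids = [vocab[ch] for ch in phonemes if ch in vocab]
print(token_ids)  # [50, 156, 86, 54, 57, 135]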
api/src/routers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #
api/src/routers/debug.py ADDED
@@ -0,0 +1,209 @@
1
+ import threading
2
+ import time
3
+ from datetime import datetime
4
+
5
+ import psutil
6
+ import torch
7
+ from fastapi import APIRouter
8
+
9
+ try:
10
+ import GPUtil
11
+
12
+ GPU_AVAILABLE = True
13
+ except ImportError:
14
+ GPU_AVAILABLE = False
15
+
16
+ router = APIRouter(tags=["debug"])
17
+
18
+
19
+ @router.get("/debug/threads")
20
+ async def get_thread_info():
21
+ process = psutil.Process()
22
+ current_threads = threading.enumerate()
23
+
24
+ # Get per-thread CPU times
25
+ thread_details = []
26
+ for thread in current_threads:
27
+ thread_info = {
28
+ "name": thread.name,
29
+ "id": thread.ident,
30
+ "alive": thread.is_alive(),
31
+ "daemon": thread.daemon,
32
+ }
33
+ thread_details.append(thread_info)
34
+
35
+ return {
36
+ "total_threads": process.num_threads(),
37
+ "active_threads": len(current_threads),
38
+ "thread_names": [t.name for t in current_threads],
39
+ "thread_details": thread_details,
40
+ "memory_mb": process.memory_info().rss / 1024 / 1024,
41
+ }
42
+
43
+
44
+ @router.get("/debug/storage")
45
+ async def get_storage_info():
46
+ # Get disk partitions
47
+ partitions = psutil.disk_partitions()
48
+ storage_info = []
49
+
50
+ for partition in partitions:
51
+ try:
52
+ usage = psutil.disk_usage(partition.mountpoint)
53
+ storage_info.append(
54
+ {
55
+ "device": partition.device,
56
+ "mountpoint": partition.mountpoint,
57
+ "fstype": partition.fstype,
58
+ "total_gb": usage.total / (1024**3),
59
+ "used_gb": usage.used / (1024**3),
60
+ "free_gb": usage.free / (1024**3),
61
+ "percent_used": usage.percent,
62
+ }
63
+ )
64
+ except PermissionError:
65
+ continue
66
+
67
+ return {"storage_info": storage_info}
68
+
69
+
70
+ @router.get("/debug/system")
71
+ async def get_system_info():
72
+ process = psutil.Process()
73
+
74
+ # CPU Info
75
+ cpu_info = {
76
+ "cpu_count": psutil.cpu_count(),
77
+ "cpu_percent": psutil.cpu_percent(interval=1),
78
+ "per_cpu_percent": psutil.cpu_percent(interval=1, percpu=True),
79
+ "load_avg": psutil.getloadavg(),
80
+ }
81
+
82
+ # Memory Info
83
+ virtual_memory = psutil.virtual_memory()
84
+ swap_memory = psutil.swap_memory()
85
+ memory_info = {
86
+ "virtual": {
87
+ "total_gb": virtual_memory.total / (1024**3),
88
+ "available_gb": virtual_memory.available / (1024**3),
89
+ "used_gb": virtual_memory.used / (1024**3),
90
+ "percent": virtual_memory.percent,
91
+ },
92
+ "swap": {
93
+ "total_gb": swap_memory.total / (1024**3),
94
+ "used_gb": swap_memory.used / (1024**3),
95
+ "free_gb": swap_memory.free / (1024**3),
96
+ "percent": swap_memory.percent,
97
+ },
98
+ }
99
+
100
+ # Process Info
101
+ process_info = {
102
+ "pid": process.pid,
103
+ "status": process.status(),
104
+ "create_time": datetime.fromtimestamp(process.create_time()).isoformat(),
105
+ "cpu_percent": process.cpu_percent(),
106
+ "memory_percent": process.memory_percent(),
107
+ }
108
+
109
+ # Network Info
110
+ network_info = {
111
+ "connections": len(process.net_connections()),
112
+ "network_io": psutil.net_io_counters()._asdict(),
113
+ }
114
+
115
+ # GPU Info if available
116
+ gpu_info = None
117
+ if torch.backends.mps.is_available():
118
+ gpu_info = {
119
+ "type": "MPS",
120
+ "available": True,
121
+ "device": "Apple Silicon",
122
+ "backend": "Metal",
123
+ }
124
+ elif GPU_AVAILABLE:
125
+ try:
126
+ gpus = GPUtil.getGPUs()
127
+ gpu_info = [
128
+ {
129
+ "id": gpu.id,
130
+ "name": gpu.name,
131
+ "load": gpu.load,
132
+ "memory": {
133
+ "total": gpu.memoryTotal,
134
+ "used": gpu.memoryUsed,
135
+ "free": gpu.memoryFree,
136
+ "percent": (gpu.memoryUsed / gpu.memoryTotal) * 100,
137
+ },
138
+ "temperature": gpu.temperature,
139
+ }
140
+ for gpu in gpus
141
+ ]
142
+ except Exception:
143
+ gpu_info = "GPU information unavailable"
144
+
145
+ return {
146
+ "cpu": cpu_info,
147
+ "memory": memory_info,
148
+ "process": process_info,
149
+ "network": network_info,
150
+ "gpu": gpu_info,
151
+ }
152
+
153
+
154
+ @router.get("/debug/session_pools")
155
+ async def get_session_pool_info():
156
+ """Get information about ONNX session pools."""
157
+ from ..inference.model_manager import get_manager
158
+
159
+ manager = await get_manager()
160
+ pools = manager._session_pools
161
+ current_time = time.time()
162
+
163
+ pool_info = {}
164
+
165
+ # Get CPU pool info
166
+ if "onnx_cpu" in pools:
167
+ cpu_pool = pools["onnx_cpu"]
168
+ pool_info["cpu"] = {
169
+ "active_sessions": len(cpu_pool._sessions),
170
+ "max_sessions": cpu_pool._max_size,
171
+ "sessions": [
172
+ {"model": path, "age_seconds": current_time - info.last_used}
173
+ for path, info in cpu_pool._sessions.items()
174
+ ],
175
+ }
176
+
177
+ # Get GPU pool info
178
+ if "onnx_gpu" in pools:
179
+ gpu_pool = pools["onnx_gpu"]
180
+ pool_info["gpu"] = {
181
+ "active_sessions": len(gpu_pool._sessions),
182
+ "max_streams": gpu_pool._max_size,
183
+ "available_streams": len(gpu_pool._available_streams),
184
+ "sessions": [
185
+ {
186
+ "model": path,
187
+ "age_seconds": current_time - info.last_used,
188
+ "stream_id": info.stream_id,
189
+ }
190
+ for path, info in gpu_pool._sessions.items()
191
+ ],
192
+ }
193
+
194
+ # Add GPU memory info if available
195
+ if GPU_AVAILABLE:
196
+ try:
197
+ gpus = GPUtil.getGPUs()
198
+ if gpus:
199
+ gpu = gpus[0] # Assume first GPU
200
+ pool_info["gpu"]["memory"] = {
201
+ "total_mb": gpu.memoryTotal,
202
+ "used_mb": gpu.memoryUsed,
203
+ "free_mb": gpu.memoryFree,
204
+ "percent_used": (gpu.memoryUsed / gpu.memoryTotal) * 100,
205
+ }
206
+ except Exception:
207
+ pass
208
+
209
+ return pool_info
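The debug router above is read-only, so it can be polled safely. A hedged usage sketch, assuming the API listens on localhost:8880:

import requests

base = "http://localhost:8880"  # host and port are assumptions
for path in ("/debug/threads", "/debug/storage", "/debug/system", "/debug/session_pools"):
    print(path, requests.get(base + path, timeout=10).json())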
api/src/routers/development.py ADDED
@@ -0,0 +1,408 @@
1
+ import base64
2
+ import json
3
+ import os
4
+ import re
5
+ from pathlib import Path
6
+ from typing import AsyncGenerator, List, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
11
+ from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
12
+ from kokoro import KPipeline
13
+ from loguru import logger
14
+
15
+ from ..core.config import settings
16
+ from ..inference.base import AudioChunk
17
+ from ..services.audio import AudioNormalizer, AudioService
18
+ from ..services.streaming_audio_writer import StreamingAudioWriter
19
+ from ..services.temp_manager import TempFileWriter
20
+ from ..services.text_processing import smart_split
21
+ from ..services.tts_service import TTSService
22
+ from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
23
+ from ..structures.custom_responses import JSONStreamingResponse
24
+ from ..structures.text_schemas import (
25
+ GenerateFromPhonemesRequest,
26
+ PhonemeRequest,
27
+ PhonemeResponse,
28
+ )
29
+ from .openai_compatible import process_and_validate_voices, stream_audio_chunks
30
+
31
+ router = APIRouter(tags=["text processing"])
32
+
33
+
34
+ async def get_tts_service() -> TTSService:
35
+ """Dependency to get TTSService instance"""
36
+ return (
37
+ await TTSService.create()
38
+ ) # Create service with properly initialized managers
39
+
40
+
41
+ @router.post("/dev/phonemize", response_model=PhonemeResponse)
42
+ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
43
+ """Convert text to phonemes using Kokoro's quiet mode.
44
+
45
+ Args:
46
+ request: Request containing text and language
47
+
48
+ Returns:
49
+ Phonemes and token IDs
50
+ """
51
+ try:
52
+ if not request.text:
53
+ raise ValueError("Text cannot be empty")
54
+
55
+ # Initialize Kokoro pipeline in quiet mode (no model)
56
+ pipeline = KPipeline(lang_code=request.language, model=False)
57
+
58
+ # Get first result from pipeline (we only need one since we're not chunking)
59
+ for result in pipeline(request.text):
60
+ # result.graphemes = original text
61
+ # result.phonemes = phonemized text
62
+ # result.tokens = token objects (if available)
63
+ return PhonemeResponse(phonemes=result.phonemes, tokens=[])
64
+
65
+ raise ValueError("Failed to generate phonemes")
66
+ except ValueError as e:
67
+ logger.error(f"Error in phoneme generation: {str(e)}")
68
+ raise HTTPException(
69
+ status_code=500, detail={"error": "Server error", "message": str(e)}
70
+ )
71
+ except Exception as e:
72
+ logger.error(f"Error in phoneme generation: {str(e)}")
73
+ raise HTTPException(
74
+ status_code=500, detail={"error": "Server error", "message": str(e)}
75
+ )
76
+
77
+
78
+ @router.post("/dev/generate_from_phonemes")
79
+ async def generate_from_phonemes(
80
+ request: GenerateFromPhonemesRequest,
81
+ client_request: Request,
82
+ tts_service: TTSService = Depends(get_tts_service),
83
+ ) -> StreamingResponse:
84
+ """Generate audio directly from phonemes using Kokoro's phoneme format"""
85
+ try:
86
+ # Basic validation
87
+ if not isinstance(request.phonemes, str):
88
+ raise ValueError("Phonemes must be a string")
89
+ if not request.phonemes:
90
+ raise ValueError("Phonemes cannot be empty")
91
+
92
+ # Create streaming audio writer and normalizer
93
+ writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
94
+ normalizer = AudioNormalizer()
95
+
96
+ async def generate_chunks():
97
+ try:
98
+ # Generate audio from phonemes
99
+ chunk_audio, _ = await tts_service.generate_from_phonemes(
100
+ phonemes=request.phonemes, # Pass complete phoneme string
101
+ voice=request.voice,
102
+ speed=1.0,
103
+ )
104
+
105
+ if chunk_audio is not None:
106
+ # Normalize audio before writing (AudioNormalizer.normalize is synchronous)
107
+ normalized_audio = normalizer.normalize(chunk_audio)
108
+ # Write chunk and yield bytes
109
+ chunk_bytes = writer.write_chunk(normalized_audio)
110
+ if chunk_bytes:
111
+ yield chunk_bytes
112
+
113
+ # Finalize and yield remaining bytes
114
+ final_bytes = writer.write_chunk(finalize=True)
115
+ if final_bytes:
116
+ yield final_bytes
117
+ else:
118
+ raise ValueError("Failed to generate audio data")
119
+
120
+ except Exception as e:
121
+ logger.error(f"Error in audio generation: {str(e)}")
122
+ # Clean up writer on error
123
+ writer.close()
124
+ # Re-raise the original exception
125
+ raise
126
+
127
+ return StreamingResponse(
128
+ generate_chunks(),
129
+ media_type="audio/wav",
130
+ headers={
131
+ "Content-Disposition": "attachment; filename=speech.wav",
132
+ "X-Accel-Buffering": "no",
133
+ "Cache-Control": "no-cache",
134
+ "Transfer-Encoding": "chunked",
135
+ },
136
+ )
137
+
138
+ except ValueError as e:
139
+ logger.error(f"Error generating audio: {str(e)}")
140
+ raise HTTPException(
141
+ status_code=400,
142
+ detail={
143
+ "error": "validation_error",
144
+ "message": str(e),
145
+ "type": "invalid_request_error",
146
+ },
147
+ )
148
+ except Exception as e:
149
+ logger.error(f"Error generating audio: {str(e)}")
150
+ raise HTTPException(
151
+ status_code=500,
152
+ detail={
153
+ "error": "processing_error",
154
+ "message": str(e),
155
+ "type": "server_error",
156
+ },
157
+ )
158
+
159
+
160
+ @router.post("/dev/captioned_speech")
161
+ async def create_captioned_speech(
162
+ request: CaptionedSpeechRequest,
163
+ client_request: Request,
164
+ x_raw_response: str = Header(None, alias="x-raw-response"),
165
+ tts_service: TTSService = Depends(get_tts_service),
166
+ ):
167
+ """Generate audio with word-level timestamps using streaming approach"""
168
+
169
+ try:
170
+ # model_name = get_model_name(request.model)
171
+ tts_service = await get_tts_service()
172
+ voice_name = await process_and_validate_voices(request.voice, tts_service)
173
+
174
+ # Set content type based on format
175
+ content_type = {
176
+ "mp3": "audio/mpeg",
177
+ "opus": "audio/opus",
178
+ "m4a": "audio/mp4",
179
+ "flac": "audio/flac",
180
+ "wav": "audio/wav",
181
+ "pcm": "audio/pcm",
182
+ }.get(request.response_format, f"audio/{request.response_format}")
183
+
184
+ writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
185
+ # Check if streaming is requested (default for OpenAI client)
186
+ if request.stream:
187
+ # Create generator but don't start it yet
188
+ generator = stream_audio_chunks(
189
+ tts_service, request, client_request, writer
190
+ )
191
+
192
+ # If download link requested, wrap generator with temp file writer
193
+ if request.return_download_link:
194
+ from ..services.temp_manager import TempFileWriter
195
+
196
+ temp_writer = TempFileWriter(request.response_format)
197
+ await temp_writer.__aenter__() # Initialize temp file
198
+
199
+ # Get download path immediately after temp file creation
200
+ download_path = temp_writer.download_path
201
+
202
+ # Create response headers with download path
203
+ headers = {
204
+ "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
205
+ "X-Accel-Buffering": "no",
206
+ "Cache-Control": "no-cache",
207
+ "Transfer-Encoding": "chunked",
208
+ "X-Download-Path": download_path,
209
+ }
210
+
211
+ # Create async generator for streaming
212
+ async def dual_output():
213
+ try:
214
+ # The timestamp accumulator collects word-level timestamps from chunks that carry no audio output.
215
+ timestamp_acumulator = []
216
+ # Write chunks to temp file and stream
217
+ async for chunk_data in generator:
218
+
219
+ if chunk_data.output: # Skip empty chunks
220
+ await temp_writer.write(chunk_data.output)
221
+ base64_chunk = base64.b64encode(
222
+ chunk_data.output
223
+ ).decode("utf-8")
224
+
225
+ # Merge any timestamps held in the accumulator into the returned word_timestamps
226
+ chunk_data.word_timestamps = (
227
+ timestamp_acumulator + chunk_data.word_timestamps
228
+ )
229
+ timestamp_acumulator = []
230
+
231
+ yield CaptionedSpeechResponse(
232
+ audio=base64_chunk,
233
+ audio_format=content_type,
234
+ timestamps=chunk_data.word_timestamps,
235
+ )
236
+ else:
237
+ if (
238
+ chunk_data.word_timestamps is not None
239
+ and len(chunk_data.word_timestamps) > 0
240
+ ):
241
+ timestamp_acumulator += chunk_data.word_timestamps
242
+
243
+ # Finalize the temp file
244
+ await temp_writer.finalize()
245
+ except Exception as e:
246
+ logger.error(f"Error in dual output streaming: {e}")
247
+ await temp_writer.__aexit__(type(e), e, e.__traceback__)
248
+ raise
249
+ finally:
250
+ # Ensure temp writer is closed
251
+ if not temp_writer._finalized:
252
+ await temp_writer.__aexit__(None, None, None)
253
+ writer.close()
254
+
255
+ # Stream with temp file writing
256
+ return JSONStreamingResponse(
257
+ dual_output(), media_type="application/json", headers=headers
258
+ )
259
+
260
+ async def single_output():
261
+ try:
262
+ # The timestamp accumulator is only used when word-level timestamps are generated but no audio is returned.
263
+ timestamp_acumulator = []
264
+
265
+ # Stream chunks
266
+ async for chunk_data in generator:
267
+ if chunk_data.output: # Skip empty chunks
268
+ # Encode the chunk bytes into base 64
269
+ base64_chunk = base64.b64encode(chunk_data.output).decode(
270
+ "utf-8"
271
+ )
272
+
273
+ # Merge any timestamps held in the accumulator into the returned word_timestamps
274
+ if chunk_data.word_timestamps is not None:
275
+ chunk_data.word_timestamps = (
276
+ timestamp_acumulator + chunk_data.word_timestamps
277
+ )
278
+ else:
279
+ chunk_data.word_timestamps = []
280
+ timestamp_acumulator = []
281
+
282
+ yield CaptionedSpeechResponse(
283
+ audio=base64_chunk,
284
+ audio_format=content_type,
285
+ timestamps=chunk_data.word_timestamps,
286
+ )
287
+ else:
288
+ if (
289
+ chunk_data.word_timestamps is not None
290
+ and len(chunk_data.word_timestamps) > 0
291
+ ):
292
+ timestamp_acumulator += chunk_data.word_timestamps
293
+
294
+ except Exception as e:
295
+ logger.error(f"Error in single output streaming: {e}")
296
+ writer.close()
297
+ raise
298
+
299
+ # Standard streaming without download link
300
+ return JSONStreamingResponse(
301
+ single_output(),
302
+ media_type="application/json",
303
+ headers={
304
+ "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
305
+ "X-Accel-Buffering": "no",
306
+ "Cache-Control": "no-cache",
307
+ "Transfer-Encoding": "chunked",
308
+ },
309
+ )
310
+ else:
311
+ # Generate complete audio using public interface
312
+ audio_data = await tts_service.generate_audio(
313
+ text=request.input,
314
+ voice=voice_name,
315
+ writer=writer,
316
+ speed=request.speed,
317
+ return_timestamps=request.return_timestamps,
318
+ normalization_options=request.normalization_options,
319
+ lang_code=request.lang_code,
320
+ )
321
+
322
+ audio_data = await AudioService.convert_audio(
323
+ audio_data,
324
+ request.response_format,
325
+ writer,
326
+ is_last_chunk=False,
327
+ trim_audio=False,
328
+ )
329
+
330
+ # Convert to requested format with proper finalization
331
+ final = await AudioService.convert_audio(
332
+ AudioChunk(np.array([], dtype=np.int16)),
333
+ request.response_format,
334
+ writer,
335
+ is_last_chunk=True,
336
+ )
337
+ output = audio_data.output + final.output
338
+
339
+ base64_output = base64.b64encode(output).decode("utf-8")
340
+
341
+ content = CaptionedSpeechResponse(
342
+ audio=base64_output,
343
+ audio_format=content_type,
344
+ timestamps=audio_data.word_timestamps,
345
+ ).model_dump()
346
+
347
+ writer.close()
348
+
349
+ return JSONResponse(
350
+ content=content,
351
+ media_type="application/json",
352
+ headers={
353
+ "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
354
+ "Cache-Control": "no-cache", # Prevent caching
355
+ },
356
+ )
357
+
358
+ except ValueError as e:
359
+ # Handle validation errors
360
+ logger.warning(f"Invalid request: {str(e)}")
361
+
362
+ try:
363
+ writer.close()
364
+ except Exception:
365
+ pass
366
+
367
+ raise HTTPException(
368
+ status_code=400,
369
+ detail={
370
+ "error": "validation_error",
371
+ "message": str(e),
372
+ "type": "invalid_request_error",
373
+ },
374
+ )
375
+ except RuntimeError as e:
376
+ # Handle runtime/processing errors
377
+ logger.error(f"Processing error: {str(e)}")
378
+
379
+ try:
380
+ writer.close()
381
+ except Exception:
382
+ pass
383
+
384
+ raise HTTPException(
385
+ status_code=500,
386
+ detail={
387
+ "error": "processing_error",
388
+ "message": str(e),
389
+ "type": "server_error",
390
+ },
391
+ )
392
+ except Exception as e:
393
+ # Handle unexpected errors
394
+ logger.error(f"Unexpected error in captioned speech generation: {str(e)}")
395
+
396
+ try:
397
+ writer.close()
398
+ except Exception:
399
+ pass
400
+
401
+ raise HTTPException(
402
+ status_code=500,
403
+ detail={
404
+ "error": "processing_error",
405
+ "message": str(e),
406
+ "type": "server_error",
407
+ },
408
+ )
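A minimal sketch of calling the /dev/phonemize endpoint above; the "a" language code (Kokoro's American English) and the host/port are assumptions:

import requests

payload = {"text": "Hello world", "language": "a"}
r = requests.post("http://localhost:8880/dev/phonemize", json=payload, timeout=30)
print(r.json()["phonemes"])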
api/src/routers/openai_compatible.py ADDED
@@ -0,0 +1,662 @@
1
+ """OpenAI-compatible router for text-to-speech"""
2
+
3
+ import io
4
+ import json
5
+ import os
6
+ import re
7
+ import tempfile
8
+ from typing import AsyncGenerator, Dict, List, Tuple, Union
10
+
11
+ import aiofiles
12
+ import numpy as np
13
+ import torch
14
+ from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
15
+ from fastapi.responses import FileResponse, StreamingResponse
16
+ from loguru import logger
17
+
18
+ from ..core.config import settings
19
+ from ..inference.base import AudioChunk
20
+ from ..services.audio import AudioService
21
+ from ..services.streaming_audio_writer import StreamingAudioWriter
22
+ from ..services.tts_service import TTSService
23
+ from ..structures import OpenAISpeechRequest
24
+ from ..structures.schemas import CaptionedSpeechRequest
25
+
26
+
27
+ # Load OpenAI mappings
28
+ def load_openai_mappings() -> Dict:
29
+ """Load OpenAI voice and model mappings from JSON"""
30
+ api_dir = os.path.dirname(os.path.dirname(__file__))
31
+ mapping_path = os.path.join(api_dir, "core", "openai_mappings.json")
32
+ try:
33
+ with open(mapping_path, "r") as f:
34
+ return json.load(f)
35
+ except Exception as e:
36
+ logger.error(f"Failed to load OpenAI mappings: {e}")
37
+ return {"models": {}, "voices": {}}
38
+
39
+
40
+ # Global mappings
41
+ _openai_mappings = load_openai_mappings()
42
+
43
+
44
+ router = APIRouter(
45
+ tags=["OpenAI Compatible TTS"],
46
+ responses={404: {"description": "Not found"}},
47
+ )
48
+
49
+ # Global TTSService instance with lock
50
+ _tts_service = None
51
+ _init_lock = None
52
+
53
+
54
+ async def get_tts_service() -> TTSService:
55
+ """Get global TTSService instance"""
56
+ global _tts_service, _init_lock
57
+
58
+ # Create lock if needed
59
+ if _init_lock is None:
60
+ import asyncio
61
+
62
+ _init_lock = asyncio.Lock()
63
+
64
+ # Initialize service if needed
65
+ if _tts_service is None:
66
+ async with _init_lock:
67
+ # Double check pattern
68
+ if _tts_service is None:
69
+ _tts_service = await TTSService.create()
70
+ logger.info("Created global TTSService instance")
71
+
72
+ return _tts_service
73
+
74
+
75
+ def get_model_name(model: str) -> str:
76
+ """Get internal model name from OpenAI model name"""
77
+ base_name = _openai_mappings["models"].get(model)
78
+ if not base_name:
79
+ raise ValueError(f"Unsupported model: {model}")
80
+ return base_name + ".pth"
81
+
82
+
83
+ async def process_and_validate_voices(
84
+ voice_input: Union[str, List[str]], tts_service: TTSService
85
+ ) -> str:
86
+ """Process voice input, handling both string and list formats
87
+
88
+ Returns:
89
+ Voice name to use (with weights if specified)
90
+ """
91
+ voices = []
92
+ # Convert input to list of voices
93
+ if isinstance(voice_input, str):
94
+ voice_input = voice_input.replace(" ", "").strip()
95
+
96
+ if voice_input[-1] in "+-" or voice_input[0] in "+-":
97
+ raise ValueError(f"Voice combination contains empty combine items")
98
+
99
+ if re.search(r"[+-]{2,}", voice_input) is not None:
100
+ raise ValueError(f"Voice combination contains empty combine items")
101
+ voices = re.split(r"([-+])", voice_input)
102
+ else:
103
+ voices = [[item, "+"] for item in voice_input][:-1]
104
+
105
+ available_voices = await tts_service.list_voices()
106
+
107
+ for voice_index in range(0, len(voices), 2):
108
+ mapped_voice = voices[voice_index].split("(")
109
+ mapped_voice = list(map(str.strip, mapped_voice))
110
+
111
+ if len(mapped_voice) > 2:
112
+ raise ValueError(
113
+ f"Voice '{voices[voice_index]}' contains too many weight items"
114
+ )
115
+
116
+ if mapped_voice.count(")") > 1:
117
+ raise ValueError(
118
+ f"Voice '{voices[voice_index]}' contains too many weight items"
119
+ )
120
+
121
+ mapped_voice[0] = _openai_mappings["voices"].get(
122
+ mapped_voice[0], mapped_voice[0]
123
+ )
124
+
125
+ if mapped_voice[0] not in available_voices:
126
+ raise ValueError(
127
+ f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
128
+ )
129
+
130
+ voices[voice_index] = "(".join(mapped_voice)
131
+
132
+ return "".join(voices)
133
+
134
+
135
+ async def stream_audio_chunks(
136
+ tts_service: TTSService,
137
+ request: Union[OpenAISpeechRequest, CaptionedSpeechRequest],
138
+ client_request: Request,
139
+ writer: StreamingAudioWriter,
140
+ ) -> AsyncGenerator[AudioChunk, None]:
141
+ """Stream audio chunks as they're generated with client disconnect handling"""
142
+ voice_name = await process_and_validate_voices(request.voice, tts_service)
143
+ unique_properties = {"return_timestamps": False}
144
+ if hasattr(request, "return_timestamps"):
145
+ unique_properties["return_timestamps"] = request.return_timestamps
146
+
147
+ try:
148
+ async for chunk_data in tts_service.generate_audio_stream(
149
+ text=request.input,
150
+ voice=voice_name,
151
+ writer=writer,
152
+ speed=request.speed,
153
+ output_format=request.response_format,
154
+ lang_code=request.lang_code,
155
+ normalization_options=request.normalization_options,
156
+ return_timestamps=unique_properties["return_timestamps"],
157
+ ):
158
+ # Check if client is still connected
159
+ is_disconnected = client_request.is_disconnected
160
+ if callable(is_disconnected):
161
+ is_disconnected = await is_disconnected()
162
+ if is_disconnected:
163
+ logger.info("Client disconnected, stopping audio generation")
164
+ break
165
+
166
+ yield chunk_data
167
+ except Exception as e:
168
+ logger.error(f"Error in audio streaming: {str(e)}")
169
+ # Let the exception propagate to trigger cleanup
170
+ raise
171
+
172
+
173
+ @router.post("/audio/speech")
174
+ async def create_speech(
175
+ request: OpenAISpeechRequest,
176
+ client_request: Request,
177
+ x_raw_response: str = Header(None, alias="x-raw-response"),
178
+ ):
179
+ """OpenAI-compatible endpoint for text-to-speech"""
180
+ # Validate model before processing request
181
+ if request.model not in _openai_mappings["models"]:
182
+ raise HTTPException(
183
+ status_code=400,
184
+ detail={
185
+ "error": "invalid_model",
186
+ "message": f"Unsupported model: {request.model}",
187
+ "type": "invalid_request_error",
188
+ },
189
+ )
190
+
191
+ try:
192
+ # model_name = get_model_name(request.model)
193
+ tts_service = await get_tts_service()
194
+ voice_name = await process_and_validate_voices(request.voice, tts_service)
195
+
196
+ # Set content type based on format
197
+ content_type = {
198
+ "mp3": "audio/mpeg",
199
+ "opus": "audio/opus",
200
+ "aac": "audio/aac",
201
+ "flac": "audio/flac",
202
+ "wav": "audio/wav",
203
+ "pcm": "audio/pcm",
204
+ }.get(request.response_format, f"audio/{request.response_format}")
205
+
206
+ writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
207
+
208
+ # Check if streaming is requested (default for OpenAI client)
209
+ if request.stream:
210
+ # Create generator but don't start it yet
211
+ generator = stream_audio_chunks(
212
+ tts_service, request, client_request, writer
213
+ )
214
+
215
+ # If download link requested, wrap generator with temp file writer
216
+ if request.return_download_link:
217
+ from ..services.temp_manager import TempFileWriter
218
+
219
+ # Use download_format if specified, otherwise use response_format
220
+ output_format = request.download_format or request.response_format
221
+ temp_writer = TempFileWriter(output_format)
222
+ await temp_writer.__aenter__() # Initialize temp file
223
+
224
+ # Get download path immediately after temp file creation
225
+ download_path = temp_writer.download_path
226
+
227
+ # Create response headers with download path
228
+ headers = {
229
+ "Content-Disposition": f"attachment; filename=speech.{output_format}",
230
+ "X-Accel-Buffering": "no",
231
+ "Cache-Control": "no-cache",
232
+ "Transfer-Encoding": "chunked",
233
+ "X-Download-Path": download_path,
234
+ }
235
+
236
+ # Add header to indicate if temp file writing is available
237
+ if temp_writer._write_error:
238
+ headers["X-Download-Status"] = "unavailable"
239
+
240
+ # Create async generator for streaming
241
+ async def dual_output():
242
+ try:
243
+ # Write chunks to temp file and stream
244
+ async for chunk_data in generator:
245
+ if chunk_data.output: # Skip empty chunks
246
+ await temp_writer.write(chunk_data.output)
247
+ # if return_json:
248
+ # yield chunk, chunk_data
249
+ # else:
250
+ yield chunk_data.output
251
+
252
+ # Finalize the temp file
253
+ await temp_writer.finalize()
254
+ except Exception as e:
255
+ logger.error(f"Error in dual output streaming: {e}")
256
+ await temp_writer.__aexit__(type(e), e, e.__traceback__)
257
+ raise
258
+ finally:
259
+ # Ensure temp writer is closed
260
+ if not temp_writer._finalized:
261
+ await temp_writer.__aexit__(None, None, None)
262
+ writer.close()
263
+
264
+ # Stream with temp file writing
265
+ return StreamingResponse(
266
+ dual_output(), media_type=content_type, headers=headers
267
+ )
268
+
269
+ async def single_output():
270
+ try:
271
+ # Stream chunks
272
+ async for chunk_data in generator:
273
+ if chunk_data.output: # Skip empty chunks
274
+ yield chunk_data.output
275
+ except Exception as e:
276
+ logger.error(f"Error in single output streaming: {e}")
277
+ writer.close()
278
+ raise
279
+
280
+ # Standard streaming without download link
281
+ return StreamingResponse(
282
+ single_output(),
283
+ media_type=content_type,
284
+ headers={
285
+ "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
286
+ "X-Accel-Buffering": "no",
287
+ "Cache-Control": "no-cache",
288
+ "Transfer-Encoding": "chunked",
289
+ },
290
+ )
291
+ else:
292
+ headers = {
293
+ "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
294
+ "Cache-Control": "no-cache", # Prevent caching
295
+ }
296
+
297
+ # Generate complete audio using public interface
298
+ audio_data = await tts_service.generate_audio(
299
+ text=request.input,
300
+ voice=voice_name,
301
+ writer=writer,
302
+ speed=request.speed,
303
+ normalization_options=request.normalization_options,
304
+ lang_code=request.lang_code,
305
+ )
306
+
307
+ audio_data = await AudioService.convert_audio(
308
+ audio_data,
309
+ request.response_format,
310
+ writer,
311
+ is_last_chunk=False,
312
+ trim_audio=False,
313
+ )
314
+
315
+ # Convert to requested format with proper finalization
316
+ final = await AudioService.convert_audio(
317
+ AudioChunk(np.array([], dtype=np.int16)),
318
+ request.response_format,
319
+ writer,
320
+ is_last_chunk=True,
321
+ )
322
+ output = audio_data.output + final.output
323
+
324
+ if request.return_download_link:
325
+ from ..services.temp_manager import TempFileWriter
326
+
327
+ # Use download_format if specified, otherwise use response_format
328
+ output_format = request.download_format or request.response_format
329
+ temp_writer = TempFileWriter(output_format)
330
+ await temp_writer.__aenter__() # Initialize temp file
331
+
332
+ # Get download path immediately after temp file creation
333
+ download_path = temp_writer.download_path
334
+ headers["X-Download-Path"] = download_path
335
+
336
+ try:
337
+ # Write chunks to temp file
338
+ logger.info("Writing chunks to tempory file for download")
339
+ await temp_writer.write(output)
340
+ # Finalize the temp file
341
+ await temp_writer.finalize()
342
+
343
+ except Exception as e:
344
+ logger.error(f"Error in dual output: {e}")
345
+ await temp_writer.__aexit__(type(e), e, e.__traceback__)
346
+ raise
347
+ finally:
348
+ # Ensure temp writer is closed
349
+ if not temp_writer._finalized:
350
+ await temp_writer.__aexit__(None, None, None)
351
+ writer.close()
352
+
353
+ return Response(
354
+ content=output,
355
+ media_type=content_type,
356
+ headers=headers,
357
+ )
358
+
359
+ except ValueError as e:
360
+ # Handle validation errors
361
+ logger.warning(f"Invalid request: {str(e)}")
362
+
363
+ try:
364
+ writer.close()
365
+ except Exception:
366
+ pass
367
+
368
+ raise HTTPException(
369
+ status_code=400,
370
+ detail={
371
+ "error": "validation_error",
372
+ "message": str(e),
373
+ "type": "invalid_request_error",
374
+ },
375
+ )
376
+ except RuntimeError as e:
377
+ # Handle runtime/processing errors
378
+ logger.error(f"Processing error: {str(e)}")
379
+
380
+ try:
381
+ writer.close()
382
+ except Exception:
383
+ pass
384
+
385
+ raise HTTPException(
386
+ status_code=500,
387
+ detail={
388
+ "error": "processing_error",
389
+ "message": str(e),
390
+ "type": "server_error",
391
+ },
392
+ )
393
+ except Exception as e:
394
+ # Handle unexpected errors
395
+ logger.error(f"Unexpected error in speech generation: {str(e)}")
396
+
397
+ try:
398
+ writer.close()
399
+ except Exception:
400
+ pass
401
+
402
+ raise HTTPException(
403
+ status_code=500,
404
+ detail={
405
+ "error": "processing_error",
406
+ "message": str(e),
407
+ "type": "server_error",
408
+ },
409
+ )
410
+
411
+
412
+ @router.get("/download/{filename}")
413
+ async def download_audio_file(filename: str):
414
+ """Download a generated audio file from temp storage"""
415
+ try:
416
+ from ..core.paths import _find_file, get_content_type
417
+
418
+ # Search for file in temp directory
419
+ file_path = await _find_file(
420
+ filename=filename, search_paths=[settings.temp_file_dir]
421
+ )
422
+
423
+ # Get content type from path helper
424
+ content_type = await get_content_type(file_path)
425
+
426
+ return FileResponse(
427
+ file_path,
428
+ media_type=content_type,
429
+ filename=filename,
430
+ headers={
431
+ "Cache-Control": "no-cache",
432
+ "Content-Disposition": f"attachment; filename={filename}",
433
+ },
434
+ )
435
+
436
+ except Exception as e:
437
+ logger.error(f"Error serving download file {filename}: {e}")
438
+ raise HTTPException(
439
+ status_code=500,
440
+ detail={
441
+ "error": "server_error",
442
+ "message": "Failed to serve audio file",
443
+ "type": "server_error",
444
+ },
445
+ )
446
+
447
+
448
+ @router.get("/models")
449
+ async def list_models():
450
+ """List all available models"""
451
+ try:
452
+ # Create standard model list
453
+ models = [
454
+ {
455
+ "id": "tts-1",
456
+ "object": "model",
457
+ "created": 1686935002,
458
+ "owned_by": "kokoro",
459
+ },
460
+ {
461
+ "id": "tts-1-hd",
462
+ "object": "model",
463
+ "created": 1686935002,
464
+ "owned_by": "kokoro",
465
+ },
466
+ {
467
+ "id": "kokoro",
468
+ "object": "model",
469
+ "created": 1686935002,
470
+ "owned_by": "kokoro",
471
+ },
472
+ ]
473
+
474
+ return {"object": "list", "data": models}
475
+ except Exception as e:
476
+ logger.error(f"Error listing models: {str(e)}")
477
+ raise HTTPException(
478
+ status_code=500,
479
+ detail={
480
+ "error": "server_error",
481
+ "message": "Failed to retrieve model list",
482
+ "type": "server_error",
483
+ },
484
+ )
485
+
486
+
487
+ @router.get("/models/{model}")
488
+ async def retrieve_model(model: str):
489
+ """Retrieve a specific model"""
490
+ try:
491
+ # Define available models
492
+ models = {
493
+ "tts-1": {
494
+ "id": "tts-1",
495
+ "object": "model",
496
+ "created": 1686935002,
497
+ "owned_by": "kokoro",
498
+ },
499
+ "tts-1-hd": {
500
+ "id": "tts-1-hd",
501
+ "object": "model",
502
+ "created": 1686935002,
503
+ "owned_by": "kokoro",
504
+ },
505
+ "kokoro": {
506
+ "id": "kokoro",
507
+ "object": "model",
508
+ "created": 1686935002,
509
+ "owned_by": "kokoro",
510
+ },
511
+ }
512
+
513
+ # Check if requested model exists
514
+ if model not in models:
515
+ raise HTTPException(
516
+ status_code=404,
517
+ detail={
518
+ "error": "model_not_found",
519
+ "message": f"Model '{model}' not found",
520
+ "type": "invalid_request_error",
521
+ },
522
+ )
523
+
524
+ # Return the specific model
525
+ return models[model]
526
+ except HTTPException:
527
+ raise
528
+ except Exception as e:
529
+ logger.error(f"Error retrieving model {model}: {str(e)}")
530
+ raise HTTPException(
531
+ status_code=500,
532
+ detail={
533
+ "error": "server_error",
534
+ "message": "Failed to retrieve model information",
535
+ "type": "server_error",
536
+ },
537
+ )
538
+
539
+
540
+ @router.get("/audio/voices")
541
+ async def list_voices():
542
+ """List all available voices for text-to-speech"""
543
+ try:
544
+ tts_service = await get_tts_service()
545
+ voices = await tts_service.list_voices()
546
+ return {"voices": voices}
547
+ except Exception as e:
548
+ logger.error(f"Error listing voices: {str(e)}")
549
+ raise HTTPException(
550
+ status_code=500,
551
+ detail={
552
+ "error": "server_error",
553
+ "message": "Failed to retrieve voice list",
554
+ "type": "server_error",
555
+ },
556
+ )
557
+
558
+
559
+ @router.post("/audio/voices/combine")
560
+ async def combine_voices(request: Union[str, List[str]]):
561
+ """Combine multiple voices into a new voice and return the .pt file.
562
+
563
+ Args:
564
+ request: Either a string with voices separated by + (e.g. "voice1+voice2")
565
+ or a list of voice names to combine
566
+
567
+ Returns:
568
+ FileResponse with the combined voice .pt file
569
+
570
+ Raises:
571
+ HTTPException:
572
+ - 400: Invalid request (wrong number of voices, voice not found)
573
+ - 500: Server error (file system issues, combination failed)
574
+ """
575
+ # Check if local voice saving is allowed
576
+ if not settings.allow_local_voice_saving:
577
+ raise HTTPException(
578
+ status_code=403,
579
+ detail={
580
+ "error": "permission_denied",
581
+ "message": "Local voice saving is disabled",
582
+ "type": "permission_error",
583
+ },
584
+ )
585
+
586
+ try:
587
+ # Convert input to list of voices
588
+ if isinstance(request, str):
589
+ # Check if it's an OpenAI voice name
590
+ mapped_voice = _openai_mappings["voices"].get(request)
591
+ if mapped_voice:
592
+ request = mapped_voice
593
+ voices = [v.strip() for v in request.split("+") if v.strip()]
594
+ else:
595
+ # For list input, map each voice if it's an OpenAI voice name
596
+ voices = [_openai_mappings["voices"].get(v, v) for v in request]
597
+ voices = [v.strip() for v in voices if v.strip()]
598
+
599
+ if not voices:
600
+ raise ValueError("No voices provided")
601
+
602
+ # For multiple voices, validate base voices exist
603
+ tts_service = await get_tts_service()
604
+ available_voices = await tts_service.list_voices()
605
+ for voice in voices:
606
+ if voice not in available_voices:
607
+ raise ValueError(
608
+ f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
609
+ )
610
+
611
+ # Combine voices
612
+ combined_tensor = await tts_service.combine_voices(voices=voices)
613
+ combined_name = "+".join(voices)
614
+
615
+ # Save to temp file
616
+ temp_dir = tempfile.gettempdir()
617
+ voice_path = os.path.join(temp_dir, f"{combined_name}.pt")
618
+ buffer = io.BytesIO()
619
+ torch.save(combined_tensor, buffer)
620
+ async with aiofiles.open(voice_path, "wb") as f:
621
+ await f.write(buffer.getvalue())
622
+
623
+ return FileResponse(
624
+ voice_path,
625
+ media_type="application/octet-stream",
626
+ filename=f"{combined_name}.pt",
627
+ headers={
628
+ "Content-Disposition": f"attachment; filename={combined_name}.pt",
629
+ "Cache-Control": "no-cache",
630
+ },
631
+ )
632
+
633
+ except ValueError as e:
634
+ logger.warning(f"Invalid voice combination request: {str(e)}")
635
+ raise HTTPException(
636
+ status_code=400,
637
+ detail={
638
+ "error": "validation_error",
639
+ "message": str(e),
640
+ "type": "invalid_request_error",
641
+ },
642
+ )
643
+ except RuntimeError as e:
644
+ logger.error(f"Voice combination processing error: {str(e)}")
645
+ raise HTTPException(
646
+ status_code=500,
647
+ detail={
648
+ "error": "processing_error",
649
+ "message": "Failed to process voice combination request",
650
+ "type": "server_error",
651
+ },
652
+ )
653
+ except Exception as e:
654
+ logger.error(f"Unexpected error in voice combination: {str(e)}")
655
+ raise HTTPException(
656
+ status_code=500,
657
+ detail={
658
+ "error": "server_error",
659
+ "message": "An unexpected error occurred",
660
+ "type": "server_error",
661
+ },
662
+ )
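Because this router mirrors OpenAI's audio API, the official openai client can point at it directly. A sketch under assumed defaults (localhost:8880 and an "af_bella" voice; query GET /v1/audio/voices for the real list):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
with client.audio.speech.with_streaming_response.create(
    model="kokoro",
    voice="af_bella",  # illustrative voice name
    input="Hello from Kokoro!",
) as response:
    response.stream_to_file("speech.mp3")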
api/src/routers/web_player.py ADDED
@@ -0,0 +1,49 @@
1
+ """Web player router with async file serving."""
2
+
3
+ from fastapi import APIRouter, HTTPException
4
+ from fastapi.responses import Response
5
+ from loguru import logger
6
+
7
+ from ..core.config import settings
8
+ from ..core.paths import get_content_type, get_web_file_path, read_bytes
9
+
10
+ router = APIRouter(
11
+ tags=["Web Player"],
12
+ responses={404: {"description": "Not found"}},
13
+ )
14
+
15
+
16
+ @router.get("/{filename:path}")
17
+ async def serve_web_file(filename: str):
18
+ """Serve web player static files asynchronously."""
19
+ if not settings.enable_web_player:
20
+ raise HTTPException(status_code=404, detail="Web player is disabled")
21
+
22
+ try:
23
+ # Default to index.html for root path
24
+ if filename == "" or filename == "/":
25
+ filename = "index.html"
26
+
27
+ # Get file path
28
+ file_path = await get_web_file_path(filename)
29
+
30
+ # Read file content
31
+ content = await read_bytes(file_path)
32
+
33
+ # Get content type
34
+ content_type = await get_content_type(file_path)
35
+
36
+ return Response(
37
+ content=content,
38
+ media_type=content_type,
39
+ headers={
40
+ "Cache-Control": "no-cache", # Prevent caching during development
41
+ },
42
+ )
43
+
44
+ except RuntimeError as e:
45
+ logger.warning(f"Web file not found: {filename}")
46
+ raise HTTPException(status_code=404, detail=str(e))
47
+ except Exception as e:
48
+ logger.error(f"Error serving web file {filename}: {e}")
49
+ raise HTTPException(status_code=500, detail="Internal server error")
api/src/services/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .tts_service import TTSService
2
+
3
+ __all__ = ["TTSService"]
api/src/services/audio.py ADDED
@@ -0,0 +1,248 @@
1
+ """Audio conversion service"""
2
+
3
+ import math
4
+ import struct
5
+ import time
6
+ from io import BytesIO
7
+ from typing import Tuple
8
+
9
+ import numpy as np
10
+ import scipy.io.wavfile as wavfile
11
+ import soundfile as sf
12
+ from loguru import logger
13
+ from pydub import AudioSegment
15
+
16
+ from ..core.config import settings
17
+ from ..inference.base import AudioChunk
18
+ from .streaming_audio_writer import StreamingAudioWriter
19
+
20
+
21
+ class AudioNormalizer:
22
+ """Handles audio normalization state for a single stream"""
23
+
24
+ def __init__(self):
25
+ self.chunk_trim_ms = settings.gap_trim_ms
26
+ self.sample_rate = 24000 # Sample rate of the audio
27
+ self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
28
+ self.samples_to_pad_start = int(50 * self.sample_rate / 1000)
29
+
30
+ def find_first_last_non_silent(
31
+ self,
32
+ audio_data: np.ndarray,
33
+ chunk_text: str,
34
+ speed: float,
35
+ silence_threshold_db: int = -45,
36
+ is_last_chunk: bool = False,
37
+ ) -> tuple[int, int]:
38
+ """Finds the indices of the first and last non-silent samples in audio data.
39
+
40
+ Args:
41
+ audio_data: Input audio data as numpy array
42
+ chunk_text: The text sent to the model to generate the resulting speech
43
+ speed: The speaking speed of the voice
44
+ silence_threshold_db: How quiet audio has to be to be considered silent
45
+ is_last_chunk: Whether this is the last chunk
46
+
47
+ Returns:
48
+ A (start, end) tuple of sample indices bounding the non-silent portion
49
+ """
50
+
51
+ pad_multiplier = 1
52
+ split_character = chunk_text.strip()
53
+ if len(split_character) > 0:
54
+ split_character = split_character[-1]
55
+ if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
56
+ pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[
57
+ split_character
58
+ ]
59
+
60
+ if not is_last_chunk:
61
+ samples_to_pad_end = max(
62
+ int(
63
+ (
64
+ settings.dynamic_gap_trim_padding_ms
65
+ * self.sample_rate
66
+ * pad_multiplier
67
+ )
68
+ / 1000
69
+ )
70
+ - self.samples_to_pad_start,
71
+ 0,
72
+ )
73
+ else:
74
+ samples_to_pad_end = self.samples_to_pad_start
75
+ # Convert dBFS threshold to amplitude
76
+ amplitude_threshold = np.iinfo(audio_data.dtype).max * (
77
+ 10 ** (silence_threshold_db / 20)
78
+ )
79
+ # Find the first samples above the silence threshold at the start and end of the audio
80
+ non_silent_index_start, non_silent_index_end = None, None
81
+
82
+ for i in range(0, len(audio_data)):
83
+ if audio_data[i] > amplitude_threshold:
84
+ non_silent_index_start = i
85
+ break
86
+
87
+ for i in range(len(audio_data) - 1, -1, -1):
88
+ if audio_data[i] > amplitude_threshold:
89
+ non_silent_index_end = i
90
+ break
91
+
92
+ # Handle the case where the entire audio is silent
93
+ if non_silent_index_start is None or non_silent_index_end is None:
94
+ return 0, len(audio_data)
95
+
96
+ return max(non_silent_index_start - self.samples_to_pad_start, 0), min(
97
+ non_silent_index_end + math.ceil(samples_to_pad_end / speed),
98
+ len(audio_data),
99
+ )
100
+
101
+ def normalize(self, audio_data: np.ndarray) -> np.ndarray:
102
+ """Convert audio data to int16 range
103
+
104
+ Args:
105
+ audio_data: Input audio data as numpy array
106
+ Returns:
107
+ Normalized audio data
108
+ """
109
+ if audio_data.dtype != np.int16:
110
+ # Scale directly to int16 range with clipping
111
+ return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
112
+ return audio_data
113
+
114
+
115
+ class AudioService:
116
+ """Service for audio format conversions with streaming support"""
117
+
118
+ # Supported formats
119
+ SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm"}
120
+
121
+ # Default audio format settings balanced for speed and compression
122
+ DEFAULT_SETTINGS = {
123
+ "mp3": {
124
+ "bitrate_mode": "CONSTANT", # Faster than variable bitrate
125
+ "compression_level": 0.0, # Balanced compression
126
+ },
127
+ "opus": {
128
+ "compression_level": 0.0, # Good balance for speech
129
+ },
130
+ "flac": {
131
+ "compression_level": 0.0, # Light compression, still fast
132
+ },
133
+ "aac": {
134
+ "bitrate": "192k", # Default AAC bitrate
135
+ },
136
+ }
137
+
138
+ @staticmethod
139
+ async def convert_audio(
140
+ audio_chunk: AudioChunk,
141
+ output_format: str,
142
+ writer: StreamingAudioWriter,
143
+ speed: float = 1,
144
+ chunk_text: str = "",
145
+ is_last_chunk: bool = False,
146
+ trim_audio: bool = True,
147
+ normalizer: AudioNormalizer = None,
148
+ ) -> AudioChunk:
149
+ """Convert audio data to specified format with streaming support
150
+
151
+ Args:
152
+ audio_chunk: Audio chunk containing the numpy samples to convert
153
+ output_format: Target format (wav, mp3, opus, flac, aac, pcm)
154
+ writer: The StreamingAudioWriter to use
155
+ speed: The speaking speed of the voice
156
+ chunk_text: The text sent to the model to generate the resulting speech
157
+ is_last_chunk: Whether this is the last chunk
158
+ trim_audio: Whether audio should be trimmed
159
+ normalizer: Optional AudioNormalizer instance for consistent normalization
160
+
161
+ Returns:
162
+ The audio chunk with its output bytes set to the converted data
163
+ """
164
+
165
+ try:
166
+ # Validate format
167
+ if output_format not in AudioService.SUPPORTED_FORMATS:
168
+ raise ValueError(f"Format {output_format} not supported")
169
+
170
+ # Always normalize audio to ensure proper amplitude scaling
171
+ if normalizer is None:
172
+ normalizer = AudioNormalizer()
173
+
174
+ audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
175
+
176
+ if trim_audio:
177
+ audio_chunk = AudioService.trim_audio(
178
+ audio_chunk, chunk_text, speed, is_last_chunk, normalizer
179
+ )
180
+
181
+ # Write audio data first; default to empty bytes so chunk_data is always bound
182
+ chunk_data = b""
+ if len(audio_chunk.audio) > 0:
183
+ chunk_data = writer.write_chunk(audio_chunk.audio)
184
+
185
+ # Then finalize if this is the last chunk
186
+ if is_last_chunk:
187
+ final_data = writer.write_chunk(finalize=True)
188
+
189
+ if final_data:
190
+ audio_chunk.output = final_data
191
+ return audio_chunk
192
+
193
+ if chunk_data:
194
+ audio_chunk.output = chunk_data
195
+ return audio_chunk
196
+
197
+ except Exception as e:
198
+ logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
199
+ raise ValueError(
200
+ f"Failed to convert audio stream to {output_format}: {str(e)}"
201
+ )
202
+
203
+ @staticmethod
204
+ def trim_audio(
205
+ audio_chunk: AudioChunk,
206
+ chunk_text: str = "",
207
+ speed: float = 1,
208
+ is_last_chunk: bool = False,
209
+ normalizer: AudioNormalizer = None,
210
+ ) -> AudioChunk:
211
+ """Trim silence from start and end
212
+
213
+ Args:
214
+ audio_chunk: Audio chunk whose samples will be trimmed
215
+ chunk_text: The text sent to the model to generate the resulting speech
216
+ speed: The speaking speed of the voice
217
+ is_last_chunk: Whether this is the last chunk
218
+ normalizer: Optional AudioNormalizer instance for consistent normalization
219
+
220
+ Returns:
221
+ The trimmed audio chunk
222
+ """
223
+ if normalizer is None:
224
+ normalizer = AudioNormalizer()
225
+
226
+ audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
227
+
228
+ trimmed_samples = 0
229
+ # Trim start and end if enough samples
230
+ if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
231
+ audio_chunk.audio = audio_chunk.audio[
232
+ normalizer.samples_to_trim : -normalizer.samples_to_trim
233
+ ]
234
+ trimmed_samples += normalizer.samples_to_trim
235
+
236
+ # Find the non-silent portion and trim
237
+ start_index, end_index = normalizer.find_first_last_non_silent(
238
+ audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk
239
+ )
240
+
241
+ audio_chunk.audio = audio_chunk.audio[start_index:end_index]
242
+ trimmed_samples += start_index
243
+
244
+ if audio_chunk.word_timestamps is not None:
245
+ for timestamp in audio_chunk.word_timestamps:
246
+ timestamp.start_time -= trimmed_samples / 24000
247
+ timestamp.end_time -= trimmed_samples / 24000
248
+ return audio_chunk
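For reference, the dBFS-to-amplitude conversion used by find_first_last_non_silent works out as follows for int16 audio at the default -45 dB threshold; a small sketch of the arithmetic:

import numpy as np

silence_threshold_db = -45
amplitude_threshold = np.iinfo(np.int16).max * (10 ** (silence_threshold_db / 20))
print(round(amplitude_threshold, 1))  # ~184.3 out of 32767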
api/src/services/streaming_audio_writer.py ADDED
@@ -0,0 +1,100 @@
1
+ """Audio conversion service with proper streaming support"""
2
+
3
+ import struct
4
+ from io import BytesIO
5
+ from typing import Optional
6
+
7
+ import av
8
+ import numpy as np
9
+ import soundfile as sf
10
+ from loguru import logger
11
+ from pydub import AudioSegment
12
+
13
+
14
+ class StreamingAudioWriter:
15
+ """Handles streaming audio format conversions"""
16
+
17
+ def __init__(self, format: str, sample_rate: int, channels: int = 1):
18
+ self.format = format.lower()
19
+ self.sample_rate = sample_rate
20
+ self.channels = channels
21
+ self.bytes_written = 0
22
+ self.pts = 0
23
+
24
+ codec_map = {
25
+ "wav": "pcm_s16le",
26
+ "mp3": "mp3",
27
+ "opus": "libopus",
28
+ "flac": "flac",
29
+ "aac": "aac",
30
+ }
31
+ # Format-specific setup
32
+ if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
33
+ if self.format != "pcm":
34
+ self.output_buffer = BytesIO()
35
+ self.container = av.open(
36
+ self.output_buffer,
37
+ mode="w",
38
+ format=self.format if self.format != "aac" else "adts",
39
+ )
40
+ self.stream = self.container.add_stream(
41
+ codec_map[self.format],
42
+ sample_rate=self.sample_rate,
43
+ layout="mono" if self.channels == 1 else "stereo",
44
+ )
45
+ self.stream.bit_rate = 128000
46
+ else:
47
+ raise ValueError(f"Unsupported format: {format}")
48
+
49
+ def close(self):
50
+ if hasattr(self, "container"):
51
+ self.container.close()
52
+
53
+ if hasattr(self, "output_buffer"):
54
+ self.output_buffer.close()
55
+
56
+ def write_chunk(
57
+ self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
58
+ ) -> bytes:
59
+ """Write a chunk of audio data and return bytes in the target format.
60
+
61
+ Args:
62
+ audio_data: Audio data to write, or None if finalizing
63
+ finalize: Whether this is the final write to close the stream
64
+ """
65
+
66
+ if finalize:
67
+ if self.format != "pcm":
68
+ packets = self.stream.encode(None)
69
+ for packet in packets:
70
+ self.container.mux(packet)
71
+
72
+ data = self.output_buffer.getvalue()
73
+ self.close()
74
+ return data
75
+
76
+ if audio_data is None or len(audio_data) == 0:
77
+ return b""
78
+
79
+ if self.format == "pcm":
80
+ # Write raw bytes
81
+ return audio_data.tobytes()
82
+ else:
83
+ frame = av.AudioFrame.from_ndarray(
84
+ audio_data.reshape(1, -1),
85
+ format="s16",
86
+ layout="mono" if self.channels == 1 else "stereo",
87
+ )
88
+ frame.sample_rate = self.sample_rate
89
+
90
+ frame.pts = self.pts
91
+ self.pts += frame.samples
92
+
93
+ packets = self.stream.encode(frame)
94
+ for packet in packets:
95
+ self.container.mux(packet)
96
+
97
+ data = self.output_buffer.getvalue()
98
+ self.output_buffer.seek(0)
99
+ self.output_buffer.truncate(0)
100
+ return data
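For orientation, a minimal sketch of how this writer is meant to be driven (the sine-wave chunk below is made up for illustration): one encoder per response, int16 PCM in, encoded bytes out, and a finalizing call to flush the codec.

import numpy as np

from api.src.services.streaming_audio_writer import StreamingAudioWriter

# One writer per stream; 24 kHz mono matches the model output
writer = StreamingAudioWriter(format="mp3", sample_rate=24000, channels=1)

# Feed int16 PCM chunks as they arrive; each call returns whatever
# encoded bytes the codec has produced so far (possibly empty)
chunk = (np.sin(np.linspace(0, 2 * np.pi * 440, 24000)) * 32767).astype(np.int16)
encoded = writer.write_chunk(chunk)

# Finalize to flush remaining packets and close the container
audio_bytes = encoded + writer.write_chunk(finalize=True)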
api/src/services/temp_manager.py ADDED
@@ -0,0 +1,170 @@
+ """Temporary file writer for audio downloads"""
+
+ import os
+ import tempfile
+ import time
+
+ import aiofiles
+ import aiofiles.os
+ from loguru import logger
+
+ from ..core.config import settings
+
+
+ async def cleanup_temp_files() -> None:
+     """Clean up old temp files"""
+     try:
+         if not await aiofiles.os.path.exists(settings.temp_file_dir):
+             await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
+             return
+
+         # Get all temp files with stats
+         files = []
+         total_size = 0
+
+         # Use os.scandir for sync iteration, but aiofiles.os.stat for async stats
+         for entry in os.scandir(settings.temp_file_dir):
+             if entry.is_file():
+                 stat = await aiofiles.os.stat(entry.path)
+                 files.append((entry.path, stat.st_mtime, stat.st_size))
+                 total_size += stat.st_size
+
+         # Sort by modification time (oldest first)
+         files.sort(key=lambda x: x[1])
+
+         # Remove files if:
+         # 1. They're too old
+         # 2. We have too many files
+         # 3. Directory is too large
+         current_time = time.time()
+         max_age = settings.max_temp_dir_age_hours * 3600
+
+         for path, mtime, size in files:
+             should_delete = False
+
+             # Check age
+             if current_time - mtime > max_age:
+                 should_delete = True
+                 logger.info(f"Deleting old temp file: {path}")
+
+             # Check count limit
+             elif len(files) > settings.max_temp_dir_count:
+                 should_delete = True
+                 logger.info(f"Deleting excess temp file: {path}")
+
+             # Check size limit
+             elif total_size > settings.max_temp_dir_size_mb * 1024 * 1024:
+                 should_delete = True
+                 logger.info(f"Deleting to reduce directory size: {path}")
+
+             if should_delete:
+                 try:
+                     await aiofiles.os.remove(path)
+                     total_size -= size
+                     logger.info(f"Deleted temp file: {path}")
+                 except Exception as e:
+                     logger.warning(f"Failed to delete temp file {path}: {e}")
+
+     except Exception as e:
+         logger.warning(f"Error during temp file cleanup: {e}")
+
+
+ class TempFileWriter:
+     """Handles writing audio chunks to a temp file"""
+
+     def __init__(self, format: str):
+         """Initialize temp file writer
+
+         Args:
+             format: Audio format extension (mp3, wav, etc.)
+         """
+         self.format = format
+         self.temp_file = None
+         self._finalized = False
+         self._write_error = False  # Flag to track whether a write has failed
+
+     async def __aenter__(self):
+         """Async context manager entry"""
+         try:
+             # Clean up old files first
+             await cleanup_temp_files()
+
+             # Create a temp file with the proper extension
+             await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
+             temp = tempfile.NamedTemporaryFile(
+                 dir=settings.temp_file_dir,
+                 delete=False,
+                 suffix=f".{self.format}",
+                 mode="wb",
+             )
+             self.temp_file = await aiofiles.open(temp.name, mode="wb")
+             self.temp_path = temp.name
+             temp.close()  # Close the sync handle; we'll use the async one
+
+             # Generate download path immediately
+             self.download_path = f"/download/{os.path.basename(self.temp_path)}"
+         except Exception as e:
+             # Handle permission issues or other errors gracefully
+             logger.error(f"Failed to create temp file: {e}")
+             self._write_error = True
+             # Set a placeholder path so the API can still function
+             self.temp_path = f"unavailable_{self.format}"
+             self.download_path = f"/download/{self.temp_path}"
+
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Async context manager exit"""
+         try:
+             if self.temp_file and not self._finalized:
+                 await self.temp_file.close()
+                 self._finalized = True
+         except Exception as e:
+             logger.error(f"Error closing temp file: {e}")
+             self._write_error = True
+
+     async def write(self, chunk: bytes) -> None:
+         """Write a chunk of audio data
+
+         Args:
+             chunk: Audio data bytes to write
+         """
+         if self._finalized:
+             raise RuntimeError("Cannot write to finalized temp file")
+
+         # Skip writing if we've already encountered an error
+         if self._write_error or not self.temp_file:
+             return
+
+         try:
+             await self.temp_file.write(chunk)
+             await self.temp_file.flush()
+         except Exception as e:
+             # Handle permission issues or other errors gracefully
+             logger.error(f"Failed to write to temp file: {e}")
+             self._write_error = True
+
+     async def finalize(self) -> str:
+         """Close the temp file and return the download path
+
+         Returns:
+             Path to use for downloading the temp file
+         """
+         if self._finalized:
+             raise RuntimeError("Temp file already finalized")
+
+         # Skip finalizing if we've already encountered an error
+         if self._write_error or not self.temp_file:
+             self._finalized = True
+             return self.download_path
+
+         try:
+             await self.temp_file.close()
+             self._finalized = True
+         except Exception as e:
+             # Handle permission issues or other errors gracefully
+             logger.error(f"Failed to finalize temp file: {e}")
+             self._write_error = True
+             self._finalized = True
+
+         return self.download_path
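A sketch of the intended usage, assuming settings.temp_file_dir is writable; old files are pruned on entry and the returned path is what the /download route serves.

from api.src.services.temp_manager import TempFileWriter

async def save_stream_to_temp(chunks) -> str:
    # Write encoded audio chunks to a temp file and return its /download path
    async with TempFileWriter("mp3") as temp_writer:
        async for chunk in chunks:
            await temp_writer.write(chunk)
        return await temp_writer.finalize()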
api/src/services/text_processing/__init__.py ADDED
@@ -0,0 +1,21 @@
+ """Text processing pipeline."""
+
+ from .normalizer import normalize_text
+ from .phonemizer import phonemize
+ from .text_processor import process_text_chunk, smart_split
+ from .vocabulary import tokenize
+
+
+ def process_text(text: str) -> list[int]:
+     """Process text into token IDs (for backward compatibility)."""
+     return process_text_chunk(text)
+
+
+ __all__ = [
+     "normalize_text",
+     "phonemize",
+     "tokenize",
+     "process_text",
+     "process_text_chunk",
+     "smart_split",
+ ]
api/src/services/text_processing/normalizer.py ADDED
@@ -0,0 +1,415 @@
+ """
+ Text normalization module for TTS processing.
+ Handles various text formats including URLs, emails, numbers, money, and special characters.
+ Converts them into a format suitable for text-to-speech processing.
+ """
+
+ import re
+
+ import inflect
+
+ from ...structures.schemas import NormalizationOptions
+
+ # Constants
+ VALID_TLDS = [
+     "com",
+     "org",
+     "net",
+     "edu",
+     "gov",
+     "mil",
+     "int",
+     "biz",
+     "info",
+     "name",
+     "pro",
+     "coop",
+     "museum",
+     "travel",
+     "jobs",
+     "mobi",
+     "tel",
+     "asia",
+     "cat",
+     "xxx",
+     "aero",
+     "arpa",
+     "bg",
+     "br",
+     "ca",
+     "cn",
+     "de",
+     "es",
+     "eu",
+     "fr",
+     "in",
+     "it",
+     "jp",
+     "mx",
+     "nl",
+     "ru",
+     "uk",
+     "us",
+     "io",
+     "co",
+ ]
+
+ VALID_UNITS = {
+     "m": "meter",
+     "cm": "centimeter",
+     "mm": "millimeter",
+     "km": "kilometer",
+     "in": "inch",
+     "ft": "foot",
+     "yd": "yard",
+     "mi": "mile",  # Length
+     "g": "gram",
+     "kg": "kilogram",
+     "mg": "milligram",  # Mass
+     "s": "second",
+     "ms": "millisecond",
+     "min": "minute",
+     "h": "hour",  # Time
+     "l": "liter",
+     "ml": "milliliter",
+     "cl": "centiliter",
+     "dl": "deciliter",  # Volume
+     "kph": "kilometer per hour",
+     "mph": "mile per hour",
+     "mi/h": "mile per hour",
+     "m/s": "meter per second",
+     "km/h": "kilometer per hour",
+     "mm/s": "millimeter per second",
+     "cm/s": "centimeter per second",
+     "ft/s": "foot per second",
+     "cm/h": "centimeter per hour",  # Speed
+     "°c": "degree celsius",
+     "c": "degree celsius",
+     "°f": "degree fahrenheit",
+     "f": "degree fahrenheit",
+     "k": "kelvin",  # Temperature
+     "pa": "pascal",
+     "kpa": "kilopascal",
+     "mpa": "megapascal",
+     "atm": "atmosphere",  # Pressure
+     "hz": "hertz",
+     "khz": "kilohertz",
+     "mhz": "megahertz",
+     "ghz": "gigahertz",  # Frequency
+     "v": "volt",
+     "kv": "kilovolt",
+     "mv": "megavolt",  # Voltage
+     "a": "amp",
+     "ma": "megaamp",
+     "ka": "kiloamp",  # Current
+     "w": "watt",
+     "kw": "kilowatt",
+     "mw": "megawatt",  # Power
+     "j": "joule",
+     "kj": "kilojoule",
+     "mj": "megajoule",  # Energy
+     "Ω": "ohm",
+     "kΩ": "kiloohm",
+     "mΩ": "megaohm",  # Resistance (Ohm)
+     # Note: a bare "f" already maps to fahrenheit above; a second "f": "farad"
+     # entry would silently overwrite it, so only prefixed farad units are kept
+     "µf": "microfarad",
+     "nf": "nanofarad",
+     "pf": "picofarad",  # Capacitance
+     "b": "bit",
+     "kb": "kilobit",
+     "mb": "megabit",
+     "gb": "gigabit",
+     "tb": "terabit",
+     "pb": "petabit",  # Data size
+     "kbps": "kilobit per second",
+     "mbps": "megabit per second",
+     "gbps": "gigabit per second",
+     "tbps": "terabit per second",
+     "px": "pixel",  # CSS units
+ }
+
+
+ # Pre-compiled regex patterns for performance
+ EMAIL_PATTERN = re.compile(
+     r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE
+ )
+ URL_PATTERN = re.compile(
+     r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:"
+     + "|".join(VALID_TLDS)
+     + r"))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?",
+     re.IGNORECASE,
+ )
+
+ UNIT_PATTERN = re.compile(
+     r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*("
+     + "|".join(sorted(list(VALID_UNITS.keys()), reverse=True))
+     + r"""){1}(?=[^\w\d]{1}|\b)""",
+     re.IGNORECASE,
+ )
+
+ TIME_PATTERN = re.compile(
+     r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
+ )
+
+ INFLECT_ENGINE = inflect.engine()
+
+
+ def split_num(num: re.Match[str]) -> str:
+     """Handle number splitting for various formats"""
+     num = num.group()
+     if "." in num:
+         return num
+     elif ":" in num:
+         h, m = [int(n) for n in num.split(":")]
+         if m == 0:
+             return f"{h} o'clock"
+         elif m < 10:
+             return f"{h} oh {m}"
+         return f"{h} {m}"
+     year = int(num[:4])
+     if year < 1100 or year % 1000 < 10:
+         return num
+     left, right = num[:2], int(num[2:4])
+     s = "s" if num.endswith("s") else ""
+     if 100 <= year % 1000 <= 999:
+         if right == 0:
+             return f"{left} hundred{s}"
+         elif right < 10:
+             return f"{left} oh {right}{s}"
+     return f"{left} {right}{s}"
+
+
+ def handle_units(u: re.Match[str]) -> str:
+     """Convert units to their full form"""
+     unit_string = u.group(6).strip()
+     unit = unit_string
+
+     if unit_string.lower() in VALID_UNITS:
+         unit = VALID_UNITS[unit_string.lower()].split(" ")
+
+         # Handles the B vs b case (bytes vs bits)
+         if unit[0].endswith("bit"):
+             b_case = unit_string[min(1, len(unit_string) - 1)]
+             if b_case == "B":
+                 unit[0] = unit[0][:-3] + "byte"
+
+         number = u.group(1).strip()
+         unit[0] = INFLECT_ENGINE.no(unit[0], number)
+     return " ".join(unit)
+
+
+ def conditional_int(number: float, threshold: float = 0.00001):
+     if abs(round(number) - number) < threshold:
+         return int(round(number))
+     return number
+
+
+ def handle_money(m: re.Match[str]) -> str:
+     """Convert money expressions to spoken form"""
+     bill = "dollar" if m.group(2) == "$" else "pound"
+     coin = "cent" if m.group(2) == "$" else "pence"
+     number = m.group(3)
+
+     multiplier = m.group(4)
+     try:
+         number = float(number)
+     except ValueError:
+         return m.group()
+
+     if m.group(1) == "-":
+         number *= -1
+
+     if number % 1 == 0 or multiplier != "":
+         text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
+     else:
+         sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
+
+         # Truncate rather than round the whole-unit part so $4.99 reads
+         # "four dollars and ninety-nine cents", not "five dollars ..."
+         text_number = f"{INFLECT_ENGINE.number_to_words(int(number))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
+
+     return text_number
+
+
+ def handle_decimal(num: re.Match[str]) -> str:
+     """Convert decimal numbers to spoken form"""
+     a, b = num.group().split(".")
+     return " point ".join([a, " ".join(b)])
+
+
+ def handle_email(m: re.Match[str]) -> str:
+     """Convert email addresses into a speakable format"""
+     email = m.group(0)
+     parts = email.split("@")
+     if len(parts) == 2:
+         user, domain = parts
+         domain = domain.replace(".", " dot ")
+         return f"{user} at {domain}"
+     return email
+
+
+ def handle_url(u: re.Match[str]) -> str:
+     """Make URLs speakable by converting special characters to spoken words"""
+     if not u:
+         return ""
+
+     url = u.group(0).strip()
+
+     # Handle protocol first
+     url = re.sub(
+         r"^https?://",
+         lambda a: "https " if "https" in a.group() else "http ",
+         url,
+         flags=re.IGNORECASE,
+     )
+     url = re.sub(r"^www\.", "www ", url, flags=re.IGNORECASE)
+
+     # Handle port numbers before other replacements
+     url = re.sub(r":(\d+)(?=/|$)", lambda m: f" colon {m.group(1)}", url)
+
+     # Split into domain and path
+     parts = url.split("/", 1)
+     domain = parts[0]
+     path = parts[1] if len(parts) > 1 else ""
+
+     # Handle dots in domain
+     domain = domain.replace(".", " dot ")
+
+     # Reconstruct URL
+     if path:
+         url = f"{domain} slash {path}"
+     else:
+         url = domain
+
+     # Replace remaining symbols with words
+     url = url.replace("-", " dash ")
+     url = url.replace("_", " underscore ")
+     url = url.replace("?", " question-mark ")
+     url = url.replace("=", " equals ")
+     url = url.replace("&", " ampersand ")
+     url = url.replace("%", " percent ")
+     url = url.replace(":", " colon ")  # Handle any remaining colons
+     url = url.replace("/", " slash ")  # Handle any remaining slashes
+
+     # Clean up extra spaces
+     return re.sub(r"\s+", " ", url).strip()
+
+
+ def handle_phone_number(p: re.Match[str]) -> str:
+     g = list(p.groups())
+
+     country_code = ""
+     if g[0] is not None:
+         g[0] = g[0].replace("+", "")
+         country_code += INFLECT_ENGINE.number_to_words(g[0])
+
+     area_code = INFLECT_ENGINE.number_to_words(
+         g[2].replace("(", "").replace(")", ""), group=1, comma=""
+     )
+
+     telephone_prefix = INFLECT_ENGINE.number_to_words(g[3], group=1, comma="")
+
+     line_number = INFLECT_ENGINE.number_to_words(g[4], group=1, comma="")
+
+     return ",".join([country_code, area_code, telephone_prefix, line_number])
+
+
+ def handle_time(t: re.Match[str]) -> str:
+     t = t.groups()
+
+     numbers = " ".join(
+         [INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")]
+     )
+
+     half = ""
+     if t[2] is not None:
+         half = t[2].strip()
+
+     return numbers + half
+
+
+ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
+     """Normalize text for TTS processing"""
+     # Handle email addresses first if enabled
+     if normalization_options.email_normalization:
+         text = EMAIL_PATTERN.sub(handle_email, text)
+
+     # Handle URLs if enabled
+     if normalization_options.url_normalization:
+         text = URL_PATTERN.sub(handle_url, text)
+
+     # Pre-process numbers with units if enabled
+     if normalization_options.unit_normalization:
+         text = UNIT_PATTERN.sub(handle_units, text)
+
+     # Replace optional pluralization
+     if normalization_options.optional_pluralization_normalization:
+         text = re.sub(r"\(s\)", "s", text)
+
+     # Replace phone numbers
+     if normalization_options.phone_normalization:
+         text = re.sub(
+             r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",
+             handle_phone_number,
+             text,
+         )
+
+     # Replace quotes and brackets
+     text = text.replace(chr(8216), "'").replace(chr(8217), "'")
+     text = text.replace("«", chr(8220)).replace("»", chr(8221))
+     text = text.replace(chr(8220), '"').replace(chr(8221), '"')
+
+     # Handle CJK punctuation and some non-standard chars
+     for a, b in zip("、。!,:;?–", ",.!,:;?-"):
+         text = text.replace(a, b + " ")
+
+     # Handle simple times in the format of HH:MM:SS
+     text = TIME_PATTERN.sub(
+         handle_time,
+         text,
+     )
+
+     # Clean up whitespace
+     text = re.sub(r"[^\S \n]", " ", text)
+     text = re.sub(r" +", " ", text)
+     text = re.sub(r"(?<=\n) +(?=\n)", "", text)
+
+     # Handle titles and abbreviations
+     text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
+     text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
+     text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
+     text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
+     text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
+
+     # Handle common words
+     text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
+
+     # Handle numbers and money
+     text = re.sub(r"(?<=\d),(?=\d)", "", text)
+
+     text = re.sub(
+         r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
+         handle_money,
+         text,
+     )
+
+     text = re.sub(
+         r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
+     )
+
+     text = re.sub(r"\d*\.\d+", handle_decimal, text)
+
+     # Handle various formatting
+     text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
+     text = re.sub(r"(?<=\d)S", " S", text)
+     text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
+     text = re.sub(r"(?<=X')S\b", "s", text)
+     text = re.sub(
+         r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
+     )
+     text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
+
+     return text.strip()
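An illustrative call (exact spoken forms come from inflect, so the expected output below is approximate):

from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions

text = "The update costs $4.99 and weighs 10MB, see https://example.com"
# Roughly: "The update costs four dollars and ninety-nine cents and
# weighs 10 megabytes, see https example dot com"
print(normalize_text(text, NormalizationOptions(unit_normalization=True)))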
api/src/services/text_processing/phonemizer.py ADDED
@@ -0,0 +1,102 @@
+ import re
+ from abc import ABC, abstractmethod
+
+ import phonemizer
+
+ from ...structures.schemas import NormalizationOptions
+ from .normalizer import normalize_text
+
+ phonemizers = {}
+
+
+ class PhonemizerBackend(ABC):
+     """Abstract base class for phonemization backends"""
+
+     @abstractmethod
+     def phonemize(self, text: str) -> str:
+         """Convert text to phonemes
+
+         Args:
+             text: Text to convert to phonemes
+
+         Returns:
+             Phonemized text
+         """
+         pass
+
+
+ class EspeakBackend(PhonemizerBackend):
+     """Espeak-based phonemizer implementation"""
+
+     def __init__(self, language: str):
+         """Initialize espeak backend
+
+         Args:
+             language: Language code ('en-us' or 'en-gb')
+         """
+         self.backend = phonemizer.backend.EspeakBackend(
+             language=language, preserve_punctuation=True, with_stress=True
+         )
+
+         self.language = language
+
+     def phonemize(self, text: str) -> str:
+         """Convert text to phonemes using espeak
+
+         Args:
+             text: Text to convert to phonemes
+
+         Returns:
+             Phonemized text
+         """
+         # Phonemize text
+         ps = self.backend.phonemize([text])
+         ps = ps[0] if ps else ""
+
+         # Handle special cases
+         ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
+         ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
+         ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
+         ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»"" ]|$)', "z", ps)
+
+         # Language-specific rules
+         if self.language == "en-us":
+             ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
+
+         return ps.strip()
+
+
+ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
+     """Factory function to create a phonemizer backend
+
+     Args:
+         language: Language code ('a' for US English, 'b' for British English)
+
+     Returns:
+         Phonemizer backend instance
+     """
+     # Map language codes to espeak language codes
+     lang_map = {"a": "en-us", "b": "en-gb"}
+
+     if language not in lang_map:
+         raise ValueError(f"Unsupported language code: {language}")
+
+     return EspeakBackend(lang_map[language])
+
+
+ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
+     """Convert text to phonemes
+
+     Args:
+         text: Text to convert to phonemes
+         language: Language code ('a' for US English, 'b' for British English)
+         normalize: Whether to normalize text before phonemization
+
+     Returns:
+         Phonemized text
+     """
+     global phonemizers
+     if normalize:
+         # normalize_text requires a NormalizationOptions instance; use defaults here
+         text = normalize_text(text, NormalizationOptions())
+     if language not in phonemizers:
+         phonemizers[language] = create_phonemizer(language)
+     return phonemizers[language].phonemize(text)
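A quick sketch of the public helper; backends are created lazily and cached per language code, and espeak-ng must be available at runtime:

from api.src.services.text_processing.phonemizer import phonemize

# 'a' maps to en-us, 'b' to en-gb
ipa = phonemize("Hello world!", language="a", normalize=False)
print(ipa)  # IPA with stress marks, e.g. something like "həlˈoʊ wˈɜːld!"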
api/src/services/text_processing/text_processor.py ADDED
@@ -0,0 +1,276 @@
+ """Unified text processing for TTS with smart chunking."""
+
+ import re
+ import time
+ from typing import AsyncGenerator, Dict, List, Tuple
+
+ from loguru import logger
+
+ from ...core.config import settings
+ from ...structures.schemas import NormalizationOptions
+ from .normalizer import normalize_text
+ from .phonemizer import phonemize
+ from .vocabulary import tokenize
+
+ # Pre-compiled regex patterns for performance
+ CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
+
+
+ def process_text_chunk(
+     text: str, language: str = "a", skip_phonemize: bool = False
+ ) -> List[int]:
+     """Process a chunk of text through normalization, phonemization, and tokenization.
+
+     Args:
+         text: Text chunk to process
+         language: Language code for phonemization
+         skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization
+
+     Returns:
+         List of token IDs
+     """
+     start_time = time.time()
+
+     if skip_phonemize:
+         # Input is already phonemes, just tokenize
+         tokens = tokenize(text)
+     else:
+         # Normal text processing pipeline (text is already normalized upstream)
+         phonemes = phonemize(text, language, normalize=False)
+         tokens = tokenize(phonemes)
+
+     total_time = time.time() - start_time
+     logger.debug(
+         f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
+     )
+
+     return tokens
+
+
+ async def yield_chunk(
+     text: str, tokens: List[int], chunk_count: int
+ ) -> Tuple[str, List[int]]:
+     """Log and return a chunk (helper for consistent logging)."""
+     logger.debug(
+         f"Yielding chunk {chunk_count}: '{text[:50]}{'...' if len(text) > 50 else ''}' ({len(tokens)} tokens)"
+     )
+     return text, tokens
+
+
+ def process_text(text: str, language: str = "a") -> List[int]:
+     """Process text into token IDs.
+
+     Args:
+         text: Text to process
+         language: Language code for phonemization
+
+     Returns:
+         List of token IDs
+     """
+     if not isinstance(text, str):
+         text = str(text) if text is not None else ""
+
+     text = text.strip()
+     if not text:
+         return []
+
+     return process_text_chunk(text, language)
+
+
+ def get_sentence_info(
+     text: str, custom_phonemes_list: Dict[str, str]
+ ) -> List[Tuple[str, List[int], int]]:
+     """Process all sentences and return info."""
+     sentences = re.split(r"([.!?;:])(?=\s|$)", text)
+     phoneme_length, min_value = len(custom_phonemes_list), 0
+
+     results = []
+     for i in range(0, len(sentences), 2):
+         sentence = sentences[i].strip()
+         # Restore any custom phoneme placeholders inside this sentence
+         for replaced in range(min_value, phoneme_length):
+             current_id = f"</|custom_phonemes_{replaced}|/>"
+             if current_id in sentence:
+                 sentence = sentence.replace(
+                     current_id, custom_phonemes_list.pop(current_id)
+                 )
+                 min_value += 1
+
+         punct = sentences[i + 1] if i + 1 < len(sentences) else ""
+
+         if not sentence:
+             continue
+
+         full = sentence + punct
+         tokens = process_text_chunk(full)
+         results.append((full, tokens, len(tokens)))
+
+     return results
+
+
+ def handle_custom_phonemes(s: re.Match[str], phonemes_list: Dict[str, str]) -> str:
+     latest_id = f"</|custom_phonemes_{len(phonemes_list)}|/>"
+     phonemes_list[latest_id] = s.group(0).strip()
+     return latest_id
+
+
+ async def smart_split(
+     text: str,
+     max_tokens: int = settings.absolute_max_tokens,
+     lang_code: str = "a",
+     normalization_options: NormalizationOptions = NormalizationOptions(),
+ ) -> AsyncGenerator[Tuple[str, List[int]], None]:
+     """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
+     start_time = time.time()
+     chunk_count = 0
+     logger.info(f"Starting smart split for {len(text)} chars")
+
+     custom_phoneme_list = {}
+
+     # Normalize text
+     if settings.advanced_text_normalization and normalization_options.normalize:
+         if lang_code in ["a", "b", "en-us", "en-gb"]:
+             text = CUSTOM_PHONEMES.sub(
+                 lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
+             )
+             text = normalize_text(text, normalization_options)
+         else:
+             logger.info(
+                 "Skipping text normalization as it is only supported for English"
+             )
+
+     # Process all sentences
+     sentences = get_sentence_info(text, custom_phoneme_list)
+
+     current_chunk = []
+     current_tokens = []
+     current_count = 0
+
+     for sentence, tokens, count in sentences:
+         # Handle sentences that exceed max tokens
+         if count > max_tokens:
+             # Yield current chunk if any
+             if current_chunk:
+                 chunk_text = " ".join(current_chunk)
+                 chunk_count += 1
+                 logger.debug(
+                     f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
+                 )
+                 yield chunk_text, current_tokens
+                 current_chunk = []
+                 current_tokens = []
+                 current_count = 0
+
+             # Split long sentences on commas
+             clauses = re.split(r"([,])", sentence)
+             clause_chunk = []
+             clause_tokens = []
+             clause_count = 0
+
+             for j in range(0, len(clauses), 2):
+                 clause = clauses[j].strip()
+                 comma = clauses[j + 1] if j + 1 < len(clauses) else ""
+
+                 if not clause:
+                     continue
+
+                 full_clause = clause + comma
+
+                 tokens = process_text_chunk(full_clause)
+                 count = len(tokens)
+
+                 # If adding the clause keeps us under max and not optimal yet
+                 if (
+                     clause_count + count <= max_tokens
+                     and clause_count + count <= settings.target_max_tokens
+                 ):
+                     clause_chunk.append(full_clause)
+                     clause_tokens.extend(tokens)
+                     clause_count += count
+                 else:
+                     # Yield clause chunk if we have one
+                     if clause_chunk:
+                         chunk_text = " ".join(clause_chunk)
+                         chunk_count += 1
+                         logger.debug(
+                             f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({clause_count} tokens)"
+                         )
+                         yield chunk_text, clause_tokens
+                     clause_chunk = [full_clause]
+                     clause_tokens = tokens
+                     clause_count = count
+
+             # Don't forget the last clause chunk
+             if clause_chunk:
+                 chunk_text = " ".join(clause_chunk)
+                 chunk_count += 1
+                 logger.debug(
+                     f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({clause_count} tokens)"
+                 )
+                 yield chunk_text, clause_tokens
+
+         # Regular sentence handling
+         elif (
+             current_count >= settings.target_min_tokens
+             and current_count + count > settings.target_max_tokens
+         ):
+             # If we have a good sized chunk and adding the next sentence would
+             # exceed the target, yield the current chunk and start a new one
+             chunk_text = " ".join(current_chunk)
+             chunk_count += 1
+             logger.info(
+                 f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
+             )
+             yield chunk_text, current_tokens
+             current_chunk = [sentence]
+             current_tokens = tokens
+             current_count = count
+         elif current_count + count <= settings.target_max_tokens:
+             # Keep building the chunk while under the target max
+             current_chunk.append(sentence)
+             current_tokens.extend(tokens)
+             current_count += count
+         elif (
+             current_count + count <= max_tokens
+             and current_count < settings.target_min_tokens
+         ):
+             # Only exceed the target max if we haven't reached the minimum size yet
+             current_chunk.append(sentence)
+             current_tokens.extend(tokens)
+             current_count += count
+         else:
+             # Yield current chunk and start a new one
+             if current_chunk:
+                 chunk_text = " ".join(current_chunk)
+                 chunk_count += 1
+                 logger.info(
+                     f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
+                 )
+                 yield chunk_text, current_tokens
+             current_chunk = [sentence]
+             current_tokens = tokens
+             current_count = count
+
+     # Don't forget the last chunk
+     if current_chunk:
+         chunk_text = " ".join(current_chunk)
+         chunk_count += 1
+         logger.info(
+             f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
+         )
+         yield chunk_text, current_tokens
+
+     total_time = time.time() - start_time
+     logger.info(
+         f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
+     )
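Since smart_split is an async generator, a minimal consumption sketch looks like this (the sample text is arbitrary):

import asyncio

from api.src.services.text_processing.text_processor import smart_split

async def preview_chunks(text: str) -> None:
    # Chunks arrive as (chunk_text, token_ids) pairs sized to the
    # configured target token window
    async for chunk_text, tokens in smart_split(text):
        print(f"{len(tokens):4d} tokens | {chunk_text[:60]!r}")

asyncio.run(preview_chunks("First sentence. Second one! A third, much longer one follows."))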
api/src/services/text_processing/vocabulary.py ADDED
@@ -0,0 +1,40 @@
+ def get_vocab():
+     """Get the vocabulary dictionary mapping characters to token IDs"""
+     _pad = "$"
+     _punctuation = ';:,.!?¡¿—…"«»"" '
+     _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+     _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+
+     # Create vocabulary dictionary
+     symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+     return {symbol: i for i, symbol in enumerate(symbols)}
+
+
+ # Initialize vocabulary
+ VOCAB = get_vocab()
+
+
+ def tokenize(phonemes: str) -> list[int]:
+     """Convert a phoneme string to token IDs
+
+     Args:
+         phonemes: String of phonemes to tokenize
+
+     Returns:
+         List of token IDs
+     """
+     return [i for i in map(VOCAB.get, phonemes) if i is not None]
+
+
+ def decode_tokens(tokens: list[int]) -> str:
+     """Convert token IDs back to a phoneme string
+
+     Args:
+         tokens: List of token IDs
+
+     Returns:
+         String of phonemes
+     """
+     # Create reverse mapping
+     id_to_symbol = {i: s for s, i in VOCAB.items()}
+     return "".join(id_to_symbol[t] for t in tokens)
api/src/services/tts_service.py ADDED
@@ -0,0 +1,459 @@
+ """TTS service using model and voice managers."""
+
+ import asyncio
+ import os
+ import re
+ import tempfile
+ import time
+ from typing import AsyncGenerator, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from loguru import logger
+
+ from ..core.config import settings
+ from ..inference.base import AudioChunk
+ from ..inference.kokoro_v1 import KokoroV1
+ from ..inference.model_manager import get_manager as get_model_manager
+ from ..inference.voice_manager import get_manager as get_voice_manager
+ from ..structures.schemas import NormalizationOptions
+ from .audio import AudioNormalizer, AudioService
+ from .streaming_audio_writer import StreamingAudioWriter
+ from .text_processing.text_processor import smart_split
+
+
+ class TTSService:
+     """Text-to-speech service."""
+
+     # Limit concurrent chunk processing
+     _chunk_semaphore = asyncio.Semaphore(4)
+
+     def __init__(self, output_dir: str = None):
+         """Initialize service."""
+         self.output_dir = output_dir
+         self.model_manager = None
+         self._voice_manager = None
+
+     @classmethod
+     async def create(cls, output_dir: str = None) -> "TTSService":
+         """Create and initialize TTSService instance."""
+         service = cls(output_dir)
+         service.model_manager = await get_model_manager()
+         service._voice_manager = await get_voice_manager()
+         return service
+
+     async def _process_chunk(
+         self,
+         chunk_text: str,
+         tokens: List[int],
+         voice_name: str,
+         voice_path: str,
+         speed: float,
+         writer: StreamingAudioWriter,
+         output_format: Optional[str] = None,
+         is_first: bool = False,
+         is_last: bool = False,
+         normalizer: Optional[AudioNormalizer] = None,
+         lang_code: Optional[str] = None,
+         return_timestamps: Optional[bool] = False,
+     ) -> AsyncGenerator[AudioChunk, None]:
+         """Process tokens into audio."""
+         async with self._chunk_semaphore:
+             try:
+                 # Handle stream finalization
+                 if is_last:
+                     # Skip format conversion for raw audio mode
+                     if not output_format:
+                         yield AudioChunk(np.array([], dtype=np.int16), output=b"")
+                         return
+                     chunk_data = await AudioService.convert_audio(
+                         AudioChunk(
+                             np.array([], dtype=np.float32)
+                         ),  # Dummy data for type checking
+                         output_format,
+                         writer,
+                         speed,
+                         "",
+                         normalizer=normalizer,
+                         is_last_chunk=True,
+                     )
+                     yield chunk_data
+                     return
+
+                 # Skip empty chunks
+                 if not tokens and not chunk_text:
+                     return
+
+                 # Get backend
+                 backend = self.model_manager.get_backend()
+
+                 # Generate audio using pre-warmed model
+                 if isinstance(backend, KokoroV1):
+                     # For Kokoro V1, pass text and voice info with lang_code
+                     async for chunk_data in self.model_manager.generate(
+                         chunk_text,
+                         (voice_name, voice_path),
+                         speed=speed,
+                         lang_code=lang_code,
+                         return_timestamps=return_timestamps,
+                     ):
+                         # For streaming, convert to bytes
+                         if output_format:
+                             try:
+                                 chunk_data = await AudioService.convert_audio(
+                                     chunk_data,
+                                     output_format,
+                                     writer,
+                                     speed,
+                                     chunk_text,
+                                     is_last_chunk=is_last,
+                                     normalizer=normalizer,
+                                 )
+                                 yield chunk_data
+                             except Exception as e:
+                                 logger.error(f"Failed to convert audio: {str(e)}")
+                         else:
+                             chunk_data = AudioService.trim_audio(
+                                 chunk_data, chunk_text, speed, is_last, normalizer
+                             )
+                             yield chunk_data
+                 else:
+                     # For legacy backends, load voice tensor
+                     voice_tensor = await self._voice_manager.load_voice(
+                         voice_name, device=backend.device
+                     )
+                     chunk_data = await self.model_manager.generate(
+                         tokens,
+                         voice_tensor,
+                         speed=speed,
+                         return_timestamps=return_timestamps,
+                     )
+
+                     if chunk_data.audio is None:
+                         logger.error("Model generated None for audio chunk")
+                         return
+
+                     if len(chunk_data.audio) == 0:
+                         logger.error("Model generated empty audio chunk")
+                         return
+
+                     # For streaming, convert to bytes
+                     if output_format:
+                         try:
+                             chunk_data = await AudioService.convert_audio(
+                                 chunk_data,
+                                 output_format,
+                                 writer,
+                                 speed,
+                                 chunk_text,
+                                 normalizer=normalizer,
+                                 is_last_chunk=is_last,
+                             )
+                             yield chunk_data
+                         except Exception as e:
+                             logger.error(f"Failed to convert audio: {str(e)}")
+                     else:
+                         trimmed = AudioService.trim_audio(
+                             chunk_data, chunk_text, speed, is_last, normalizer
+                         )
+                         yield trimmed
+             except Exception as e:
+                 logger.error(f"Failed to process tokens: {str(e)}")
+
+     async def _load_voice_from_path(self, path: str, weight: float):
+         # Raise a ValueError if no voice path was provided
+         if not path:
+             raise ValueError(f"Voice not found at path: {path}")
+
+         logger.debug(f"Loading voice tensor from path: {path}")
+         return torch.load(path, map_location="cpu") * weight
+
+     async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
+         """Get voice path, handling combined voices.
+
+         Args:
+             voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')
+
+         Returns:
+             Tuple of (voice name to use, voice path to use)
+
+         Raises:
+             RuntimeError: If voice not found
+         """
+         try:
+             # Split the voice on + and -, keeping the operators in the list,
+             # e.g. "hi+bob" -> ["hi", "+", "bob"]
+             split_voice = re.split(r"([-+])", voice)
+
+             # If it is only one voice, there is no point in loading it up,
+             # doing nothing with it, then saving it
+             if len(split_voice) == 1:
+                 # For a single voice the weight only matters when
+                 # voice_weight_normalization is disabled
+                 if (
+                     "(" not in voice and ")" not in voice
+                 ) or settings.voice_weight_normalization:
+                     path = await self._voice_manager.get_voice_path(voice)
+                     if not path:
+                         raise RuntimeError(f"Voice not found: {voice}")
+                     logger.debug(f"Using single voice path: {path}")
+                     return voice, path
+
+             total_weight = 0
+
+             for voice_index in range(0, len(split_voice), 2):
+                 voice_object = split_voice[voice_index]
+
+                 if "(" in voice_object and ")" in voice_object:
+                     voice_name = voice_object.split("(")[0].strip()
+                     voice_weight = float(voice_object.split("(")[1].split(")")[0])
+                 else:
+                     voice_name = voice_object
+                     voice_weight = 1
+
+                 total_weight += voice_weight
+                 split_voice[voice_index] = (voice_name, voice_weight)
+
+             # If voice_weight_normalization is disabled, skip normalization by
+             # dividing each weight by 1
+             if not settings.voice_weight_normalization:
+                 total_weight = 1
+
+             # Load the first voice as the starting point to combine onto
+             path = await self._voice_manager.get_voice_path(split_voice[0][0])
+             combined_tensor = await self._load_voice_from_path(
+                 path, split_voice[0][1] / total_weight
+             )
+
+             # Loop through each + or - in split_voice and apply it to the combined voice
+             for operation_index in range(1, len(split_voice) - 1, 2):
+                 # Get the voice path of the voice one index ahead of the operator
+                 path = await self._voice_manager.get_voice_path(
+                     split_voice[operation_index + 1][0]
+                 )
+                 voice_tensor = await self._load_voice_from_path(
+                     path, split_voice[operation_index + 1][1] / total_weight
+                 )
+
+                 # Either add or subtract the voice from the current combined voice
+                 if split_voice[operation_index] == "+":
+                     combined_tensor += voice_tensor
+                 else:
+                     combined_tensor -= voice_tensor
+
+             # Save the new combined voice so it can be loaded later
+             temp_dir = tempfile.gettempdir()
+             combined_path = os.path.join(temp_dir, f"{voice}.pt")
+             logger.debug(f"Saving combined voice to: {combined_path}")
+             torch.save(combined_tensor, combined_path)
+             return voice, combined_path
+         except Exception as e:
+             logger.error(f"Failed to get voice path: {e}")
+             raise
+
+     async def generate_audio_stream(
+         self,
+         text: str,
+         voice: str,
+         writer: StreamingAudioWriter,
+         speed: float = 1.0,
+         output_format: str = "wav",
+         lang_code: Optional[str] = None,
+         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
+         return_timestamps: Optional[bool] = False,
+     ) -> AsyncGenerator[AudioChunk, None]:
+         """Generate and stream audio chunks."""
+         stream_normalizer = AudioNormalizer()
+         chunk_index = 0
+         current_offset = 0.0
+         try:
+             # Get backend
+             backend = self.model_manager.get_backend()
+
+             # Get voice path, handling combined voices
+             voice_name, voice_path = await self._get_voices_path(voice)
+             logger.debug(f"Using voice path: {voice_path}")
+
+             # Use provided lang_code or determine from voice name
+             pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+             logger.info(
+                 f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
+             )
+
+             # Process text in chunks with smart splitting
+             async for chunk_text, tokens in smart_split(
+                 text,
+                 lang_code=pipeline_lang_code,
+                 normalization_options=normalization_options,
+             ):
+                 try:
+                     # Process audio for chunk
+                     async for chunk_data in self._process_chunk(
+                         chunk_text,  # Pass text for Kokoro V1
+                         tokens,  # Pass tokens for legacy backends
+                         voice_name,
+                         voice_path,
+                         speed,
+                         writer,
+                         output_format,
+                         is_first=(chunk_index == 0),
+                         is_last=False,  # The finalizing chunk is sent separately below
+                         normalizer=stream_normalizer,
+                         lang_code=pipeline_lang_code,
+                         return_timestamps=return_timestamps,
+                     ):
+                         if chunk_data.word_timestamps is not None:
+                             # Shift chunk-local timestamps to stream time
+                             for timestamp in chunk_data.word_timestamps:
+                                 timestamp.start_time += current_offset
+                                 timestamp.end_time += current_offset
+
+                         current_offset += len(chunk_data.audio) / 24000
+
+                         if chunk_data.output is not None:
+                             yield chunk_data
+                         else:
+                             logger.warning(
+                                 f"No audio generated for chunk: '{chunk_text[:100]}...'"
+                             )
+                     chunk_index += 1
+                 except Exception as e:
+                     logger.error(
+                         f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
+                     )
+                     continue
+
+             # Only finalize if we successfully processed at least one chunk
+             if chunk_index > 0:
+                 try:
+                     # Empty tokens list to finalize audio
+                     async for chunk_data in self._process_chunk(
+                         "",  # Empty text
+                         [],  # Empty tokens
+                         voice_name,
+                         voice_path,
+                         speed,
+                         writer,
+                         output_format,
+                         is_first=False,
+                         is_last=True,  # Signal this is the last chunk
+                         normalizer=stream_normalizer,
+                         lang_code=pipeline_lang_code,
+                     ):
+                         if chunk_data.output is not None:
+                             yield chunk_data
+                 except Exception as e:
+                     logger.error(f"Failed to finalize audio stream: {str(e)}")
+
+         except Exception as e:
+             logger.error(f"Error in audio stream generation: {str(e)}")
+             raise
+
+     async def generate_audio(
+         self,
+         text: str,
+         voice: str,
+         writer: StreamingAudioWriter,
+         speed: float = 1.0,
+         return_timestamps: bool = False,
+         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
+         lang_code: Optional[str] = None,
+     ) -> AudioChunk:
+         """Generate complete audio for text using streaming internally."""
+         audio_data_chunks = []
+
+         try:
+             async for audio_stream_data in self.generate_audio_stream(
+                 text,
+                 voice,
+                 writer,
+                 speed=speed,
+                 normalization_options=normalization_options,
+                 return_timestamps=return_timestamps,
+                 lang_code=lang_code,
+                 output_format=None,
+             ):
+                 if len(audio_stream_data.audio) > 0:
+                     audio_data_chunks.append(audio_stream_data)
+
+             combined_audio_data = AudioChunk.combine(audio_data_chunks)
+             return combined_audio_data
+         except Exception as e:
+             logger.error(f"Error in audio generation: {str(e)}")
+             raise
+
+     async def combine_voices(self, voices: List[str]) -> torch.Tensor:
+         """Combine multiple voices.
+
+         Returns:
+             Combined voice tensor
+         """
+         return await self._voice_manager.combine_voices(voices)
+
+     async def list_voices(self) -> List[str]:
+         """List available voices."""
+         return await self._voice_manager.list_voices()
+
+     async def generate_from_phonemes(
+         self,
+         phonemes: str,
+         voice: str,
+         speed: float = 1.0,
+         lang_code: Optional[str] = None,
+     ) -> Tuple[np.ndarray, float]:
+         """Generate audio directly from phonemes.
+
+         Args:
+             phonemes: Phonemes in Kokoro format
+             voice: Voice name
+             speed: Speed multiplier
+             lang_code: Optional language code override
+
+         Returns:
+             Tuple of (audio array, processing time)
+         """
+         start_time = time.time()
+         try:
+             # Get backend and voice path
+             backend = self.model_manager.get_backend()
+             voice_name, voice_path = await self._get_voices_path(voice)
+
+             if isinstance(backend, KokoroV1):
+                 # For Kokoro V1, use generate_from_tokens with raw phonemes
+                 result = None
+                 # Use provided lang_code or determine from voice name
+                 pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
+                 logger.info(
+                     f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
+                 )
+
+                 try:
+                     # Use the backend's pipeline management
+                     for r in backend._get_pipeline(
+                         pipeline_lang_code
+                     ).generate_from_tokens(
+                         tokens=phonemes,  # Pass raw phonemes string
+                         voice=voice_path,
+                         speed=speed,
+                     ):
+                         if r.audio is not None:
+                             result = r
+                             break
+                 except Exception as e:
+                     logger.error(f"Failed to generate from phonemes: {e}")
+                     raise RuntimeError(f"Phoneme generation failed: {e}")
+
+                 if result is None or result.audio is None:
+                     raise ValueError("No audio generated")
+
+                 processing_time = time.time() - start_time
+                 return result.audio.numpy(), processing_time
+             else:
+                 raise ValueError(
+                     "Phoneme generation only supported with Kokoro V1 backend"
+                 )
+
+         except Exception as e:
+             logger.error(f"Error in phoneme audio generation: {str(e)}")
+             raise
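An end-to-end sketch, assuming the model and voice managers can initialize at runtime (model files present, espeak-ng installed); the voice string also supports weighted blends like "af_heart(2)+af_sky(1)":

import asyncio

from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.services.tts_service import TTSService

async def synthesize() -> bytes:
    service = await TTSService.create()
    writer = StreamingAudioWriter(format="mp3", sample_rate=24000)
    parts = []
    async for chunk in service.generate_audio_stream(
        "Hello from Kokoro.", "af_heart", writer, output_format="mp3"
    ):
        if chunk.output:
            parts.append(chunk.output)
    return b"".join(parts)

audio = asyncio.run(synthesize())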
api/src/structures/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .schemas import (
+     CaptionedSpeechRequest,
+     CaptionedSpeechResponse,
+     OpenAISpeechRequest,
+     TTSStatus,
+     VoiceCombineRequest,
+     WordTimestamp,
+ )
+
+ __all__ = [
+     "OpenAISpeechRequest",
+     "CaptionedSpeechRequest",
+     "CaptionedSpeechResponse",
+     "WordTimestamp",
+     "TTSStatus",
+     "VoiceCombineRequest",
+ ]
api/src/structures/custom_responses.py ADDED
@@ -0,0 +1,50 @@
+ import json
+ import typing
+ from collections.abc import AsyncIterable, Iterable
+
+ from pydantic import BaseModel
+ from starlette.background import BackgroundTask
+ from starlette.concurrency import iterate_in_threadpool
+ from starlette.responses import JSONResponse, StreamingResponse
+
+
+ class JSONStreamingResponse(StreamingResponse, JSONResponse):
+     """StreamingResponse that also renders as JSON."""
+
+     def __init__(
+         self,
+         content: Iterable | AsyncIterable,
+         status_code: int = 200,
+         headers: dict[str, str] | None = None,
+         media_type: str | None = None,
+         background: BackgroundTask | None = None,
+     ) -> None:
+         if isinstance(content, AsyncIterable):
+             self._content_iterable: AsyncIterable = content
+         else:
+             self._content_iterable = iterate_in_threadpool(content)
+
+         async def body_iterator() -> AsyncIterable[bytes]:
+             async for content_ in self._content_iterable:
+                 if isinstance(content_, BaseModel):
+                     content_ = content_.model_dump()
+                 yield self.render(content_)
+
+         self.body_iterator = body_iterator()
+         self.status_code = status_code
+         if media_type is not None:
+             self.media_type = media_type
+         self.background = background
+         self.init_headers(headers)
+
+     def render(self, content: typing.Any) -> bytes:
+         return (
+             json.dumps(
+                 content,
+                 ensure_ascii=False,
+                 allow_nan=False,
+                 indent=None,
+                 separators=(",", ":"),
+             )
+             + "\n"
+         ).encode("utf-8")
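A hypothetical FastAPI endpoint showing the intent: each yielded item (dicts or pydantic models) is rendered as one newline-delimited JSON object.

from fastapi import FastAPI

from api.src.structures.custom_responses import JSONStreamingResponse

app = FastAPI()

@app.get("/events")  # illustrative endpoint, not part of this commit
async def events():
    async def gen():
        for i in range(3):
            yield {"tick": i}

    return JSONStreamingResponse(gen(), media_type="application/json")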
api/src/structures/model_schemas.py ADDED
@@ -0,0 +1,16 @@
+ """Voice configuration schemas."""
+
+ from pydantic import BaseModel, Field
+
+
+ class VoiceConfig(BaseModel):
+     """Voice configuration."""
+
+     use_cache: bool = Field(True, description="Whether to cache loaded voices")
+     cache_size: int = Field(3, description="Number of voices to cache")
+     validate_on_load: bool = Field(
+         True, description="Whether to validate voices when loading"
+     )
+
+     class Config:
+         frozen = True  # Make config immutable
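Because the config is frozen, instances are read-only after construction; a small sketch:

from api.src.structures.model_schemas import VoiceConfig

cfg = VoiceConfig(cache_size=5)
print(cfg.use_cache)  # True (default)
# cfg.cache_size = 10  # would raise: the model is immutable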
api/src/structures/schemas.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import List, Literal, Optional, Union
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+
7
+ class VoiceCombineRequest(BaseModel):
8
+ """Request schema for voice combination endpoint that accepts either a string with + or a list"""
9
+
10
+ voices: Union[str, List[str]] = Field(
11
+ ...,
12
+ description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine",
13
+ )
14
+
15
+
16
+ class TTSStatus(str, Enum):
17
+ PENDING = "pending"
18
+ PROCESSING = "processing"
19
+ COMPLETED = "completed"
20
+ FAILED = "failed"
21
+ DELETED = "deleted" # For files removed by cleanup
22
+
23
+
24
+ # OpenAI-compatible schemas
25
+ class WordTimestamp(BaseModel):
26
+ """Word-level timestamp information"""
27
+
28
+ word: str = Field(..., description="The word or token")
29
+ start_time: float = Field(..., description="Start time in seconds")
30
+ end_time: float = Field(..., description="End time in seconds")
31
+
32
+
33
+ class CaptionedSpeechResponse(BaseModel):
34
+ """Response schema for captioned speech endpoint"""
35
+
36
+ audio: str = Field(..., description="The generated audio data encoded in base 64")
37
+ audio_format: str = Field(..., description="The format of the output audio")
38
+ timestamps: Optional[List[WordTimestamp]] = Field(
39
+ ..., description="Word-level timestamps"
40
+ )
41
+
42
+
43
+ class NormalizationOptions(BaseModel):
44
+ """Options for the normalization system"""
45
+
46
+ normalize: bool = Field(
47
+ default=True,
48
+ description="Normalizes input text to make it easier for the model to say",
49
+ )
50
+ unit_normalization: bool = Field(
51
+ default=False, description="Transforms units like 10KB to 10 kilobytes"
52
+ )
53
+ url_normalization: bool = Field(
54
+ default=True,
55
+ description="Changes urls so they can be properly pronounced by kokoro",
56
+ )
57
+ email_normalization: bool = Field(
58
+ default=True,
59
+ description="Changes emails so they can be properly pronouced by kokoro",
60
+ )
61
+ optional_pluralization_normalization: bool = Field(
62
+ default=True,
63
+ description="Replaces (s) with s so some words get pronounced correctly",
64
+ )
65
+ phone_normalization: bool = Field(
66
+ default=True,
67
+ description="Changes phone numbers so they can be properly pronouced by kokoro",
68
+ )
69
+
70
+
71
+ class OpenAISpeechRequest(BaseModel):
72
+ """Request schema for OpenAI-compatible speech endpoint"""
73
+
74
+ model: str = Field(
75
+ default="kokoro",
76
+ description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro",
77
+ )
78
+ input: str = Field(..., description="The text to generate audio for")
79
+ voice: str = Field(
80
+ default="af_heart",
81
+ description="The voice to use for generation. Can be a base voice or a combined voice name.",
82
+ )
83
+ response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
84
+ default="mp3",
85
+ description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
86
+ )
87
+ download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = (
88
+ Field(
89
+ default=None,
90
+ description="Optional different format for the final download. If not provided, uses response_format.",
91
+ )
92
+ )
93
+ speed: float = Field(
94
+ default=1.0,
95
+ ge=0.25,
96
+ le=4.0,
97
+ description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
98
+ )
99
+ stream: bool = Field(
100
+ default=True, # Default to streaming for OpenAI compatibility
101
+ description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
102
+ )
103
+ return_download_link: bool = Field(
104
+ default=False,
105
+ description="If true, returns a download link in X-Download-Path header after streaming completes",
106
+ )
107
+ lang_code: Optional[str] = Field(
108
+ default=None,
109
+ description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
110
+ )
111
+ normalization_options: Optional[NormalizationOptions] = Field(
112
+ default=NormalizationOptions(),
113
+ description="Options for the normalization system",
114
+ )
115
+
116
+
117
+ class CaptionedSpeechRequest(BaseModel):
118
+ """Request schema for captioned speech endpoint"""
119
+
120
+ model: str = Field(
121
+ default="kokoro",
122
+ description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro",
123
+ )
124
+ input: str = Field(..., description="The text to generate audio for")
125
+ voice: str = Field(
126
+ default="af_heart",
127
+ description="The voice to use for generation. Can be a base voice or a combined voice name.",
128
+ )
129
+ response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
130
+ default="mp3",
131
+ description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
132
+ )
133
+ speed: float = Field(
134
+ default=1.0,
135
+ ge=0.25,
136
+ le=4.0,
137
+ description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
138
+ )
139
+ stream: bool = Field(
140
+ default=True, # Default to streaming for OpenAI compatibility
141
+ description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
142
+ )
143
+ return_timestamps: bool = Field(
144
+ default=True,
145
+ description="If true (default), returns word-level timestamps in the response",
146
+ )
147
+ return_download_link: bool = Field(
148
+ default=False,
149
+ description="If true, returns a download link in X-Download-Path header after streaming completes",
150
+ )
151
+ lang_code: Optional[str] = Field(
152
+ default=None,
153
+ description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
154
+ )
155
+ normalization_options: Optional[NormalizationOptions] = Field(
156
+ default=NormalizationOptions(),
157
+ description="Options for the normalization system",
158
+ )
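A quick usage sketch (not part of the commit) of the request schema above. It assumes the import path used elsewhere in this diff and relies only on fields defined here; input is the only required field, the rest fall back to the defaults shown:

from api.src.structures.schemas import NormalizationOptions, OpenAISpeechRequest

req = OpenAISpeechRequest(
    input="Visit https://example.com for details.",
    response_format="wav",
    stream=False,
    normalization_options=NormalizationOptions(unit_normalization=True),
)
print(req.model_dump())  # model defaults to "kokoro", voice to "af_heart", speed to 1.0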
api/src/structures/text_schemas.py ADDED
@@ -0,0 +1,41 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from pydantic import BaseModel, Field, field_validator
4
+
5
+
6
+ class PhonemeRequest(BaseModel):
7
+ text: str
8
+ language: str = "a" # Default to American English
9
+
10
+
11
+ class PhonemeResponse(BaseModel):
12
+ phonemes: str
13
+ tokens: list[int]
14
+
15
+
16
+ class StitchOptions(BaseModel):
17
+ """Options for stitching audio chunks together"""
18
+
19
+ gap_method: str = Field(
20
+ default="static_trim",
21
+ description="Method to handle gaps between chunks. Currently only 'static_trim' supported.",
22
+ )
23
+ trim_ms: int = Field(
24
+ default=0,
25
+ ge=0,
26
+ description="Milliseconds to trim from chunk boundaries when using static_trim",
27
+ )
28
+
29
+ @field_validator("gap_method")
30
+ @classmethod
31
+ def validate_gap_method(cls, v: str) -> str:
32
+ if v != "static_trim":
33
+ raise ValueError("Currently only 'static_trim' gap method is supported")
34
+ return v
35
+
36
+
37
+ class GenerateFromPhonemesRequest(BaseModel):
38
+ """Simple request for phoneme-to-speech generation"""
39
+
40
+ phonemes: str = Field(..., description="Phoneme string to synthesize")
41
+ voice: str = Field(..., description="Voice ID to use for generation")
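A minimal sketch of how the gap_method validator above behaves (assuming pydantic v2, which wraps the ValueError raised in a field_validator into a ValidationError):

from pydantic import ValidationError

from api.src.structures.text_schemas import StitchOptions

StitchOptions(gap_method="static_trim", trim_ms=25)  # passes validation
try:
    StitchOptions(gap_method="crossfade")  # hypothetical unsupported method
except ValidationError as err:
    print(err)  # includes "Currently only 'static_trim' gap method is supported"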
api/tests/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Make tests directory a Python package
api/tests/conftest.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ from pathlib import Path
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
+
5
+ import numpy as np
6
+ import pytest
7
+ import pytest_asyncio
8
+ import torch
9
+
10
+ from api.src.inference.model_manager import ModelManager
11
+ from api.src.inference.voice_manager import VoiceManager
12
+ from api.src.services.tts_service import TTSService
13
+ from api.src.structures.model_schemas import VoiceConfig
14
+
15
+
16
+ @pytest.fixture
17
+ def mock_voice_tensor():
18
+ """Load a real voice tensor for testing."""
19
+ voice_path = os.path.join(
20
+ os.path.dirname(os.path.dirname(__file__)), "src/voices/af_bella.pt"
21
+ )
22
+ return torch.load(voice_path, map_location="cpu", weights_only=False)
23
+
24
+
25
+ @pytest.fixture
26
+ def mock_audio_output():
27
+ """Load pre-generated test audio for consistent testing."""
28
+ test_audio_path = os.path.join(
29
+ os.path.dirname(__file__), "test_data/test_audio.npy"
30
+ )
31
+ return np.load(test_audio_path) # Return as numpy array instead of bytes
32
+
33
+
34
+ @pytest_asyncio.fixture
35
+ async def mock_model_manager(mock_audio_output):
36
+ """Mock model manager for testing."""
37
+ manager = AsyncMock(spec=ModelManager)
38
+ manager.get_backend = MagicMock()
39
+
40
+ async def mock_generate(*args, **kwargs):
41
+ # Simulate successful audio generation
42
+ return np.random.rand(24000).astype(np.float32) # 1 second of random audio data
43
+
44
+ manager.generate = AsyncMock(side_effect=mock_generate)
45
+ return manager
46
+
47
+
48
+ @pytest_asyncio.fixture
49
+ async def mock_voice_manager(mock_voice_tensor):
50
+ """Mock voice manager for testing."""
51
+ manager = AsyncMock(spec=VoiceManager)
52
+ manager.get_voice_path = MagicMock(return_value="/mock/path/voice.pt")
53
+ manager.load_voice = AsyncMock(return_value=mock_voice_tensor)
54
+ manager.list_voices = AsyncMock(return_value=["voice1", "voice2"])
55
+ manager.combine_voices = AsyncMock(return_value="voice1_voice2")
56
+ return manager
57
+
58
+
59
+ @pytest_asyncio.fixture
60
+ async def tts_service(mock_model_manager, mock_voice_manager):
61
+ """Get mocked TTS service instance."""
62
+ service = TTSService()
63
+ service.model_manager = mock_model_manager
64
+ service._voice_manager = mock_voice_manager
65
+ return service
66
+
67
+
68
+ @pytest.fixture
69
+ def test_voice():
70
+ """Return a test voice name."""
71
+ return "voice1"
api/tests/test_audio_service.py ADDED
@@ -0,0 +1,256 @@
1
+ """Tests for AudioService"""
2
+
3
+ from unittest.mock import patch
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+ from api.src.inference.base import AudioChunk
9
+ from api.src.services.audio import AudioNormalizer, AudioService
10
+ from api.src.services.streaming_audio_writer import StreamingAudioWriter
11
+
12
+
13
+ @pytest.fixture(autouse=True)
14
+ def mock_settings():
15
+ """Mock settings for all tests"""
16
+ with patch("api.src.services.audio.settings") as mock_settings:
17
+ mock_settings.gap_trim_ms = 250
18
+ yield mock_settings
19
+
20
+
21
+ @pytest.fixture
22
+ def sample_audio():
23
+ """Generate a simple sine wave for testing"""
24
+ sample_rate = 24000
25
+ duration = 0.1 # 100ms
26
+ t = np.linspace(0, duration, int(sample_rate * duration))
27
+ frequency = 440 # A4 note
28
+ return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
29
+
30
+
31
+ @pytest.mark.asyncio
32
+ async def test_convert_to_wav(sample_audio):
33
+ """Test converting to WAV format"""
34
+ audio_data, sample_rate = sample_audio
35
+
36
+ writer = StreamingAudioWriter("wav", sample_rate=24000)
37
+ # Write and finalize in one step for WAV
38
+ audio_chunk = await AudioService.convert_audio(
39
+ AudioChunk(audio_data), "wav", writer, is_last_chunk=False
40
+ )
41
+
42
+ writer.close()
43
+
44
+ assert isinstance(audio_chunk.output, bytes)
45
+ assert isinstance(audio_chunk, AudioChunk)
46
+ assert len(audio_chunk.output) > 0
47
+ # Check WAV header
48
+ assert audio_chunk.output.startswith(b"RIFF")
49
+ assert b"WAVE" in audio_chunk.output[:12]
50
+
51
+
52
+ @pytest.mark.asyncio
53
+ async def test_convert_to_mp3(sample_audio):
54
+ """Test converting to MP3 format"""
55
+ audio_data, sample_rate = sample_audio
56
+
57
+ writer = StreamingAudioWriter("mp3", sample_rate=24000)
58
+
59
+ audio_chunk = await AudioService.convert_audio(
60
+ AudioChunk(audio_data), "mp3", writer
61
+ )
62
+
63
+ writer.close()
64
+
65
+ assert isinstance(audio_chunk.output, bytes)
66
+ assert isinstance(audio_chunk, AudioChunk)
67
+ assert len(audio_chunk.output) > 0
68
+ # Check MP3 header (ID3 or MPEG frame sync)
69
+ assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
70
+ b"\xff\xfb"
71
+ )
72
+
73
+
74
+ @pytest.mark.asyncio
75
+ async def test_convert_to_opus(sample_audio):
76
+ """Test converting to Opus format"""
77
+
78
+ audio_data, sample_rate = sample_audio
79
+
80
+ writer = StreamingAudioWriter("opus", sample_rate=24000)
81
+
82
+ audio_chunk = await AudioService.convert_audio(
83
+ AudioChunk(audio_data), "opus", writer
84
+ )
85
+
86
+ writer.close()
87
+
88
+ assert isinstance(audio_chunk.output, bytes)
89
+ assert isinstance(audio_chunk, AudioChunk)
90
+ assert len(audio_chunk.output) > 0
91
+ # Check OGG header
92
+ assert audio_chunk.output.startswith(b"OggS")
93
+
94
+
95
+ @pytest.mark.asyncio
96
+ async def test_convert_to_flac(sample_audio):
97
+ """Test converting to FLAC format"""
98
+ audio_data, sample_rate = sample_audio
99
+
100
+ writer = StreamingAudioWriter("flac", sample_rate=24000)
101
+
102
+ audio_chunk = await AudioService.convert_audio(
103
+ AudioChunk(audio_data), "flac", writer
104
+ )
105
+
106
+ writer.close()
107
+
108
+ assert isinstance(audio_chunk.output, bytes)
109
+ assert isinstance(audio_chunk, AudioChunk)
110
+ assert len(audio_chunk.output) > 0
111
+ # Check FLAC header
112
+ assert audio_chunk.output.startswith(b"fLaC")
113
+
114
+
115
+ @pytest.mark.asyncio
116
+ async def test_convert_to_aac(sample_audio):
117
+ """Test converting to M4A format"""
118
+ audio_data, sample_rate = sample_audio
119
+
120
+ writer = StreamingAudioWriter("aac", sample_rate=24000)
121
+
122
+ audio_chunk = await AudioService.convert_audio(
123
+ AudioChunk(audio_data), "aac", writer
124
+ )
125
+
126
+ writer.close()
127
+
128
+ assert isinstance(audio_chunk.output, bytes)
129
+ assert isinstance(audio_chunk, AudioChunk)
130
+ assert len(audio_chunk.output) > 0
131
+ # Check ADTS header (AAC)
132
+ assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
133
+ b"\xff\xf1"
134
+ )
135
+
136
+
137
+ @pytest.mark.asyncio
138
+ async def test_convert_to_pcm(sample_audio):
139
+ """Test converting to PCM format"""
140
+ audio_data, sample_rate = sample_audio
141
+
142
+ writer = StreamingAudioWriter("pcm", sample_rate=24000)
143
+
144
+ audio_chunk = await AudioService.convert_audio(
145
+ AudioChunk(audio_data), "pcm", writer
146
+ )
147
+
148
+ writer.close()
149
+
150
+ assert isinstance(audio_chunk.output, bytes)
151
+ assert isinstance(audio_chunk, AudioChunk)
152
+ assert len(audio_chunk.output) > 0
153
+ # PCM is raw bytes, so no header to check
154
+
155
+
156
+ @pytest.mark.asyncio
157
+ async def test_convert_to_invalid_format_raises_error(sample_audio):
158
+ """Test that converting to an invalid format raises an error"""
159
+ # audio_data, sample_rate = sample_audio
160
+ with pytest.raises(ValueError, match="Unsupported format: invalid"):
161
+ writer = StreamingAudioWriter("invalid", sample_rate=24000)
162
+
163
+
164
+ @pytest.mark.asyncio
165
+ async def test_normalization_wav(sample_audio):
166
+ """Test that WAV output is properly normalized to int16 range"""
167
+ audio_data, sample_rate = sample_audio
168
+
169
+ writer = StreamingAudioWriter("wav", sample_rate=24000)
170
+
171
+ # Create audio data outside int16 range
172
+ large_audio = audio_data * 1e5
173
+ # Write and finalize in one step for WAV
174
+ audio_chunk = await AudioService.convert_audio(
175
+ AudioChunk(large_audio), "wav", writer
176
+ )
177
+
178
+ writer.close()
179
+
180
+ assert isinstance(audio_chunk.output, bytes)
181
+ assert isinstance(audio_chunk, AudioChunk)
182
+ assert len(audio_chunk.output) > 0
183
+
184
+
185
+ @pytest.mark.asyncio
186
+ async def test_normalization_pcm(sample_audio):
187
+ """Test that PCM output is properly normalized to int16 range"""
188
+ audio_data, sample_rate = sample_audio
189
+
190
+ writer = StreamingAudioWriter("pcm", sample_rate=24000)
191
+
192
+ # Create audio data outside int16 range
193
+ large_audio = audio_data * 1e5
194
+ audio_chunk = await AudioService.convert_audio(
195
+ AudioChunk(large_audio), "pcm", writer
196
+ )
197
+ assert isinstance(audio_chunk.output, bytes)
198
+ assert isinstance(audio_chunk, AudioChunk)
199
+ assert len(audio_chunk.output) > 0
200
+
201
+
202
+ @pytest.mark.asyncio
203
+ async def test_invalid_audio_data():
204
+ """Test handling of invalid audio data"""
205
+ invalid_audio = np.array([]) # Empty array
206
+ sample_rate = 24000
207
+
208
+ writer = StreamingAudioWriter("wav", sample_rate=24000)
209
+
210
+ with pytest.raises(ValueError):
211
+ await AudioService.convert_audio(AudioChunk(invalid_audio), "wav", writer)
212
+
213
+
214
+ @pytest.mark.asyncio
215
+ async def test_different_sample_rates(sample_audio):
216
+ """Test converting audio with different sample rates"""
217
+ audio_data, _ = sample_audio
218
+ sample_rates = [8000, 16000, 44100, 48000]
219
+
220
+ for rate in sample_rates:
221
+ writer = StreamingAudioWriter("wav", sample_rate=rate)
222
+
223
+ audio_chunk = await AudioService.convert_audio(
224
+ AudioChunk(audio_data), "wav", writer
225
+ )
226
+
227
+ writer.close()
228
+
229
+ assert isinstance(audio_chunk.output, bytes)
230
+ assert isinstance(audio_chunk, AudioChunk)
231
+ assert len(audio_chunk.output) > 0
232
+
233
+
234
+ @pytest.mark.asyncio
235
+ async def test_buffer_position_after_conversion(sample_audio):
236
+ """Test that buffer position is reset after writing"""
237
+ audio_data, sample_rate = sample_audio
238
+
239
+ writer = StreamingAudioWriter("wav", sample_rate=24000)
240
+
241
+ # Write and finalize in one step for first conversion
242
+ audio_chunk1 = await AudioService.convert_audio(
243
+ AudioChunk(audio_data), "wav", writer, is_last_chunk=True
244
+ )
245
+ assert isinstance(audio_chunk1.output, bytes)
246
+ assert isinstance(audio_chunk1, AudioChunk)
247
+ # Convert again to ensure buffer was properly reset
248
+
249
+ writer = StreamingAudioWriter("wav", sample_rate=24000)
250
+
251
+ audio_chunk2 = await AudioService.convert_audio(
252
+ AudioChunk(audio_data), "wav", writer, is_last_chunk=True
253
+ )
254
+ assert isinstance(audio_chunk2.output, bytes)
255
+ assert isinstance(audio_chunk2, AudioChunk)
256
+ assert len(audio_chunk1.output) == len(audio_chunk2.output)
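For reference, a standalone sketch of the conversion flow these tests exercise, using the same convert_audio(chunk, format, writer) call shape:

import asyncio

import numpy as np

from api.src.inference.base import AudioChunk
from api.src.services.audio import AudioService
from api.src.services.streaming_audio_writer import StreamingAudioWriter

async def main():
    # 100 ms of a 440 Hz tone at 24 kHz, matching the sample_audio fixture
    t = np.linspace(0, 0.1, 2400)
    audio = np.sin(2 * np.pi * 440 * t).astype(np.float32)
    writer = StreamingAudioWriter("wav", sample_rate=24000)
    chunk = await AudioService.convert_audio(AudioChunk(audio), "wav", writer)
    writer.close()
    print(f"{len(chunk.output)} bytes of WAV")

asyncio.run(main())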
api/tests/test_data/generate_test_data.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+
3
+ import numpy as np
4
+
5
+
6
+ def generate_test_audio():
7
+ """Generate test audio data - 1 second of 440Hz tone"""
8
+ # Create 1 second of silence at 24kHz
9
+ audio = np.zeros(24000, dtype=np.float32)
10
+
11
+ # Add a simple sine wave to make it non-zero
12
+ t = np.linspace(0, 1, 24000)
13
+ audio += 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz tone at half amplitude
14
+
15
+ # Create test_data directory if it doesn't exist
16
+ os.makedirs("api/tests/test_data", exist_ok=True)
17
+
18
+ # Save the test audio
19
+ np.save("api/tests/test_data/test_audio.npy", audio)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ generate_test_audio()
api/tests/test_data/test_audio.npy ADDED
Binary file (96.1 kB).
api/tests/test_development.py ADDED
@@ -0,0 +1,34 @@
1
+ import base64
2
+ import json
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import pytest
6
+ import requests
7
+
8
+
9
+ def test_generate_captioned_speech():
10
+ """Test the generate_captioned_speech function with mocked responses"""
11
+ # Mock the API responses
12
+ mock_audio_response = MagicMock()
13
+ mock_audio_response.status_code = 200
14
+
15
+ mock_timestamps_response = MagicMock()
16
+ mock_timestamps_response.status_code = 200
17
+ mock_timestamps_response.content = json.dumps(
18
+ {
19
+ "audio": base64.b64encode(b"mock audio data").decode("utf-8"),
20
+ "timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
21
+ }
22
+ )
23
+
24
+ # Patch the HTTP requests
25
+ with patch("requests.post", return_value=mock_timestamps_response):
26
+ # Import here to avoid module-level import issues
27
+ from examples.captioned_speech_example import generate_captioned_speech
28
+
29
+ # Test the function
30
+ audio, timestamps = generate_captioned_speech("test text")
31
+
32
+ # Verify we got both audio and timestamps
33
+ assert audio == b"mock audio data"
34
+ assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
api/tests/test_kokoro_v1.py ADDED
@@ -0,0 +1,165 @@
1
+ from unittest.mock import ANY, MagicMock, patch
2
+
3
+ import numpy as np
4
+ import pytest
5
+ import torch
6
+
7
+ from api.src.inference.kokoro_v1 import KokoroV1
8
+
9
+
10
+ @pytest.fixture
11
+ def kokoro_backend():
12
+ """Create a KokoroV1 instance for testing."""
13
+ return KokoroV1()
14
+
15
+
16
+ def test_initial_state(kokoro_backend):
17
+ """Test initial state of KokoroV1."""
18
+ assert not kokoro_backend.is_loaded
19
+ assert kokoro_backend._model is None
20
+ assert kokoro_backend._pipelines == {} # Now using dict of pipelines
21
+ # Device should be set based on settings
22
+ assert kokoro_backend.device in ["cuda", "cpu"]
23
+
24
+
25
+ @patch("torch.cuda.is_available", return_value=True)
26
+ @patch("torch.cuda.memory_allocated", return_value=5e9)
27
+ def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
28
+ """Test GPU memory management functions."""
29
+ # Patch backend so it thinks we have cuda
30
+ with patch.object(kokoro_backend, "_device", "cuda"):
31
+ # Test memory check
32
+ with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
33
+ mock_config.pytorch_gpu.memory_threshold = 4
34
+ assert kokoro_backend._check_memory() == True
35
+
36
+ mock_config.pytorch_gpu.memory_threshold = 6
37
+ assert kokoro_backend._check_memory() == False
38
+
39
+
40
+ @patch("torch.cuda.empty_cache")
41
+ @patch("torch.cuda.synchronize")
42
+ def test_clear_memory(mock_sync, mock_clear, kokoro_backend):
43
+ """Test memory clearing."""
44
+ with patch.object(kokoro_backend, "_device", "cuda"):
45
+ kokoro_backend._clear_memory()
46
+ mock_clear.assert_called_once()
47
+ mock_sync.assert_called_once()
48
+
49
+
50
+ @pytest.mark.asyncio
51
+ async def test_load_model_validation(kokoro_backend):
52
+ """Test model loading validation."""
53
+ with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
54
+ await kokoro_backend.load_model("nonexistent_model.pth")
55
+
56
+
57
+ def test_unload_with_pipelines(kokoro_backend):
58
+ """Test model unloading with multiple pipelines."""
59
+ # Mock loaded state with multiple pipelines
60
+ kokoro_backend._model = MagicMock()
61
+ pipeline_a = MagicMock()
62
+ pipeline_e = MagicMock()
63
+ kokoro_backend._pipelines = {"a": pipeline_a, "e": pipeline_e}
64
+ assert kokoro_backend.is_loaded
65
+
66
+ # Test unload
67
+ kokoro_backend.unload()
68
+ assert not kokoro_backend.is_loaded
69
+ assert kokoro_backend._model is None
70
+ assert kokoro_backend._pipelines == {} # All pipelines should be cleared
71
+
72
+
73
+ @pytest.mark.asyncio
74
+ async def test_generate_validation(kokoro_backend):
75
+ """Test generation validation."""
76
+ with pytest.raises(RuntimeError, match="Model not loaded"):
77
+ async for _ in kokoro_backend.generate("test", "voice"):
78
+ pass
79
+
80
+
81
+ @pytest.mark.asyncio
82
+ async def test_generate_from_tokens_validation(kokoro_backend):
83
+ """Test token generation validation."""
84
+ with pytest.raises(RuntimeError, match="Model not loaded"):
85
+ async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
86
+ pass
87
+
88
+
89
+ def test_get_pipeline_creates_new(kokoro_backend):
90
+ """Test that _get_pipeline creates new pipeline for new language code."""
91
+ # Mock loaded state
92
+ kokoro_backend._model = MagicMock()
93
+
94
+ # Mock KPipeline
95
+ mock_pipeline = MagicMock()
96
+ with patch(
97
+ "api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
98
+ ) as mock_kpipeline:
99
+ # Get pipeline for Spanish
100
+ pipeline_e = kokoro_backend._get_pipeline("e")
101
+
102
+ # Should create new pipeline with correct params
103
+ mock_kpipeline.assert_called_once_with(
104
+ lang_code="e", model=kokoro_backend._model, device=kokoro_backend._device
105
+ )
106
+ assert pipeline_e == mock_pipeline
107
+ assert kokoro_backend._pipelines["e"] == mock_pipeline
108
+
109
+
110
+ def test_get_pipeline_reuses_existing(kokoro_backend):
111
+ """Test that _get_pipeline reuses existing pipeline for same language code."""
112
+ # Mock loaded state
113
+ kokoro_backend._model = MagicMock()
114
+
115
+ # Mock KPipeline
116
+ mock_pipeline = MagicMock()
117
+ with patch(
118
+ "api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
119
+ ) as mock_kpipeline:
120
+ # Get pipeline twice for same language
121
+ pipeline1 = kokoro_backend._get_pipeline("e")
122
+ pipeline2 = kokoro_backend._get_pipeline("e")
123
+
124
+ # Should only create pipeline once
125
+ mock_kpipeline.assert_called_once()
126
+ assert pipeline1 == pipeline2
127
+ assert kokoro_backend._pipelines["e"] == mock_pipeline
128
+
129
+
130
+ @pytest.mark.asyncio
131
+ async def test_generate_uses_correct_pipeline(kokoro_backend):
132
+ """Test that generate uses correct pipeline for language code."""
133
+ # Mock loaded state
134
+ kokoro_backend._model = MagicMock()
135
+
136
+ # Mock voice path handling
137
+ with (
138
+ patch("api.src.core.paths.load_voice_tensor") as mock_load_voice,
139
+ patch("api.src.core.paths.save_voice_tensor"),
140
+ patch("tempfile.gettempdir") as mock_tempdir,
141
+ ):
142
+ mock_load_voice.return_value = torch.ones(1)
143
+ mock_tempdir.return_value = "/tmp"
144
+
145
+ # Mock KPipeline
146
+ mock_pipeline = MagicMock()
147
+ mock_pipeline.return_value = iter([]) # Empty generator for testing
148
+ with patch("api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline):
149
+ # Generate with Spanish voice and explicit lang_code
150
+ async for _ in kokoro_backend.generate("test", "ef_voice", lang_code="e"):
151
+ pass
152
+
153
+ # Should create pipeline with Spanish lang_code
154
+ assert "e" in kokoro_backend._pipelines
155
+ # Use ANY to match the temp file path since it's dynamic
156
+ mock_pipeline.assert_called_with(
157
+ "test",
158
+ voice=ANY, # Don't check exact path since it's dynamic
159
+ speed=1.0,
160
+ model=kokoro_backend._model,
161
+ )
162
+ # Verify the voice path is a temp file path
163
+ call_args = mock_pipeline.call_args
164
+ assert isinstance(call_args[1]["voice"], str)
165
+ assert call_args[1]["voice"].startswith("/tmp/temp_voice_")
api/tests/test_normalizer.py ADDED
@@ -0,0 +1,179 @@
1
+ """Tests for text normalization service"""
2
+
3
+ import pytest
4
+
5
+ from api.src.services.text_processing.normalizer import normalize_text
6
+ from api.src.structures.schemas import NormalizationOptions
7
+
8
+
9
+ def test_url_protocols():
10
+ """Test URL protocol handling"""
11
+ assert (
12
+ normalize_text(
13
+ "Check out https://example.com",
14
+ normalization_options=NormalizationOptions(),
15
+ )
16
+ == "Check out https example dot com"
17
+ )
18
+ assert (
19
+ normalize_text(
20
+ "Visit http://site.com", normalization_options=NormalizationOptions()
21
+ )
22
+ == "Visit http site dot com"
23
+ )
24
+ assert (
25
+ normalize_text(
26
+ "Go to https://test.org/path", normalization_options=NormalizationOptions()
27
+ )
28
+ == "Go to https test dot org slash path"
29
+ )
30
+
31
+
32
+ def test_url_www():
33
+ """Test www prefix handling"""
34
+ assert (
35
+ normalize_text(
36
+ "Go to www.example.com", normalization_options=NormalizationOptions()
37
+ )
38
+ == "Go to www example dot com"
39
+ )
40
+ assert (
41
+ normalize_text(
42
+ "Visit www.test.org/docs", normalization_options=NormalizationOptions()
43
+ )
44
+ == "Visit www test dot org slash docs"
45
+ )
46
+ assert (
47
+ normalize_text(
48
+ "Check www.site.com?q=test", normalization_options=NormalizationOptions()
49
+ )
50
+ == "Check www site dot com question-mark q equals test"
51
+ )
52
+
53
+
54
+ def test_url_localhost():
55
+ """Test localhost URL handling"""
56
+ assert (
57
+ normalize_text(
58
+ "Running on localhost:7860", normalization_options=NormalizationOptions()
59
+ )
60
+ == "Running on localhost colon 78 60"
61
+ )
62
+ assert (
63
+ normalize_text(
64
+ "Server at localhost:8080/api", normalization_options=NormalizationOptions()
65
+ )
66
+ == "Server at localhost colon 80 80 slash api"
67
+ )
68
+ assert (
69
+ normalize_text(
70
+ "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
71
+ )
72
+ == "Test localhost colon 3000 slash test question-mark v equals 1"
73
+ )
74
+
75
+
76
+ def test_url_ip_addresses():
77
+ """Test IP address URL handling"""
78
+ assert (
79
+ normalize_text(
80
+ "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
81
+ )
82
+ == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
83
+ )
84
+ assert (
85
+ normalize_text(
86
+ "API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
87
+ )
88
+ == "API at 192 dot 168 dot 1 dot 1 colon 8000"
89
+ )
90
+ assert (
91
+ normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
92
+ == "Server 127 dot 0 dot 0 dot 1"
93
+ )
94
+
95
+
96
+ def test_url_raw_domains():
97
+ """Test raw domain handling"""
98
+ assert (
99
+ normalize_text(
100
+ "Visit google.com/search", normalization_options=NormalizationOptions()
101
+ )
102
+ == "Visit google dot com slash search"
103
+ )
104
+ assert (
105
+ normalize_text(
106
+ "Go to example.com/path?q=test",
107
+ normalization_options=NormalizationOptions(),
108
+ )
109
+ == "Go to example dot com slash path question-mark q equals test"
110
+ )
111
+ assert (
112
+ normalize_text(
113
+ "Check docs.test.com", normalization_options=NormalizationOptions()
114
+ )
115
+ == "Check docs dot test dot com"
116
+ )
117
+
118
+
119
+ def test_url_email_addresses():
120
+ """Test email address handling"""
121
+ assert (
122
+ normalize_text(
123
+ "Email me at [email protected]", normalization_options=NormalizationOptions()
124
+ )
125
+ == "Email me at user at example dot com"
126
+ )
127
+ assert (
128
+ normalize_text(
129
+ "Contact [email protected]", normalization_options=NormalizationOptions()
130
+ )
131
+ == "Contact admin at test dot org"
132
+ )
133
+ assert (
134
+ normalize_text(
135
+ "Send to [email protected]", normalization_options=NormalizationOptions()
136
+ )
137
+ == "Send to test dot user at site dot com"
138
+ )
139
+
140
+
141
+ def test_money():
142
+ """Test that money text is normalized correctly"""
143
+ assert (
144
+ normalize_text(
145
+ "He lost $5.3 thousand.", normalization_options=NormalizationOptions()
146
+ )
147
+ == "He lost five point three thousand dollars."
148
+ )
149
+ assert (
150
+ normalize_text(
151
+ "To put it weirdly -$6.9 million",
152
+ normalization_options=NormalizationOptions(),
153
+ )
154
+ == "To put it weirdly minus six point nine million dollars"
155
+ )
156
+ assert (
157
+ normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
158
+ == "It costs fifty dollars and thirty cents."
159
+ )
160
+
161
+
162
+ def test_non_url_text():
163
+ """Test that non-URL text is unaffected"""
164
+ assert (
165
+ normalize_text(
166
+ "This is not.a.url text", normalization_options=NormalizationOptions()
167
+ )
168
+ == "This is not-a-url text"
169
+ )
170
+ assert (
171
+ normalize_text(
172
+ "Hello, how are you today?", normalization_options=NormalizationOptions()
173
+ )
174
+ == "Hello, how are you today?"
175
+ )
176
+ assert (
177
+ normalize_text("It costs $50.", normalization_options=NormalizationOptions())
178
+ == "It costs fifty dollars."
179
+ )
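A short sketch of toggling individual options (field names come from api/src/structures/schemas.py in this diff; the exact output depends on the normalizer's rules):

from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions

opts = NormalizationOptions(url_normalization=False, unit_normalization=True)
print(normalize_text("Download 10KB from https://example.com", normalization_options=opts))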
api/tests/test_openai_endpoints.py ADDED
@@ -0,0 +1,499 @@
1
+ import asyncio
2
+ import json
3
+ import os
4
+ from typing import AsyncGenerator, Tuple
5
+ from unittest.mock import AsyncMock, MagicMock, patch
6
+
7
+ import numpy as np
8
+ import pytest
9
+ from fastapi.testclient import TestClient
10
+
11
+ from api.src.core.config import settings
12
+ from api.src.inference.base import AudioChunk
13
+ from api.src.main import app
14
+ from api.src.routers.openai_compatible import (
15
+ get_tts_service,
16
+ load_openai_mappings,
17
+ stream_audio_chunks,
18
+ )
19
+ from api.src.services.streaming_audio_writer import StreamingAudioWriter
20
+ from api.src.services.tts_service import TTSService
21
+ from api.src.structures.schemas import OpenAISpeechRequest
22
+
23
+ client = TestClient(app)
24
+
25
+
26
+ @pytest.fixture
27
+ def test_voice():
28
+ """Fixture providing a test voice name."""
29
+ return "test_voice"
30
+
31
+
32
+ @pytest.fixture
33
+ def mock_openai_mappings():
34
+ """Mock OpenAI mappings for testing."""
35
+ with patch(
36
+ "api.src.routers.openai_compatible._openai_mappings",
37
+ {
38
+ "models": {"tts-1": "kokoro-v1_0", "tts-1-hd": "kokoro-v1_0"},
39
+ "voices": {"alloy": "am_adam", "nova": "bf_isabella"},
40
+ },
41
+ ):
42
+ yield
43
+
44
+
45
+ @pytest.fixture
46
+ def mock_json_file(tmp_path):
47
+ """Create a temporary mock JSON file."""
48
+ content = {
49
+ "models": {"test-model": "test-kokoro"},
50
+ "voices": {"test-voice": "test-internal"},
51
+ }
52
+ json_file = tmp_path / "test_mappings.json"
53
+ json_file.write_text(json.dumps(content))
54
+ return json_file
55
+
56
+
57
+ def test_load_openai_mappings(mock_json_file):
58
+ """Test loading OpenAI mappings from JSON file"""
59
+ with patch("os.path.join", return_value=str(mock_json_file)):
60
+ mappings = load_openai_mappings()
61
+ assert "models" in mappings
62
+ assert "voices" in mappings
63
+ assert mappings["models"]["test-model"] == "test-kokoro"
64
+ assert mappings["voices"]["test-voice"] == "test-internal"
65
+
66
+
67
+ def test_load_openai_mappings_file_not_found():
68
+ """Test handling of missing mappings file"""
69
+ with patch("os.path.join", return_value="/nonexistent/path"):
70
+ mappings = load_openai_mappings()
71
+ assert mappings == {"models": {}, "voices": {}}
72
+
73
+
74
+ def test_list_models(mock_openai_mappings):
75
+ """Test listing available models endpoint"""
76
+ response = client.get("/v1/models")
77
+ assert response.status_code == 200
78
+ data = response.json()
79
+ assert data["object"] == "list"
80
+ assert isinstance(data["data"], list)
81
+ assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro
82
+
83
+ # Verify all expected models are present
84
+ model_ids = [model["id"] for model in data["data"]]
85
+ assert "tts-1" in model_ids
86
+ assert "tts-1-hd" in model_ids
87
+ assert "kokoro" in model_ids
88
+
89
+ # Verify model format
90
+ for model in data["data"]:
91
+ assert model["object"] == "model"
92
+ assert "created" in model
93
+ assert model["owned_by"] == "kokoro"
94
+
95
+
96
+ def test_retrieve_model(mock_openai_mappings):
97
+ """Test retrieving a specific model endpoint"""
98
+ # Test successful model retrieval
99
+ response = client.get("/v1/models/tts-1")
100
+ assert response.status_code == 200
101
+ data = response.json()
102
+ assert data["id"] == "tts-1"
103
+ assert data["object"] == "model"
104
+ assert data["owned_by"] == "kokoro"
105
+ assert "created" in data
106
+
107
+ # Test non-existent model
108
+ response = client.get("/v1/models/nonexistent-model")
109
+ assert response.status_code == 404
110
+ error = response.json()
111
+ assert error["detail"]["error"] == "model_not_found"
112
+ assert "not found" in error["detail"]["message"]
113
+ assert error["detail"]["type"] == "invalid_request_error"
114
+
115
+
116
+ @pytest.mark.asyncio
117
+ async def test_get_tts_service_initialization():
118
+ """Test TTSService initialization"""
119
+ with patch("api.src.routers.openai_compatible._tts_service", None):
120
+ with patch("api.src.routers.openai_compatible._init_lock", None):
121
+ with patch("api.src.services.tts_service.TTSService.create") as mock_create:
122
+ mock_service = AsyncMock()
123
+ mock_create.return_value = mock_service
124
+
125
+ # Test concurrent access
126
+ async def get_service():
127
+ return await get_tts_service()
128
+
129
+ # Create multiple concurrent requests
130
+ tasks = [get_service() for _ in range(5)]
131
+ results = await asyncio.gather(*tasks)
132
+
133
+ # Verify service was created only once
134
+ mock_create.assert_called_once()
135
+ assert all(r == mock_service for r in results)
136
+
137
+
138
+ @pytest.mark.asyncio
139
+ async def test_stream_audio_chunks_client_disconnect():
140
+ """Test handling of client disconnect during streaming"""
141
+ mock_request = MagicMock()
142
+ mock_request.is_disconnected = AsyncMock(return_value=True)
143
+
144
+ mock_service = AsyncMock()
145
+
146
+ async def mock_stream(*args, **kwargs):
147
+ for i in range(5):
148
+ yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")
149
+
150
+ mock_service.generate_audio_stream = mock_stream
151
+ mock_service.list_voices.return_value = ["test_voice"]
152
+
153
+ request = OpenAISpeechRequest(
154
+ model="kokoro",
155
+ input="Test text",
156
+ voice="test_voice",
157
+ response_format="mp3",
158
+ stream=True,
159
+ speed=1.0,
160
+ )
161
+
162
+ writer = StreamingAudioWriter("mp3", 24000)
163
+
164
+ chunks = []
165
+ async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer):
166
+ chunks.append(chunk)
167
+
168
+ writer.close()
169
+
170
+ assert len(chunks) == 0 # Should stop immediately due to disconnect
171
+
172
+
173
+ def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
174
+ """Test OpenAI voice name mapping"""
175
+ mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
176
+
177
+ response = client.post(
178
+ "/v1/audio/speech",
179
+ json={
180
+ "model": "tts-1",
181
+ "input": "Hello world",
182
+ "voice": "alloy", # OpenAI voice name
183
+ "response_format": "mp3",
184
+ "stream": False,
185
+ },
186
+ )
187
+ assert response.status_code == 200
188
+ mock_tts_service.generate_audio.assert_called_once()
189
+ assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam"
190
+
191
+
192
+ def test_openai_voice_mapping_streaming(
193
+ mock_tts_service, mock_openai_mappings, mock_audio_bytes
194
+ ):
195
+ """Test OpenAI voice mapping in streaming mode"""
196
+ mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
197
+
198
+ response = client.post(
199
+ "/v1/audio/speech",
200
+ json={
201
+ "model": "tts-1-hd",
202
+ "input": "Hello world",
203
+ "voice": "nova", # OpenAI voice name
204
+ "response_format": "mp3",
205
+ "stream": True,
206
+ },
207
+ )
208
+ assert response.status_code == 200
209
+ content = b""
210
+ for chunk in response.iter_bytes():
211
+ content += chunk
212
+ assert content == mock_audio_bytes
213
+
214
+
215
+ def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
216
+ """Test error handling for invalid OpenAI model"""
217
+ response = client.post(
218
+ "/v1/audio/speech",
219
+ json={
220
+ "model": "invalid-model",
221
+ "input": "Hello world",
222
+ "voice": "alloy",
223
+ "response_format": "mp3",
224
+ "stream": False,
225
+ },
226
+ )
227
+ assert response.status_code == 400
228
+ error_response = response.json()
229
+ assert error_response["detail"]["error"] == "invalid_model"
230
+ assert "Unsupported model" in error_response["detail"]["message"]
231
+
232
+
233
+ @pytest.fixture
234
+ def mock_audio_bytes():
235
+ """Mock audio bytes for testing."""
236
+ return b"mock audio data"
237
+
238
+
239
+ @pytest.fixture
240
+ def mock_tts_service(mock_audio_bytes):
241
+ """Mock TTS service for testing."""
242
+ with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
243
+ service = AsyncMock(spec=TTSService)
244
+ service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
245
+
246
+ async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
247
+ yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)
248
+
249
+ service.generate_audio_stream = mock_stream
250
+ service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
251
+ service.combine_voices.return_value = "voice1_voice2"
252
+
253
+ mock_get.return_value = service
254
+ mock_get.side_effect = None
255
+ yield service
256
+
257
+
258
+ @patch("api.src.services.audio.AudioService.convert_audio")
259
+ def test_openai_speech_endpoint(
260
+ mock_convert, mock_tts_service, test_voice, mock_audio_bytes
261
+ ):
262
+ """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
263
+ # Configure mocks
264
+ mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
265
+ mock_convert.return_value = AudioChunk(
266
+ np.zeros(1000, np.int16), output=mock_audio_bytes
267
+ )
268
+
269
+ response = client.post(
270
+ "/v1/audio/speech",
271
+ json={
272
+ "model": "kokoro",
273
+ "input": "Hello world",
274
+ "voice": test_voice,
275
+ "response_format": "mp3",
276
+ "stream": False,
277
+ },
278
+ )
279
+ assert response.status_code == 200
280
+ assert response.headers["content-type"] == "audio/mpeg"
281
+ assert len(response.content) > 0
282
+ assert response.content == mock_audio_bytes + mock_audio_bytes
283
+
284
+ mock_tts_service.generate_audio.assert_called_once()
285
+ assert mock_convert.call_count == 2
286
+
287
+
288
+ def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
289
+ """Test the OpenAI-compatible speech endpoint with streaming"""
290
+ response = client.post(
291
+ "/v1/audio/speech",
292
+ json={
293
+ "model": "kokoro",
294
+ "input": "Hello world",
295
+ "voice": test_voice,
296
+ "response_format": "mp3",
297
+ "stream": True,
298
+ },
299
+ )
300
+ assert response.status_code == 200
301
+ assert response.headers["content-type"] == "audio/mpeg"
302
+ assert "Transfer-Encoding" in response.headers
303
+ assert response.headers["Transfer-Encoding"] == "chunked"
304
+
305
+ content = b""
306
+ for chunk in response.iter_bytes():
307
+ content += chunk
308
+ assert content == mock_audio_bytes
309
+
310
+
311
+ def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
312
+ """Test PCM streaming format"""
313
+ response = client.post(
314
+ "/v1/audio/speech",
315
+ json={
316
+ "model": "kokoro",
317
+ "input": "Hello world",
318
+ "voice": test_voice,
319
+ "response_format": "pcm",
320
+ "stream": True,
321
+ },
322
+ )
323
+ assert response.status_code == 200
324
+ assert response.headers["content-type"] == "audio/pcm"
325
+
326
+ content = b""
327
+ for chunk in response.iter_bytes():
328
+ content += chunk
329
+ assert content == mock_audio_bytes
330
+
331
+
332
+ def test_openai_speech_invalid_voice(mock_tts_service):
333
+ """Test error handling for invalid voice"""
334
+ mock_tts_service.generate_audio.side_effect = ValueError(
335
+ "Voice 'invalid_voice' not found"
336
+ )
337
+
338
+ response = client.post(
339
+ "/v1/audio/speech",
340
+ json={
341
+ "model": "kokoro",
342
+ "input": "Hello world",
343
+ "voice": "invalid_voice",
344
+ "response_format": "mp3",
345
+ "stream": False,
346
+ },
347
+ )
348
+ assert response.status_code == 400
349
+ error_response = response.json()
350
+ assert error_response["detail"]["error"] == "validation_error"
351
+ assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"]
352
+ assert error_response["detail"]["type"] == "invalid_request_error"
353
+
354
+
355
+ def test_openai_speech_empty_text(mock_tts_service, test_voice):
356
+ """Test error handling for empty text"""
357
+
358
+ async def mock_error_stream(*args, **kwargs):
359
+ raise ValueError("Text is empty after preprocessing")
360
+
361
+ mock_tts_service.generate_audio = mock_error_stream
362
+ mock_tts_service.list_voices.return_value = ["test_voice"]
363
+
364
+ response = client.post(
365
+ "/v1/audio/speech",
366
+ json={
367
+ "model": "kokoro",
368
+ "input": "",
369
+ "voice": test_voice,
370
+ "response_format": "mp3",
371
+ "stream": False,
372
+ },
373
+ )
374
+ assert response.status_code == 400
375
+ error_response = response.json()
376
+ assert error_response["detail"]["error"] == "validation_error"
377
+ assert "Text is empty after preprocessing" in error_response["detail"]["message"]
378
+ assert error_response["detail"]["type"] == "invalid_request_error"
379
+
380
+
381
+ def test_openai_speech_invalid_format(mock_tts_service, test_voice):
382
+ """Test error handling for invalid format"""
383
+ response = client.post(
384
+ "/v1/audio/speech",
385
+ json={
386
+ "model": "kokoro",
387
+ "input": "Hello world",
388
+ "voice": test_voice,
389
+ "response_format": "invalid_format",
390
+ "stream": False,
391
+ },
392
+ )
393
+ assert response.status_code == 422 # Validation error from Pydantic
394
+
395
+
396
+ def test_list_voices(mock_tts_service):
397
+ """Test listing available voices"""
398
+ # Override the mock for this specific test
399
+ mock_tts_service.list_voices.return_value = ["voice1", "voice2"]
400
+
401
+ response = client.get("/v1/audio/voices")
402
+ assert response.status_code == 200
403
+ data = response.json()
404
+ assert "voices" in data
405
+ assert len(data["voices"]) == 2
406
+ assert "voice1" in data["voices"]
407
+ assert "voice2" in data["voices"]
408
+
409
+
410
+ @patch("api.src.routers.openai_compatible.settings")
411
+ def test_combine_voices(mock_settings, mock_tts_service):
412
+ """Test combining voices endpoint"""
413
+ # Enable local voice saving for this test
414
+ mock_settings.allow_local_voice_saving = True
415
+
416
+ response = client.post("/v1/audio/voices/combine", json="voice1+voice2")
417
+ assert response.status_code == 200
418
+ assert response.headers["content-type"] == "application/octet-stream"
419
+ assert "voice1+voice2.pt" in response.headers["content-disposition"]
420
+
421
+
422
+ def test_server_error(mock_tts_service, test_voice):
423
+ """Test handling of server errors"""
424
+
425
+ async def mock_error_stream(*args, **kwargs):
426
+ raise RuntimeError("Internal server error")
427
+
428
+ mock_tts_service.generate_audio = mock_error_stream
429
+ mock_tts_service.list_voices.return_value = ["test_voice"]
430
+
431
+ response = client.post(
432
+ "/v1/audio/speech",
433
+ json={
434
+ "model": "kokoro",
435
+ "input": "Hello world",
436
+ "voice": test_voice,
437
+ "response_format": "mp3",
438
+ "stream": False,
439
+ },
440
+ )
441
+ assert response.status_code == 500
442
+ error_response = response.json()
443
+ assert error_response["detail"]["error"] == "processing_error"
444
+ assert error_response["detail"]["type"] == "server_error"
445
+
446
+
447
+ def test_streaming_error(mock_tts_service, test_voice):
448
+ """Test handling streaming errors"""
449
+ # Mock process_voices to raise the error
450
+ mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")
451
+
452
+ response = client.post(
453
+ "/v1/audio/speech",
454
+ json={
455
+ "model": "kokoro",
456
+ "input": "Hello world",
457
+ "voice": test_voice,
458
+ "response_format": "mp3",
459
+ "stream": True,
460
+ },
461
+ )
462
+
463
+ assert response.status_code == 500
464
+ error_data = response.json()
465
+ assert error_data["detail"]["error"] == "processing_error"
466
+ assert error_data["detail"]["type"] == "server_error"
467
+ assert "Streaming failed" in error_data["detail"]["message"]
468
+
469
+
470
+ @pytest.mark.asyncio
471
+ async def test_streaming_initialization_error():
472
+ """Test handling of streaming initialization errors"""
473
+ mock_service = AsyncMock()
474
+
475
+ async def mock_error_stream(*args, **kwargs):
476
+ if False: # This makes it a proper generator
477
+ yield b""
478
+ raise RuntimeError("Failed to initialize stream")
479
+
480
+ mock_service.generate_audio_stream = mock_error_stream
481
+ mock_service.list_voices.return_value = ["test_voice"]
482
+
483
+ request = OpenAISpeechRequest(
484
+ model="kokoro",
485
+ input="Test text",
486
+ voice="test_voice",
487
+ response_format="mp3",
488
+ stream=True,
489
+ speed=1.0,
490
+ )
491
+
492
+ writer = StreamingAudioWriter("mp3", 24000)
493
+
494
+ with pytest.raises(RuntimeError) as exc:
495
+ async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer):
496
+ pass
497
+
498
+ writer.close()
499
+ assert "Failed to initialize stream" in str(exc.value)
api/tests/test_paths.py ADDED
@@ -0,0 +1,138 @@
1
+ import os
2
+ from unittest.mock import patch
3
+
4
+ import pytest
5
+
6
+ from api.src.core.paths import (
7
+ _find_file,
8
+ _scan_directories,
9
+ get_content_type,
10
+ get_temp_dir_size,
11
+ get_temp_file_path,
12
+ list_temp_files,
13
+ )
14
+
15
+
16
+ @pytest.mark.asyncio
17
+ async def test_find_file_exists():
18
+ """Test finding existing file."""
19
+ with patch("aiofiles.os.path.exists") as mock_exists:
20
+ mock_exists.return_value = True
21
+ path = await _find_file("test.txt", ["/test/path"])
22
+ assert path == "/test/path/test.txt"
23
+
24
+
25
+ @pytest.mark.asyncio
26
+ async def test_find_file_not_exists():
27
+ """Test finding non-existent file."""
28
+ with patch("aiofiles.os.path.exists") as mock_exists:
29
+ mock_exists.return_value = False
30
+ with pytest.raises(FileNotFoundError, match="File not found"):
31
+ await _find_file("test.txt", ["/test/path"])
32
+
33
+
34
+ @pytest.mark.asyncio
35
+ async def test_find_file_with_filter():
36
+ """Test finding file with filter function."""
37
+ with patch("aiofiles.os.path.exists") as mock_exists:
38
+ mock_exists.return_value = True
39
+ filter_fn = lambda p: p.endswith(".txt")
40
+ path = await _find_file("test.txt", ["/test/path"], filter_fn)
41
+ assert path == "/test/path/test.txt"
42
+
43
+
44
+ @pytest.mark.asyncio
45
+ async def test_scan_directories():
46
+ """Test scanning directories."""
47
+ mock_entry = type("MockEntry", (), {"name": "test.txt"})()
48
+
49
+ with (
50
+ patch("aiofiles.os.path.exists") as mock_exists,
51
+ patch("aiofiles.os.scandir") as mock_scandir,
52
+ ):
53
+ mock_exists.return_value = True
54
+ mock_scandir.return_value = [mock_entry]
55
+
56
+ files = await _scan_directories(["/test/path"])
57
+ assert "test.txt" in files
58
+
59
+
60
+ @pytest.mark.asyncio
61
+ async def test_get_content_type():
62
+ """Test content type detection."""
63
+ test_cases = [
64
+ ("test.html", "text/html"),
65
+ ("test.js", "application/javascript"),
66
+ ("test.css", "text/css"),
67
+ ("test.png", "image/png"),
68
+ ("test.unknown", "application/octet-stream"),
69
+ ]
70
+
71
+ for filename, expected in test_cases:
72
+ content_type = await get_content_type(filename)
73
+ assert content_type == expected
74
+
75
+
76
+ @pytest.mark.asyncio
77
+ async def test_get_temp_file_path():
78
+ """Test temp file path generation."""
79
+ with (
80
+ patch("aiofiles.os.path.exists") as mock_exists,
81
+ patch("aiofiles.os.makedirs") as mock_makedirs,
82
+ ):
83
+ mock_exists.return_value = False
84
+
85
+ path = await get_temp_file_path("test.wav")
86
+ assert "test.wav" in path
87
+ mock_makedirs.assert_called_once()
88
+
89
+
90
+ @pytest.mark.asyncio
91
+ async def test_list_temp_files():
92
+ """Test listing temp files."""
93
+
94
+ class MockEntry:
95
+ def __init__(self, name):
96
+ self.name = name
97
+
98
+ def is_file(self):
99
+ return True
100
+
101
+ mock_entry = MockEntry("test.wav")
102
+
103
+ with (
104
+ patch("aiofiles.os.path.exists") as mock_exists,
105
+ patch("aiofiles.os.scandir") as mock_scandir,
106
+ ):
107
+ mock_exists.return_value = True
108
+ mock_scandir.return_value = [mock_entry]
109
+
110
+ files = await list_temp_files()
111
+ assert "test.wav" in files
112
+
113
+
114
+ @pytest.mark.asyncio
115
+ async def test_get_temp_dir_size():
116
+ """Test getting temp directory size."""
117
+
118
+ class MockEntry:
119
+ def __init__(self, path):
120
+ self.path = path
121
+
122
+ def is_file(self):
123
+ return True
124
+
125
+ mock_entry = MockEntry("/tmp/test.wav")
126
+ mock_stat = type("MockStat", (), {"st_size": 1024})()
127
+
128
+ with (
129
+ patch("aiofiles.os.path.exists") as mock_exists,
130
+ patch("aiofiles.os.scandir") as mock_scandir,
131
+ patch("aiofiles.os.stat") as mock_stat_fn,
132
+ ):
133
+ mock_exists.return_value = True
134
+ mock_scandir.return_value = [mock_entry]
135
+ mock_stat_fn.return_value = mock_stat
136
+
137
+ size = await get_temp_dir_size()
138
+ assert size == 1024
api/tests/test_text_processor.py ADDED
@@ -0,0 +1,105 @@
1
+ import pytest
2
+
3
+ from api.src.services.text_processing.text_processor import (
4
+ get_sentence_info,
5
+ process_text_chunk,
6
+ smart_split,
7
+ )
8
+
9
+
10
+ def test_process_text_chunk_basic():
11
+ """Test basic text chunk processing."""
12
+ text = "Hello world"
13
+ tokens = process_text_chunk(text)
14
+ assert isinstance(tokens, list)
15
+ assert len(tokens) > 0
16
+
17
+
18
+ def test_process_text_chunk_empty():
19
+ """Test processing empty text."""
20
+ text = ""
21
+ tokens = process_text_chunk(text)
22
+ assert isinstance(tokens, list)
23
+ assert len(tokens) == 0
24
+
25
+
26
+ def test_process_text_chunk_phonemes():
27
+ """Test processing with skip_phonemize."""
28
+ phonemes = "h @ l @U" # Example phoneme sequence
29
+ tokens = process_text_chunk(phonemes, skip_phonemize=True)
30
+ assert isinstance(tokens, list)
31
+ assert len(tokens) > 0
32
+
33
+
34
+ def test_get_sentence_info():
35
+ """Test sentence splitting and info extraction."""
36
+ text = "This is sentence one. This is sentence two! What about three?"
37
+ results = get_sentence_info(text, {})
38
+
39
+ assert len(results) == 3
40
+ for sentence, tokens, count in results:
41
+ assert isinstance(sentence, str)
42
+ assert isinstance(tokens, list)
43
+ assert isinstance(count, int)
44
+ assert count == len(tokens)
45
+ assert count > 0
46
+
47
+
48
+ def test_get_sentence_info_phonemes():
49
+ """Test sentence splitting and info extraction."""
50
+ text = (
51
+ "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
52
+ )
53
+ results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
54
+
55
+ assert len(results) == 3
56
+ assert "sˈɛntᵊns" in results[1][0]
57
+ for sentence, tokens, count in results:
58
+ assert isinstance(sentence, str)
59
+ assert isinstance(tokens, list)
60
+ assert isinstance(count, int)
61
+ assert count == len(tokens)
62
+ assert count > 0
63
+
64
+
65
+ @pytest.mark.asyncio
66
+ async def test_smart_split_short_text():
67
+ """Test smart splitting with text under max tokens."""
68
+ text = "This is a short test sentence."
69
+ chunks = []
70
+ async for chunk_text, chunk_tokens in smart_split(text):
71
+ chunks.append((chunk_text, chunk_tokens))
72
+
73
+ assert len(chunks) == 1
74
+ assert isinstance(chunks[0][0], str)
75
+ assert isinstance(chunks[0][1], list)
76
+
77
+
78
+ @pytest.mark.asyncio
79
+ async def test_smart_split_long_text():
80
+ """Test smart splitting with longer text."""
81
+ # Create text that should split into multiple chunks
82
+ text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
83
+
84
+ chunks = []
85
+ async for chunk_text, chunk_tokens in smart_split(text):
86
+ chunks.append((chunk_text, chunk_tokens))
87
+
88
+ assert len(chunks) > 1
89
+ for chunk_text, chunk_tokens in chunks:
90
+ assert isinstance(chunk_text, str)
91
+ assert isinstance(chunk_tokens, list)
92
+ assert len(chunk_tokens) > 0
93
+
94
+
95
+ @pytest.mark.asyncio
96
+ async def test_smart_split_with_punctuation():
97
+ """Test smart splitting handles punctuation correctly."""
98
+ text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
99
+
100
+ chunks = []
101
+ async for chunk_text, chunk_tokens in smart_split(text):
102
+ chunks.append(chunk_text)
103
+
104
+ # Verify punctuation is preserved
105
+ assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
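As the tests show, smart_split is an async generator yielding (chunk_text, chunk_tokens) pairs; a sketch of driving it outside pytest:

import asyncio

from api.src.services.text_processing.text_processor import smart_split

async def main():
    text = ". ".join(f"This is test sentence number {i}" for i in range(20))
    async for chunk_text, chunk_tokens in smart_split(text):
        print(len(chunk_tokens), repr(chunk_text[:40]))

asyncio.run(main())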
api/tests/test_tts_service.py ADDED
@@ -0,0 +1,126 @@
1
+ from unittest.mock import AsyncMock, MagicMock, patch
2
+
3
+ import numpy as np
4
+ import pytest
5
+ import torch
6
+
7
+ from api.src.services.tts_service import TTSService
8
+
9
+
10
+ @pytest.fixture
11
+ def mock_managers():
12
+ """Mock model and voice managers."""
13
+
14
+ async def _mock_managers():
15
+ model_manager = AsyncMock()
16
+ model_manager.get_backend.return_value = MagicMock()
17
+
18
+ voice_manager = AsyncMock()
19
+ voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
20
+ voice_manager.list_voices.return_value = ["voice1", "voice2"]
21
+
22
+ with (
23
+ patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
24
+ patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
25
+ ):
26
+ mock_get_model.return_value = model_manager
27
+ mock_get_voice.return_value = voice_manager
28
+ return model_manager, voice_manager
29
+
30
+ return _mock_managers()
31
+
32
+
33
+ @pytest.fixture
34
+ def tts_service(mock_managers):
35
+ """Create TTSService instance with mocked dependencies."""
36
+
37
+ async def _create_service():
38
+ return await TTSService.create("test_output")
39
+
40
+ return _create_service()
41
+
42
+
+@pytest.mark.asyncio
+async def test_service_creation():
+    """Test service creation and initialization."""
+    model_manager = AsyncMock()
+    voice_manager = AsyncMock()
+
+    with (
+        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
+        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
+    ):
+        mock_get_model.return_value = model_manager
+        mock_get_voice.return_value = voice_manager
+
+        service = await TTSService.create("test_output")
+        assert service.output_dir == "test_output"
+        assert service.model_manager is model_manager
+        assert service._voice_manager is voice_manager
+
+
+@pytest.mark.asyncio
+async def test_get_voice_path_single():
+    """Test getting path for single voice."""
+    model_manager = AsyncMock()
+    voice_manager = AsyncMock()
+    voice_manager.get_voice_path.return_value = "/path/to/voice1.pt"
+
+    with (
+        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
+        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
+    ):
+        mock_get_model.return_value = model_manager
+        mock_get_voice.return_value = voice_manager
+
+        service = await TTSService.create("test_output")
+        name, path = await service._get_voices_path("voice1")
+        assert name == "voice1"
+        assert path == "/path/to/voice1.pt"
+        voice_manager.get_voice_path.assert_called_once_with("voice1")
+
+
+@pytest.mark.asyncio
+async def test_get_voice_path_combined():
+    """Test getting path for combined voices."""
+    model_manager = AsyncMock()
+    voice_manager = AsyncMock()
+    voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
+
+    with (
+        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
+        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
+        patch("torch.load") as mock_load,
+        patch("torch.save") as mock_save,
+        patch("tempfile.gettempdir") as mock_temp,
+    ):
+        mock_get_model.return_value = model_manager
+        mock_get_voice.return_value = voice_manager
+        mock_temp.return_value = "/tmp"
+        mock_load.return_value = torch.ones(10)
+
+        service = await TTSService.create("test_output")
+        name, path = await service._get_voices_path("voice1+voice2")
+        assert name == "voice1+voice2"
+        assert path.endswith("voice1+voice2.pt")
+        mock_save.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_list_voices():
+    """Test listing available voices."""
+    model_manager = AsyncMock()
+    voice_manager = AsyncMock()
+    voice_manager.list_voices.return_value = ["voice1", "voice2"]
+
+    with (
+        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
+        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
+    ):
+        mock_get_model.return_value = model_manager
+        mock_get_voice.return_value = voice_manager
+
+        service = await TTSService.create("test_output")
+        voices = await service.list_voices()
+        assert voices == ["voice1", "voice2"]
+        voice_manager.list_voices.assert_called_once()
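The combined-voice test above fixes only the observable behavior of `_get_voices_path`: each component voice tensor is loaded, a combined tensor is saved as `voice1+voice2.pt` under the temp directory, and the joined name plus that path are returned. A hypothetical sketch consistent with those assertions — the helper name `combine_voices_sketch` and the element-wise averaging rule are assumptions for illustration, not the confirmed implementation — could be:

import os
import tempfile
from typing import Awaitable, Callable, List, Tuple

import torch


async def combine_voices_sketch(
    voice_names: List[str],
    get_voice_path: Callable[[str], Awaitable[str]],
) -> Tuple[str, str]:
    """Hypothetical: average component voice tensors and cache the result."""
    tensors = [torch.load(await get_voice_path(name)) for name in voice_names]
    combined = torch.mean(torch.stack(tensors), dim=0)  # assumed combination rule
    name = "+".join(voice_names)  # "voice1+voice2", matching the test's expectation
    path = os.path.join(tempfile.gettempdir(), f"{name}.pt")
    torch.save(combined, path)
    return name, path

This shape also explains why the test patches `torch.load`, `torch.save`, and `tempfile.gettempdir` together: the combination itself is pure tensor math, so mocking the I/O boundary is enough to exercise it.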
charts/kokoro-fastapi/.helmignore ADDED
@@ -0,0 +1,23 @@
+# Patterns to ignore when building packages.
+# This supports shell glob matching, relative path matching, and
+# negation (prefixed with !). Only one pattern per line.
+.DS_Store
+# Common VCS dirs
+.git/
+.gitignore
+.bzr/
+.bzrignore
+.hg/
+.hgignore
+.svn/
+# Common backup files
+*.swp
+*.bak
+*.tmp
+*.orig
+*~
+# Various IDEs
+.project
+.idea/
+*.tmproj
+.vscode/
charts/kokoro-fastapi/Chart.yaml ADDED
@@ -0,0 +1,12 @@
+apiVersion: v2
+name: kokoro-fastapi
+description: A Helm chart for deploying the Kokoro FastAPI TTS service to Kubernetes
+type: application
+version: 0.3.0
+appVersion: "0.3.0"
+
+keywords:
+  - tts
+  - fastapi
+  - gpu
+  - kokoro
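With this metadata in place, the chart can be installed straight from a repo checkout with the standard `helm install kokoro-fastapi ./charts/kokoro-fastapi`. Note the Helm semantics here: `version` is the version of the chart packaging itself, while `appVersion` records the application release it deploys; the two are kept in lockstep at 0.3.0.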