Commit 05b45a5 · Michael Hu committed
1 Parent(s): e55a2a8

initial check in
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- Dockerfile +66 -0
- api/__init__.py +1 -0
- api/src/builds/v1_0/config.json +172 -0
- api/src/core/__init__.py +3 -0
- api/src/core/config.py +85 -0
- api/src/core/don_quixote.txt +9 -0
- api/src/core/model_config.py +50 -0
- api/src/core/openai_mappings.json +18 -0
- api/src/core/paths.py +413 -0
- api/src/inference/__init__.py +12 -0
- api/src/inference/base.py +127 -0
- api/src/inference/kokoro_v1.py +370 -0
- api/src/inference/model_manager.py +171 -0
- api/src/inference/voice_manager.py +115 -0
- api/src/main.py +152 -0
- api/src/models/v1_0/config.json +150 -0
- api/src/routers/__init__.py +1 -0
- api/src/routers/debug.py +209 -0
- api/src/routers/development.py +408 -0
- api/src/routers/openai_compatible.py +662 -0
- api/src/routers/web_player.py +49 -0
- api/src/services/__init__.py +3 -0
- api/src/services/audio.py +248 -0
- api/src/services/streaming_audio_writer.py +100 -0
- api/src/services/temp_manager.py +170 -0
- api/src/services/text_processing/__init__.py +21 -0
- api/src/services/text_processing/normalizer.py +415 -0
- api/src/services/text_processing/phonemizer.py +102 -0
- api/src/services/text_processing/text_processor.py +276 -0
- api/src/services/text_processing/vocabulary.py +40 -0
- api/src/services/tts_service.py +459 -0
- api/src/structures/__init__.py +17 -0
- api/src/structures/custom_responses.py +50 -0
- api/src/structures/model_schemas.py +16 -0
- api/src/structures/schemas.py +158 -0
- api/src/structures/text_schemas.py +41 -0
- api/tests/__init__.py +1 -0
- api/tests/conftest.py +71 -0
- api/tests/test_audio_service.py +256 -0
- api/tests/test_data/generate_test_data.py +23 -0
- api/tests/test_data/test_audio.npy +0 -0
- api/tests/test_development.py +34 -0
- api/tests/test_kokoro_v1.py +165 -0
- api/tests/test_normalizer.py +179 -0
- api/tests/test_openai_endpoints.py +499 -0
- api/tests/test_paths.py +138 -0
- api/tests/test_text_processor.py +105 -0
- api/tests/test_tts_service.py +126 -0
- charts/kokoro-fastapi/.helmignore +23 -0
- charts/kokoro-fastapi/Chart.yaml +12 -0
Dockerfile
ADDED
@@ -0,0 +1,66 @@
FROM python:3.10-slim

# Install dependencies and check espeak location
RUN apt-get update && apt-get install -y \
    espeak-ng \
    espeak-ng-data \
    git \
    libsndfile1 \
    curl \
    ffmpeg \
    g++ \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* \
    && mkdir -p /usr/share/espeak-ng-data \
    && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/

# Install UV using the installer script
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
    mv /root/.local/bin/uv /usr/local/bin/ && \
    mv /root/.local/bin/uvx /usr/local/bin/

# Create non-root user and set up directories and permissions
RUN useradd -m -u 1000 appuser && \
    mkdir -p /app/api/src/models/v1_0 && \
    chown -R appuser:appuser /app

USER appuser
WORKDIR /app

# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml

# Install Rust (required to build sudachipy and pyopenjtalk-plus)
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
ENV PATH="/home/appuser/.cargo/bin:$PATH"

# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
    uv venv --python 3.10 && \
    uv sync --extra cpu

# Copy project files including models
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/ ./
RUN chmod +x ./entrypoint.sh

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONPATH=/app:/app/api \
    PATH="/app/.venv/bin:$PATH" \
    UV_LINK_MODE=copy \
    USE_GPU=false \
    PHONEMIZER_ESPEAK_PATH=/usr/bin \
    PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
    ESPEAK_DATA_PATH=/usr/share/espeak-ng-data

ENV DOWNLOAD_MODEL=true
# Download model if enabled
RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
    python download_model.py --output api/src/models/v1_0; \
    fi

ENV DEVICE="cpu"
# Run FastAPI server through entrypoint.sh
CMD ["./entrypoint.sh"]
api/__init__.py
ADDED
@@ -0,0 +1 @@
# Make api directory a Python package
api/src/builds/v1_0/config.json
ADDED
@@ -0,0 +1,172 @@
{
    "istftnet": {
        "upsample_kernel_sizes": [
            20,
            12
        ],
        "upsample_rates": [
            10,
            6
        ],
        "gen_istft_hop_size": 5,
        "gen_istft_n_fft": 20,
        "resblock_dilation_sizes": [
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ]
        ],
        "resblock_kernel_sizes": [
            3,
            7,
            11
        ],
        "upsample_initial_channel": 512
    },
    "dim_in": 64,
    "dropout": 0.2,
    "hidden_dim": 512,
    "max_conv_dim": 512,
    "max_dur": 50,
    "multispeaker": true,
    "n_layer": 3,
    "n_mels": 80,
    "n_token": 178,
    "style_dim": 128,
    "text_encoder_kernel_size": 5,
    "plbert": {
        "hidden_size": 768,
        "num_attention_heads": 12,
        "intermediate_size": 2048,
        "max_position_embeddings": 512,
        "num_hidden_layers": 12,
        "dropout": 0.1
    },
    "vocab": {
        ";": 1,
        ":": 2,
        ",": 3,
        ".": 4,
        "!": 5,
        "?": 6,
        "—": 9,
        "…": 10,
        "\"": 11,
        "(": 12,
        ")": 13,
        "“": 14,
        "”": 15,
        " ": 16,
        "̃": 17,
        "ʣ": 18,
        "ʥ": 19,
        "ʦ": 20,
        "ʨ": 21,
        "ᵝ": 22,
        "ꭧ": 23,
        "A": 24,
        "I": 25,
        "O": 31,
        "Q": 33,
        "S": 35,
        "T": 36,
        "W": 39,
        "Y": 41,
        "ᵊ": 42,
        "a": 43,
        "b": 44,
        "c": 45,
        "d": 46,
        "e": 47,
        "f": 48,
        "h": 50,
        "i": 51,
        "j": 52,
        "k": 53,
        "l": 54,
        "m": 55,
        "n": 56,
        "o": 57,
        "p": 58,
        "q": 59,
        "r": 60,
        "s": 61,
        "t": 62,
        "u": 63,
        "v": 64,
        "w": 65,
        "x": 66,
        "y": 67,
        "z": 68,
        "ɑ": 69,
        "ɐ": 70,
        "ɒ": 71,
        "æ": 72,
        "β": 75,
        "ɔ": 76,
        "ɕ": 77,
        "ç": 78,
        "ɖ": 80,
        "ð": 81,
        "ʤ": 82,
        "ə": 83,
        "ɚ": 85,
        "ɛ": 86,
        "ɜ": 87,
        "ɟ": 90,
        "ɡ": 92,
        "ɥ": 99,
        "ɨ": 101,
        "ɪ": 102,
        "ʝ": 103,
        "ɯ": 110,
        "ɰ": 111,
        "ŋ": 112,
        "ɳ": 113,
        "ɲ": 114,
        "ɴ": 115,
        "ø": 116,
        "ɸ": 118,
        "θ": 119,
        "œ": 120,
        "ɹ": 123,
        "ɾ": 125,
        "ɻ": 126,
        "ʁ": 128,
        "ɽ": 129,
        "ʂ": 130,
        "ʃ": 131,
        "ʈ": 132,
        "ʧ": 133,
        "ʊ": 135,
        "ʋ": 136,
        "ʌ": 138,
        "ɣ": 139,
        "ɤ": 140,
        "χ": 142,
        "ʎ": 143,
        "ʒ": 147,
        "ʔ": 148,
        "ˈ": 156,
        "ˌ": 157,
        "ː": 158,
        "ʰ": 162,
        "ʲ": 164,
        "↓": 169,
        "→": 171,
        "↗": 172,
        "↘": 173,
        "ᵻ": 177
    }
}
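
An assumed reading, not stated in the diff: the "vocab" block maps the model's phoneme alphabet to integer token ids, so a phoneme string is tokenized by a per-character lookup. A minimal sketch:

import json

# Hypothetical illustration of applying the vocab to a phoneme string.
with open("api/src/builds/v1_0/config.json", encoding="utf-8") as f:
    vocab = json.load(f)["vocab"]

phonemes = "həlˈoʊ"  # example IPA-style string
ids = [vocab[ch] for ch in phonemes if ch in vocab]
print(ids)  # [50, 83, 54, 156, 57, 135]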
api/src/core/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .config import settings

__all__ = ["settings"]
api/src/core/config.py
ADDED
@@ -0,0 +1,85 @@
import torch
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # API Settings
    api_title: str = "Kokoro TTS API"
    api_description: str = "API for text-to-speech generation using Kokoro"
    api_version: str = "1.0.0"
    host: str = "0.0.0.0"
    port: int = 8880

    # Application Settings
    output_dir: str = "output"
    output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
    default_voice: str = "af_heart"
    default_voice_code: str | None = (
        None  # If set, overrides the first letter of voice name, though api call param still takes precedence
    )
    use_gpu: bool = True  # Whether to use GPU acceleration if available
    device_type: str | None = (
        None  # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
    )
    allow_local_voice_saving: bool = (
        False  # Whether to allow saving combined voices locally
    )

    # Container absolute paths
    model_dir: str = "/app/api/src/models"  # Absolute path in container
    voices_dir: str = "/app/api/src/voices/v1_0"  # Absolute path in container

    # Audio Settings
    sample_rate: int = 24000
    # Text Processing Settings
    target_min_tokens: int = 175  # Target minimum tokens per chunk
    target_max_tokens: int = 250  # Target maximum tokens per chunk
    absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
    advanced_text_normalization: bool = True  # Preprocesses the text before misaki
    voice_weight_normalization: bool = (
        True  # Normalize the voice weights so they add up to 1
    )

    gap_trim_ms: int = (
        1  # Base amount to trim from streaming chunk ends in milliseconds
    )
    dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
    dynamic_gap_trim_padding_char_multiplier: dict[str, float] = {
        ".": 1,
        "!": 0.9,
        "?": 1,
        ",": 0.8,
    }

    # Web Player Settings
    enable_web_player: bool = True  # Whether to serve the web player UI
    web_player_path: str = "web"  # Path to web player static files
    cors_origins: list[str] = ["*"]  # CORS origins for web player
    cors_enabled: bool = True  # Whether to enable CORS

    # Temp File Settings for Web UI
    temp_file_dir: str = "api/temp_files"  # Directory for temporary audio files (relative to project root)
    max_temp_dir_size_mb: int = 2048  # Maximum size of temp directory (2GB)
    max_temp_dir_age_hours: int = 1  # Remove temp files older than 1 hour
    max_temp_dir_count: int = 3  # Maximum number of temp files to keep

    class Config:
        env_file = ".env"

    def get_device(self) -> str:
        """Get the appropriate device based on settings and availability"""
        if not self.use_gpu:
            return "cpu"

        if self.device_type:
            return self.device_type

        # Auto-detect device
        if torch.backends.mps.is_available():
            return "mps"
        elif torch.cuda.is_available():
            return "cuda"
        return "cpu"


settings = Settings()
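
A minimal usage sketch (not part of the commit): Settings is a pydantic-settings model, so fields can be overridden through environment variables or the .env file, and get_device() resolves the runtime device. The import path assumes /app is on PYTHONPATH as the Dockerfile sets.

import os

# Hypothetical overrides; pydantic-settings matches env var names to fields case-insensitively.
os.environ["USE_GPU"] = "false"
os.environ["DEFAULT_VOICE"] = "af_heart"

from api.src.core.config import Settings

settings = Settings()            # values come from env / .env, falling back to the defaults above
print(settings.default_voice)    # "af_heart"
print(settings.get_device())     # "cpu", because USE_GPU=false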
api/src/core/don_quixote.txt
ADDED
@@ -0,0 +1,9 @@
In a village of La Mancha, the name of which I have no desire to call
to mind, there lived not long since one of those gentlemen that keep a
lance in the lance-rack, an old buckler, a lean hack, and a greyhound
for coursing. An olla of rather more beef than mutton, a salad on most
nights, scraps on Saturdays, lentils on Fridays, and a pigeon or so
extra on Sundays, made away with three-quarters of his income. The rest
of it went in a doublet of fine cloth and velvet breeches and shoes to
match for holidays, while on week-days he made a brave figure in his
best homespun.
api/src/core/model_config.py
ADDED
@@ -0,0 +1,50 @@
"""Model configuration for Kokoro V1.

This module provides model-specific configuration settings that complement the application-level
settings in config.py. While config.py handles general application settings (API, paths, etc.),
this module focuses on memory management and model file paths.
"""

from pydantic import BaseModel, Field


class KokoroV1Config(BaseModel):
    """Kokoro V1 configuration."""

    languages: list[str] = ["en"]

    class Config:
        frozen = True


class PyTorchConfig(BaseModel):
    """PyTorch backend configuration."""

    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")

    class Config:
        frozen = True


class ModelConfig(BaseModel):
    """Kokoro V1 model configuration."""

    # General settings
    cache_voices: bool = Field(True, description="Whether to cache voice tensors")
    voice_cache_size: int = Field(2, description="Maximum number of cached voices")

    # Model filename
    pytorch_kokoro_v1_file: str = Field(
        "v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename"
    )

    # Backend config
    pytorch_gpu: PyTorchConfig = Field(default_factory=PyTorchConfig)

    class Config:
        frozen = True


# Global instance
model_config = ModelConfig()
api/src/core/openai_mappings.json
ADDED
@@ -0,0 +1,18 @@
{
    "models": {
        "tts-1": "kokoro-v1_0",
        "tts-1-hd": "kokoro-v1_0",
        "kokoro": "kokoro-v1_0"
    },
    "voices": {
        "alloy": "am_v0adam",
        "ash": "af_v0nicole",
        "coral": "bf_v0emma",
        "echo": "af_v0bella",
        "fable": "af_sarah",
        "onyx": "bm_george",
        "nova": "bf_isabella",
        "sage": "am_michael",
        "shimmer": "af_sky"
    }
}
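
A rough sketch of how this mapping might be consumed; the routers that actually use it appear later in the diff, and the helper below is hypothetical.

import json

with open("api/src/core/openai_mappings.json", encoding="utf-8") as f:
    mappings = json.load(f)

def resolve_voice(openai_voice: str) -> str:
    # Translate an OpenAI-style voice name to a Kokoro voice id,
    # falling back to the requested name if it already is one.
    return mappings["voices"].get(openai_voice, openai_voice)

print(resolve_voice("alloy"))   # am_v0adam
print(resolve_voice("af_sky"))  # af_sky (already a Kokoro voice)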
api/src/core/paths.py
ADDED
@@ -0,0 +1,413 @@
"""Async file and path operations."""

import io
import json
import os
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set

import aiofiles
import aiofiles.os
import torch
from loguru import logger

from .config import settings


async def _find_file(
    filename: str,
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None,
) -> str:
    """Find file in search paths.

    Args:
        filename: Name of file to find
        search_paths: List of paths to search in
        filter_fn: Optional function to filter files

    Returns:
        Absolute path to file

    Raises:
        RuntimeError: If file not found
    """
    if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
        return filename

    for path in search_paths:
        full_path = os.path.join(path, filename)
        if await aiofiles.os.path.exists(full_path):
            if filter_fn is None or filter_fn(full_path):
                return full_path

    raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}")


async def _scan_directories(
    search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
    """Scan directories for files.

    Args:
        search_paths: List of paths to scan
        filter_fn: Optional function to filter files

    Returns:
        Set of matching filenames
    """
    results = set()

    for path in search_paths:
        if not await aiofiles.os.path.exists(path):
            continue

        try:
            # Get directory entries first
            entries = await aiofiles.os.scandir(path)
            # Then process entries after await completes
            for entry in entries:
                if filter_fn is None or filter_fn(entry.name):
                    results.add(entry.name)
        except Exception as e:
            logger.warning(f"Error scanning {path}: {e}")

    return results


async def get_model_path(model_name: str) -> str:
    """Get path to model file.

    Args:
        model_name: Name of model file

    Returns:
        Absolute path to model file

    Raises:
        RuntimeError: If model not found
    """
    # Get api directory path (two levels up from core)
    api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    # Construct model directory path relative to api directory
    model_dir = os.path.join(api_dir, settings.model_dir)

    # Ensure model directory exists
    os.makedirs(model_dir, exist_ok=True)

    # Search in model directory
    search_paths = [model_dir]
    logger.debug(f"Searching for model in path: {model_dir}")

    return await _find_file(model_name, search_paths)


async def get_voice_path(voice_name: str) -> str:
    """Get path to voice file.

    Args:
        voice_name: Name of voice file (without .pt extension)

    Returns:
        Absolute path to voice file

    Raises:
        RuntimeError: If voice not found
    """
    # Get api directory path
    api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    # Construct voice directory path relative to api directory
    voice_dir = os.path.join(api_dir, settings.voices_dir)

    # Ensure voice directory exists
    os.makedirs(voice_dir, exist_ok=True)

    voice_file = f"{voice_name}.pt"

    # Search in voice directory
    search_paths = [voice_dir]
    logger.debug(f"Searching for voice in path: {voice_dir}")

    return await _find_file(voice_file, search_paths)


async def list_voices() -> List[str]:
    """List available voice files.

    Returns:
        List of voice names (without .pt extension)
    """
    # Get api directory path
    api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    # Construct voice directory path relative to api directory
    voice_dir = os.path.join(api_dir, settings.voices_dir)

    # Ensure voice directory exists
    os.makedirs(voice_dir, exist_ok=True)

    # Search in voice directory
    search_paths = [voice_dir]
    logger.debug(f"Scanning for voices in path: {voice_dir}")

    def filter_voice_files(name: str) -> bool:
        return name.endswith(".pt")

    voices = await _scan_directories(search_paths, filter_voice_files)
    return sorted([name[:-3] for name in voices])  # Remove .pt extension


async def load_voice_tensor(
    voice_path: str, device: str = "cpu", weights_only=False
) -> torch.Tensor:
    """Load voice tensor from file.

    Args:
        voice_path: Path to voice file
        device: Device to load tensor to

    Returns:
        Voice tensor

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(voice_path, "rb") as f:
            data = await f.read()
            return torch.load(
                io.BytesIO(data), map_location=device, weights_only=weights_only
            )
    except Exception as e:
        raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")


async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
    """Save voice tensor to file.

    Args:
        tensor: Voice tensor to save
        voice_path: Path to save voice file

    Raises:
        RuntimeError: If file cannot be written
    """
    try:
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        async with aiofiles.open(voice_path, "wb") as f:
            await f.write(buffer.getvalue())
    except Exception as e:
        raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")


async def load_json(path: str) -> dict:
    """Load JSON file asynchronously.

    Args:
        path: Path to JSON file

    Returns:
        Parsed JSON data

    Raises:
        RuntimeError: If file cannot be read or parsed
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as f:
            content = await f.read()
            return json.loads(content)
    except Exception as e:
        raise RuntimeError(f"Failed to load JSON file {path}: {e}")


async def load_model_weights(path: str, device: str = "cpu") -> dict:
    """Load model weights asynchronously.

    Args:
        path: Path to model file (.pth or .onnx)
        device: Device to load model to

    Returns:
        Model weights

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(path, "rb") as f:
            data = await f.read()
            return torch.load(io.BytesIO(data), map_location=device, weights_only=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load model weights from {path}: {e}")


async def read_file(path: str) -> str:
    """Read text file asynchronously.

    Args:
        path: Path to file

    Returns:
        File contents as string

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as f:
            return await f.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")


async def read_bytes(path: str) -> bytes:
    """Read file as bytes asynchronously.

    Args:
        path: Path to file

    Returns:
        File contents as bytes

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(path, "rb") as f:
            return await f.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")


async def get_web_file_path(filename: str) -> str:
    """Get path to web static file.

    Args:
        filename: Name of file in web directory

    Returns:
        Absolute path to file

    Raises:
        RuntimeError: If file not found
    """
    # Get project root directory (four levels up from core to get to project root)
    root_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    )

    # Construct web directory path relative to project root
    web_dir = os.path.join("/app", settings.web_player_path)

    # Search in web directory
    search_paths = [web_dir]
    logger.debug(f"Searching for web file in path: {web_dir}")

    return await _find_file(filename, search_paths)


async def get_content_type(path: str) -> str:
    """Get content type for file.

    Args:
        path: Path to file

    Returns:
        Content type string
    """
    ext = os.path.splitext(path)[1].lower()
    return {
        ".html": "text/html",
        ".js": "application/javascript",
        ".css": "text/css",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".svg": "image/svg+xml",
        ".ico": "image/x-icon",
    }.get(ext, "application/octet-stream")


async def verify_model_path(model_path: str) -> bool:
    """Verify model file exists at path."""
    return await aiofiles.os.path.exists(model_path)


async def cleanup_temp_files() -> None:
    """Clean up old temp files on startup"""
    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if entry.is_file():
                stat = await aiofiles.os.stat(entry.path)
                max_age = stat.st_mtime + (settings.max_temp_dir_age_hours * 3600)
                if max_age < stat.st_mtime:
                    try:
                        await aiofiles.os.remove(entry.path)
                        logger.info(f"Cleaned up old temp file: {entry.name}")
                    except Exception as e:
                        logger.warning(
                            f"Failed to delete old temp file {entry.name}: {e}"
                        )
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")


async def get_temp_file_path(filename: str) -> str:
    """Get path to temporary audio file.

    Args:
        filename: Name of temp file

    Returns:
        Absolute path to temp file

    Raises:
        RuntimeError: If temp directory does not exist
    """
    temp_path = os.path.join(settings.temp_file_dir, filename)

    # Ensure temp directory exists
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)

    return temp_path


async def list_temp_files() -> List[str]:
    """List temporary audio files.

    Returns:
        List of temp file names
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return []

    entries = await aiofiles.os.scandir(settings.temp_file_dir)
    return [entry.name for entry in entries if entry.is_file()]


async def get_temp_dir_size() -> int:
    """Get total size of temp directory in bytes.

    Returns:
        Size in bytes
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return 0

    total = 0
    entries = await aiofiles.os.scandir(settings.temp_file_dir)
    for entry in entries:
        if entry.is_file():
            stat = await aiofiles.os.stat(entry.path)
            total += stat.st_size
    return total
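
For illustration only (not part of the commit), these async helpers are meant to be awaited from request handlers; a standalone sketch, assuming voices exist under settings.voices_dir and /app is on PYTHONPATH as in the Dockerfile:

import asyncio

from api.src.core import paths

async def main():
    voices = await paths.list_voices()              # e.g. ["af_heart", "af_sky", ...]
    if voices:
        voice_path = await paths.get_voice_path(voices[0])
        tensor = await paths.load_voice_tensor(voice_path, device="cpu")
        print(voices[0], tuple(tensor.shape))

asyncio.run(main())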
api/src/inference/__init__.py
ADDED
@@ -0,0 +1,12 @@
"""Model inference package."""

from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1
from .model_manager import ModelManager, get_manager

__all__ = [
    "BaseModelBackend",
    "ModelManager",
    "get_manager",
    "KokoroV1",
]
api/src/inference/base.py
ADDED
@@ -0,0 +1,127 @@
"""Base interface for Kokoro inference."""

from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Optional, Tuple, Union

import numpy as np
import torch


class AudioChunk:
    """Class for audio chunks returned by model backends"""

    def __init__(
        self,
        audio: np.ndarray,
        word_timestamps: Optional[List] = [],
        output: Optional[Union[bytes, np.ndarray]] = b"",
    ):
        self.audio = audio
        self.word_timestamps = word_timestamps
        self.output = output

    @staticmethod
    def combine(audio_chunk_list: List):
        output = AudioChunk(
            audio_chunk_list[0].audio, audio_chunk_list[0].word_timestamps
        )

        for audio_chunk in audio_chunk_list[1:]:
            output.audio = np.concatenate(
                (output.audio, audio_chunk.audio), dtype=np.int16
            )
            if output.word_timestamps is not None:
                output.word_timestamps += audio_chunk.word_timestamps

        return output


class ModelBackend(ABC):
    """Abstract base class for model inference backend."""

    @abstractmethod
    async def load_model(self, path: str) -> None:
        """Load model from path.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        pass

    @abstractmethod
    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate audio from text.

        Args:
            text: Input text to synthesize
            voice: Either a voice path or tuple of (name, tensor/path)
            speed: Speed multiplier

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """Unload model and free resources."""
        pass

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Check if model is loaded.

        Returns:
            True if model is loaded, False otherwise
        """
        pass

    @property
    @abstractmethod
    def device(self) -> str:
        """Get device model is running on.

        Returns:
            Device string ('cpu' or 'cuda')
        """
        pass


class BaseModelBackend(ModelBackend):
    """Base implementation of model backend."""

    def __init__(self):
        """Initialize base backend."""
        self._model: Optional[torch.nn.Module] = None
        self._device: str = "cpu"

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
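
A small illustration (assumed, not from the diff) of how AudioChunk.combine stitches streamed chunks back into one buffer; the import path assumes the project layout above with its dependencies installed.

import numpy as np

from api.src.inference.base import AudioChunk

# Two fake int16 chunks standing in for streamed model output.
a = AudioChunk(np.zeros(24000, dtype=np.int16), word_timestamps=[])
b = AudioChunk(np.ones(12000, dtype=np.int16), word_timestamps=[])

merged = AudioChunk.combine([a, b])
print(merged.audio.shape)  # (36000,) - one contiguous buffer, timestamps concatenated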
api/src/inference/kokoro_v1.py
ADDED
@@ -0,0 +1,370 @@
"""Clean Kokoro implementation with controlled resource management."""

import os
from typing import AsyncGenerator, Dict, Optional, Tuple, Union

import numpy as np
import torch
from kokoro import KModel, KPipeline
from loguru import logger

from ..core import paths
from ..core.config import settings
from ..core.model_config import model_config
from ..structures.schemas import WordTimestamp
from .base import AudioChunk, BaseModelBackend


class KokoroV1(BaseModelBackend):
    """Kokoro backend with controlled resource management."""

    def __init__(self):
        """Initialize backend with environment-based configuration."""
        super().__init__()
        # Strictly respect settings.use_gpu
        self._device = settings.get_device()
        self._model: Optional[KModel] = None
        self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

    async def load_model(self, path: str) -> None:
        """Load pre-baked model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)
            config_path = os.path.join(os.path.dirname(model_path), "config.json")

            if not os.path.exists(config_path):
                raise RuntimeError(f"Config file not found: {config_path}")

            logger.info(f"Loading Kokoro model on {self._device}")
            logger.info(f"Config path: {config_path}")
            logger.info(f"Model path: {model_path}")

            # Load model and let KModel handle device mapping
            self._model = KModel(config=config_path, model=model_path).eval()
            # For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
            if self._device == "mps":
                logger.info(
                    "Moving model to MPS device with CPU fallback for unsupported operations"
                )
                self._model = self._model.to(torch.device("mps"))
            elif self._device == "cuda":
                self._model = self._model.cuda()
            else:
                self._model = self._model.cpu()

        except FileNotFoundError as e:
            raise e
        except Exception as e:
            raise RuntimeError(f"Failed to load Kokoro model: {e}")

    def _get_pipeline(self, lang_code: str) -> KPipeline:
        """Get or create pipeline for language code.

        Args:
            lang_code: Language code to use

        Returns:
            KPipeline instance for the language
        """
        if not self._model:
            raise RuntimeError("Model not loaded")

        if lang_code not in self._pipelines:
            logger.info(f"Creating new pipeline for language code: {lang_code}")
            self._pipelines[lang_code] = KPipeline(
                lang_code=lang_code, model=self._model, device=self._device
            )
        return self._pipelines[lang_code]

    async def generate_from_tokens(
        self,
        tokens: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio from phoneme tokens.

        Args:
            tokens: Input phoneme tokens to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda":
                if self._check_memory():
                    self._clear_memory()

            # Handle voice input
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to temporary file
                    import tempfile

                    temp_dir = tempfile.gettempdir()
                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
                    # Save tensor with CPU mapping for portability
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping
            voice_tensor = await paths.load_voice_tensor(
                voice_path, device=self._device
            )
            # Save back to a temporary file with proper device mapping
            import tempfile

            temp_dir = tempfile.gettempdir()
            temp_path = os.path.join(
                temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, settings voice code override, or first letter of voice name
            if lang_code:  # api is given priority
                pipeline_lang_code = lang_code
            elif settings.default_voice_code:  # settings is next priority
                pipeline_lang_code = settings.default_voice_code
            else:  # voice name is default/fallback
                pipeline_lang_code = voice_name[0].lower()

            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}{'...' if len(tokens) > 100 else ''}'"
            )
            for result in pipeline.generate_from_tokens(
                tokens=tokens, voice=voice_path, speed=speed, model=self._model
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                    yield result.audio.numpy()
                else:
                    logger.warning("No audio in chunk")

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
                async for chunk in self.generate_from_tokens(
                    tokens, voice, speed, lang_code
                ):
                    yield chunk
            raise

    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
        return_timestamps: Optional[bool] = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate audio using model.

        Args:
            text: Input text to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")
        try:
            # Memory management for GPU
            if self._device == "cuda":
                if self._check_memory():
                    self._clear_memory()

            # Handle voice input
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to temporary file
                    import tempfile

                    temp_dir = tempfile.gettempdir()
                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
                    # Save tensor with CPU mapping for portability
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping
            voice_tensor = await paths.load_voice_tensor(
                voice_path, device=self._device
            )
            # Save back to a temporary file with proper device mapping
            import tempfile

            temp_dir = tempfile.gettempdir()
            temp_path = os.path.join(
                temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, settings voice code override, or first letter of voice name
            pipeline_lang_code = (
                lang_code
                if lang_code
                else (
                    settings.default_voice_code
                    if settings.default_voice_code
                    else voice_name[0].lower()
                )
            )
            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}{'...' if len(text) > 100 else ''}'"
            )
            for result in pipeline(
                text, voice=voice_path, speed=speed, model=self._model
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                    word_timestamps = None
                    if (
                        return_timestamps
                        and hasattr(result, "tokens")
                        and result.tokens
                    ):
                        word_timestamps = []
                        current_offset = 0.0
                        logger.debug(
                            f"Processing chunk timestamps with {len(result.tokens)} tokens"
                        )
                        if result.pred_dur is not None:
                            try:
                                # Add timestamps with offset
                                for token in result.tokens:
                                    if not all(
                                        hasattr(token, attr)
                                        for attr in [
                                            "text",
                                            "start_ts",
                                            "end_ts",
                                        ]
                                    ):
                                        continue
                                    if not token.text or not token.text.strip():
                                        continue

                                    start_time = float(token.start_ts) + current_offset
                                    end_time = float(token.end_ts) + current_offset
                                    word_timestamps.append(
                                        WordTimestamp(
                                            word=str(token.text).strip(),
                                            start_time=start_time,
                                            end_time=end_time,
                                        )
                                    )
                                    logger.debug(
                                        f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
                                    )

                            except Exception as e:
                                logger.error(
                                    f"Failed to process timestamps for chunk: {e}"
                                )

                    yield AudioChunk(
                        result.audio.numpy(), word_timestamps=word_timestamps
                    )
                else:
                    logger.warning("No audio in chunk")

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
                async for chunk in self.generate(text, voice, speed, lang_code):
                    yield chunk
            raise

    def _check_memory(self) -> bool:
        """Check if memory usage is above threshold."""
        if self._device == "cuda":
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > model_config.pytorch_gpu.memory_threshold
        # MPS doesn't provide memory management APIs
        return False

    def _clear_memory(self) -> None:
        """Clear device memory."""
        if self._device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        elif self._device == "mps":
            # Empty cache if available (future-proofing)
            if hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
        for pipeline in self._pipelines.values():
            del pipeline
        self._pipelines.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device
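
A rough end-to-end sketch of driving this backend directly (hypothetical; in the service it is wrapped by ModelManager below, and the model/voice files are assumed to be present):

import asyncio

from api.src.core import paths
from api.src.inference.kokoro_v1 import KokoroV1

async def main():
    backend = KokoroV1()
    await backend.load_model("v1_0/kokoro-v1_0.pth")     # filename from model_config.py
    voice_path = await paths.get_voice_path("af_heart")  # default voice from config.py
    async for chunk in backend.generate("Hello from Kokoro.", ("af_heart", voice_path)):
        print("chunk:", chunk.audio.shape)

asyncio.run(main())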
api/src/inference/model_manager.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Kokoro V1 model management."""
|
2 |
+
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
from ..core import paths
|
8 |
+
from ..core.config import settings
|
9 |
+
from ..core.model_config import ModelConfig, model_config
|
10 |
+
from .base import BaseModelBackend
|
11 |
+
from .kokoro_v1 import KokoroV1
|
12 |
+
|
13 |
+
|
14 |
+
class ModelManager:
|
15 |
+
"""Manages Kokoro V1 model loading and inference."""
|
16 |
+
|
17 |
+
# Singleton instance
|
18 |
+
_instance = None
|
19 |
+
|
20 |
+
def __init__(self, config: Optional[ModelConfig] = None):
|
21 |
+
"""Initialize manager.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
config: Optional model configuration override
|
25 |
+
"""
|
26 |
+
self._config = config or model_config
|
27 |
+
self._backend: Optional[KokoroV1] = None # Explicitly type as KokoroV1
|
28 |
+
self._device: Optional[str] = None
|
29 |
+
|
30 |
+
def _determine_device(self) -> str:
|
31 |
+
"""Determine device based on settings."""
|
32 |
+
return "cuda" if settings.use_gpu else "cpu"
|
33 |
+
|
34 |
+
async def initialize(self) -> None:
|
35 |
+
"""Initialize Kokoro V1 backend."""
|
36 |
+
try:
|
37 |
+
self._device = self._determine_device()
|
38 |
+
logger.info(f"Initializing Kokoro V1 on {self._device}")
|
39 |
+
self._backend = KokoroV1()
|
40 |
+
|
41 |
+
except Exception as e:
|
42 |
+
raise RuntimeError(f"Failed to initialize Kokoro V1: {e}")
|
43 |
+
|
44 |
+
async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
|
45 |
+
"""Initialize and warm up model.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
voice_manager: Voice manager instance for warmup
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Tuple of (device, backend type, voice count)
|
52 |
+
|
53 |
+
Raises:
|
54 |
+
RuntimeError: If initialization fails
|
55 |
+
"""
|
56 |
+
import time
|
57 |
+
|
58 |
+
start = time.perf_counter()
|
59 |
+
|
        try:
            # Initialize backend
            await self.initialize()

            # Load model
            model_path = self._config.pytorch_kokoro_v1_file
            await self.load_model(model_path)

            # Use paths module to get voice path
            try:
                voices = await paths.list_voices()
                voice_path = await paths.get_voice_path(settings.default_voice)

                # Warm up with short text
                warmup_text = "Warmup text for initialization."
                # Use default voice name for warmup
                voice_name = settings.default_voice
                logger.debug(f"Using default voice '{voice_name}' for warmup")
                async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                    pass
            except Exception as e:
                raise RuntimeError(f"Failed to get default voice: {e}")

            ms = int((time.perf_counter() - start) * 1000)
            logger.info(f"Warmup completed in {ms}ms")

            return self._device, "kokoro_v1", len(voices)
        except FileNotFoundError as e:
            logger.error("""
Model files not found! You need to download the Kokoro V1 model:

1. Download model using the script:
   python docker/scripts/download_model.py --output api/src/models/v1_0

2. Or set environment variable in docker-compose:
   DOWNLOAD_MODEL=true
""")
            exit(0)
        except Exception as e:
            raise RuntimeError(f"Warmup failed: {e}")

    def get_backend(self) -> BaseModelBackend:
        """Get initialized backend.

        Returns:
            Initialized backend instance

        Raises:
            RuntimeError: If backend not initialized
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")
        return self._backend

    async def load_model(self, path: str) -> None:
        """Load model using initialized backend.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If loading fails
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            await self._backend.load_model(path)
        except FileNotFoundError as e:
            raise e
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}")

    async def generate(self, *args, **kwargs):
        """Generate audio using initialized backend.

        Raises:
            RuntimeError: If generation fails
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            async for chunk in self._backend.generate(*args, **kwargs):
                yield chunk
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def unload_all(self) -> None:
        """Unload model and free resources."""
        if self._backend:
            self._backend.unload()
            self._backend = None

    @property
    def current_backend(self) -> str:
        """Get current backend type."""
        return "kokoro_v1"


async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get model manager instance.

    Args:
        config: Optional configuration override

    Returns:
        ModelManager instance
    """
    if ModelManager._instance is None:
        ModelManager._instance = ModelManager(config)
    return ModelManager._instance
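The `get_manager` helper above returns a process-wide singleton. A minimal usage sketch (illustrative only, not part of this commit; it assumes the package is importable as `api.src` and that the model files are already downloaded):

# Illustrative sketch, not part of the commit.
import asyncio
from api.src.inference.model_manager import get_manager

async def check_singleton():
    first = await get_manager()
    second = await get_manager()
    assert first is second  # ModelManager._instance is reused

asyncio.run(check_singleton())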
api/src/inference/voice_manager.py
ADDED
@@ -0,0 +1,115 @@
"""Voice management with controlled resource handling."""

from typing import Dict, List, Optional

import aiofiles
import torch
from loguru import logger

from ..core import paths
from ..core.config import settings


class VoiceManager:
    """Manages voice loading and caching with controlled resource usage."""

    # Singleton instance
    _instance = None

    def __init__(self):
        """Initialize voice manager."""
        # Strictly respect settings.use_gpu
        self._device = settings.get_device()
        self._voices: Dict[str, torch.Tensor] = {}

    async def get_voice_path(self, voice_name: str) -> str:
        """Get path to voice file.

        Args:
            voice_name: Name of voice

        Returns:
            Path to voice file

        Raises:
            RuntimeError: If voice not found
        """
        return await paths.get_voice_path(voice_name)

    async def load_voice(
        self, voice_name: str, device: Optional[str] = None
    ) -> torch.Tensor:
        """Load voice tensor.

        Args:
            voice_name: Name of voice to load
            device: Optional override for target device

        Returns:
            Voice tensor

        Raises:
            RuntimeError: If voice not found
        """
        try:
            voice_path = await self.get_voice_path(voice_name)
            target_device = device or self._device
            voice = await paths.load_voice_tensor(voice_path, target_device)
            self._voices[voice_name] = voice
            return voice
        except Exception as e:
            raise RuntimeError(f"Failed to load voice {voice_name}: {e}")

    async def combine_voices(
        self, voices: List[str], device: Optional[str] = None
    ) -> torch.Tensor:
        """Combine multiple voices.

        Args:
            voices: List of voice names to combine
            device: Optional override for target device

        Returns:
            Combined voice tensor

        Raises:
            RuntimeError: If any voice not found
        """
        if len(voices) < 2:
            raise ValueError("Need at least 2 voices to combine")

        target_device = device or self._device
        voice_tensors = []
        for name in voices:
            voice = await self.load_voice(name, target_device)
            voice_tensors.append(voice)

        combined = torch.mean(torch.stack(voice_tensors), dim=0)
        return combined

    async def list_voices(self) -> List[str]:
        """List available voice names.

        Returns:
            List of voice names
        """
        return await paths.list_voices()

    def cache_info(self) -> Dict[str, int]:
        """Get cache statistics.

        Returns:
            Dict with cache statistics
        """
        return {"loaded_voices": len(self._voices), "device": self._device}


async def get_manager() -> VoiceManager:
    """Get voice manager instance.

    Returns:
        VoiceManager instance
    """
    if VoiceManager._instance is None:
        VoiceManager._instance = VoiceManager()
    return VoiceManager._instance
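For illustration (not part of the commit): `combine_voices` blends voices by taking an element-wise mean over the stacked voice tensors, so a two-voice combination is an equal-weight mix. The tensor shape below is a placeholder, not the real voice-pack shape:

# Illustrative sketch of the same averaging used by VoiceManager.combine_voices.
import torch

voice_a = torch.randn(510, 1, 256)  # placeholder shape, for demonstration only
voice_b = torch.randn(510, 1, 256)
blend = torch.mean(torch.stack([voice_a, voice_b]), dim=0)  # equal 50/50 mix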
api/src/main.py
ADDED
@@ -0,0 +1,152 @@
"""
FastAPI OpenAI Compatible API
"""

import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

from .core.config import settings
from .routers.debug import router as debug_router
from .routers.development import router as dev_router
from .routers.openai_compatible import router as openai_router
from .routers.web_player import router as web_router


def setup_logger():
    """Configure loguru logger with custom formatting"""
    config = {
        "handlers": [
            {
                "sink": sys.stdout,
                "format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
                "{level: <8} | "
                "<fg #4169E1>{module}:{line}</fg #4169E1> | "
                "{message}",
                "colorize": True,
                "level": "DEBUG",
            },
        ],
    }
    logger.remove()
    logger.configure(**config)
    logger.level("ERROR", color="<red>")


# Configure logger
setup_logger()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for model initialization"""
    from .inference.model_manager import get_manager
    from .inference.voice_manager import get_manager as get_voice_manager
    from .services.temp_manager import cleanup_temp_files

    # Clean old temp files on startup
    await cleanup_temp_files()

    logger.info("Loading TTS model and voice packs...")

    try:
        # Initialize managers
        model_manager = await get_manager()
        voice_manager = await get_voice_manager()

        # Initialize model with warmup and get status
        device, model, voicepack_count = await model_manager.initialize_with_warmup(
            voice_manager
        )

    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        raise

    boundary = "░" * 2 * 12
    startup_msg = f"""

{boundary}

    ╔═╗┌─┐┌─┐┌┬┐
    ╠╣ ├─┤└─┐ │
    ╚  ┴ ┴└─┘ ┴
    ╦╔═┌─┐┬┌─┌─┐
    ╠╩╗│ │├┴┐│ │
    ╩ ╩└─┘┴ ┴└─┘

{boundary}
"""
    startup_msg += f"\nModel warmed up on {device}: {model}"
    if device == "mps":
        startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
    elif device == "cuda":
        startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
    else:
        startup_msg += "\nRunning on CPU"
    startup_msg += f"\n{voicepack_count} voice packs loaded"

    # Add web player info if enabled
    if settings.enable_web_player:
        startup_msg += (
            f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
        )
        startup_msg += f"\nor http://localhost:{settings.port}/web/"
    else:
        startup_msg += "\n\nWeb Player: disabled"

    startup_msg += f"\n{boundary}\n"
    logger.info(startup_msg)

    yield


# Initialize FastAPI app
app = FastAPI(
    title=settings.api_title,
    description=settings.api_description,
    version=settings.api_version,
    lifespan=lifespan,
    openapi_url="/openapi.json",  # Explicitly enable OpenAPI schema
)

# Add CORS middleware if enabled
if settings.cors_enabled:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# Include routers
app.include_router(openai_router, prefix="/v1")
app.include_router(dev_router)  # Development endpoints
app.include_router(debug_router)  # Debug endpoints
if settings.enable_web_player:
    app.include_router(web_router, prefix="/web")  # Web player static files


# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}


@app.get("/v1/test")
async def test_endpoint():
    """Test endpoint to verify routing"""
    return {"status": "ok"}


if __name__ == "__main__":
    uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True)
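Once this app is running (for example via `uvicorn api.src.main:app`), the OpenAI-compatible route mounted under `/v1` can be exercised with a small client script. The host, port, voice name, and output filename below are assumptions for illustration, not values fixed by this commit:

# Illustrative client sketch; host/port, voice, and default mp3 format are assumptions.
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"model": "kokoro", "input": "Hello world", "voice": "af_heart", "stream": False},
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)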
api/src/models/v1_0/config.json
ADDED
@@ -0,0 +1,150 @@
{
    "istftnet": {
        "upsample_kernel_sizes": [20, 12],
        "upsample_rates": [10, 6],
        "gen_istft_hop_size": 5,
        "gen_istft_n_fft": 20,
        "resblock_dilation_sizes": [
            [1, 3, 5],
            [1, 3, 5],
            [1, 3, 5]
        ],
        "resblock_kernel_sizes": [3, 7, 11],
        "upsample_initial_channel": 512
    },
    "dim_in": 64,
    "dropout": 0.2,
    "hidden_dim": 512,
    "max_conv_dim": 512,
    "max_dur": 50,
    "multispeaker": true,
    "n_layer": 3,
    "n_mels": 80,
    "n_token": 178,
    "style_dim": 128,
    "text_encoder_kernel_size": 5,
    "plbert": {
        "hidden_size": 768,
        "num_attention_heads": 12,
        "intermediate_size": 2048,
        "max_position_embeddings": 512,
        "num_hidden_layers": 12,
        "dropout": 0.1
    },
    "vocab": {
        ";": 1,
        ":": 2,
        ",": 3,
        ".": 4,
        "!": 5,
        "?": 6,
        "—": 9,
        "…": 10,
        "\"": 11,
        "(": 12,
        ")": 13,
        "“": 14,
        "”": 15,
        " ": 16,
        "\u0303": 17,
        "ʣ": 18,
        "ʥ": 19,
        "ʦ": 20,
        "ʨ": 21,
        "ᵝ": 22,
        "\uAB67": 23,
        "A": 24,
        "I": 25,
        "O": 31,
        "Q": 33,
        "S": 35,
        "T": 36,
        "W": 39,
        "Y": 41,
        "ᵊ": 42,
        "a": 43,
        "b": 44,
        "c": 45,
        "d": 46,
        "e": 47,
        "f": 48,
        "h": 50,
        "i": 51,
        "j": 52,
        "k": 53,
        "l": 54,
        "m": 55,
        "n": 56,
        "o": 57,
        "p": 58,
        "q": 59,
        "r": 60,
        "s": 61,
        "t": 62,
        "u": 63,
        "v": 64,
        "w": 65,
        "x": 66,
        "y": 67,
        "z": 68,
        "ɑ": 69,
        "ɐ": 70,
        "ɒ": 71,
        "æ": 72,
        "β": 75,
        "ɔ": 76,
        "ɕ": 77,
        "ç": 78,
        "ɖ": 80,
        "ð": 81,
        "ʤ": 82,
        "ə": 83,
        "ɚ": 85,
        "ɛ": 86,
        "ɜ": 87,
        "ɟ": 90,
        "ɡ": 92,
        "ɥ": 99,
        "ɨ": 101,
        "ɪ": 102,
        "ʝ": 103,
        "ɯ": 110,
        "ɰ": 111,
        "ŋ": 112,
        "ɳ": 113,
        "ɲ": 114,
        "ɴ": 115,
        "ø": 116,
        "ɸ": 118,
        "θ": 119,
        "œ": 120,
        "ɹ": 123,
        "ɾ": 125,
        "ɻ": 126,
        "ʁ": 128,
        "ɽ": 129,
        "ʂ": 130,
        "ʃ": 131,
        "ʈ": 132,
        "ʧ": 133,
        "ʊ": 135,
        "ʋ": 136,
        "ʌ": 138,
        "ɣ": 139,
        "ɤ": 140,
        "χ": 142,
        "ʎ": 143,
        "ʒ": 147,
        "ʔ": 148,
        "ˈ": 156,
        "ˌ": 157,
        "ː": 158,
        "ʰ": 162,
        "ʲ": 164,
        "↓": 169,
        "→": 171,
        "↗": 172,
        "↘": 173,
        "ᵻ": 177
    }
}
api/src/routers/__init__.py
ADDED
@@ -0,0 +1 @@
#
api/src/routers/debug.py
ADDED
@@ -0,0 +1,209 @@
import threading
import time
from datetime import datetime

import psutil
import torch
from fastapi import APIRouter

try:
    import GPUtil

    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

router = APIRouter(tags=["debug"])


@router.get("/debug/threads")
async def get_thread_info():
    process = psutil.Process()
    current_threads = threading.enumerate()

    # Get per-thread CPU times
    thread_details = []
    for thread in current_threads:
        thread_info = {
            "name": thread.name,
            "id": thread.ident,
            "alive": thread.is_alive(),
            "daemon": thread.daemon,
        }
        thread_details.append(thread_info)

    return {
        "total_threads": process.num_threads(),
        "active_threads": len(current_threads),
        "thread_names": [t.name for t in current_threads],
        "thread_details": thread_details,
        "memory_mb": process.memory_info().rss / 1024 / 1024,
    }


@router.get("/debug/storage")
async def get_storage_info():
    # Get disk partitions
    partitions = psutil.disk_partitions()
    storage_info = []

    for partition in partitions:
        try:
            usage = psutil.disk_usage(partition.mountpoint)
            storage_info.append(
                {
                    "device": partition.device,
                    "mountpoint": partition.mountpoint,
                    "fstype": partition.fstype,
                    "total_gb": usage.total / (1024**3),
                    "used_gb": usage.used / (1024**3),
                    "free_gb": usage.free / (1024**3),
                    "percent_used": usage.percent,
                }
            )
        except PermissionError:
            continue

    return {"storage_info": storage_info}


@router.get("/debug/system")
async def get_system_info():
    process = psutil.Process()

    # CPU Info
    cpu_info = {
        "cpu_count": psutil.cpu_count(),
        "cpu_percent": psutil.cpu_percent(interval=1),
        "per_cpu_percent": psutil.cpu_percent(interval=1, percpu=True),
        "load_avg": psutil.getloadavg(),
    }

    # Memory Info
    virtual_memory = psutil.virtual_memory()
    swap_memory = psutil.swap_memory()
    memory_info = {
        "virtual": {
            "total_gb": virtual_memory.total / (1024**3),
            "available_gb": virtual_memory.available / (1024**3),
            "used_gb": virtual_memory.used / (1024**3),
            "percent": virtual_memory.percent,
        },
        "swap": {
            "total_gb": swap_memory.total / (1024**3),
            "used_gb": swap_memory.used / (1024**3),
            "free_gb": swap_memory.free / (1024**3),
            "percent": swap_memory.percent,
        },
    }

    # Process Info
    process_info = {
        "pid": process.pid,
        "status": process.status(),
        "create_time": datetime.fromtimestamp(process.create_time()).isoformat(),
        "cpu_percent": process.cpu_percent(),
        "memory_percent": process.memory_percent(),
    }

    # Network Info
    network_info = {
        "connections": len(process.net_connections()),
        "network_io": psutil.net_io_counters()._asdict(),
    }

    # GPU Info if available
    gpu_info = None
    if torch.backends.mps.is_available():
        gpu_info = {
            "type": "MPS",
            "available": True,
            "device": "Apple Silicon",
            "backend": "Metal",
        }
    elif GPU_AVAILABLE:
        try:
            gpus = GPUtil.getGPUs()
            gpu_info = [
                {
                    "id": gpu.id,
                    "name": gpu.name,
                    "load": gpu.load,
                    "memory": {
                        "total": gpu.memoryTotal,
                        "used": gpu.memoryUsed,
                        "free": gpu.memoryFree,
                        "percent": (gpu.memoryUsed / gpu.memoryTotal) * 100,
                    },
                    "temperature": gpu.temperature,
                }
                for gpu in gpus
            ]
        except Exception:
            gpu_info = "GPU information unavailable"

    return {
        "cpu": cpu_info,
        "memory": memory_info,
        "process": process_info,
        "network": network_info,
        "gpu": gpu_info,
    }


@router.get("/debug/session_pools")
async def get_session_pool_info():
    """Get information about ONNX session pools."""
    from ..inference.model_manager import get_manager

    manager = await get_manager()
    pools = manager._session_pools
    current_time = time.time()

    pool_info = {}

    # Get CPU pool info
    if "onnx_cpu" in pools:
        cpu_pool = pools["onnx_cpu"]
        pool_info["cpu"] = {
            "active_sessions": len(cpu_pool._sessions),
            "max_sessions": cpu_pool._max_size,
            "sessions": [
                {"model": path, "age_seconds": current_time - info.last_used}
                for path, info in cpu_pool._sessions.items()
            ],
        }

    # Get GPU pool info
    if "onnx_gpu" in pools:
        gpu_pool = pools["onnx_gpu"]
        pool_info["gpu"] = {
            "active_sessions": len(gpu_pool._sessions),
            "max_streams": gpu_pool._max_size,
            "available_streams": len(gpu_pool._available_streams),
            "sessions": [
                {
                    "model": path,
                    "age_seconds": current_time - info.last_used,
                    "stream_id": info.stream_id,
                }
                for path, info in gpu_pool._sessions.items()
            ],
        }

    # Add GPU memory info if available
    if GPU_AVAILABLE:
        try:
            gpus = GPUtil.getGPUs()
            if gpus:
                gpu = gpus[0]  # Assume first GPU
                pool_info["gpu"]["memory"] = {
                    "total_mb": gpu.memoryTotal,
                    "used_mb": gpu.memoryUsed,
                    "free_mb": gpu.memoryFree,
                    "percent_used": (gpu.memoryUsed / gpu.memoryTotal) * 100,
                }
        except Exception:
            pass

    return pool_info
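These debug endpoints are plain GET routes, so (assuming the server above is reachable on localhost:8880, which is an assumption, not a value fixed by this commit) they can be polled with a few lines of Python:

# Illustrative only; the base URL is an assumption.
import requests

base = "http://localhost:8880"
print(requests.get(f"{base}/debug/system").json()["cpu"])
print(requests.get(f"{base}/debug/threads").json()["total_threads"])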
api/src/routers/development.py
ADDED
@@ -0,0 +1,408 @@
import base64
import json
import os
import re
from pathlib import Path
from typing import AsyncGenerator, List, Tuple, Union

import numpy as np
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from kokoro import KPipeline
from loguru import logger

from ..core.config import settings
from ..inference.base import AudioChunk
from ..services.audio import AudioNormalizer, AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.temp_manager import TempFileWriter
from ..services.text_processing import smart_split
from ..services.tts_service import TTSService
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
from ..structures.custom_responses import JSONStreamingResponse
from ..structures.text_schemas import (
    GenerateFromPhonemesRequest,
    PhonemeRequest,
    PhonemeResponse,
)
from .openai_compatible import process_and_validate_voices, stream_audio_chunks

router = APIRouter(tags=["text processing"])


async def get_tts_service() -> TTSService:
    """Dependency to get TTSService instance"""
    return (
        await TTSService.create()
    )  # Create service with properly initialized managers


@router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
    """Convert text to phonemes using Kokoro's quiet mode.

    Args:
        request: Request containing text and language

    Returns:
        Phonemes and token IDs
    """
    try:
        if not request.text:
            raise ValueError("Text cannot be empty")

        # Initialize Kokoro pipeline in quiet mode (no model)
        pipeline = KPipeline(lang_code=request.language, model=False)

        # Get first result from pipeline (we only need one since we're not chunking)
        for result in pipeline(request.text):
            # result.graphemes = original text
            # result.phonemes = phonemized text
            # result.tokens = token objects (if available)
            return PhonemeResponse(phonemes=result.phonemes, tokens=[])

        raise ValueError("Failed to generate phonemes")
    except ValueError as e:
        logger.error(f"Error in phoneme generation: {str(e)}")
        raise HTTPException(
            status_code=500, detail={"error": "Server error", "message": str(e)}
        )
    except Exception as e:
        logger.error(f"Error in phoneme generation: {str(e)}")
        raise HTTPException(
            status_code=500, detail={"error": "Server error", "message": str(e)}
        )


@router.post("/dev/generate_from_phonemes")
async def generate_from_phonemes(
    request: GenerateFromPhonemesRequest,
    client_request: Request,
    tts_service: TTSService = Depends(get_tts_service),
) -> StreamingResponse:
    """Generate audio directly from phonemes using Kokoro's phoneme format"""
    try:
        # Basic validation
        if not isinstance(request.phonemes, str):
            raise ValueError("Phonemes must be a string")
        if not request.phonemes:
            raise ValueError("Phonemes cannot be empty")

        # Create streaming audio writer and normalizer
        writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
        normalizer = AudioNormalizer()

        async def generate_chunks():
            try:
                # Generate audio from phonemes
                chunk_audio, _ = await tts_service.generate_from_phonemes(
                    phonemes=request.phonemes,  # Pass complete phoneme string
                    voice=request.voice,
                    speed=1.0,
                )

                if chunk_audio is not None:
                    # Normalize audio before writing
                    normalized_audio = await normalizer.normalize(chunk_audio)
                    # Write chunk and yield bytes
                    chunk_bytes = writer.write_chunk(normalized_audio)
                    if chunk_bytes:
                        yield chunk_bytes

                    # Finalize and yield remaining bytes
                    final_bytes = writer.write_chunk(finalize=True)
                    if final_bytes:
                        yield final_bytes
                else:
                    raise ValueError("Failed to generate audio data")

            except Exception as e:
                logger.error(f"Error in audio generation: {str(e)}")
                # Clean up writer on error
                writer.close()
                # Re-raise the original exception
                raise

        return StreamingResponse(
            generate_chunks(),
            media_type="audio/wav",
            headers={
                "Content-Disposition": "attachment; filename=speech.wav",
                "X-Accel-Buffering": "no",
                "Cache-Control": "no-cache",
                "Transfer-Encoding": "chunked",
            },
        )

    except ValueError as e:
        logger.error(f"Error generating audio: {str(e)}")
        raise HTTPException(
            status_code=400,
            detail={
                "error": "validation_error",
                "message": str(e),
                "type": "invalid_request_error",
            },
        )
    except Exception as e:
        logger.error(f"Error generating audio: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": str(e),
                "type": "server_error",
            },
        )


@router.post("/dev/captioned_speech")
async def create_captioned_speech(
    request: CaptionedSpeechRequest,
    client_request: Request,
    x_raw_response: str = Header(None, alias="x-raw-response"),
    tts_service: TTSService = Depends(get_tts_service),
):
    """Generate audio with word-level timestamps using streaming approach"""

    try:
        # model_name = get_model_name(request.model)
        tts_service = await get_tts_service()
        voice_name = await process_and_validate_voices(request.voice, tts_service)

        # Set content type based on format
        content_type = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "m4a": "audio/mp4",
            "flac": "audio/flac",
            "wav": "audio/wav",
            "pcm": "audio/pcm",
        }.get(request.response_format, f"audio/{request.response_format}")

        writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
        # Check if streaming is requested (default for OpenAI client)
        if request.stream:
            # Create generator but don't start it yet
            generator = stream_audio_chunks(
                tts_service, request, client_request, writer
            )

            # If download link requested, wrap generator with temp file writer
            if request.return_download_link:
                from ..services.temp_manager import TempFileWriter

                temp_writer = TempFileWriter(request.response_format)
                await temp_writer.__aenter__()  # Initialize temp file

                # Get download path immediately after temp file creation
                download_path = temp_writer.download_path

                # Create response headers with download path
                headers = {
                    "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                    "X-Accel-Buffering": "no",
                    "Cache-Control": "no-cache",
                    "Transfer-Encoding": "chunked",
                    "X-Download-Path": download_path,
                }

                # Create async generator for streaming
                async def dual_output():
                    try:
                        # Write chunks to temp file and stream
                        async for chunk_data in generator:
                            # The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
                            timestamp_acumulator = []

                            if chunk_data.output:  # Skip empty chunks
                                await temp_writer.write(chunk_data.output)
                                base64_chunk = base64.b64encode(
                                    chunk_data.output
                                ).decode("utf-8")

                                # Add any chunks that may be in the acumulator into the return word_timestamps
                                chunk_data.word_timestamps = (
                                    timestamp_acumulator + chunk_data.word_timestamps
                                )
                                timestamp_acumulator = []

                                yield CaptionedSpeechResponse(
                                    audio=base64_chunk,
                                    audio_format=content_type,
                                    timestamps=chunk_data.word_timestamps,
                                )
                            else:
                                if (
                                    chunk_data.word_timestamps is not None
                                    and len(chunk_data.word_timestamps) > 0
                                ):
                                    timestamp_acumulator += chunk_data.word_timestamps

                        # Finalize the temp file
                        await temp_writer.finalize()
                    except Exception as e:
                        logger.error(f"Error in dual output streaming: {e}")
                        await temp_writer.__aexit__(type(e), e, e.__traceback__)
                        raise
                    finally:
                        # Ensure temp writer is closed
                        if not temp_writer._finalized:
                            await temp_writer.__aexit__(None, None, None)
                        writer.close()

                # Stream with temp file writing
                return JSONStreamingResponse(
                    dual_output(), media_type="application/json", headers=headers
                )

            async def single_output():
                try:
                    # The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
                    timestamp_acumulator = []

                    # Stream chunks
                    async for chunk_data in generator:
                        if chunk_data.output:  # Skip empty chunks
                            # Encode the chunk bytes into base 64
                            base64_chunk = base64.b64encode(chunk_data.output).decode(
                                "utf-8"
                            )

                            # Add any chunks that may be in the acumulator into the return word_timestamps
                            if chunk_data.word_timestamps != None:
                                chunk_data.word_timestamps = (
                                    timestamp_acumulator + chunk_data.word_timestamps
                                )
                            else:
                                chunk_data.word_timestamps = []
                            timestamp_acumulator = []

                            yield CaptionedSpeechResponse(
                                audio=base64_chunk,
                                audio_format=content_type,
                                timestamps=chunk_data.word_timestamps,
                            )
                        else:
                            if (
                                chunk_data.word_timestamps is not None
                                and len(chunk_data.word_timestamps) > 0
                            ):
                                timestamp_acumulator += chunk_data.word_timestamps

                except Exception as e:
                    logger.error(f"Error in single output streaming: {e}")
                    writer.close()
                    raise

            # Standard streaming without download link
            return JSONStreamingResponse(
                single_output(),
                media_type="application/json",
                headers={
                    "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                    "X-Accel-Buffering": "no",
                    "Cache-Control": "no-cache",
                    "Transfer-Encoding": "chunked",
                },
            )
        else:
            # Generate complete audio using public interface
            audio_data = await tts_service.generate_audio(
                text=request.input,
                voice=voice_name,
                writer=writer,
                speed=request.speed,
                return_timestamps=request.return_timestamps,
                normalization_options=request.normalization_options,
                lang_code=request.lang_code,
            )

            audio_data = await AudioService.convert_audio(
                audio_data,
                request.response_format,
                writer,
                is_last_chunk=False,
                trim_audio=False,
            )

            # Convert to requested format with proper finalization
            final = await AudioService.convert_audio(
                AudioChunk(np.array([], dtype=np.int16)),
                request.response_format,
                writer,
                is_last_chunk=True,
            )
            output = audio_data.output + final.output

            base64_output = base64.b64encode(output).decode("utf-8")

            content = CaptionedSpeechResponse(
                audio=base64_output,
                audio_format=content_type,
                timestamps=audio_data.word_timestamps,
            ).model_dump()

            writer.close()

            return JSONResponse(
                content=content,
                media_type="application/json",
                headers={
                    "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                    "Cache-Control": "no-cache",  # Prevent caching
                },
            )

    except ValueError as e:
        # Handle validation errors
        logger.warning(f"Invalid request: {str(e)}")

        try:
            writer.close()
        except:
            pass

        raise HTTPException(
            status_code=400,
            detail={
                "error": "validation_error",
                "message": str(e),
                "type": "invalid_request_error",
            },
        )
    except RuntimeError as e:
        # Handle runtime/processing errors
        logger.error(f"Processing error: {str(e)}")

        try:
            writer.close()
        except:
            pass

        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": str(e),
                "type": "server_error",
            },
        )
    except Exception as e:
        # Handle unexpected errors
        logger.error(f"Unexpected error in captioned speech generation: {str(e)}")

        try:
            writer.close()
        except:
            pass

        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": str(e),
                "type": "server_error",
            },
        )
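A hedged sketch of calling the captioned-speech route in non-streaming mode and decoding the base64 audio it returns (request fields are taken from the handler above; the host, port, and voice name are assumptions):

# Illustrative only; decodes the base64 audio from /dev/captioned_speech.
import base64
import requests

resp = requests.post(
    "http://localhost:8880/dev/captioned_speech",
    json={
        "model": "kokoro",
        "input": "Words with timestamps",
        "voice": "af_heart",
        "response_format": "wav",
        "stream": False,
    },
).json()

with open("captioned.wav", "wb") as f:
    f.write(base64.b64decode(resp["audio"]))
print(resp["timestamps"][:3])  # first few word-level timing entries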
api/src/routers/openai_compatible.py
ADDED
@@ -0,0 +1,662 @@
"""OpenAI-compatible router for text-to-speech"""

import io
import json
import os
import re
import tempfile
from typing import AsyncGenerator, Dict, List, Tuple, Union
from urllib import response

import aiofiles
import numpy as np
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger

from ..core.config import settings
from ..inference.base import AudioChunk
from ..services.audio import AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.tts_service import TTSService
from ..structures import OpenAISpeechRequest
from ..structures.schemas import CaptionedSpeechRequest


# Load OpenAI mappings
def load_openai_mappings() -> Dict:
    """Load OpenAI voice and model mappings from JSON"""
    api_dir = os.path.dirname(os.path.dirname(__file__))
    mapping_path = os.path.join(api_dir, "core", "openai_mappings.json")
    try:
        with open(mapping_path, "r") as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Failed to load OpenAI mappings: {e}")
        return {"models": {}, "voices": {}}


# Global mappings
_openai_mappings = load_openai_mappings()


router = APIRouter(
    tags=["OpenAI Compatible TTS"],
    responses={404: {"description": "Not found"}},
)

# Global TTSService instance with lock
_tts_service = None
_init_lock = None


async def get_tts_service() -> TTSService:
    """Get global TTSService instance"""
    global _tts_service, _init_lock

    # Create lock if needed
    if _init_lock is None:
        import asyncio

        _init_lock = asyncio.Lock()

    # Initialize service if needed
    if _tts_service is None:
        async with _init_lock:
            # Double check pattern
            if _tts_service is None:
                _tts_service = await TTSService.create()
                logger.info("Created global TTSService instance")

    return _tts_service


def get_model_name(model: str) -> str:
    """Get internal model name from OpenAI model name"""
    base_name = _openai_mappings["models"].get(model)
    if not base_name:
        raise ValueError(f"Unsupported model: {model}")
    return base_name + ".pth"


async def process_and_validate_voices(
    voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
    """Process voice input, handling both string and list formats

    Returns:
        Voice name to use (with weights if specified)
    """
    voices = []
    # Convert input to list of voices
    if isinstance(voice_input, str):
        voice_input = voice_input.replace(" ", "").strip()

        if voice_input[-1] in "+-" or voice_input[0] in "+-":
            raise ValueError(f"Voice combination contains empty combine items")

        if re.search(r"[+-]{2,}", voice_input) is not None:
            raise ValueError(f"Voice combination contains empty combine items")
        voices = re.split(r"([-+])", voice_input)
    else:
        voices = [[item, "+"] for item in voice_input][:-1]

    available_voices = await tts_service.list_voices()

    for voice_index in range(0, len(voices), 2):
        mapped_voice = voices[voice_index].split("(")
        mapped_voice = list(map(str.strip, mapped_voice))

        if len(mapped_voice) > 2:
            raise ValueError(
                f"Voice '{voices[voice_index]}' contains too many weight items"
            )

        if mapped_voice.count(")") > 1:
            raise ValueError(
                f"Voice '{voices[voice_index]}' contains too many weight items"
            )

        mapped_voice[0] = _openai_mappings["voices"].get(
            mapped_voice[0], mapped_voice[0]
        )

        if mapped_voice[0] not in available_voices:
            raise ValueError(
                f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
            )

        voices[voice_index] = "(".join(mapped_voice)

    return "".join(voices)


async def stream_audio_chunks(
    tts_service: TTSService,
    request: Union[OpenAISpeechRequest, CaptionedSpeechRequest],
    client_request: Request,
    writer: StreamingAudioWriter,
) -> AsyncGenerator[AudioChunk, None]:
    """Stream audio chunks as they're generated with client disconnect handling"""
    voice_name = await process_and_validate_voices(request.voice, tts_service)
    unique_properties = {"return_timestamps": False}
    if hasattr(request, "return_timestamps"):
        unique_properties["return_timestamps"] = request.return_timestamps

    try:
        async for chunk_data in tts_service.generate_audio_stream(
            text=request.input,
            voice=voice_name,
            writer=writer,
            speed=request.speed,
            output_format=request.response_format,
            lang_code=request.lang_code,
            normalization_options=request.normalization_options,
            return_timestamps=unique_properties["return_timestamps"],
        ):
            # Check if client is still connected
            is_disconnected = client_request.is_disconnected
            if callable(is_disconnected):
                is_disconnected = await is_disconnected()
            if is_disconnected:
                logger.info("Client disconnected, stopping audio generation")
                break

            yield chunk_data
    except Exception as e:
        logger.error(f"Error in audio streaming: {str(e)}")
        # Let the exception propagate to trigger cleanup
        raise


@router.post("/audio/speech")
async def create_speech(
    request: OpenAISpeechRequest,
    client_request: Request,
    x_raw_response: str = Header(None, alias="x-raw-response"),
):
    """OpenAI-compatible endpoint for text-to-speech"""
    # Validate model before processing request
    if request.model not in _openai_mappings["models"]:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "invalid_model",
                "message": f"Unsupported model: {request.model}",
                "type": "invalid_request_error",
            },
        )

    try:
        # model_name = get_model_name(request.model)
        tts_service = await get_tts_service()
        voice_name = await process_and_validate_voices(request.voice, tts_service)

        # Set content type based on format
        content_type = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
            "pcm": "audio/pcm",
        }.get(request.response_format, f"audio/{request.response_format}")

        writer = StreamingAudioWriter(request.response_format, sample_rate=24000)

        # Check if streaming is requested (default for OpenAI client)
        if request.stream:
            # Create generator but don't start it yet
            generator = stream_audio_chunks(
                tts_service, request, client_request, writer
            )

            # If download link requested, wrap generator with temp file writer
            if request.return_download_link:
                from ..services.temp_manager import TempFileWriter

                # Use download_format if specified, otherwise use response_format
                output_format = request.download_format or request.response_format
                temp_writer = TempFileWriter(output_format)
                await temp_writer.__aenter__()  # Initialize temp file

                # Get download path immediately after temp file creation
                download_path = temp_writer.download_path

                # Create response headers with download path
                headers = {
                    "Content-Disposition": f"attachment; filename=speech.{output_format}",
                    "X-Accel-Buffering": "no",
                    "Cache-Control": "no-cache",
                    "Transfer-Encoding": "chunked",
                    "X-Download-Path": download_path,
                }

                # Add header to indicate if temp file writing is available
                if temp_writer._write_error:
                    headers["X-Download-Status"] = "unavailable"

                # Create async generator for streaming
                async def dual_output():
                    try:
                        # Write chunks to temp file and stream
                        async for chunk_data in generator:
                            if chunk_data.output:  # Skip empty chunks
                                await temp_writer.write(chunk_data.output)
                                # if return_json:
                                #     yield chunk, chunk_data
                                # else:
                                yield chunk_data.output

                        # Finalize the temp file
                        await temp_writer.finalize()
                    except Exception as e:
                        logger.error(f"Error in dual output streaming: {e}")
                        await temp_writer.__aexit__(type(e), e, e.__traceback__)
                        raise
                    finally:
                        # Ensure temp writer is closed
                        if not temp_writer._finalized:
                            await temp_writer.__aexit__(None, None, None)
                        writer.close()

                # Stream with temp file writing
                return StreamingResponse(
                    dual_output(), media_type=content_type, headers=headers
                )

            async def single_output():
                try:
                    # Stream chunks
                    async for chunk_data in generator:
                        if chunk_data.output:  # Skip empty chunks
                            yield chunk_data.output
                except Exception as e:
                    logger.error(f"Error in single output streaming: {e}")
                    writer.close()
                    raise

            # Standard streaming without download link
            return StreamingResponse(
                single_output(),
                media_type=content_type,
                headers={
                    "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                    "X-Accel-Buffering": "no",
                    "Cache-Control": "no-cache",
                    "Transfer-Encoding": "chunked",
                },
            )
        else:
            headers = {
                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                "Cache-Control": "no-cache",  # Prevent caching
            }

            # Generate complete audio using public interface
            audio_data = await tts_service.generate_audio(
                text=request.input,
                voice=voice_name,
                writer=writer,
                speed=request.speed,
                normalization_options=request.normalization_options,
                lang_code=request.lang_code,
            )

            audio_data = await AudioService.convert_audio(
                audio_data,
                request.response_format,
                writer,
                is_last_chunk=False,
                trim_audio=False,
            )

            # Convert to requested format with proper finalization
            final = await AudioService.convert_audio(
                AudioChunk(np.array([], dtype=np.int16)),
                request.response_format,
                writer,
                is_last_chunk=True,
            )
            output = audio_data.output + final.output

            if request.return_download_link:
                from ..services.temp_manager import TempFileWriter

                # Use download_format if specified, otherwise use response_format
                output_format = request.download_format or request.response_format
                temp_writer = TempFileWriter(output_format)
                await temp_writer.__aenter__()  # Initialize temp file

                # Get download path immediately after temp file creation
                download_path = temp_writer.download_path
                headers["X-Download-Path"] = download_path

                try:
                    # Write chunks to temp file
                    logger.info("Writing chunks to tempory file for download")
                    await temp_writer.write(output)
                    # Finalize the temp file
                    await temp_writer.finalize()

                except Exception as e:
                    logger.error(f"Error in dual output: {e}")
                    await temp_writer.__aexit__(type(e), e, e.__traceback__)
                    raise
                finally:
                    # Ensure temp writer is closed
                    if not temp_writer._finalized:
                        await temp_writer.__aexit__(None, None, None)
                    writer.close()

            return Response(
                content=output,
                media_type=content_type,
                headers=headers,
            )

    except ValueError as e:
        # Handle validation errors
        logger.warning(f"Invalid request: {str(e)}")

        try:
            writer.close()
        except:
            pass

        raise HTTPException(
            status_code=400,
            detail={
                "error": "validation_error",
                "message": str(e),
                "type": "invalid_request_error",
            },
        )
    except RuntimeError as e:
        # Handle runtime/processing errors
        logger.error(f"Processing error: {str(e)}")

        try:
            writer.close()
        except:
            pass

        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": str(e),
                "type": "server_error",
            },
        )
    except Exception as e:
        # Handle unexpected errors
        logger.error(f"Unexpected error in speech generation: {str(e)}")

        try:
            writer.close()
        except:
            pass

        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": str(e),
                "type": "server_error",
            },
        )


@router.get("/download/{filename}")
async def download_audio_file(filename: str):
    """Download a generated audio file from temp storage"""
    try:
        from ..core.paths import _find_file, get_content_type

        # Search for file in temp directory
        file_path = await _find_file(
            filename=filename, search_paths=[settings.temp_file_dir]
        )

        # Get content type from path helper
        content_type = await get_content_type(file_path)

        return FileResponse(
            file_path,
            media_type=content_type,
            filename=filename,
            headers={
                "Cache-Control": "no-cache",
                "Content-Disposition": f"attachment; filename={filename}",
            },
        )

    except Exception as e:
        logger.error(f"Error serving download file {filename}: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to serve audio file",
                "type": "server_error",
            },
        )


@router.get("/models")
async def list_models():
    """List all available models"""
    try:
        # Create standard model list
        models = [
            {
                "id": "tts-1",
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            },
            {
                "id": "tts-1-hd",
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            },
            {
                "id": "kokoro",
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            },
        ]

        return {"object": "list", "data": models}
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to retrieve model list",
                "type": "server_error",
            },
        )


@router.get("/models/{model}")
async def retrieve_model(model: str):
    """Retrieve a specific model"""
    try:
        # Define available models
        models = {
            "tts-1": {
                "id": "tts-1",
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            },
            "tts-1-hd": {
                "id": "tts-1-hd",
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            },
            "kokoro": {
                "id": "kokoro",
                "object": "model",
                "created": 1686935002,
                "owned_by": "kokoro",
            },
        }

        # Check if requested model exists
        if model not in models:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": "model_not_found",
                    "message": f"Model '{model}' not found",
                    "type": "invalid_request_error",
                },
            )

        # Return the specific model
        return models[model]
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error retrieving model {model}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to retrieve model information",
                "type": "server_error",
            },
        )


@router.get("/audio/voices")
async def list_voices():
    """List all available voices for text-to-speech"""
    try:
        tts_service = await get_tts_service()
        voices = await tts_service.list_voices()
        return {"voices": voices}
    except Exception as e:
        logger.error(f"Error listing voices: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "server_error",
                "message": "Failed to retrieve voice list",
                "type": "server_error",
            },
        )


@router.post("/audio/voices/combine")
async def combine_voices(request: Union[str, List[str]]):
    """Combine multiple voices into a new voice and return the .pt file.

    Args:
        request: Either a string with voices separated by + (e.g. "voice1+voice2")
                or a list of voice names to combine

    Returns:
        FileResponse with the combined voice .pt file

    Raises:
        HTTPException:
            - 400: Invalid request (wrong number of voices, voice not found)
            - 500: Server error (file system issues, combination failed)
    """
    # Check if local voice saving is allowed
    if not settings.allow_local_voice_saving:
        raise HTTPException(
            status_code=403,
            detail={
                "error": "permission_denied",
                "message": "Local voice saving is disabled",
                "type": "permission_error",
            },
        )

    try:
        # Convert input to list of voices
        if isinstance(request, str):
            # Check if it's an OpenAI voice name
            mapped_voice = _openai_mappings["voices"].get(request)
            if mapped_voice:
                request = mapped_voice
            voices = [v.strip() for v in request.split("+") if v.strip()]
        else:
            # For list input, map each voice if it's an OpenAI voice name
            voices = [_openai_mappings["voices"].get(v, v) for v in request]
            voices = [v.strip() for v in voices if v.strip()]

        if not voices:
            raise ValueError("No voices provided")

        # For multiple voices, validate base voices exist
|
603 |
+
tts_service = await get_tts_service()
|
604 |
+
available_voices = await tts_service.list_voices()
|
605 |
+
for voice in voices:
|
606 |
+
if voice not in available_voices:
|
607 |
+
raise ValueError(
|
608 |
+
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
609 |
+
)
|
610 |
+
|
611 |
+
# Combine voices
|
612 |
+
combined_tensor = await tts_service.combine_voices(voices=voices)
|
613 |
+
combined_name = "+".join(voices)
|
614 |
+
|
615 |
+
# Save to temp file
|
616 |
+
temp_dir = tempfile.gettempdir()
|
617 |
+
voice_path = os.path.join(temp_dir, f"{combined_name}.pt")
|
618 |
+
buffer = io.BytesIO()
|
619 |
+
torch.save(combined_tensor, buffer)
|
620 |
+
async with aiofiles.open(voice_path, "wb") as f:
|
621 |
+
await f.write(buffer.getvalue())
|
622 |
+
|
623 |
+
return FileResponse(
|
624 |
+
voice_path,
|
625 |
+
media_type="application/octet-stream",
|
626 |
+
filename=f"{combined_name}.pt",
|
627 |
+
headers={
|
628 |
+
"Content-Disposition": f"attachment; filename={combined_name}.pt",
|
629 |
+
"Cache-Control": "no-cache",
|
630 |
+
},
|
631 |
+
)
|
632 |
+
|
633 |
+
except ValueError as e:
|
634 |
+
logger.warning(f"Invalid voice combination request: {str(e)}")
|
635 |
+
raise HTTPException(
|
636 |
+
status_code=400,
|
637 |
+
detail={
|
638 |
+
"error": "validation_error",
|
639 |
+
"message": str(e),
|
640 |
+
"type": "invalid_request_error",
|
641 |
+
},
|
642 |
+
)
|
643 |
+
except RuntimeError as e:
|
644 |
+
logger.error(f"Voice combination processing error: {str(e)}")
|
645 |
+
raise HTTPException(
|
646 |
+
status_code=500,
|
647 |
+
detail={
|
648 |
+
"error": "processing_error",
|
649 |
+
"message": "Failed to process voice combination request",
|
650 |
+
"type": "server_error",
|
651 |
+
},
|
652 |
+
)
|
653 |
+
except Exception as e:
|
654 |
+
logger.error(f"Unexpected error in voice combination: {str(e)}")
|
655 |
+
raise HTTPException(
|
656 |
+
status_code=500,
|
657 |
+
detail={
|
658 |
+
"error": "server_error",
|
659 |
+
"message": "An unexpected error occurred",
|
660 |
+
"type": "server_error",
|
661 |
+
},
|
662 |
+
)
|
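For reference, a minimal client-side sketch of exercising the voice-combination endpoint above once the server is running. The base URL, router prefix, and voice names are assumptions for illustration, not part of the commit; the server must also have local voice saving enabled.

import requests

# Hypothetical deployment URL and voice names; adjust to your setup.
BASE_URL = "http://localhost:8880/v1"

# The endpoint accepts either a "voice1+voice2" string or a list of names
# and responds with the combined voice tensor as a .pt attachment.
resp = requests.post(f"{BASE_URL}/audio/voices/combine", json="af_bella+af_sky")
resp.raise_for_status()
with open("af_bella+af_sky.pt", "wb") as f:
    f.write(resp.content)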
api/src/routers/web_player.py
ADDED
@@ -0,0 +1,49 @@
"""Web player router with async file serving."""

from fastapi import APIRouter, HTTPException
from fastapi.responses import Response
from loguru import logger

from ..core.config import settings
from ..core.paths import get_content_type, get_web_file_path, read_bytes

router = APIRouter(
    tags=["Web Player"],
    responses={404: {"description": "Not found"}},
)


@router.get("/{filename:path}")
async def serve_web_file(filename: str):
    """Serve web player static files asynchronously."""
    if not settings.enable_web_player:
        raise HTTPException(status_code=404, detail="Web player is disabled")

    try:
        # Default to index.html for root path
        if filename == "" or filename == "/":
            filename = "index.html"

        # Get file path
        file_path = await get_web_file_path(filename)

        # Read file content
        content = await read_bytes(file_path)

        # Get content type
        content_type = await get_content_type(file_path)

        return Response(
            content=content,
            media_type=content_type,
            headers={
                "Cache-Control": "no-cache",  # Prevent caching during development
            },
        )

    except RuntimeError as e:
        logger.warning(f"Web file not found: {filename}")
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"Error serving web file {filename}: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
api/src/services/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .tts_service import TTSService

__all__ = ["TTSService"]
api/src/services/audio.py
ADDED
@@ -0,0 +1,248 @@
"""Audio conversion service"""

import math
import struct
import time
from io import BytesIO
from typing import Tuple

import numpy as np
import scipy.io.wavfile as wavfile
import soundfile as sf
from loguru import logger
from pydub import AudioSegment
from torch import norm

from ..core.config import settings
from ..inference.base import AudioChunk
from .streaming_audio_writer import StreamingAudioWriter


class AudioNormalizer:
    """Handles audio normalization state for a single stream"""

    def __init__(self):
        self.chunk_trim_ms = settings.gap_trim_ms
        self.sample_rate = 24000  # Sample rate of the audio
        self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
        self.samples_to_pad_start = int(50 * self.sample_rate / 1000)

    def find_first_last_non_silent(
        self,
        audio_data: np.ndarray,
        chunk_text: str,
        speed: float,
        silence_threshold_db: int = -45,
        is_last_chunk: bool = False,
    ) -> tuple[int, int]:
        """Finds the indices of the first and last non-silent samples in audio data.

        Args:
            audio_data: Input audio data as numpy array
            chunk_text: The text sent to the model to generate the resulting speech
            speed: The speaking speed of the voice
            silence_threshold_db: How quiet audio has to be to be considered silent
            is_last_chunk: Whether this is the last chunk

        Returns:
            A tuple with the start of the non silent portion and with the end of the non silent portion
        """

        pad_multiplier = 1
        split_character = chunk_text.strip()
        if len(split_character) > 0:
            split_character = split_character[-1]
            if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
                pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[
                    split_character
                ]

        if not is_last_chunk:
            samples_to_pad_end = max(
                int(
                    (
                        settings.dynamic_gap_trim_padding_ms
                        * self.sample_rate
                        * pad_multiplier
                    )
                    / 1000
                )
                - self.samples_to_pad_start,
                0,
            )
        else:
            samples_to_pad_end = self.samples_to_pad_start
        # Convert dBFS threshold to amplitude
        amplitude_threshold = np.iinfo(audio_data.dtype).max * (
            10 ** (silence_threshold_db / 20)
        )
        # Find the first samples above the silence threshold at the start and end of the audio
        non_silent_index_start, non_silent_index_end = None, None

        for X in range(0, len(audio_data)):
            if audio_data[X] > amplitude_threshold:
                non_silent_index_start = X
                break

        for X in range(len(audio_data) - 1, -1, -1):
            if audio_data[X] > amplitude_threshold:
                non_silent_index_end = X
                break

        # Handle the case where the entire audio is silent
        if non_silent_index_start == None or non_silent_index_end == None:
            return 0, len(audio_data)

        return max(non_silent_index_start - self.samples_to_pad_start, 0), min(
            non_silent_index_end + math.ceil(samples_to_pad_end / speed),
            len(audio_data),
        )

    def normalize(self, audio_data: np.ndarray) -> np.ndarray:
        """Convert audio data to int16 range

        Args:
            audio_data: Input audio data as numpy array
        Returns:
            Normalized audio data
        """
        if audio_data.dtype != np.int16:
            # Scale directly to int16 range with clipping
            return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
        return audio_data


class AudioService:
    """Service for audio format conversions with streaming support"""

    # Supported formats
    SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm"}

    # Default audio format settings balanced for speed and compression
    DEFAULT_SETTINGS = {
        "mp3": {
            "bitrate_mode": "CONSTANT",  # Faster than variable bitrate
            "compression_level": 0.0,  # Balanced compression
        },
        "opus": {
            "compression_level": 0.0,  # Good balance for speech
        },
        "flac": {
            "compression_level": 0.0,  # Light compression, still fast
        },
        "aac": {
            "bitrate": "192k",  # Default AAC bitrate
        },
    }

    @staticmethod
    async def convert_audio(
        audio_chunk: AudioChunk,
        output_format: str,
        writer: StreamingAudioWriter,
        speed: float = 1,
        chunk_text: str = "",
        is_last_chunk: bool = False,
        trim_audio: bool = True,
        normalizer: AudioNormalizer = None,
    ) -> AudioChunk:
        """Convert audio data to specified format with streaming support

        Args:
            audio_data: Numpy array of audio samples
            output_format: Target format (wav, mp3, ogg, pcm)
            writer: The StreamingAudioWriter to use
            speed: The speaking speed of the voice
            chunk_text: The text sent to the model to generate the resulting speech
            is_last_chunk: Whether this is the last chunk
            trim_audio: Whether audio should be trimmed
            normalizer: Optional AudioNormalizer instance for consistent normalization

        Returns:
            Bytes of the converted audio chunk
        """

        try:
            # Validate format
            if output_format not in AudioService.SUPPORTED_FORMATS:
                raise ValueError(f"Format {output_format} not supported")

            # Always normalize audio to ensure proper amplitude scaling
            if normalizer is None:
                normalizer = AudioNormalizer()

            audio_chunk.audio = normalizer.normalize(audio_chunk.audio)

            if trim_audio == True:
                audio_chunk = AudioService.trim_audio(
                    audio_chunk, chunk_text, speed, is_last_chunk, normalizer
                )

            # Write audio data first
            if len(audio_chunk.audio) > 0:
                chunk_data = writer.write_chunk(audio_chunk.audio)

            # Then finalize if this is the last chunk
            if is_last_chunk:
                final_data = writer.write_chunk(finalize=True)

                if final_data:
                    audio_chunk.output = final_data
                return audio_chunk

            if chunk_data:
                audio_chunk.output = chunk_data
            return audio_chunk

        except Exception as e:
            logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
            raise ValueError(
                f"Failed to convert audio stream to {output_format}: {str(e)}"
            )

    @staticmethod
    def trim_audio(
        audio_chunk: AudioChunk,
        chunk_text: str = "",
        speed: float = 1,
        is_last_chunk: bool = False,
        normalizer: AudioNormalizer = None,
    ) -> AudioChunk:
        """Trim silence from start and end

        Args:
            audio_data: Input audio data as numpy array
            chunk_text: The text sent to the model to generate the resulting speech
            speed: The speaking speed of the voice
            is_last_chunk: Whether this is the last chunk
            normalizer: Optional AudioNormalizer instance for consistent normalization

        Returns:
            Trimmed audio data
        """
        if normalizer is None:
            normalizer = AudioNormalizer()

        audio_chunk.audio = normalizer.normalize(audio_chunk.audio)

        trimed_samples = 0
        # Trim start and end if enough samples
        if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
            audio_chunk.audio = audio_chunk.audio[
                normalizer.samples_to_trim : -normalizer.samples_to_trim
            ]
            trimed_samples += normalizer.samples_to_trim

        # Find non silent portion and trim
        start_index, end_index = normalizer.find_first_last_non_silent(
            audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk
        )

        audio_chunk.audio = audio_chunk.audio[start_index:end_index]
        trimed_samples += start_index

        if audio_chunk.word_timestamps is not None:
            for timestamp in audio_chunk.word_timestamps:
                timestamp.start_time -= trimed_samples / 24000
                timestamp.end_time -= trimed_samples / 24000
        return audio_chunk
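As a side note on the silence-trimming logic above: the dBFS threshold is turned into a linear int16 amplitude via max_int16 * 10^(dB/20). A quick standalone check of that conversion (not part of the commit, plain numpy only):

import numpy as np

# Mirrors the conversion in AudioNormalizer.find_first_last_non_silent:
# -45 dBFS relative to int16 full scale (32767).
silence_threshold_db = -45
amplitude_threshold = np.iinfo(np.int16).max * (10 ** (silence_threshold_db / 20))
print(round(amplitude_threshold))  # ~184 out of 32767; samples below this count as silence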
api/src/services/streaming_audio_writer.py
ADDED
@@ -0,0 +1,100 @@
"""Audio conversion service with proper streaming support"""

import struct
from io import BytesIO
from typing import Optional

import av
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment


class StreamingAudioWriter:
    """Handles streaming audio format conversions"""

    def __init__(self, format: str, sample_rate: int, channels: int = 1):
        self.format = format.lower()
        self.sample_rate = sample_rate
        self.channels = channels
        self.bytes_written = 0
        self.pts = 0

        codec_map = {
            "wav": "pcm_s16le",
            "mp3": "mp3",
            "opus": "libopus",
            "flac": "flac",
            "aac": "aac",
        }
        # Format-specific setup
        if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
            if self.format != "pcm":
                self.output_buffer = BytesIO()
                self.container = av.open(
                    self.output_buffer,
                    mode="w",
                    format=self.format if self.format != "aac" else "adts",
                )
                self.stream = self.container.add_stream(
                    codec_map[self.format],
                    sample_rate=self.sample_rate,
                    layout="mono" if self.channels == 1 else "stereo",
                )
                self.stream.bit_rate = 128000
        else:
            raise ValueError(f"Unsupported format: {format}")

    def close(self):
        if hasattr(self, "container"):
            self.container.close()

        if hasattr(self, "output_buffer"):
            self.output_buffer.close()

    def write_chunk(
        self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
    ) -> bytes:
        """Write a chunk of audio data and return bytes in the target format.

        Args:
            audio_data: Audio data to write, or None if finalizing
            finalize: Whether this is the final write to close the stream
        """

        if finalize:
            if self.format != "pcm":
                packets = self.stream.encode(None)
                for packet in packets:
                    self.container.mux(packet)

                data = self.output_buffer.getvalue()
                self.close()
                return data

        if audio_data is None or len(audio_data) == 0:
            return b""

        if self.format == "pcm":
            # Write raw bytes
            return audio_data.tobytes()
        else:
            frame = av.AudioFrame.from_ndarray(
                audio_data.reshape(1, -1),
                format="s16",
                layout="mono" if self.channels == 1 else "stereo",
            )
            frame.sample_rate = self.sample_rate

            frame.pts = self.pts
            self.pts += frame.samples

            packets = self.stream.encode(frame)
            for packet in packets:
                self.container.mux(packet)

            data = self.output_buffer.getvalue()
            self.output_buffer.seek(0)
            self.output_buffer.truncate(0)
            return data
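A minimal sketch of how this writer is driven, assuming PyAV is installed and the input is 24 kHz mono int16 (as produced by AudioNormalizer above): feed chunks, then finalize once to flush the encoder and close the container.

import numpy as np

# Hypothetical driver: one second of silence, encoded in four chunks.
writer = StreamingAudioWriter(format="mp3", sample_rate=24000, channels=1)
silence = np.zeros(24000, dtype=np.int16)

encoded = bytearray()
for chunk in np.array_split(silence, 4):
    encoded.extend(writer.write_chunk(chunk))       # incremental encoded bytes
encoded.extend(writer.write_chunk(finalize=True))   # flush encoder, close container

with open("out.mp3", "wb") as f:
    f.write(bytes(encoded))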
api/src/services/temp_manager.py
ADDED
@@ -0,0 +1,170 @@
"""Temporary file writer for audio downloads"""

import os
import tempfile
from typing import List, Optional

import aiofiles
from fastapi import HTTPException
from loguru import logger

from ..core.config import settings


async def cleanup_temp_files() -> None:
    """Clean up old temp files"""
    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        # Get all temp files with stats
        files = []
        total_size = 0

        # Use os.scandir for sync iteration, but aiofiles.os.stat for async stats
        for entry in os.scandir(settings.temp_file_dir):
            if entry.is_file():
                stat = await aiofiles.os.stat(entry.path)
                files.append((entry.path, stat.st_mtime, stat.st_size))
                total_size += stat.st_size

        # Sort by modification time (oldest first)
        files.sort(key=lambda x: x[1])

        # Remove files if:
        # 1. They're too old
        # 2. We have too many files
        # 3. Directory is too large
        current_time = (await aiofiles.os.stat(settings.temp_file_dir)).st_mtime
        max_age = settings.max_temp_dir_age_hours * 3600

        for path, mtime, size in files:
            should_delete = False

            # Check age
            if current_time - mtime > max_age:
                should_delete = True
                logger.info(f"Deleting old temp file: {path}")

            # Check count limit
            elif len(files) > settings.max_temp_dir_count:
                should_delete = True
                logger.info(f"Deleting excess temp file: {path}")

            # Check size limit
            elif total_size > settings.max_temp_dir_size_mb * 1024 * 1024:
                should_delete = True
                logger.info(f"Deleting to reduce directory size: {path}")

            if should_delete:
                try:
                    await aiofiles.os.remove(path)
                    total_size -= size
                    logger.info(f"Deleted temp file: {path}")
                except Exception as e:
                    logger.warning(f"Failed to delete temp file {path}: {e}")

    except Exception as e:
        logger.warning(f"Error during temp file cleanup: {e}")


class TempFileWriter:
    """Handles writing audio chunks to a temp file"""

    def __init__(self, format: str):
        """Initialize temp file writer

        Args:
            format: Audio format extension (mp3, wav, etc)
        """
        self.format = format
        self.temp_file = None
        self._finalized = False
        self._write_error = False  # Flag to track if we've had a write error

    async def __aenter__(self):
        """Async context manager entry"""
        try:
            # Clean up old files first
            await cleanup_temp_files()

            # Create temp file with proper extension
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            temp = tempfile.NamedTemporaryFile(
                dir=settings.temp_file_dir,
                delete=False,
                suffix=f".{self.format}",
                mode="wb",
            )
            self.temp_file = await aiofiles.open(temp.name, mode="wb")
            self.temp_path = temp.name
            temp.close()  # Close sync file, we'll use async version

            # Generate download path immediately
            self.download_path = f"/download/{os.path.basename(self.temp_path)}"
        except Exception as e:
            # Handle permission issues or other errors gracefully
            logger.error(f"Failed to create temp file: {e}")
            self._write_error = True
            # Set a placeholder path so the API can still function
            self.temp_path = f"unavailable_{self.format}"
            self.download_path = f"/download/{self.temp_path}"

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        try:
            if self.temp_file and not self._finalized:
                await self.temp_file.close()
                self._finalized = True
        except Exception as e:
            logger.error(f"Error closing temp file: {e}")
            self._write_error = True

    async def write(self, chunk: bytes) -> None:
        """Write a chunk of audio data

        Args:
            chunk: Audio data bytes to write
        """
        if self._finalized:
            raise RuntimeError("Cannot write to finalized temp file")

        # Skip writing if we've already encountered an error
        if self._write_error or not self.temp_file:
            return

        try:
            await self.temp_file.write(chunk)
            await self.temp_file.flush()
        except Exception as e:
            # Handle permission issues or other errors gracefully
            logger.error(f"Failed to write to temp file: {e}")
            self._write_error = True

    async def finalize(self) -> str:
        """Close temp file and return download path

        Returns:
            Path to use for downloading the temp file
        """
        if self._finalized:
            raise RuntimeError("Temp file already finalized")

        # Skip finalizing if we've already encountered an error
        if self._write_error or not self.temp_file:
            self._finalized = True
            return self.download_path

        try:
            await self.temp_file.close()
            self._finalized = True
        except Exception as e:
            # Handle permission issues or other errors gracefully
            logger.error(f"Failed to finalize temp file: {e}")
            self._write_error = True
            self._finalized = True

        return self.download_path
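A small sketch of the intended async-context-manager usage of TempFileWriter: write encoded bytes, then return the /download/... path that the openai_compatible router serves. The caller function and the byte payload here are placeholders, and the app settings (temp_file_dir, limits) must be importable for this to run.

import asyncio

async def save_chunks(chunks):
    # Hypothetical caller: stream encoded audio bytes into a temp file,
    # then hand the /download/... path back to the client.
    async with TempFileWriter("mp3") as writer:
        for chunk in chunks:
            await writer.write(chunk)
        return await writer.finalize()

download_path = asyncio.run(save_chunks([b"\x00" * 1024]))
print(download_path)  # e.g. /download/tmpXXXX.mp3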
api/src/services/text_processing/__init__.py
ADDED
@@ -0,0 +1,21 @@
"""Text processing pipeline."""

from .normalizer import normalize_text
from .phonemizer import phonemize
from .text_processor import process_text_chunk, smart_split
from .vocabulary import tokenize


def process_text(text: str) -> list[int]:
    """Process text into token IDs (for backward compatibility)."""
    return process_text_chunk(text)


__all__ = [
    "normalize_text",
    "phonemize",
    "tokenize",
    "process_text",
    "process_text_chunk",
    "smart_split",
]
api/src/services/text_processing/normalizer.py
ADDED
@@ -0,0 +1,415 @@
"""
Text normalization module for TTS processing.
Handles various text formats including URLs, emails, numbers, money, and special characters.
Converts them into a format suitable for text-to-speech processing.
"""

import re
from functools import lru_cache

import inflect
from numpy import number
from text_to_num import text2num
from torch import mul

from ...structures.schemas import NormalizationOptions

# Constants
VALID_TLDS = [
    "com", "org", "net", "edu", "gov", "mil", "int", "biz", "info", "name",
    "pro", "coop", "museum", "travel", "jobs", "mobi", "tel", "asia", "cat",
    "xxx", "aero", "arpa", "bg", "br", "ca", "cn", "de", "es", "eu", "fr",
    "in", "it", "jp", "mx", "nl", "ru", "uk", "us", "io", "co",
]

VALID_UNITS = {
    "m": "meter", "cm": "centimeter", "mm": "millimeter", "km": "kilometer",
    "in": "inch", "ft": "foot", "yd": "yard", "mi": "mile",  # Length
    "g": "gram", "kg": "kilogram", "mg": "milligram",  # Mass
    "s": "second", "ms": "millisecond", "min": "minutes", "h": "hour",  # Time
    "l": "liter", "ml": "mililiter", "cl": "centiliter", "dl": "deciliter",  # Volume
    "kph": "kilometer per hour", "mph": "mile per hour", "mi/h": "mile per hour",
    "m/s": "meter per second", "km/h": "kilometer per hour", "mm/s": "milimeter per second",
    "cm/s": "centimeter per second", "ft/s": "feet per second", "cm/h": "centimeter per day",  # Speed
    "°c": "degree celsius", "c": "degree celsius",
    "°f": "degree fahrenheit", "f": "degree fahrenheit", "k": "kelvin",  # Temperature
    "pa": "pascal", "kpa": "kilopascal", "mpa": "megapascal", "atm": "atmosphere",  # Pressure
    "hz": "hertz", "khz": "kilohertz", "mhz": "megahertz", "ghz": "gigahertz",  # Frequency
    "v": "volt", "kv": "kilovolt", "mv": "mergavolt",  # Voltage
    "a": "amp", "ma": "megaamp", "ka": "kiloamp",  # Current
    "w": "watt", "kw": "kilowatt", "mw": "megawatt",  # Power
    "j": "joule", "kj": "kilojoule", "mj": "megajoule",  # Energy
    "Ω": "ohm", "kΩ": "kiloohm", "mΩ": "megaohm",  # Resistance (Ohm)
    "f": "farad", "µf": "microfarad", "nf": "nanofarad", "pf": "picofarad",  # Capacitance
    "b": "bit", "kb": "kilobit", "mb": "megabit", "gb": "gigabit",
    "tb": "terabit", "pb": "petabit",  # Data size
    "kbps": "kilobit per second", "mbps": "megabit per second",
    "gbps": "gigabit per second", "tbps": "terabit per second",
    "px": "pixel",  # CSS units
}


# Pre-compiled regex patterns for performance
EMAIL_PATTERN = re.compile(
    r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE
)
URL_PATTERN = re.compile(
    r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:"
    + "|".join(VALID_TLDS)
    + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?",
    re.IGNORECASE,
)

UNIT_PATTERN = re.compile(
    r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*("
    + "|".join(sorted(list(VALID_UNITS.keys()), reverse=True))
    + r"""){1}(?=[^\w\d]{1}|\b)""",
    re.IGNORECASE,
)

TIME_PATTERN = re.compile(
    r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
)

INFLECT_ENGINE = inflect.engine()


def split_num(num: re.Match[str]) -> str:
    """Handle number splitting for various formats"""
    num = num.group()
    if "." in num:
        return num
    elif ":" in num:
        h, m = [int(n) for n in num.split(":")]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f"{h} oh {m}"
        return f"{h} {m}"
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = "s" if num.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f"{left} hundred{s}"
        elif right < 10:
            return f"{left} oh {right}{s}"
    return f"{left} {right}{s}"


def handle_units(u: re.Match[str]) -> str:
    """Converts units to their full form"""
    unit_string = u.group(6).strip()
    unit = unit_string

    if unit_string.lower() in VALID_UNITS:
        unit = VALID_UNITS[unit_string.lower()].split(" ")

        # Handles the B vs b case
        if unit[0].endswith("bit"):
            b_case = unit_string[min(1, len(unit_string) - 1)]
            if b_case == "B":
                unit[0] = unit[0][:-3] + "byte"

        number = u.group(1).strip()
        unit[0] = INFLECT_ENGINE.no(unit[0], number)
    return " ".join(unit)


def conditional_int(number: float, threshold: float = 0.00001):
    if abs(round(number) - number) < threshold:
        return int(round(number))
    return number


def handle_money(m: re.Match[str]) -> str:
    """Convert money expressions to spoken form"""

    bill = "dollar" if m.group(2) == "$" else "pound"
    coin = "cent" if m.group(2) == "$" else "pence"
    number = m.group(3)

    multiplier = m.group(4)
    try:
        number = float(number)
    except:
        return m.group()

    if m.group(1) == "-":
        number *= -1

    if number % 1 == 0 or multiplier != "":
        text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
    else:
        sub_number = int(str(number).split(".")[-1].ljust(2, "0"))

        text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"

    return text_number


def handle_decimal(num: re.Match[str]) -> str:
    """Convert decimal numbers to spoken form"""
    a, b = num.group().split(".")
    return " point ".join([a, " ".join(b)])


def handle_email(m: re.Match[str]) -> str:
    """Convert email addresses into speakable format"""
    email = m.group(0)
    parts = email.split("@")
    if len(parts) == 2:
        user, domain = parts
        domain = domain.replace(".", " dot ")
        return f"{user} at {domain}"
    return email


def handle_url(u: re.Match[str]) -> str:
    """Make URLs speakable by converting special characters to spoken words"""
    if not u:
        return ""

    url = u.group(0).strip()

    # Handle protocol first
    url = re.sub(
        r"^https?://",
        lambda a: "https " if "https" in a.group() else "http ",
        url,
        flags=re.IGNORECASE,
    )
    url = re.sub(r"^www\.", "www ", url, flags=re.IGNORECASE)

    # Handle port numbers before other replacements
    url = re.sub(r":(\d+)(?=/|$)", lambda m: f" colon {m.group(1)}", url)

    # Split into domain and path
    parts = url.split("/", 1)
    domain = parts[0]
    path = parts[1] if len(parts) > 1 else ""

    # Handle dots in domain
    domain = domain.replace(".", " dot ")

    # Reconstruct URL
    if path:
        url = f"{domain} slash {path}"
    else:
        url = domain

    # Replace remaining symbols with words
    url = url.replace("-", " dash ")
    url = url.replace("_", " underscore ")
    url = url.replace("?", " question-mark ")
    url = url.replace("=", " equals ")
    url = url.replace("&", " ampersand ")
    url = url.replace("%", " percent ")
    url = url.replace(":", " colon ")  # Handle any remaining colons
    url = url.replace("/", " slash ")  # Handle any remaining slashes

    # Clean up extra spaces
    return re.sub(r"\s+", " ", url).strip()


def handle_phone_number(p: re.Match[str]) -> str:
    p = list(p.groups())

    country_code = ""
    if p[0] is not None:
        p[0] = p[0].replace("+", "")
        country_code += INFLECT_ENGINE.number_to_words(p[0])

    area_code = INFLECT_ENGINE.number_to_words(
        p[2].replace("(", "").replace(")", ""), group=1, comma=""
    )

    telephone_prefix = INFLECT_ENGINE.number_to_words(p[3], group=1, comma="")

    line_number = INFLECT_ENGINE.number_to_words(p[4], group=1, comma="")

    return ",".join([country_code, area_code, telephone_prefix, line_number])


def handle_time(t: re.Match[str]) -> str:
    t = t.groups()

    numbers = " ".join(
        [INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")]
    )

    half = ""
    if t[2] is not None:
        half = t[2].strip()

    return numbers + half


def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
    """Normalize text for TTS processing"""
    # Handle email addresses first if enabled
    if normalization_options.email_normalization:
        text = EMAIL_PATTERN.sub(handle_email, text)

    # Handle URLs if enabled
    if normalization_options.url_normalization:
        text = URL_PATTERN.sub(handle_url, text)

    # Pre-process numbers with units if enabled
    if normalization_options.unit_normalization:
        text = UNIT_PATTERN.sub(handle_units, text)

    # Replace optional pluralization
    if normalization_options.optional_pluralization_normalization:
        text = re.sub(r"\(s\)", "s", text)

    # Replace phone numbers:
    if normalization_options.phone_normalization:
        text = re.sub(
            r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",
            handle_phone_number,
            text,
        )

    # Replace quotes and brackets
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')

    # Handle CJK punctuation and some non standard chars
    for a, b in zip("、。!,:;?–", ",.!,:;?-"):
        text = text.replace(a, b + " ")

    # Handle simple time in the format of HH:MM:SS
    text = TIME_PATTERN.sub(
        handle_time,
        text,
    )

    # Clean up whitespace
    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)

    # Handle titles and abbreviations
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)

    # Handle common words
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)

    # Handle numbers and money
    text = re.sub(r"(?<=\d),(?=\d)", "", text)

    text = re.sub(
        r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
        handle_money,
        text,
    )

    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
    )

    text = re.sub(r"\d*\.\d+", handle_decimal, text)

    # Handle various formatting
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)

    return text.strip()
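To illustrate the normalization pipeline above, a small hedged example of calling normalize_text directly. The import path assumes the repository root is on PYTHONPATH, and the spoken output shown in the comment is indicative only (it depends on the installed inflect version).

from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions

sample = "Email support@example.com or visit https://example.com/docs. It costs $5.50."
print(normalize_text(sample, NormalizationOptions()))
# Roughly: "Email support at example dot com or visit https example dot com slash docs.
#           It costs five dollars and fifty cents."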
api/src/services/text_processing/phonemizer.py
ADDED
@@ -0,0 +1,102 @@
import re
from abc import ABC, abstractmethod

import phonemizer

from .normalizer import normalize_text

phonemizers = {}


class PhonemizerBackend(ABC):
    """Abstract base class for phonemization backends"""

    @abstractmethod
    def phonemize(self, text: str) -> str:
        """Convert text to phonemes

        Args:
            text: Text to convert to phonemes

        Returns:
            Phonemized text
        """
        pass


class EspeakBackend(PhonemizerBackend):
    """Espeak-based phonemizer implementation"""

    def __init__(self, language: str):
        """Initialize espeak backend

        Args:
            language: Language code ('en-us' or 'en-gb')
        """
        self.backend = phonemizer.backend.EspeakBackend(
            language=language, preserve_punctuation=True, with_stress=True
        )

        self.language = language

    def phonemize(self, text: str) -> str:
        """Convert text to phonemes using espeak

        Args:
            text: Text to convert to phonemes

        Returns:
            Phonemized text
        """
        # Phonemize text
        ps = self.backend.phonemize([text])
        ps = ps[0] if ps else ""

        # Handle special cases
        ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
        ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
        ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
        ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»"" ]|$)', "z", ps)

        # Language-specific rules
        if self.language == "en-us":
            ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)

        return ps.strip()


def create_phonemizer(language: str = "a") -> PhonemizerBackend:
    """Factory function to create phonemizer backend

    Args:
        language: Language code ('a' for US English, 'b' for British English)

    Returns:
        Phonemizer backend instance
    """
    # Map language codes to espeak language codes
    lang_map = {"a": "en-us", "b": "en-gb"}

    if language not in lang_map:
        raise ValueError(f"Unsupported language code: {language}")

    return EspeakBackend(lang_map[language])


def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
    """Convert text to phonemes

    Args:
        text: Text to convert to phonemes
        language: Language code ('a' for US English, 'b' for British English)
        normalize: Whether to normalize text before phonemization

    Returns:
        Phonemized text
    """
    global phonemizers
    if normalize:
        text = normalize_text(text)
    if language not in phonemizers:
        phonemizers[language] = create_phonemizer(language)
    return phonemizers[language].phonemize(text)
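A quick usage sketch of the module-level phonemize helper above. It assumes espeak-ng is installed (as in the Dockerfile) and the repository root is importable; the exact IPA output varies with the espeak-ng version.

from api.src.services.text_processing.phonemizer import phonemize

# 'a' selects the en-us espeak backend. normalize=False is used here because the
# normalize=True path calls normalize_text without NormalizationOptions.
ipa = phonemize("Hello world", language="a", normalize=False)
print(ipa)  # e.g. something like "həlˈoʊ wˈɜːld"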
api/src/services/text_processing/text_processor.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Unified text processing for TTS with smart chunking."""
|
2 |
+
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
from typing import AsyncGenerator, Dict, List, Tuple
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from ...core.config import settings
|
10 |
+
from ...structures.schemas import NormalizationOptions
|
11 |
+
from .normalizer import normalize_text
|
12 |
+
from .phonemizer import phonemize
|
13 |
+
from .vocabulary import tokenize
|
14 |
+
|
15 |
+
# Pre-compiled regex patterns for performance
|
16 |
+
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
|
17 |
+
|
18 |
+
|
19 |
+
def process_text_chunk(
|
20 |
+
text: str, language: str = "a", skip_phonemize: bool = False
|
21 |
+
) -> List[int]:
|
22 |
+
"""Process a chunk of text through normalization, phonemization, and tokenization.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
text: Text chunk to process
|
26 |
+
language: Language code for phonemization
|
27 |
+
skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
List of token IDs
|
31 |
+
"""
|
32 |
+
start_time = time.time()
|
33 |
+
|
34 |
+
if skip_phonemize:
|
35 |
+
# Input is already phonemes, just tokenize
|
36 |
+
t0 = time.time()
|
37 |
+
tokens = tokenize(text)
|
38 |
+
t1 = time.time()
|
39 |
+
else:
|
40 |
+
# Normal text processing pipeline
|
41 |
+
t0 = time.time()
|
42 |
+
t1 = time.time()
|
43 |
+
|
44 |
+
t0 = time.time()
|
45 |
+
phonemes = phonemize(text, language, normalize=False) # Already normalized
|
46 |
+
t1 = time.time()
|
47 |
+
|
48 |
+
t0 = time.time()
|
49 |
+
tokens = tokenize(phonemes)
|
50 |
+
t1 = time.time()
|
51 |
+
|
52 |
+
total_time = time.time() - start_time
|
53 |
+
logger.debug(
|
54 |
+
f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
|
55 |
+
)
|
56 |
+
|
57 |
+
return tokens
|
58 |
+
|
59 |
+
|
60 |
+
async def yield_chunk(
|
61 |
+
text: str, tokens: List[int], chunk_count: int
|
62 |
+
) -> Tuple[str, List[int]]:
|
63 |
+
"""Yield a chunk with consistent logging."""
|
64 |
+
logger.debug(
|
65 |
+
f"Yielding chunk {chunk_count}: '{text[:50]}{'...' if len(text) > 50 else ''}' ({len(tokens)} tokens)"
|
66 |
+
)
|
67 |
+
return text, tokens
|
68 |
+
|
69 |
+
|
70 |
+
def process_text(text: str, language: str = "a") -> List[int]:
|
71 |
+
"""Process text into token IDs.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
text: Text to process
|
75 |
+
language: Language code for phonemization
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
List of token IDs
|
79 |
+
"""
|
80 |
+
if not isinstance(text, str):
|
81 |
+
text = str(text) if text is not None else ""
|
82 |
+
|
83 |
+
text = text.strip()
|
84 |
+
if not text:
|
85 |
+
return []
|
86 |
+
|
87 |
+
return process_text_chunk(text, language)
|
88 |
+
|
89 |
+
|
90 |
+
def get_sentence_info(
|
91 |
+
text: str, custom_phenomes_list: Dict[str, str]
|
92 |
+
) -> List[Tuple[str, List[int], int]]:
|
93 |
+
"""Process all sentences and return info."""
|
94 |
+
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
|
95 |
+
phoneme_length, min_value = len(custom_phenomes_list), 0
|
96 |
+
|
97 |
+
results = []
|
98 |
+
for i in range(0, len(sentences), 2):
|
99 |
+
sentence = sentences[i].strip()
|
100 |
+
for replaced in range(min_value, phoneme_length):
|
101 |
+
current_id = f"</|custom_phonemes_{replaced}|/>"
|
102 |
+
if current_id in sentence:
|
103 |
+
sentence = sentence.replace(
|
104 |
+
current_id, custom_phenomes_list.pop(current_id)
|
105 |
+
)
|
106 |
+
min_value += 1
|
107 |
+
|
108 |
+
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
|
109 |
+
|
110 |
+
if not sentence:
|
111 |
+
continue
|
112 |
+
|
113 |
+
full = sentence + punct
|
114 |
+
tokens = process_text_chunk(full)
|
115 |
        results.append((full, tokens, len(tokens)))

    return results


def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str:
    latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
    phenomes_list[latest_id] = s.group(0).strip()
    return latest_id


async def smart_split(
    text: str,
    max_tokens: int = settings.absolute_max_tokens,
    lang_code: str = "a",
    normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]:
    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
    start_time = time.time()
    chunk_count = 0
    logger.info(f"Starting smart split for {len(text)} chars")

    custom_phoneme_list = {}

    # Normalize text
    if settings.advanced_text_normalization and normalization_options.normalize:
        logger.debug(f"lang_code: {lang_code}")
        if lang_code in ["a", "b", "en-us", "en-gb"]:
            text = CUSTOM_PHONEMES.sub(
                lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
            )
            text = normalize_text(text, normalization_options)
        else:
            logger.info(
                "Skipping text normalization as it is only supported for English"
            )

    # Process all sentences
    sentences = get_sentence_info(text, custom_phoneme_list)

    current_chunk = []
    current_tokens = []
    current_count = 0

    for sentence, tokens, count in sentences:
        # Handle sentences that exceed max tokens
        if count > max_tokens:
            # Yield current chunk if any
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunk_count += 1
                logger.debug(
                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
                )
                yield chunk_text, current_tokens
                current_chunk = []
                current_tokens = []
                current_count = 0

            # Split long sentence on commas
            clauses = re.split(r"([,])", sentence)
            clause_chunk = []
            clause_tokens = []
            clause_count = 0

            for j in range(0, len(clauses), 2):
                clause = clauses[j].strip()
                comma = clauses[j + 1] if j + 1 < len(clauses) else ""

                if not clause:
                    continue

                full_clause = clause + comma

                tokens = process_text_chunk(full_clause)
                count = len(tokens)

                # If adding clause keeps us under max and not optimal yet
                if (
                    clause_count + count <= max_tokens
                    and clause_count + count <= settings.target_max_tokens
                ):
                    clause_chunk.append(full_clause)
                    clause_tokens.extend(tokens)
                    clause_count += count
                else:
                    # Yield clause chunk if we have one
                    if clause_chunk:
                        chunk_text = " ".join(clause_chunk)
                        chunk_count += 1
                        logger.debug(
                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
                        )
                        yield chunk_text, clause_tokens
                    clause_chunk = [full_clause]
                    clause_tokens = tokens
                    clause_count = count

            # Don't forget last clause chunk
            if clause_chunk:
                chunk_text = " ".join(clause_chunk)
                chunk_count += 1
                logger.debug(
                    f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
                )
                yield chunk_text, clause_tokens

        # Regular sentence handling
        elif (
            current_count >= settings.target_min_tokens
            and current_count + count > settings.target_max_tokens
        ):
            # If we have a good sized chunk and adding next sentence exceeds target,
            # yield current chunk and start new one
            chunk_text = " ".join(current_chunk)
            chunk_count += 1
            logger.info(
                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
            )
            yield chunk_text, current_tokens
            current_chunk = [sentence]
            current_tokens = tokens
            current_count = count
        elif current_count + count <= settings.target_max_tokens:
            # Keep building chunk while under target max
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
        elif (
            current_count + count <= max_tokens
            and current_count < settings.target_min_tokens
        ):
            # Only exceed target max if we haven't reached minimum size yet
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
        else:
            # Yield current chunk and start new one
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunk_count += 1
                logger.info(
                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
                )
                yield chunk_text, current_tokens
            current_chunk = [sentence]
            current_tokens = tokens
            current_count = count

    # Don't forget the last chunk
    if current_chunk:
        chunk_text = " ".join(current_chunk)
        chunk_count += 1
        logger.info(
            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
        )
        yield chunk_text, current_tokens

    total_time = time.time() - start_time
    logger.info(
        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
    )
api/src/services/text_processing/vocabulary.py
ADDED
@@ -0,0 +1,40 @@
def get_vocab():
    """Get the vocabulary dictionary mapping characters to token IDs"""
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»"" '
    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

    # Create vocabulary dictionary
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    return {symbol: i for i, symbol in enumerate(symbols)}


# Initialize vocabulary
VOCAB = get_vocab()


def tokenize(phonemes: str) -> list[int]:
    """Convert phonemes string to token IDs

    Args:
        phonemes: String of phonemes to tokenize

    Returns:
        List of token IDs
    """
    return [i for i in map(VOCAB.get, phonemes) if i is not None]


def decode_tokens(tokens: list[int]) -> str:
    """Convert token IDs back to phonemes string

    Args:
        tokens: List of token IDs

    Returns:
        String of phonemes
    """
    # Create reverse mapping
    id_to_symbol = {i: s for s, i in VOCAB.items()}
    return "".join(id_to_symbol[t] for t in tokens)
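A quick round-trip sketch of the vocabulary helpers above (illustrative, not part of the commit). It assumes every symbol in the phoneme string exists in VOCAB, since tokenize silently drops unknown symbols.

    ids = tokenize("tˈɛst")      # phonemes -> token IDs
    text = decode_tokens(ids)    # token IDs -> phonemes
    assert text == "tˈɛst"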
api/src/services/tts_service.py
ADDED
@@ -0,0 +1,459 @@
"""TTS service using model and voice managers."""

import asyncio
import os
import re
import tempfile
import time
from typing import AsyncGenerator, List, Optional, Tuple, Union

import numpy as np
import torch
from kokoro import KPipeline
from loguru import logger

from ..core.config import settings
from ..inference.base import AudioChunk
from ..inference.kokoro_v1 import KokoroV1
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from ..structures.schemas import NormalizationOptions
from .audio import AudioNormalizer, AudioService
from .streaming_audio_writer import StreamingAudioWriter
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split


class TTSService:
    """Text-to-speech service."""

    # Limit concurrent chunk processing
    _chunk_semaphore = asyncio.Semaphore(4)

    def __init__(self, output_dir: str = None):
        """Initialize service."""
        self.output_dir = output_dir
        self.model_manager = None
        self._voice_manager = None

    @classmethod
    async def create(cls, output_dir: str = None) -> "TTSService":
        """Create and initialize TTSService instance."""
        service = cls(output_dir)
        service.model_manager = await get_model_manager()
        service._voice_manager = await get_voice_manager()
        return service

    async def _process_chunk(
        self,
        chunk_text: str,
        tokens: List[int],
        voice_name: str,
        voice_path: str,
        speed: float,
        writer: StreamingAudioWriter,
        output_format: Optional[str] = None,
        is_first: bool = False,
        is_last: bool = False,
        normalizer: Optional[AudioNormalizer] = None,
        lang_code: Optional[str] = None,
        return_timestamps: Optional[bool] = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Process tokens into audio."""
        async with self._chunk_semaphore:
            try:
                # Handle stream finalization
                if is_last:
                    # Skip format conversion for raw audio mode
                    if not output_format:
                        yield AudioChunk(np.array([], dtype=np.int16), output=b"")
                        return
                    chunk_data = await AudioService.convert_audio(
                        AudioChunk(
                            np.array([], dtype=np.float32)
                        ),  # Dummy data for type checking
                        output_format,
                        writer,
                        speed,
                        "",
                        normalizer=normalizer,
                        is_last_chunk=True,
                    )
                    yield chunk_data
                    return

                # Skip empty chunks
                if not tokens and not chunk_text:
                    return

                # Get backend
                backend = self.model_manager.get_backend()

                # Generate audio using pre-warmed model
                if isinstance(backend, KokoroV1):
                    chunk_index = 0
                    # For Kokoro V1, pass text and voice info with lang_code
                    async for chunk_data in self.model_manager.generate(
                        chunk_text,
                        (voice_name, voice_path),
                        speed=speed,
                        lang_code=lang_code,
                        return_timestamps=return_timestamps,
                    ):
                        # For streaming, convert to bytes
                        if output_format:
                            try:
                                chunk_data = await AudioService.convert_audio(
                                    chunk_data,
                                    output_format,
                                    writer,
                                    speed,
                                    chunk_text,
                                    is_last_chunk=is_last,
                                    normalizer=normalizer,
                                )
                                yield chunk_data
                            except Exception as e:
                                logger.error(f"Failed to convert audio: {str(e)}")
                        else:
                            chunk_data = AudioService.trim_audio(
                                chunk_data, chunk_text, speed, is_last, normalizer
                            )
                            yield chunk_data
                        chunk_index += 1
                else:
                    # For legacy backends, load voice tensor
                    voice_tensor = await self._voice_manager.load_voice(
                        voice_name, device=backend.device
                    )
                    chunk_data = await self.model_manager.generate(
                        tokens,
                        voice_tensor,
                        speed=speed,
                        return_timestamps=return_timestamps,
                    )

                    if chunk_data.audio is None:
                        logger.error("Model generated None for audio chunk")
                        return

                    if len(chunk_data.audio) == 0:
                        logger.error("Model generated empty audio chunk")
                        return

                    # For streaming, convert to bytes
                    if output_format:
                        try:
                            chunk_data = await AudioService.convert_audio(
                                chunk_data,
                                output_format,
                                writer,
                                speed,
                                chunk_text,
                                normalizer=normalizer,
                                is_last_chunk=is_last,
                            )
                            yield chunk_data
                        except Exception as e:
                            logger.error(f"Failed to convert audio: {str(e)}")
                    else:
                        trimmed = AudioService.trim_audio(
                            chunk_data, chunk_text, speed, is_last, normalizer
                        )
                        yield trimmed
            except Exception as e:
                logger.error(f"Failed to process tokens: {str(e)}")

    async def _load_voice_from_path(self, path: str, weight: float):
        # Raise a ValueError if the path is missing
        if not path:
            raise ValueError(f"Voice not found at path: {path}")

        logger.debug(f"Loading voice tensor from path: {path}")
        return torch.load(path, map_location="cpu") * weight

    async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
        """Get voice path, handling combined voices.

        Args:
            voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')

        Returns:
            Tuple of (voice name to use, voice path to use)

        Raises:
            RuntimeError: If voice not found
        """
        try:
            # Split the voice on + and - and keep the operators in the list, e.g. hi+bob -> ["hi", "+", "bob"]
            split_voice = re.split(r"([-+])", voice)

            # If it is only one voice, there is no point in loading it up, doing nothing with it, then saving it
            if len(split_voice) == 1:
                # Since it's a single voice, the weight only matters if voice_weight_normalization is off
                if (
                    "(" not in voice and ")" not in voice
                ) or settings.voice_weight_normalization == True:
                    path = await self._voice_manager.get_voice_path(voice)
                    if not path:
                        raise RuntimeError(f"Voice not found: {voice}")
                    logger.debug(f"Using single voice path: {path}")
                    return voice, path

            total_weight = 0

            for voice_index in range(0, len(split_voice), 2):
                voice_object = split_voice[voice_index]

                if "(" in voice_object and ")" in voice_object:
                    voice_name = voice_object.split("(")[0].strip()
                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
                else:
                    voice_name = voice_object
                    voice_weight = 1

                total_weight += voice_weight
                split_voice[voice_index] = (voice_name, voice_weight)

            # If voice_weight_normalization is false, prevent normalizing the weights by setting total_weight to 1 so each weight is divided by 1
            if settings.voice_weight_normalization == False:
                total_weight = 1

            # Load the first voice as the starting point for voices to be combined onto
            path = await self._voice_manager.get_voice_path(split_voice[0][0])
            combined_tensor = await self._load_voice_from_path(
                path, split_voice[0][1] / total_weight
            )

            # Loop through each + or - in split_voice so they can be applied to the combined voice
            for operation_index in range(1, len(split_voice) - 1, 2):
                # Get the voice path of the voice 1 index ahead of the operator
                path = await self._voice_manager.get_voice_path(
                    split_voice[operation_index + 1][0]
                )
                voice_tensor = await self._load_voice_from_path(
                    path, split_voice[operation_index + 1][1] / total_weight
                )

                # Either add or subtract the voice from the current combined voice
                if split_voice[operation_index] == "+":
                    combined_tensor += voice_tensor
                else:
                    combined_tensor -= voice_tensor

            # Save the new combined voice so it can be loaded later
            temp_dir = tempfile.gettempdir()
            combined_path = os.path.join(temp_dir, f"{voice}.pt")
            logger.debug(f"Saving combined voice to: {combined_path}")
            torch.save(combined_tensor, combined_path)
            return voice, combined_path
        except Exception as e:
            logger.error(f"Failed to get voice path: {e}")
            raise

    async def generate_audio_stream(
        self,
        text: str,
        voice: str,
        writer: StreamingAudioWriter,
        speed: float = 1.0,
        output_format: str = "wav",
        lang_code: Optional[str] = None,
        normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
        return_timestamps: Optional[bool] = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate and stream audio chunks."""
        stream_normalizer = AudioNormalizer()
        chunk_index = 0
        current_offset = 0.0
        try:
            # Get backend
            backend = self.model_manager.get_backend()

            # Get voice path, handling combined voices
            voice_name, voice_path = await self._get_voices_path(voice)
            logger.debug(f"Using voice path: {voice_path}")

            # Use provided lang_code or determine from voice name
            pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
            logger.info(
                f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
            )

            # Process text in chunks with smart splitting
            async for chunk_text, tokens in smart_split(
                text,
                lang_code=pipeline_lang_code,
                normalization_options=normalization_options,
            ):
                try:
                    # Process audio for chunk
                    async for chunk_data in self._process_chunk(
                        chunk_text,  # Pass text for Kokoro V1
                        tokens,  # Pass tokens for legacy backends
                        voice_name,  # Pass voice name
                        voice_path,  # Pass voice path
                        speed,
                        writer,
                        output_format,
                        is_first=(chunk_index == 0),
                        is_last=False,  # We'll update the last chunk later
                        normalizer=stream_normalizer,
                        lang_code=pipeline_lang_code,  # Pass lang_code
                        return_timestamps=return_timestamps,
                    ):
                        if chunk_data.word_timestamps is not None:
                            for timestamp in chunk_data.word_timestamps:
                                timestamp.start_time += current_offset
                                timestamp.end_time += current_offset

                        current_offset += len(chunk_data.audio) / 24000

                        if chunk_data.output is not None:
                            yield chunk_data
                        else:
                            logger.warning(
                                f"No audio generated for chunk: '{chunk_text[:100]}...'"
                            )
                        chunk_index += 1
                except Exception as e:
                    logger.error(
                        f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
                    )
                    continue

            # Only finalize if we successfully processed at least one chunk
            if chunk_index > 0:
                try:
                    # Empty tokens list to finalize audio
                    async for chunk_data in self._process_chunk(
                        "",  # Empty text
                        [],  # Empty tokens
                        voice_name,
                        voice_path,
                        speed,
                        writer,
                        output_format,
                        is_first=False,
                        is_last=True,  # Signal this is the last chunk
                        normalizer=stream_normalizer,
                        lang_code=pipeline_lang_code,  # Pass lang_code
                    ):
                        if chunk_data.output is not None:
                            yield chunk_data
                except Exception as e:
                    logger.error(f"Failed to finalize audio stream: {str(e)}")

        except Exception as e:
            logger.error(f"Error in phoneme audio generation: {str(e)}")
            raise e

    async def generate_audio(
        self,
        text: str,
        voice: str,
        writer: StreamingAudioWriter,
        speed: float = 1.0,
        return_timestamps: bool = False,
        normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
        lang_code: Optional[str] = None,
    ) -> AudioChunk:
        """Generate complete audio for text using streaming internally."""
        audio_data_chunks = []

        try:
            async for audio_stream_data in self.generate_audio_stream(
                text,
                voice,
                writer,
                speed=speed,
                normalization_options=normalization_options,
                return_timestamps=return_timestamps,
                lang_code=lang_code,
                output_format=None,
            ):
                if len(audio_stream_data.audio) > 0:
                    audio_data_chunks.append(audio_stream_data)

            combined_audio_data = AudioChunk.combine(audio_data_chunks)
            return combined_audio_data
        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    async def combine_voices(self, voices: List[str]) -> torch.Tensor:
        """Combine multiple voices.

        Returns:
            Combined voice tensor
        """

        return await self._voice_manager.combine_voices(voices)

    async def list_voices(self) -> List[str]:
        """List available voices."""
        return await self._voice_manager.list_voices()

    async def generate_from_phonemes(
        self,
        phonemes: str,
        voice: str,
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> Tuple[np.ndarray, float]:
        """Generate audio directly from phonemes.

        Args:
            phonemes: Phonemes in Kokoro format
            voice: Voice name
            speed: Speed multiplier
            lang_code: Optional language code override

        Returns:
            Tuple of (audio array, processing time)
        """
        start_time = time.time()
        try:
            # Get backend and voice path
            backend = self.model_manager.get_backend()
            voice_name, voice_path = await self._get_voices_path(voice)

            if isinstance(backend, KokoroV1):
                # For Kokoro V1, use generate_from_tokens with raw phonemes
                result = None
                # Use provided lang_code or determine from voice name
                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
                logger.info(
                    f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
                )

                try:
                    # Use backend's pipeline management
                    for r in backend._get_pipeline(
                        pipeline_lang_code
                    ).generate_from_tokens(
                        tokens=phonemes,  # Pass raw phonemes string
                        voice=voice_path,
                        speed=speed,
                    ):
                        if r.audio is not None:
                            result = r
                            break
                except Exception as e:
                    logger.error(f"Failed to generate from phonemes: {e}")
                    raise RuntimeError(f"Phoneme generation failed: {e}")

                if result is None or result.audio is None:
                    raise ValueError("No audio generated")

                processing_time = time.time() - start_time
                return result.audio.numpy(), processing_time
            else:
                raise ValueError(
                    "Phoneme generation only supported with Kokoro V1 backend"
                )

        except Exception as e:
            logger.error(f"Error in phoneme audio generation: {str(e)}")
            raise
api/src/structures/__init__.py
ADDED
@@ -0,0 +1,17 @@
from .schemas import (
    CaptionedSpeechRequest,
    CaptionedSpeechResponse,
    OpenAISpeechRequest,
    TTSStatus,
    VoiceCombineRequest,
    WordTimestamp,
)

__all__ = [
    "OpenAISpeechRequest",
    "CaptionedSpeechRequest",
    "CaptionedSpeechResponse",
    "WordTimestamp",
    "TTSStatus",
    "VoiceCombineRequest",
]
api/src/structures/custom_responses.py
ADDED
@@ -0,0 +1,50 @@
import json
import typing
from collections.abc import AsyncIterable, Iterable

from pydantic import BaseModel
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.responses import JSONResponse, StreamingResponse


class JSONStreamingResponse(StreamingResponse, JSONResponse):
    """StreamingResponse that also renders each item as JSON."""

    def __init__(
        self,
        content: Iterable | AsyncIterable,
        status_code: int = 200,
        headers: dict[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        if isinstance(content, AsyncIterable):
            self._content_iterable: AsyncIterable = content
        else:
            self._content_iterable = iterate_in_threadpool(content)

        async def body_iterator() -> AsyncIterable[bytes]:
            async for content_ in self._content_iterable:
                if isinstance(content_, BaseModel):
                    content_ = content_.model_dump()
                yield self.render(content_)

        self.body_iterator = body_iterator()
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        self.init_headers(headers)

    def render(self, content: typing.Any) -> bytes:
        return (
            json.dumps(
                content,
                ensure_ascii=False,
                allow_nan=False,
                indent=None,
                separators=(",", ":"),
            )
            + "\n"
        ).encode("utf-8")
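A hedged usage sketch (not part of the commit) of JSONStreamingResponse: returning newline-delimited JSON from a route that yields Pydantic models. The app and Item model here are assumptions for illustration only.

    from fastapi import FastAPI
    from pydantic import BaseModel

    app = FastAPI()

    class Item(BaseModel):
        n: int

    async def items():
        for n in range(3):
            yield Item(n=n)  # each model is rendered as one JSON line

    @app.get("/items")
    async def stream_items():
        return JSONStreamingResponse(items(), media_type="application/json")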
api/src/structures/model_schemas.py
ADDED
@@ -0,0 +1,16 @@
"""Voice configuration schemas."""

from pydantic import BaseModel, Field


class VoiceConfig(BaseModel):
    """Voice configuration."""

    use_cache: bool = Field(True, description="Whether to cache loaded voices")
    cache_size: int = Field(3, description="Number of voices to cache")
    validate_on_load: bool = Field(
        True, description="Whether to validate voices when loading"
    )

    class Config:
        frozen = True  # Make config immutable
api/src/structures/schemas.py
ADDED
@@ -0,0 +1,158 @@
from enum import Enum
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field


class VoiceCombineRequest(BaseModel):
    """Request schema for voice combination endpoint that accepts either a string with + or a list"""

    voices: Union[str, List[str]] = Field(
        ...,
        description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine",
    )


class TTSStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    DELETED = "deleted"  # For files removed by cleanup


# OpenAI-compatible schemas
class WordTimestamp(BaseModel):
    """Word-level timestamp information"""

    word: str = Field(..., description="The word or token")
    start_time: float = Field(..., description="Start time in seconds")
    end_time: float = Field(..., description="End time in seconds")


class CaptionedSpeechResponse(BaseModel):
    """Response schema for captioned speech endpoint"""

    audio: str = Field(..., description="The generated audio data encoded in base64")
    audio_format: str = Field(..., description="The format of the output audio")
    timestamps: Optional[List[WordTimestamp]] = Field(
        ..., description="Word-level timestamps"
    )


class NormalizationOptions(BaseModel):
    """Options for the normalization system"""

    normalize: bool = Field(
        default=True,
        description="Normalizes input text to make it easier for the model to say",
    )
    unit_normalization: bool = Field(
        default=False, description="Transforms units like 10KB to 10 kilobytes"
    )
    url_normalization: bool = Field(
        default=True,
        description="Changes URLs so they can be properly pronounced by Kokoro",
    )
    email_normalization: bool = Field(
        default=True,
        description="Changes emails so they can be properly pronounced by Kokoro",
    )
    optional_pluralization_normalization: bool = Field(
        default=True,
        description="Replaces (s) with s so some words get pronounced correctly",
    )
    phone_normalization: bool = Field(
        default=True,
        description="Changes phone numbers so they can be properly pronounced by Kokoro",
    )


class OpenAISpeechRequest(BaseModel):
    """Request schema for OpenAI-compatible speech endpoint"""

    model: str = Field(
        default="kokoro",
        description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro",
    )
    input: str = Field(..., description="The text to generate audio for")
    voice: str = Field(
        default="af_heart",
        description="The voice to use for generation. Can be a base voice or a combined voice name.",
    )
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
        default="mp3",
        description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
    )
    download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = (
        Field(
            default=None,
            description="Optional different format for the final download. If not provided, uses response_format.",
        )
    )
    speed: float = Field(
        default=1.0,
        ge=0.25,
        le=4.0,
        description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
    )
    stream: bool = Field(
        default=True,  # Default to streaming for OpenAI compatibility
        description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
    )
    return_download_link: bool = Field(
        default=False,
        description="If true, returns a download link in X-Download-Path header after streaming completes",
    )
    lang_code: Optional[str] = Field(
        default=None,
        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
    )
    normalization_options: Optional[NormalizationOptions] = Field(
        default=NormalizationOptions(),
        description="Options for the normalization system",
    )


class CaptionedSpeechRequest(BaseModel):
    """Request schema for captioned speech endpoint"""

    model: str = Field(
        default="kokoro",
        description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro",
    )
    input: str = Field(..., description="The text to generate audio for")
    voice: str = Field(
        default="af_heart",
        description="The voice to use for generation. Can be a base voice or a combined voice name.",
    )
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
        default="mp3",
        description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
    )
    speed: float = Field(
        default=1.0,
        ge=0.25,
        le=4.0,
        description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
    )
    stream: bool = Field(
        default=True,  # Default to streaming for OpenAI compatibility
        description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
    )
    return_timestamps: bool = Field(
        default=True,
        description="If true (default), returns word-level timestamps in the response",
    )
    return_download_link: bool = Field(
        default=False,
        description="If true, returns a download link in X-Download-Path header after streaming completes",
    )
    lang_code: Optional[str] = Field(
        default=None,
        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
    )
    normalization_options: Optional[NormalizationOptions] = Field(
        default=NormalizationOptions(),
        description="Options for the normalization system",
    )
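An illustrative sketch (not from the repository) of how OpenAISpeechRequest fills in its defaults; only input is required.

    req = OpenAISpeechRequest(input="Hello world", stream=False)
    print(req.voice, req.response_format, req.speed)  # af_heart mp3 1.0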
api/src/structures/text_schemas.py
ADDED
@@ -0,0 +1,41 @@
from typing import List, Optional, Union

from pydantic import BaseModel, Field, field_validator


class PhonemeRequest(BaseModel):
    text: str
    language: str = "a"  # Default to American English


class PhonemeResponse(BaseModel):
    phonemes: str
    tokens: list[int]


class StitchOptions(BaseModel):
    """Options for stitching audio chunks together"""

    gap_method: str = Field(
        default="static_trim",
        description="Method to handle gaps between chunks. Currently only 'static_trim' supported.",
    )
    trim_ms: int = Field(
        default=0,
        ge=0,
        description="Milliseconds to trim from chunk boundaries when using static_trim",
    )

    @field_validator("gap_method")
    @classmethod
    def validate_gap_method(cls, v: str) -> str:
        if v != "static_trim":
            raise ValueError("Currently only 'static_trim' gap method is supported")
        return v


class GenerateFromPhonemesRequest(BaseModel):
    """Simple request for phoneme-to-speech generation"""

    phonemes: str = Field(..., description="Phoneme string to synthesize")
    voice: str = Field(..., description="Voice ID to use for generation")
api/tests/__init__.py
ADDED
@@ -0,0 +1 @@
# Make tests directory a Python package
api/tests/conftest.py
ADDED
@@ -0,0 +1,71 @@
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import pytest_asyncio
import torch

from api.src.inference.model_manager import ModelManager
from api.src.inference.voice_manager import VoiceManager
from api.src.services.tts_service import TTSService
from api.src.structures.model_schemas import VoiceConfig


@pytest.fixture
def mock_voice_tensor():
    """Load a real voice tensor for testing."""
    voice_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "src/voices/af_bella.pt"
    )
    return torch.load(voice_path, map_location="cpu", weights_only=False)


@pytest.fixture
def mock_audio_output():
    """Load pre-generated test audio for consistent testing."""
    test_audio_path = os.path.join(
        os.path.dirname(__file__), "test_data/test_audio.npy"
    )
    return np.load(test_audio_path)  # Return as numpy array instead of bytes


@pytest_asyncio.fixture
async def mock_model_manager(mock_audio_output):
    """Mock model manager for testing."""
    manager = AsyncMock(spec=ModelManager)
    manager.get_backend = MagicMock()

    async def mock_generate(*args, **kwargs):
        # Simulate successful audio generation
        return np.random.rand(24000).astype(np.float32)  # 1 second of random audio data

    manager.generate = AsyncMock(side_effect=mock_generate)
    return manager


@pytest_asyncio.fixture
async def mock_voice_manager(mock_voice_tensor):
    """Mock voice manager for testing."""
    manager = AsyncMock(spec=VoiceManager)
    manager.get_voice_path = MagicMock(return_value="/mock/path/voice.pt")
    manager.load_voice = AsyncMock(return_value=mock_voice_tensor)
    manager.list_voices = AsyncMock(return_value=["voice1", "voice2"])
    manager.combine_voices = AsyncMock(return_value="voice1_voice2")
    return manager


@pytest_asyncio.fixture
async def tts_service(mock_model_manager, mock_voice_manager):
    """Get mocked TTS service instance."""
    service = TTSService()
    service.model_manager = mock_model_manager
    service._voice_manager = mock_voice_manager
    return service


@pytest.fixture
def test_voice():
    """Return a test voice name."""
    return "voice1"
api/tests/test_audio_service.py
ADDED
@@ -0,0 +1,256 @@
"""Tests for AudioService"""

from unittest.mock import patch

import numpy as np
import pytest

from api.src.inference.base import AudioChunk
from api.src.services.audio import AudioNormalizer, AudioService
from api.src.services.streaming_audio_writer import StreamingAudioWriter


@pytest.fixture(autouse=True)
def mock_settings():
    """Mock settings for all tests"""
    with patch("api.src.services.audio.settings") as mock_settings:
        mock_settings.gap_trim_ms = 250
        yield mock_settings


@pytest.fixture
def sample_audio():
    """Generate a simple sine wave for testing"""
    sample_rate = 24000
    duration = 0.1  # 100ms
    t = np.linspace(0, duration, int(sample_rate * duration))
    frequency = 440  # A4 note
    return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate


@pytest.mark.asyncio
async def test_convert_to_wav(sample_audio):
    """Test converting to WAV format"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("wav", sample_rate=24000)
    # Write and finalize in one step for WAV
    audio_chunk = await AudioService.convert_audio(
        AudioChunk(audio_data), "wav", writer, is_last_chunk=False
    )

    writer.close()

    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # Check WAV header
    assert audio_chunk.output.startswith(b"RIFF")
    assert b"WAVE" in audio_chunk.output[:12]


@pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio):
    """Test converting to MP3 format"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("mp3", sample_rate=24000)

    audio_chunk = await AudioService.convert_audio(
        AudioChunk(audio_data), "mp3", writer
    )

    writer.close()

    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # Check MP3 header (ID3 or MPEG frame sync)
    assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
        b"\xff\xfb"
    )


@pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
    """Test converting to Opus format"""

    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("opus", sample_rate=24000)

    audio_chunk = await AudioService.convert_audio(
        AudioChunk(audio_data), "opus", writer
    )

    writer.close()

    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # Check OGG header
    assert audio_chunk.output.startswith(b"OggS")


@pytest.mark.asyncio
async def test_convert_to_flac(sample_audio):
    """Test converting to FLAC format"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("flac", sample_rate=24000)

    audio_chunk = await AudioService.convert_audio(
        AudioChunk(audio_data), "flac", writer
    )

    writer.close()

    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # Check FLAC header
    assert audio_chunk.output.startswith(b"fLaC")


@pytest.mark.asyncio
async def test_convert_to_aac(sample_audio):
    """Test converting to AAC format"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("aac", sample_rate=24000)

    audio_chunk = await AudioService.convert_audio(
        AudioChunk(audio_data), "aac", writer
    )

    writer.close()

    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # Check ADTS header (AAC)
    assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
        b"\xff\xf1"
    )


@pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
    """Test converting to PCM format"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("pcm", sample_rate=24000)

    audio_chunk = await AudioService.convert_audio(
        AudioChunk(audio_data), "pcm", writer
    )

    writer.close()

    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0
    # PCM is raw bytes, so no header to check


@pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
    """Test that converting to an invalid format raises an error"""
    # audio_data, sample_rate = sample_audio
    with pytest.raises(ValueError, match="Unsupported format: invalid"):
        writer = StreamingAudioWriter("invalid", sample_rate=24000)


@pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
    """Test that WAV output is properly normalized to int16 range"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("wav", sample_rate=24000)

    # Create audio data outside int16 range
    large_audio = audio_data * 1e5
    # Write and finalize in one step for WAV
    audio_chunk = await AudioService.convert_audio(
        AudioChunk(large_audio), "wav", writer
    )

    writer.close()

    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0


@pytest.mark.asyncio
async def test_normalization_pcm(sample_audio):
    """Test that PCM output is properly normalized to int16 range"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("pcm", sample_rate=24000)

    # Create audio data outside int16 range
    large_audio = audio_data * 1e5
    audio_chunk = await AudioService.convert_audio(
        AudioChunk(large_audio), "pcm", writer
    )
    assert isinstance(audio_chunk.output, bytes)
    assert isinstance(audio_chunk, AudioChunk)
    assert len(audio_chunk.output) > 0


@pytest.mark.asyncio
async def test_invalid_audio_data():
    """Test handling of invalid audio data"""
    invalid_audio = np.array([])  # Empty array
    sample_rate = 24000

    writer = StreamingAudioWriter("wav", sample_rate=24000)

    with pytest.raises(ValueError):
        await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer)


@pytest.mark.asyncio
async def test_different_sample_rates(sample_audio):
    """Test converting audio with different sample rates"""
    audio_data, _ = sample_audio
    sample_rates = [8000, 16000, 44100, 48000]

    for rate in sample_rates:
        writer = StreamingAudioWriter("wav", sample_rate=rate)

        audio_chunk = await AudioService.convert_audio(
            AudioChunk(audio_data), "wav", writer
        )

        writer.close()

        assert isinstance(audio_chunk.output, bytes)
        assert isinstance(audio_chunk, AudioChunk)
        assert len(audio_chunk.output) > 0


@pytest.mark.asyncio
async def test_buffer_position_after_conversion(sample_audio):
    """Test that buffer position is reset after writing"""
    audio_data, sample_rate = sample_audio

    writer = StreamingAudioWriter("wav", sample_rate=24000)

    # Write and finalize in one step for first conversion
    audio_chunk1 = await AudioService.convert_audio(
        AudioChunk(audio_data), "wav", writer, is_last_chunk=True
    )
    assert isinstance(audio_chunk1.output, bytes)
    assert isinstance(audio_chunk1, AudioChunk)
    # Convert again to ensure buffer was properly reset

    writer = StreamingAudioWriter("wav", sample_rate=24000)

    audio_chunk2 = await AudioService.convert_audio(
        AudioChunk(audio_data), "wav", writer, is_last_chunk=True
    )
    assert isinstance(audio_chunk2.output, bytes)
    assert isinstance(audio_chunk2, AudioChunk)
    assert len(audio_chunk1.output) == len(audio_chunk2.output)
api/tests/test_data/generate_test_data.py
ADDED
@@ -0,0 +1,23 @@
import os

import numpy as np


def generate_test_audio():
    """Generate test audio data - 1 second of a 440 Hz tone"""
    # Create 1 second of silence at 24kHz
    audio = np.zeros(24000, dtype=np.float32)

    # Add a simple sine wave to make it non-zero
    t = np.linspace(0, 1, 24000)
    audio += 0.5 * np.sin(2 * np.pi * 440 * t)  # 440 Hz tone at half amplitude

    # Create test_data directory if it doesn't exist
    os.makedirs("api/tests/test_data", exist_ok=True)

    # Save the test audio
    np.save("api/tests/test_data/test_audio.npy", audio)


if __name__ == "__main__":
    generate_test_audio()
api/tests/test_data/test_audio.npy
ADDED
Binary file (96.1 kB).
api/tests/test_development.py
ADDED
@@ -0,0 +1,34 @@
import base64
import json
from unittest.mock import MagicMock, patch

import pytest
import requests


def test_generate_captioned_speech():
    """Test the generate_captioned_speech function with mocked responses"""
    # Mock the API responses
    mock_audio_response = MagicMock()
    mock_audio_response.status_code = 200

    mock_timestamps_response = MagicMock()
    mock_timestamps_response.status_code = 200
    mock_timestamps_response.content = json.dumps(
        {
            "audio": base64.b64encode(b"mock audio data").decode("utf-8"),
            "timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
        }
    )

    # Patch the HTTP requests
    with patch("requests.post", return_value=mock_timestamps_response):
        # Import here to avoid module-level import issues
        from examples.captioned_speech_example import generate_captioned_speech

        # Test the function
        audio, timestamps = generate_captioned_speech("test text")

        # Verify we got both audio and timestamps
        assert audio == b"mock audio data"
        assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
api/tests/test_kokoro_v1.py
ADDED
@@ -0,0 +1,165 @@
from unittest.mock import ANY, MagicMock, patch

import numpy as np
import pytest
import torch

from api.src.inference.kokoro_v1 import KokoroV1


@pytest.fixture
def kokoro_backend():
    """Create a KokoroV1 instance for testing."""
    return KokoroV1()


def test_initial_state(kokoro_backend):
    """Test initial state of KokoroV1."""
    assert not kokoro_backend.is_loaded
    assert kokoro_backend._model is None
    assert kokoro_backend._pipelines == {}  # Now using dict of pipelines
    # Device should be set based on settings
    assert kokoro_backend.device in ["cuda", "cpu"]


@patch("torch.cuda.is_available", return_value=True)
@patch("torch.cuda.memory_allocated", return_value=5e9)
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
    """Test GPU memory management functions."""
    # Patch backend so it thinks we have cuda
    with patch.object(kokoro_backend, "_device", "cuda"):
        # Test memory check
        with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
            mock_config.pytorch_gpu.memory_threshold = 4
            assert kokoro_backend._check_memory() == True

            mock_config.pytorch_gpu.memory_threshold = 6
            assert kokoro_backend._check_memory() == False


@patch("torch.cuda.empty_cache")
@patch("torch.cuda.synchronize")
def test_clear_memory(mock_sync, mock_clear, kokoro_backend):
    """Test memory clearing."""
    with patch.object(kokoro_backend, "_device", "cuda"):
        kokoro_backend._clear_memory()
        mock_clear.assert_called_once()
        mock_sync.assert_called_once()


@pytest.mark.asyncio
async def test_load_model_validation(kokoro_backend):
    """Test model loading validation."""
    with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
        await kokoro_backend.load_model("nonexistent_model.pth")


def test_unload_with_pipelines(kokoro_backend):
    """Test model unloading with multiple pipelines."""
    # Mock loaded state with multiple pipelines
    kokoro_backend._model = MagicMock()
    pipeline_a = MagicMock()
    pipeline_e = MagicMock()
    kokoro_backend._pipelines = {"a": pipeline_a, "e": pipeline_e}
    assert kokoro_backend.is_loaded

    # Test unload
    kokoro_backend.unload()
    assert not kokoro_backend.is_loaded
    assert kokoro_backend._model is None
    assert kokoro_backend._pipelines == {}  # All pipelines should be cleared


@pytest.mark.asyncio
async def test_generate_validation(kokoro_backend):
    """Test generation validation."""
    with pytest.raises(RuntimeError, match="Model not loaded"):
        async for _ in kokoro_backend.generate("test", "voice"):
            pass


@pytest.mark.asyncio
async def test_generate_from_tokens_validation(kokoro_backend):
    """Test token generation validation."""
    with pytest.raises(RuntimeError, match="Model not loaded"):
        async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
            pass


def test_get_pipeline_creates_new(kokoro_backend):
    """Test that _get_pipeline creates new pipeline for new language code."""
    # Mock loaded state
    kokoro_backend._model = MagicMock()

    # Mock KPipeline
    mock_pipeline = MagicMock()
    with patch(
        "api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
    ) as mock_kpipeline:
        # Get pipeline for Spanish
        pipeline_e = kokoro_backend._get_pipeline("e")

        # Should create new pipeline with correct params
        mock_kpipeline.assert_called_once_with(
            lang_code="e", model=kokoro_backend._model, device=kokoro_backend._device
        )
        assert pipeline_e == mock_pipeline
        assert kokoro_backend._pipelines["e"] == mock_pipeline


def test_get_pipeline_reuses_existing(kokoro_backend):
    """Test that _get_pipeline reuses existing pipeline for same language code."""
    # Mock loaded state
    kokoro_backend._model = MagicMock()

    # Mock KPipeline
    mock_pipeline = MagicMock()
    with patch(
        "api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
    ) as mock_kpipeline:
        # Get pipeline twice for same language
        pipeline1 = kokoro_backend._get_pipeline("e")
        pipeline2 = kokoro_backend._get_pipeline("e")

        # Should only create pipeline once
        mock_kpipeline.assert_called_once()
        assert pipeline1 == pipeline2
        assert kokoro_backend._pipelines["e"] == mock_pipeline


@pytest.mark.asyncio
async def test_generate_uses_correct_pipeline(kokoro_backend):
    """Test that generate uses correct pipeline for language code."""
    # Mock loaded state
    kokoro_backend._model = MagicMock()

    # Mock voice path handling
    with (
        patch("api.src.core.paths.load_voice_tensor") as mock_load_voice,
        patch("api.src.core.paths.save_voice_tensor"),
        patch("tempfile.gettempdir") as mock_tempdir,
    ):
        mock_load_voice.return_value = torch.ones(1)
        mock_tempdir.return_value = "/tmp"

        # Mock KPipeline
        mock_pipeline = MagicMock()
        mock_pipeline.return_value = iter([])  # Empty generator for testing
        with patch("api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline):
            # Generate with Spanish voice and explicit lang_code
            async for _ in kokoro_backend.generate("test", "ef_voice", lang_code="e"):
                pass

            # Should create pipeline with Spanish lang_code
            assert "e" in kokoro_backend._pipelines
            # Use ANY to match the temp file path since it's dynamic
            mock_pipeline.assert_called_with(
                "test",
                voice=ANY,  # Don't check exact path since it's dynamic
                speed=1.0,
                model=kokoro_backend._model,
            )
            # Verify the voice path is a temp file path
            call_args = mock_pipeline.call_args
            assert isinstance(call_args[1]["voice"], str)
            assert call_args[1]["voice"].startswith("/tmp/temp_voice_")
api/tests/test_normalizer.py
ADDED
@@ -0,0 +1,179 @@
"""Tests for text normalization service"""

import pytest

from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions


def test_url_protocols():
    """Test URL protocol handling"""
    assert (
        normalize_text(
            "Check out https://example.com",
            normalization_options=NormalizationOptions(),
        )
        == "Check out https example dot com"
    )
    assert (
        normalize_text(
            "Visit http://site.com", normalization_options=NormalizationOptions()
        )
        == "Visit http site dot com"
    )
    assert (
        normalize_text(
            "Go to https://test.org/path", normalization_options=NormalizationOptions()
        )
        == "Go to https test dot org slash path"
    )


def test_url_www():
    """Test www prefix handling"""
    assert (
        normalize_text(
            "Go to www.example.com", normalization_options=NormalizationOptions()
        )
        == "Go to www example dot com"
    )
    assert (
        normalize_text(
            "Visit www.test.org/docs", normalization_options=NormalizationOptions()
        )
        == "Visit www test dot org slash docs"
    )
    assert (
        normalize_text(
            "Check www.site.com?q=test", normalization_options=NormalizationOptions()
        )
        == "Check www site dot com question-mark q equals test"
    )


def test_url_localhost():
    """Test localhost URL handling"""
    assert (
        normalize_text(
            "Running on localhost:7860", normalization_options=NormalizationOptions()
        )
        == "Running on localhost colon 78 60"
    )
    assert (
        normalize_text(
            "Server at localhost:8080/api", normalization_options=NormalizationOptions()
        )
        == "Server at localhost colon 80 80 slash api"
    )
    assert (
        normalize_text(
            "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
        )
        == "Test localhost colon 3000 slash test question-mark v equals 1"
    )


def test_url_ip_addresses():
    """Test IP address URL handling"""
    assert (
        normalize_text(
            "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
        )
        == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
    )
    assert (
        normalize_text(
            "API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
        )
        == "API at 192 dot 168 dot 1 dot 1 colon 8000"
    )
    assert (
        normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
        == "Server 127 dot 0 dot 0 dot 1"
    )


def test_url_raw_domains():
    """Test raw domain handling"""
    assert (
        normalize_text(
            "Visit google.com/search", normalization_options=NormalizationOptions()
        )
        == "Visit google dot com slash search"
    )
    assert (
        normalize_text(
            "Go to example.com/path?q=test",
            normalization_options=NormalizationOptions(),
        )
        == "Go to example dot com slash path question-mark q equals test"
    )
    assert (
        normalize_text(
            "Check docs.test.com", normalization_options=NormalizationOptions()
        )
        == "Check docs dot test dot com"
    )


def test_url_email_addresses():
    """Test email address handling"""
    assert (
        normalize_text(
            "Email me at user@example.com", normalization_options=NormalizationOptions()
        )
        == "Email me at user at example dot com"
    )
    assert (
        normalize_text(
            "Contact admin@test.org", normalization_options=NormalizationOptions()
        )
        == "Contact admin at test dot org"
    )
    assert (
        normalize_text(
            "Send to test.user@site.com", normalization_options=NormalizationOptions()
        )
        == "Send to test dot user at site dot com"
    )


def test_money():
    """Test that money text is normalized correctly"""
    assert (
        normalize_text(
            "He lost $5.3 thousand.", normalization_options=NormalizationOptions()
        )
        == "He lost five point three thousand dollars."
    )
    assert (
        normalize_text(
            "To put it weirdly -$6.9 million",
            normalization_options=NormalizationOptions(),
        )
        == "To put it weirdly minus six point nine million dollars"
    )
    assert (
        normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
        == "It costs fifty dollars and thirty cents."
    )


def test_non_url_text():
    """Test that non-URL text is unaffected"""
    assert (
        normalize_text(
            "This is not.a.url text", normalization_options=NormalizationOptions()
        )
        == "This is not-a-url text"
    )
    assert (
        normalize_text(
            "Hello, how are you today?", normalization_options=NormalizationOptions()
        )
        == "Hello, how are you today?"
    )
    assert (
        normalize_text("It costs $50.", normalization_options=NormalizationOptions())
        == "It costs fifty dollars."
    )
api/tests/test_openai_endpoints.py
ADDED
@@ -0,0 +1,499 @@
import asyncio
import json
import os
from typing import AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
from fastapi.testclient import TestClient

from api.src.core.config import settings
from api.src.inference.base import AudioChunk
from api.src.main import app
from api.src.routers.openai_compatible import (
    get_tts_service,
    load_openai_mappings,
    stream_audio_chunks,
)
from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.services.tts_service import TTSService
from api.src.structures.schemas import OpenAISpeechRequest

client = TestClient(app)


@pytest.fixture
def test_voice():
    """Fixture providing a test voice name."""
    return "test_voice"


@pytest.fixture
def mock_openai_mappings():
    """Mock OpenAI mappings for testing."""
    with patch(
        "api.src.routers.openai_compatible._openai_mappings",
        {
            "models": {"tts-1": "kokoro-v1_0", "tts-1-hd": "kokoro-v1_0"},
            "voices": {"alloy": "am_adam", "nova": "bf_isabella"},
        },
    ):
        yield


@pytest.fixture
def mock_json_file(tmp_path):
    """Create a temporary mock JSON file."""
    content = {
        "models": {"test-model": "test-kokoro"},
        "voices": {"test-voice": "test-internal"},
    }
    json_file = tmp_path / "test_mappings.json"
    json_file.write_text(json.dumps(content))
    return json_file


def test_load_openai_mappings(mock_json_file):
    """Test loading OpenAI mappings from JSON file"""
    with patch("os.path.join", return_value=str(mock_json_file)):
        mappings = load_openai_mappings()
        assert "models" in mappings
        assert "voices" in mappings
        assert mappings["models"]["test-model"] == "test-kokoro"
        assert mappings["voices"]["test-voice"] == "test-internal"


def test_load_openai_mappings_file_not_found():
    """Test handling of missing mappings file"""
    with patch("os.path.join", return_value="/nonexistent/path"):
        mappings = load_openai_mappings()
        assert mappings == {"models": {}, "voices": {}}


def test_list_models(mock_openai_mappings):
    """Test listing available models endpoint"""
    response = client.get("/v1/models")
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert isinstance(data["data"], list)
    assert len(data["data"]) == 3  # tts-1, tts-1-hd, and kokoro

    # Verify all expected models are present
    model_ids = [model["id"] for model in data["data"]]
    assert "tts-1" in model_ids
    assert "tts-1-hd" in model_ids
    assert "kokoro" in model_ids

    # Verify model format
    for model in data["data"]:
        assert model["object"] == "model"
        assert "created" in model
        assert model["owned_by"] == "kokoro"


def test_retrieve_model(mock_openai_mappings):
    """Test retrieving a specific model endpoint"""
    # Test successful model retrieval
    response = client.get("/v1/models/tts-1")
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == "tts-1"
    assert data["object"] == "model"
    assert data["owned_by"] == "kokoro"
    assert "created" in data

    # Test non-existent model
    response = client.get("/v1/models/nonexistent-model")
    assert response.status_code == 404
    error = response.json()
    assert error["detail"]["error"] == "model_not_found"
    assert "not found" in error["detail"]["message"]
    assert error["detail"]["type"] == "invalid_request_error"


@pytest.mark.asyncio
async def test_get_tts_service_initialization():
    """Test TTSService initialization"""
    with patch("api.src.routers.openai_compatible._tts_service", None):
        with patch("api.src.routers.openai_compatible._init_lock", None):
            with patch("api.src.services.tts_service.TTSService.create") as mock_create:
                mock_service = AsyncMock()
                mock_create.return_value = mock_service

                # Test concurrent access
                async def get_service():
                    return await get_tts_service()

                # Create multiple concurrent requests
                tasks = [get_service() for _ in range(5)]
                results = await asyncio.gather(*tasks)

                # Verify service was created only once
                mock_create.assert_called_once()
                assert all(r == mock_service for r in results)


@pytest.mark.asyncio
async def test_stream_audio_chunks_client_disconnect():
    """Test handling of client disconnect during streaming"""
    mock_request = MagicMock()
    mock_request.is_disconnected = AsyncMock(return_value=True)

    mock_service = AsyncMock()

    async def mock_stream(*args, **kwargs):
        for i in range(5):
            yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")

    mock_service.generate_audio_stream = mock_stream
    mock_service.list_voices.return_value = ["test_voice"]

    request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )

    writer = StreamingAudioWriter("mp3", 24000)

    chunks = []
    async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer):
        chunks.append(chunk)

    writer.close()

    assert len(chunks) == 0  # Should stop immediately due to disconnect


def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
    """Test OpenAI voice name mapping"""
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "tts-1",
            "input": "Hello world",
            "voice": "alloy",  # OpenAI voice name
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 200
    mock_tts_service.generate_audio.assert_called_once()
    assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam"


def test_openai_voice_mapping_streaming(
    mock_tts_service, mock_openai_mappings, mock_audio_bytes
):
    """Test OpenAI voice mapping in streaming mode"""
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "tts-1-hd",
            "input": "Hello world",
            "voice": "nova",  # OpenAI voice name
            "response_format": "mp3",
            "stream": True,
        },
    )
    assert response.status_code == 200
    content = b""
    for chunk in response.iter_bytes():
        content += chunk
    assert content == mock_audio_bytes


def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
    """Test error handling for invalid OpenAI model"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "invalid-model",
            "input": "Hello world",
            "voice": "alloy",
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 400
    error_response = response.json()
    assert error_response["detail"]["error"] == "invalid_model"
    assert "Unsupported model" in error_response["detail"]["message"]


@pytest.fixture
def mock_audio_bytes():
    """Mock audio bytes for testing."""
    return b"mock audio data"


@pytest.fixture
def mock_tts_service(mock_audio_bytes):
    """Mock TTS service for testing."""
    with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
        service = AsyncMock(spec=TTSService)
        service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))

        async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
            yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)

        service.generate_audio_stream = mock_stream
        service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
        service.combine_voices.return_value = "voice1_voice2"

        mock_get.return_value = service
        mock_get.side_effect = None
        yield service


@patch("api.src.services.audio.AudioService.convert_audio")
def test_openai_speech_endpoint(
    mock_convert, mock_tts_service, test_voice, mock_audio_bytes
):
    """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
    # Configure mocks
    mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
    mock_convert.return_value = AudioChunk(
        np.zeros(1000, np.int16), output=mock_audio_bytes
    )

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
    assert len(response.content) > 0
    assert response.content == mock_audio_bytes + mock_audio_bytes

    mock_tts_service.generate_audio.assert_called_once()
    assert mock_convert.call_count == 2


def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """Test the OpenAI-compatible speech endpoint with streaming"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": True,
        },
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
    assert "Transfer-Encoding" in response.headers
    assert response.headers["Transfer-Encoding"] == "chunked"

    content = b""
    for chunk in response.iter_bytes():
        content += chunk
    assert content == mock_audio_bytes


def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """Test PCM streaming format"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "pcm",
            "stream": True,
        },
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/pcm"

    content = b""
    for chunk in response.iter_bytes():
        content += chunk
    assert content == mock_audio_bytes


def test_openai_speech_invalid_voice(mock_tts_service):
    """Test error handling for invalid voice"""
    mock_tts_service.generate_audio.side_effect = ValueError(
        "Voice 'invalid_voice' not found"
    )

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": "invalid_voice",
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 400
    error_response = response.json()
    assert error_response["detail"]["error"] == "validation_error"
    assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"]
    assert error_response["detail"]["type"] == "invalid_request_error"


def test_openai_speech_empty_text(mock_tts_service, test_voice):
    """Test error handling for empty text"""

    async def mock_error_stream(*args, **kwargs):
        raise ValueError("Text is empty after preprocessing")

    mock_tts_service.generate_audio = mock_error_stream
    mock_tts_service.list_voices.return_value = ["test_voice"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 400
    error_response = response.json()
    assert error_response["detail"]["error"] == "validation_error"
    assert "Text is empty after preprocessing" in error_response["detail"]["message"]
    assert error_response["detail"]["type"] == "invalid_request_error"


def test_openai_speech_invalid_format(mock_tts_service, test_voice):
    """Test error handling for invalid format"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "invalid_format",
            "stream": False,
        },
    )
    assert response.status_code == 422  # Validation error from Pydantic


def test_list_voices(mock_tts_service):
    """Test listing available voices"""
    # Override the mock for this specific test
    mock_tts_service.list_voices.return_value = ["voice1", "voice2"]

    response = client.get("/v1/audio/voices")
    assert response.status_code == 200
    data = response.json()
    assert "voices" in data
    assert len(data["voices"]) == 2
    assert "voice1" in data["voices"]
    assert "voice2" in data["voices"]


@patch("api.src.routers.openai_compatible.settings")
def test_combine_voices(mock_settings, mock_tts_service):
    """Test combining voices endpoint"""
    # Enable local voice saving for this test
    mock_settings.allow_local_voice_saving = True

    response = client.post("/v1/audio/voices/combine", json="voice1+voice2")
    assert response.status_code == 200
    assert response.headers["content-type"] == "application/octet-stream"
    assert "voice1+voice2.pt" in response.headers["content-disposition"]


def test_server_error(mock_tts_service, test_voice):
    """Test handling of server errors"""

    async def mock_error_stream(*args, **kwargs):
        raise RuntimeError("Internal server error")

    mock_tts_service.generate_audio = mock_error_stream
    mock_tts_service.list_voices.return_value = ["test_voice"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 500
    error_response = response.json()
    assert error_response["detail"]["error"] == "processing_error"
    assert error_response["detail"]["type"] == "server_error"


def test_streaming_error(mock_tts_service, test_voice):
    """Test handling streaming errors"""
    # Mock process_voices to raise the error
    mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": True,
        },
    )

    assert response.status_code == 500
    error_data = response.json()
    assert error_data["detail"]["error"] == "processing_error"
    assert error_data["detail"]["type"] == "server_error"
    assert "Streaming failed" in error_data["detail"]["message"]


@pytest.mark.asyncio
async def test_streaming_initialization_error():
    """Test handling of streaming initialization errors"""
    mock_service = AsyncMock()

    async def mock_error_stream(*args, **kwargs):
        if False:  # This makes it a proper generator
            yield b""
        raise RuntimeError("Failed to initialize stream")

    mock_service.generate_audio_stream = mock_error_stream
    mock_service.list_voices.return_value = ["test_voice"]

    request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )

    writer = StreamingAudioWriter("mp3", 24000)

    with pytest.raises(RuntimeError) as exc:
        async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer):
            pass

    writer.close()
    assert "Failed to initialize stream" in str(exc.value)
api/tests/test_paths.py
ADDED
@@ -0,0 +1,138 @@
import os
from unittest.mock import patch

import pytest

from api.src.core.paths import (
    _find_file,
    _scan_directories,
    get_content_type,
    get_temp_dir_size,
    get_temp_file_path,
    list_temp_files,
)


@pytest.mark.asyncio
async def test_find_file_exists():
    """Test finding existing file."""
    with patch("aiofiles.os.path.exists") as mock_exists:
        mock_exists.return_value = True
        path = await _find_file("test.txt", ["/test/path"])
        assert path == "/test/path/test.txt"


@pytest.mark.asyncio
async def test_find_file_not_exists():
    """Test finding non-existent file."""
    with patch("aiofiles.os.path.exists") as mock_exists:
        mock_exists.return_value = False
        with pytest.raises(FileNotFoundError, match="File not found"):
            await _find_file("test.txt", ["/test/path"])


@pytest.mark.asyncio
async def test_find_file_with_filter():
    """Test finding file with filter function."""
    with patch("aiofiles.os.path.exists") as mock_exists:
        mock_exists.return_value = True
        filter_fn = lambda p: p.endswith(".txt")
        path = await _find_file("test.txt", ["/test/path"], filter_fn)
        assert path == "/test/path/test.txt"


@pytest.mark.asyncio
async def test_scan_directories():
    """Test scanning directories."""
    mock_entry = type("MockEntry", (), {"name": "test.txt"})()

    with (
        patch("aiofiles.os.path.exists") as mock_exists,
        patch("aiofiles.os.scandir") as mock_scandir,
    ):
        mock_exists.return_value = True
        mock_scandir.return_value = [mock_entry]

        files = await _scan_directories(["/test/path"])
        assert "test.txt" in files


@pytest.mark.asyncio
async def test_get_content_type():
    """Test content type detection."""
    test_cases = [
        ("test.html", "text/html"),
        ("test.js", "application/javascript"),
        ("test.css", "text/css"),
        ("test.png", "image/png"),
        ("test.unknown", "application/octet-stream"),
    ]

    for filename, expected in test_cases:
        content_type = await get_content_type(filename)
        assert content_type == expected


@pytest.mark.asyncio
async def test_get_temp_file_path():
    """Test temp file path generation."""
    with (
        patch("aiofiles.os.path.exists") as mock_exists,
        patch("aiofiles.os.makedirs") as mock_makedirs,
    ):
        mock_exists.return_value = False

        path = await get_temp_file_path("test.wav")
        assert "test.wav" in path
        mock_makedirs.assert_called_once()


@pytest.mark.asyncio
async def test_list_temp_files():
    """Test listing temp files."""

    class MockEntry:
        def __init__(self, name):
            self.name = name

        def is_file(self):
            return True

    mock_entry = MockEntry("test.wav")

    with (
        patch("aiofiles.os.path.exists") as mock_exists,
        patch("aiofiles.os.scandir") as mock_scandir,
    ):
        mock_exists.return_value = True
        mock_scandir.return_value = [mock_entry]

        files = await list_temp_files()
        assert "test.wav" in files


@pytest.mark.asyncio
async def test_get_temp_dir_size():
    """Test getting temp directory size."""

    class MockEntry:
        def __init__(self, path):
            self.path = path

        def is_file(self):
            return True

    mock_entry = MockEntry("/tmp/test.wav")
    mock_stat = type("MockStat", (), {"st_size": 1024})()

    with (
        patch("aiofiles.os.path.exists") as mock_exists,
        patch("aiofiles.os.scandir") as mock_scandir,
        patch("aiofiles.os.stat") as mock_stat_fn,
    ):
        mock_exists.return_value = True
        mock_scandir.return_value = [mock_entry]
        mock_stat_fn.return_value = mock_stat

        size = await get_temp_dir_size()
        assert size == 1024
api/tests/test_text_processor.py
ADDED
@@ -0,0 +1,105 @@
import pytest

from api.src.services.text_processing.text_processor import (
    get_sentence_info,
    process_text_chunk,
    smart_split,
)


def test_process_text_chunk_basic():
    """Test basic text chunk processing."""
    text = "Hello world"
    tokens = process_text_chunk(text)
    assert isinstance(tokens, list)
    assert len(tokens) > 0


def test_process_text_chunk_empty():
    """Test processing empty text."""
    text = ""
    tokens = process_text_chunk(text)
    assert isinstance(tokens, list)
    assert len(tokens) == 0


def test_process_text_chunk_phonemes():
    """Test processing with skip_phonemize."""
    phonemes = "h @ l @U"  # Example phoneme sequence
    tokens = process_text_chunk(phonemes, skip_phonemize=True)
    assert isinstance(tokens, list)
    assert len(tokens) > 0


def test_get_sentence_info():
    """Test sentence splitting and info extraction."""
    text = "This is sentence one. This is sentence two! What about three?"
    results = get_sentence_info(text, {})

    assert len(results) == 3
    for sentence, tokens, count in results:
        assert isinstance(sentence, str)
        assert isinstance(tokens, list)
        assert isinstance(count, int)
        assert count == len(tokens)
        assert count > 0


def test_get_sentence_info_phenomoes():
    """Test sentence splitting and info extraction."""
    text = (
        "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
    )
    results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})

    assert len(results) == 3
    assert "sˈɛntᵊns" in results[1][0]
    for sentence, tokens, count in results:
        assert isinstance(sentence, str)
        assert isinstance(tokens, list)
        assert isinstance(count, int)
        assert count == len(tokens)
        assert count > 0


@pytest.mark.asyncio
async def test_smart_split_short_text():
    """Test smart splitting with text under max tokens."""
    text = "This is a short test sentence."
    chunks = []
    async for chunk_text, chunk_tokens in smart_split(text):
        chunks.append((chunk_text, chunk_tokens))

    assert len(chunks) == 1
    assert isinstance(chunks[0][0], str)
    assert isinstance(chunks[0][1], list)


@pytest.mark.asyncio
async def test_smart_split_long_text():
    """Test smart splitting with longer text."""
    # Create text that should split into multiple chunks
    text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])

    chunks = []
    async for chunk_text, chunk_tokens in smart_split(text):
        chunks.append((chunk_text, chunk_tokens))

    assert len(chunks) > 1
    for chunk_text, chunk_tokens in chunks:
        assert isinstance(chunk_text, str)
        assert isinstance(chunk_tokens, list)
        assert len(chunk_tokens) > 0


@pytest.mark.asyncio
async def test_smart_split_with_punctuation():
    """Test smart splitting handles punctuation correctly."""
    text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."

    chunks = []
    async for chunk_text, chunk_tokens in smart_split(text):
        chunks.append(chunk_text)

    # Verify punctuation is preserved
    assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
api/tests/test_tts_service.py
ADDED
@@ -0,0 +1,126 @@
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import torch

from api.src.services.tts_service import TTSService


@pytest.fixture
def mock_managers():
    """Mock model and voice managers."""

    async def _mock_managers():
        model_manager = AsyncMock()
        model_manager.get_backend.return_value = MagicMock()

        voice_manager = AsyncMock()
        voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
        voice_manager.list_voices.return_value = ["voice1", "voice2"]

        with (
            patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
            patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
        ):
            mock_get_model.return_value = model_manager
            mock_get_voice.return_value = voice_manager
            return model_manager, voice_manager

    return _mock_managers()


@pytest.fixture
def tts_service(mock_managers):
    """Create TTSService instance with mocked dependencies."""

    async def _create_service():
        return await TTSService.create("test_output")

    return _create_service()


@pytest.mark.asyncio
async def test_service_creation():
    """Test service creation and initialization."""
    model_manager = AsyncMock()
    voice_manager = AsyncMock()

    with (
        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
    ):
        mock_get_model.return_value = model_manager
        mock_get_voice.return_value = voice_manager

        service = await TTSService.create("test_output")
        assert service.output_dir == "test_output"
        assert service.model_manager is model_manager
        assert service._voice_manager is voice_manager


@pytest.mark.asyncio
async def test_get_voice_path_single():
    """Test getting path for single voice."""
    model_manager = AsyncMock()
    voice_manager = AsyncMock()
    voice_manager.get_voice_path.return_value = "/path/to/voice1.pt"

    with (
        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
    ):
        mock_get_model.return_value = model_manager
        mock_get_voice.return_value = voice_manager

        service = await TTSService.create("test_output")
        name, path = await service._get_voices_path("voice1")
        assert name == "voice1"
        assert path == "/path/to/voice1.pt"
        voice_manager.get_voice_path.assert_called_once_with("voice1")


@pytest.mark.asyncio
async def test_get_voice_path_combined():
    """Test getting path for combined voices."""
    model_manager = AsyncMock()
    voice_manager = AsyncMock()
    voice_manager.get_voice_path.return_value = "/path/to/voice.pt"

    with (
        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
        patch("torch.load") as mock_load,
        patch("torch.save") as mock_save,
        patch("tempfile.gettempdir") as mock_temp,
    ):
        mock_get_model.return_value = model_manager
        mock_get_voice.return_value = voice_manager
        mock_temp.return_value = "/tmp"
        mock_load.return_value = torch.ones(10)

        service = await TTSService.create("test_output")
        name, path = await service._get_voices_path("voice1+voice2")
        assert name == "voice1+voice2"
        assert path.endswith("voice1+voice2.pt")
        mock_save.assert_called_once()


@pytest.mark.asyncio
async def test_list_voices():
    """Test listing available voices."""
    model_manager = AsyncMock()
    voice_manager = AsyncMock()
    voice_manager.list_voices.return_value = ["voice1", "voice2"]

    with (
        patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
        patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
    ):
        mock_get_model.return_value = model_manager
        mock_get_voice.return_value = voice_manager

        service = await TTSService.create("test_output")
        voices = await service.list_voices()
        assert voices == ["voice1", "voice2"]
        voice_manager.list_voices.assert_called_once()
charts/kokoro-fastapi/.helmignore
ADDED
@@ -0,0 +1,23 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/
charts/kokoro-fastapi/Chart.yaml
ADDED
@@ -0,0 +1,12 @@
apiVersion: v2
name: kokoro-fastapi
description: A Helm chart for deploying the Kokoro FastAPI TTS service to Kubernetes
type: application
version: 0.3.0
appVersion: "0.3.0"

keywords:
  - tts
  - fastapi
  - gpu
  - kokoro