Tomtom84 commited on
Commit
4715aa2
Β·
verified Β·
1 Parent(s): be98391

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +381 -378
app.py CHANGED
@@ -1,398 +1,401 @@
1
- # app.py ──────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
- import json
4
- import torch
5
- import asyncio
6
- import traceback # Import traceback for better error logging
7
-
8
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
- from huggingface_hub import login
10
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
11
- # Import BaseStreamer for the interface
12
- from transformers.generation.streamers import BaseStreamer
13
- from snac import SNAC # Ensure you have 'pip install snac'
14
-
15
- # --- Globals (populated in load_models) ---
16
- tok = None
17
- model = None
18
- snac = None
19
- masker = None
20
- stopping_criteria = None
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
-
23
- # 0) Login + Device ---------------------------------------------------
24
- HF_TOKEN = os.getenv("HF_TOKEN")
25
- if HF_TOKEN:
26
- print("πŸ”‘ Logging in to Hugging Face Hub...")
27
- login(HF_TOKEN)
28
-
29
- # torch.backends.cuda.enable_flash_sdp(False) # Uncomment if needed for PyTorch‑2.2‑Bug
30
-
31
- # 1) Konstanten -------------------------------------------------------
32
- REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
33
- START_TOKEN = 128259
34
- NEW_BLOCK = 128257
35
- EOS_TOKEN = 128258 # Ensure this is correct for the model
36
- AUDIO_BASE = 128266
37
- AUDIO_SPAN = 4096 * 7 # 28672 Codes
38
- CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
39
- AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
40
-
41
- # 2) Logit‑Mask -------------------------------------------------------
42
- class AudioMask(LogitsProcessor):
43
- def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
44
- super().__init__()
45
- self.allow = torch.cat([
46
- torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long),
47
- audio_ids
48
- ], dim=0)
49
- self.eos = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
50
- self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
51
- self.sent_blocks = 0 # State: Number of audio blocks sent
52
-
53
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
54
- current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
55
- mask = torch.full_like(scores, float("-inf"))
56
- mask[:, current_allow] = 0
57
- return scores + mask
58
-
59
- def reset(self):
60
- self.sent_blocks = 0
61
-
62
- # 3) StoppingCriteria fΓΌr EOS ---------------------------------------
63
- class EosStoppingCriteria(StoppingCriteria):
64
- def __init__(self, eos_token_id: int):
65
- self.eos_token_id = eos_token_id
66
-
67
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
68
- if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
69
- # print("StoppingCriteria: EOS detected.") # Optional: Uncomment for debugging
70
- return True
71
- return False
72
-
73
- # 4) Benutzerdefinierter AudioStreamer -------------------------------
74
- class AudioStreamer(BaseStreamer):
75
- def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
76
- self.ws = ws
77
- self.snac = snac_decoder
78
- self.masker = audio_mask
79
- self.loop = loop
80
- self.device = target_device
81
- self.buf: list[int] = []
82
- self.tasks = set()
83
-
84
- def _decode_block(self, block7: list[int]) -> bytes:
85
- """
86
- Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
87
- Uses modulo to extract base code value (0-4095).
88
- Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
89
- """
90
- if len(block7) != 7:
91
- print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
92
- return b""
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  try:
95
- # --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
96
- code_val_0 = block7[0] % CODEBOOK_SIZE
97
- code_val_1 = block7[1] % CODEBOOK_SIZE
98
- code_val_2 = block7[2] % CODEBOOK_SIZE
99
- code_val_3 = block7[3] % CODEBOOK_SIZE
100
- code_val_4 = block7[4] % CODEBOOK_SIZE
101
- code_val_5 = block7[5] % CODEBOOK_SIZE
102
- code_val_6 = block7[6] % CODEBOOK_SIZE
103
-
104
- # --- Map the extracted code values to the SNAC codebooks (l1, l2, l3) ---
105
- l1 = [code_val_0]
106
- l2 = [code_val_1, code_val_4]
107
- l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
108
-
109
- except IndexError:
110
- print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
111
- return b""
112
- except Exception as e_map:
113
- print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
114
- return b""
115
-
116
- # --- Convert lists to tensors on the correct device ---
117
- try:
118
- codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
119
- codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
120
- codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
121
- codes = [codes_l1, codes_l2, codes_l3]
122
- except Exception as e_tensor:
123
- print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
124
- return b""
125
-
126
- # --- Decode using SNAC ---
127
- try:
128
- with torch.no_grad():
129
- audio = self.snac.decode(codes)[0]
130
- except Exception as e_decode:
131
- print(f"Streamer Error: Exception during snac.decode: {e_decode}")
132
- print(f"Input codes shapes: {[c.shape for c in codes]}")
133
- print(f"Input codes dtypes: {[c.dtype for c in codes]}")
134
- print(f"Input codes devices: {[c.device for c in codes]}")
135
- print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
136
- return b""
137
-
138
- # --- Post-processing ---
139
- try:
140
- audio_np = audio.squeeze().detach().cpu().numpy()
141
- audio_bytes = (audio_np * 32767).astype("int16").tobytes()
142
- return audio_bytes
143
- except Exception as e_post:
144
- print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
145
- return b""
146
-
147
- async def _send_audio_bytes(self, data: bytes):
148
- """Coroutine to send bytes over WebSocket."""
149
- if not data:
150
- return
151
- try:
152
- await self.ws.send_bytes(data)
153
- except WebSocketDisconnect:
154
- print("Streamer: WebSocket disconnected during send.")
155
  except Exception as e:
156
- # Log errors other than expected disconnects more visibly maybe
157
- if "Cannot call \"send\" once a close message has been sent" not in str(e):
158
- print(f"Streamer: Error sending bytes: {e}")
159
- # else: # Optionally print disconnect errors quietly
160
- # print("Streamer: Attempted send after close.")
161
- pass # Avoid flooding logs if client disconnects early
162
-
163
- def put(self, value: torch.LongTensor):
164
- """
165
- Receives new token IDs (Tensor) from generate().
166
- Processes tokens, decodes full blocks, and schedules sending.
167
- """
168
- if value.numel() == 0:
169
- return
170
- # Ensure value is on CPU and flatten to a list of ints
171
- new_token_ids = value.squeeze().cpu().tolist() # Move to CPU before list conversion
172
- if isinstance(new_token_ids, int):
173
- new_token_ids = [new_token_ids]
174
-
175
- for t in new_token_ids:
176
- # --- DEBUGGING PRINT ---
177
- # Log every token ID received from the model
178
- print(f"Streamer received token ID: {t}")
179
- # --- END DEBUGGING ---
180
-
181
- if t == EOS_TOKEN:
182
- # print("Streamer: EOS token encountered.") # Optional debugging
183
- break # Stop processing this batch if EOS is found
184
-
185
- if t == NEW_BLOCK:
186
- # print("Streamer: NEW_BLOCK token encountered.") # Optional debugging
187
- self.buf.clear()
188
- continue # Move to the next token
189
-
190
- # Check if token is within the expected audio range
191
- if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
192
- self.buf.append(t - AUDIO_BASE) # Store value relative to base
193
- # else: # Log unexpected tokens if needed
194
- # print(f"Streamer Warning: Ignoring unexpected token {t} (outside audio range [{AUDIO_BASE}, {AUDIO_BASE + AUDIO_SPAN}))")
195
- pass
196
-
197
- # If buffer has 7 tokens, decode and send
198
- if len(self.buf) == 7:
199
- audio_bytes = self._decode_block(self.buf)
200
- self.buf.clear() # Clear buffer after processing
201
-
202
- if audio_bytes: # Only send if decoding was successful
203
- # Schedule the async send function to run on the main event loop
204
- future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
205
- self.tasks.add(future)
206
- # Optional: Remove completed tasks to prevent memory leak if generation is very long
207
- future.add_done_callback(self.tasks.discard)
208
-
209
- # Allow EOS only after the first full block has been processed and scheduled for sending
210
- if self.masker.sent_blocks == 0:
211
- # print("Streamer: First audio block processed, allowing EOS.")
212
- self.masker.sent_blocks = 1 # Update state in the mask
213
-
214
- def end(self):
215
- """Called by generate() when generation finishes."""
216
- if len(self.buf) > 0:
217
- print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
218
- self.buf.clear()
219
- # print(f"Streamer: Generation finished.") # Optional debugging
220
- pass
221
-
222
- # 5) FastAPI App ------------------------------------------------------
223
- app = FastAPI()
224
 
225
- @app.on_event("startup")
226
- async def load_models_startup():
227
- global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, EOS_TOKEN
228
-
229
- print(f"πŸš€ Starting up on device: {device}")
230
- print("⏳ Lade Modelle …", flush=True)
231
-
232
- tok = AutoTokenizer.from_pretrained(REPO)
233
- print("Tokenizer loaded.")
234
-
235
- snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
236
- print(f"SNAC loaded to {device}.")
237
-
238
- model_dtype = torch.float32
239
- if device == "cuda":
240
- if torch.cuda.is_bf16_supported():
241
- model_dtype = torch.bfloat16
242
- print("Using bfloat16 for model.")
243
- else:
244
- model_dtype = torch.float16
245
- print("Using float16 for model.")
246
-
247
- model = AutoModelForCausalLM.from_pretrained(
248
- REPO,
249
- device_map={"": 0} if device == "cuda" else None,
250
- torch_dtype=model_dtype,
251
- low_cpu_mem_usage=True,
252
- )
253
 
254
- # --- Verify EOS Token ---
255
- # Use the actual EOS token ID from the loaded model/tokenizer config
256
- config_eos_id = model.config.eos_token_id
257
- tokenizer_eos_id = tok.eos_token_id
258
-
259
- if config_eos_id is None:
260
- print("🚨 WARNING: model.config.eos_token_id is None!")
261
- # Fallback or default? Let's use the constant for now, but this needs checking.
262
- final_eos_token_id = EOS_TOKEN
263
- elif tokenizer_eos_id is not None and config_eos_id != tokenizer_eos_id:
264
- print(f"⚠️ WARNING: Mismatch! model.config.eos_token_id ({config_eos_id}) != tok.eos_token_id ({tokenizer_eos_id}). Using model config ID.")
265
- final_eos_token_id = config_eos_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  else:
267
- final_eos_token_id = config_eos_id
 
 
 
 
268
 
269
- # Update the global constant if it differs or wasn't set properly by config
270
- if final_eos_token_id != EOS_TOKEN:
271
- print(f"πŸ”„ Updating EOS_TOKEN constant from {EOS_TOKEN} to {final_eos_token_id}")
272
- EOS_TOKEN = final_eos_token_id # Update the global constant
273
 
274
- # Set pad_token_id to the determined EOS token ID
275
- model.config.pad_token_id = EOS_TOKEN
276
- print(f"Using EOS Token ID: {EOS_TOKEN}")
277
- # --- End Verify EOS Token ---
 
 
278
 
279
 
280
- print(f"Model loaded to {model.device} with dtype {model.dtype}.")
281
- model.eval()
 
 
282
 
283
- audio_ids_device = AUDIO_IDS_CPU.to(device)
284
- masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN) # Use updated EOS_TOKEN
285
- print("AudioMask initialized.")
286
 
287
- stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)]) # Use updated EOS_TOKEN
288
- print("StoppingCriteria initialized.")
289
 
290
- print("βœ… Modelle geladen und bereit!", flush=True)
 
 
291
 
292
- @app.get("/")
293
- def hello():
294
- return {"status": "ok", "message": "TTS Service is running"}
295
-
296
- # 6) Helper zum Prompt Bauen -------------------------------------------
297
- def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
298
- """Builds the input_ids and attention_mask for the model."""
299
- prompt_text = f"{voice}: {text}"
300
- prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
301
-
302
- input_ids = torch.cat([
303
- torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
304
- prompt_ids,
305
- torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
306
- ], dim=1)
307
-
308
- attention_mask = torch.ones_like(input_ids)
309
- return input_ids, attention_mask
310
-
311
- # 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
312
- @app.websocket("/ws/tts")
313
- async def tts(ws: WebSocket):
314
- await ws.accept()
315
- print("πŸ”Œ Client connected")
316
- streamer = None
317
- main_loop = asyncio.get_running_loop()
318
 
319
- try:
320
- req_text = await ws.receive_text()
321
- print(f"Received request: {req_text}")
322
- req = json.loads(req_text)
323
- text = req.get("text", "Hallo Welt, wie geht es dir heute?")
324
- voice = req.get("voice", "Jakob")
325
-
326
- if not text:
327
- print("⚠️ Request text is empty.")
328
- await ws.close(code=1003, reason="Text cannot be empty")
329
- return
330
-
331
- print(f"Generating audio for: '{text}' with voice '{voice}'")
332
- ids, attn = build_prompt(text, voice)
333
- masker.reset()
334
- streamer = AudioStreamer(ws, snac, masker, main_loop, device)
335
-
336
- print("Starting generation in background thread...")
337
- # --- DEBUGGING: Adjusted Generation Parameters ---
338
- await asyncio.to_thread(
339
- model.generate,
340
- input_ids=ids,
341
- attention_mask=attn,
342
- max_new_tokens=1500, # Keep lower for faster debugging cycles initially
343
- logits_processor=[masker],
344
- stopping_criteria=stopping_criteria,
345
- # --- Adjusted Parameters for Debugging Repetition ---
346
- do_sample=True,
347
- temperature=0.7, # Slightly higher temperature
348
- # top_p=0.9, # Commented out top_p for simpler testing
349
- repetition_penalty=1.2, # Slightly stronger penalty
350
- # --- End Adjusted Parameters ---
351
- use_cache=True,
352
- streamer=streamer
353
  )
354
- print("Generation thread finished.")
355
-
356
- except WebSocketDisconnect:
357
- print("πŸ”Œ Client disconnected.")
358
- except json.JSONDecodeError:
359
- print("❌ Invalid JSON received.")
360
- if ws.client_state.name == "CONNECTED":
361
- await ws.close(code=1003, reason="Invalid JSON format")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  except Exception as e:
363
- error_details = traceback.format_exc()
364
- print(f"❌ WS‑Error: {e}\n{error_details}", flush=True)
365
- error_payload = json.dumps({"error": str(e)})
366
- try:
367
- if ws.client_state.name == "CONNECTED":
368
- await ws.send_text(error_payload)
369
- except Exception:
370
- pass
371
- if ws.client_state.name == "CONNECTED":
372
- await ws.close(code=1011)
373
- finally:
374
- if streamer:
375
- try:
376
- streamer.end()
377
- except Exception as e_end:
378
- print(f"Error during streamer.end(): {e_end}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
- print("Closing connection.")
381
- if ws.client_state.name == "CONNECTED":
382
- try:
383
- await ws.close(code=1000)
384
- except RuntimeError as e_close:
385
- print(f"Runtime error closing websocket: {e_close}")
386
- except Exception as e_close_final:
387
- print(f"Error closing websocket: {e_close_final}")
388
- elif ws.client_state.name != "DISCONNECTED":
389
- print(f"WebSocket final state: {ws.client_state.name}")
390
- print("Connection closed.")
391
-
392
- # 8) Dev‑Start --------------------------------------------------------
393
  if __name__ == "__main__":
394
- import uvicorn
395
- print("Starting Uvicorn server...")
396
- # Note: Consider running with --workers 1 if you face issues with globals/GPU memory
397
- # uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", workers=1)
398
- uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ print("Starting server")
3
+ import logging
4
+
5
+ # Enable or disable debug logging
6
+ DEBUG_LOGGING = False
7
+
8
+ if DEBUG_LOGGING:
9
+ logging.basicConfig(level=logging.DEBUG)
10
+ else:
11
+ logging.basicConfig(level=logging.WARNING)
12
+
13
+
14
+ from RealtimeTTS import (
15
+ TextToAudioStream,
16
+ AzureEngine,
17
+ ElevenlabsEngine,
18
+ SystemEngine,
19
+ CoquiEngine,
20
+ OpenAIEngine,
21
+ KokoroEngine
22
+ )
23
+
24
+ from RealtimeTTS import register_engine
25
+
26
+ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
27
+ from fastapi.middleware.cors import CORSMiddleware
28
+ from fastapi import FastAPI, Query, Request
29
+ from fastapi.staticfiles import StaticFiles
30
+
31
+ from queue import Queue
32
+ import threading
33
+ import logging
34
+ import uvicorn
35
+ import wave
36
+ import io
37
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ PORT = int(os.environ.get("TTS_FASTAPI_PORT", 8000))
40
+
41
+ register_engine("orpheus", OrpheusEngine)
42
+
43
+ SUPPORTED_ENGINES = [
44
+ "azure",
45
+ "openai",
46
+ "elevenlabs",
47
+ "system",
48
+ # "coqui", #multiple queries are not supported on coqui engine right now, comment coqui out for tests where you need server start often,
49
+ "kokoro"
50
+ ]
51
+
52
+ # change start engine by moving engine name
53
+ # to the first position in SUPPORTED_ENGINES
54
+ START_ENGINE = SUPPORTED_ENGINES[0]
55
+
56
+ BROWSER_IDENTIFIERS = [
57
+ "mozilla",
58
+ "chrome",
59
+ "safari",
60
+ "firefox",
61
+ "edge",
62
+ "opera",
63
+ "msie",
64
+ "trident",
65
+ ]
66
+
67
+ origins = [
68
+ "http://localhost",
69
+ f"http://localhost:{PORT}",
70
+ "http://127.0.0.1",
71
+ f"http://127.0.0.1:{PORT}",
72
+ "https://localhost",
73
+ f"https://localhost:{PORT}",
74
+ "https://127.0.0.1",
75
+ f"https://127.0.0.1:{PORT}",
76
+ ]
77
+
78
+ play_text_to_speech_semaphore = threading.Semaphore(1)
79
+ engines = {}
80
+ voices = {}
81
+ current_engine = None
82
+ speaking_lock = threading.Lock()
83
+ tts_lock = threading.Lock()
84
+ gen_lock = threading.Lock()
85
+
86
+
87
+ class TTSRequestHandler:
88
+ def __init__(self, engine):
89
+ self.engine = engine
90
+ self.audio_queue = Queue()
91
+ self.stream = TextToAudioStream(
92
+ engine, on_audio_stream_stop=self.on_audio_stream_stop, muted=True
93
+ )
94
+ self.speaking = False
95
+
96
+ def on_audio_chunk(self, chunk):
97
+ self.audio_queue.put(chunk)
98
+
99
+ def on_audio_stream_stop(self):
100
+ self.audio_queue.put(None)
101
+ self.speaking = False
102
+
103
+ def play_text_to_speech(self, text):
104
+ self.speaking = True
105
+ self.stream.feed(text)
106
+ logging.debug(f"Playing audio for text: {text}")
107
+ print(f'Synthesizing: "{text}"')
108
+ self.stream.play_async(on_audio_chunk=self.on_audio_chunk, muted=True)
109
+
110
+ def audio_chunk_generator(self, send_wave_headers):
111
+ first_chunk = False
112
  try:
113
+ while True:
114
+ chunk = self.audio_queue.get()
115
+ if chunk is None:
116
+ print("Terminating stream")
117
+ break
118
+ if not first_chunk:
119
+ if send_wave_headers:
120
+ print("Sending wave header")
121
+ yield create_wave_header_for_engine(self.engine)
122
+ first_chunk = True
123
+ yield chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  except Exception as e:
125
+ print(f"Error during streaming: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ app = FastAPI()
129
+ app.mount("/static", StaticFiles(directory="static"), name="static")
130
+ app.add_middleware(
131
+ CORSMiddleware,
132
+ allow_origins=origins,
133
+ allow_credentials=True,
134
+ allow_methods=["*"],
135
+ allow_headers=["*"],
136
+ )
137
+
138
+ # Define a CSP that allows 'self' for script sources for firefox
139
+ csp = {
140
+ "default-src": "'self'",
141
+ "script-src": "'self'",
142
+ "style-src": "'self' 'unsafe-inline'",
143
+ "img-src": "'self' data:",
144
+ "font-src": "'self' data:",
145
+ "media-src": "'self' blob:",
146
+ }
147
+ csp_string = "; ".join(f"{key} {value}" for key, value in csp.items())
148
+
149
+
150
+ @app.middleware("http")
151
+ async def add_security_headers(request: Request, call_next):
152
+ response = await call_next(request)
153
+ response.headers["Content-Security-Policy"] = csp_string
154
+ return response
155
+
156
+
157
+ @app.get("/favicon.ico")
158
+ async def favicon():
159
+ return FileResponse("static/favicon.ico")
160
+
161
+
162
+ def _set_engine(engine_name):
163
+ global current_engine, stream
164
+ if current_engine is None:
165
+ current_engine = engines[engine_name]
166
  else:
167
+ current_engine = engines[engine_name]
168
+
169
+ if voices[engine_name]:
170
+ engines[engine_name].set_voice(voices[engine_name][0].name)
171
+
172
 
173
+ @app.get("/set_engine")
174
+ def set_engine(request: Request, engine_name: str = Query(...)):
175
+ if engine_name not in engines:
176
+ return {"error": "Engine not supported"}
177
 
178
+ try:
179
+ _set_engine(engine_name)
180
+ return {"message": f"Switched to {engine_name} engine"}
181
+ except Exception as e:
182
+ logging.error(f"Error switching engine: {str(e)}")
183
+ return {"error": "Failed to switch engine"}
184
 
185
 
186
+ def is_browser_request(request):
187
+ user_agent = request.headers.get("user-agent", "").lower()
188
+ is_browser = any(browser_id in user_agent for browser_id in BROWSER_IDENTIFIERS)
189
+ return is_browser
190
 
 
 
 
191
 
192
+ def create_wave_header_for_engine(engine):
193
+ _, _, sample_rate = engine.get_stream_info()
194
 
195
+ num_channels = 1
196
+ sample_width = 2
197
+ frame_rate = sample_rate
198
 
199
+ wav_header = io.BytesIO()
200
+ with wave.open(wav_header, "wb") as wav_file:
201
+ wav_file.setnchannels(num_channels)
202
+ wav_file.setsampwidth(sample_width)
203
+ wav_file.setframerate(frame_rate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ wav_header.seek(0)
206
+ wave_header_bytes = wav_header.read()
207
+ wav_header.close()
208
+
209
+ # Create a new BytesIO with the correct MIME type for Firefox
210
+ final_wave_header = io.BytesIO()
211
+ final_wave_header.write(wave_header_bytes)
212
+ final_wave_header.seek(0)
213
+
214
+ return final_wave_header.getvalue()
215
+
216
+
217
+ @app.get("/tts")
218
+ async def tts(request: Request, text: str = Query(...)):
219
+ with tts_lock:
220
+ request_handler = TTSRequestHandler(current_engine)
221
+ browser_request = is_browser_request(request)
222
+
223
+ if play_text_to_speech_semaphore.acquire(blocking=False):
224
+ try:
225
+ threading.Thread(
226
+ target=request_handler.play_text_to_speech,
227
+ args=(text,),
228
+ daemon=True,
229
+ ).start()
230
+ finally:
231
+ play_text_to_speech_semaphore.release()
232
+
233
+ return StreamingResponse(
234
+ request_handler.audio_chunk_generator(browser_request),
235
+ media_type="audio/wav",
 
 
 
236
  )
237
+
238
+
239
+ @app.get("/engines")
240
+ def get_engines():
241
+ return list(engines.keys())
242
+
243
+
244
+ @app.get("/voices")
245
+ def get_voices():
246
+ voices_list = []
247
+ for voice in voices[current_engine.engine_name]:
248
+ voices_list.append(voice.name)
249
+ return voices_list
250
+
251
+
252
+ @app.get("/setvoice")
253
+ def set_voice(request: Request, voice_name: str = Query(...)):
254
+ print(f"Getting request: {voice_name}")
255
+ if not current_engine:
256
+ print("No engine is currently selected")
257
+ return {"error": "No engine is currently selected"}
258
+
259
+ try:
260
+ print(f"Setting voice to {voice_name}")
261
+ current_engine.set_voice(voice_name)
262
+ return {"message": f"Voice set to {voice_name} successfully"}
263
  except Exception as e:
264
+ print(f"Error setting voice: {str(e)}")
265
+ logging.error(f"Error setting voice: {str(e)}")
266
+ return {"error": "Failed to set voice"}
267
+
268
+
269
+ @app.get("/")
270
+ def root_page():
271
+ engines_options = "".join(
272
+ [
273
+ f'<option value="{engine}">{engine.title()}</option>'
274
+ for engine in engines.keys()
275
+ ]
276
+ )
277
+ content = f"""
278
+ <!DOCTYPE html>
279
+ <html>
280
+ <head>
281
+ <title>Text-To-Speech</title>
282
+ <style>
283
+ body {{
284
+ font-family: Arial, sans-serif;
285
+ background-color: #f0f0f0;
286
+ margin: 0;
287
+ padding: 0;
288
+ }}
289
+ h2 {{
290
+ color: #333;
291
+ text-align: center;
292
+ }}
293
+ #container {{
294
+ width: 80%;
295
+ margin: 50px auto;
296
+ background-color: #fff;
297
+ border-radius: 10px;
298
+ padding: 20px;
299
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
300
+ }}
301
+ label {{
302
+ font-weight: bold;
303
+ }}
304
+ select, textarea {{
305
+ width: 100%;
306
+ padding: 10px;
307
+ margin: 10px 0;
308
+ border: 1px solid #ccc;
309
+ border-radius: 5px;
310
+ box-sizing: border-box;
311
+ font-size: 16px;
312
+ }}
313
+ button {{
314
+ display: block;
315
+ width: 100%;
316
+ padding: 15px;
317
+ background-color: #007bff;
318
+ border: none;
319
+ border-radius: 5px;
320
+ color: #fff;
321
+ font-size: 16px;
322
+ cursor: pointer;
323
+ transition: background-color 0.3s;
324
+ }}
325
+ button:hover {{
326
+ background-color: #0056b3;
327
+ }}
328
+ audio {{
329
+ width: 80%;
330
+ margin: 10px auto;
331
+ display: block;
332
+ }}
333
+ </style>
334
+ </head>
335
+ <body>
336
+ <div id="container">
337
+ <h2>Text to Speech</h2>
338
+ <label for="engine">Select Engine:</label>
339
+ <select id="engine">
340
+ {engines_options}
341
+ </select>
342
+ <label for="voice">Select Voice:</label>
343
+ <select id="voice">
344
+ <!-- Options will be dynamically populated by JavaScript -->
345
+ </select>
346
+ <textarea id="text" rows="4" cols="50" placeholder="Enter text here..."></textarea>
347
+ <button id="speakButton">Speak</button>
348
+ <audio id="audio" controls></audio> <!-- Hidden audio player -->
349
+ </div>
350
+ <script src="/static/tts.js"></script>
351
+ </body>
352
+ </html>
353
+ """
354
+ return HTMLResponse(content=content)
355
+
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  if __name__ == "__main__":
358
+ print("Initializing TTS Engines")
359
+
360
+ for engine_name in SUPPORTED_ENGINES:
361
+ if "azure" == engine_name:
362
+ azure_api_key = os.environ.get("AZURE_SPEECH_KEY")
363
+ azure_region = os.environ.get("AZURE_SPEECH_REGION")
364
+ if azure_api_key and azure_region:
365
+ print("Initializing azure engine")
366
+ engines["azure"] = AzureEngine(azure_api_key, azure_region)
367
+
368
+ if "elevenlabs" == engine_name:
369
+ elevenlabs_api_key = os.environ.get("ELEVENLABS_API_KEY")
370
+ if elevenlabs_api_key:
371
+ print("Initializing elevenlabs engine")
372
+ engines["elevenlabs"] = ElevenlabsEngine(elevenlabs_api_key)
373
+
374
+ if "system" == engine_name:
375
+ print("Initializing system engine")
376
+ engines["system"] = SystemEngine()
377
+
378
+ if "coqui" == engine_name:
379
+ print("Initializing coqui engine")
380
+ engines["coqui"] = CoquiEngine()
381
+
382
+ if "kokoro" == engine_name:
383
+ print("Initializing kokoro engine")
384
+ engines["kokoro"] = KokoroEngine()
385
+
386
+ if "openai" == engine_name:
387
+ print("Initializing openai engine")
388
+ engines["openai"] = OpenAIEngine()
389
+
390
+ for _engine in engines.keys():
391
+ print(f"Retrieving voices for TTS Engine {_engine}")
392
+ try:
393
+ voices[_engine] = engines[_engine].get_voices()
394
+ except Exception as e:
395
+ voices[_engine] = []
396
+ logging.error(f"Error retrieving voices for {_engine}: {str(e)}")
397
+
398
+ _set_engine(START_ENGINE)
399
+
400
+ print("Server ready")
401
+ uvicorn.run(app, host="0.0.0.0", port=PORT)