Tomtom84 committed on
Commit
b45ed35
·
verified ·
1 Parent(s): d229a19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -376
app.py CHANGED
@@ -1,404 +1,223 @@
1
- if __name__ == "__main__":
2
- print("Starting server")
3
- import logging
4
-
5
- # Enable or disable debug logging
6
- DEBUG_LOGGING = False
7
-
8
- if DEBUG_LOGGING:
9
- logging.basicConfig(level=logging.DEBUG)
10
- else:
11
- logging.basicConfig(level=logging.WARNING)
12
-
13
-
14
- from RealtimeTTS import (
15
- TextToAudioStream,
16
- )
17
- from engines.orpheus_engine import OrpheusEngine
18
-
19
- from huggingface_hub import login
20
-
21
- from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
22
  from fastapi.middleware.cors import CORSMiddleware
23
- from fastapi import FastAPI, Query, Request
24
- from fastapi.staticfiles import StaticFiles
25
-
26
- from queue import Queue
27
- import threading
28
- import logging
29
- import uvicorn
30
- import wave
31
- import io
32
  import os
 
 
 
33
 
34
- HF_TOKEN = os.getenv("HF_TOKEN")
35
- if HF_TOKEN:
36
- print("🔑 Logging in to Hugging Face Hub...")
37
- login(HF_TOKEN)
38
-
39
- engines: dict[str, "BaseEngine"] = {}
40
- voices: dict[str, list] = {}
41
- current_engine = None
42
-
43
- play_text_to_speech_semaphore = threading.Semaphore(1)
44
-
45
- PORT = int(os.getenv("TTS_FASTAPI_PORT", 7860)) # Zahl kongruent halten
46
- SUPPORTED_ENGINES = ["orpheus"]
47
- engines["orpheus"] = OrpheusEngine(
48
- api_url=os.getenv("ORPHEUS_API_URL"), # http://127.0.0.1:1234/v1/completions
49
- model=os.getenv("ORPHEUS_MODEL") # Kartoffel-ID
50
- )
51
-
52
- voices["orpheus"] = engines["orpheus"].get_voices()
53
-
54
- current_engine = engines["orpheus"]
55
-
56
- # change start engine by moving engine name
57
- # to the first position in SUPPORTED_ENGINES
58
- # START_ENGINE = SUPPORTED_ENGINES[0]
59
- START_ENGINE = "orpheus"
60
-
61
- BROWSER_IDENTIFIERS = [
62
- "mozilla",
63
- "chrome",
64
- "safari",
65
- "firefox",
66
- "edge",
67
- "opera",
68
- "msie",
69
- "trident",
70
- ]
71
-
72
- origins = [
73
- "http://localhost",
74
- f"http://localhost:{PORT}",
75
- "http://127.0.0.1",
76
- f"http://127.0.0.1:{PORT}",
77
- "https://localhost",
78
- f"https://localhost:{PORT}",
79
- "https://127.0.0.1",
80
- f"https://127.0.0.1:{PORT}",
81
- ]
82
-
83
-
84
- speaking_lock = threading.Lock()
85
- tts_lock = threading.Lock()
86
- gen_lock = threading.Lock()
87
-
88
-
89
- class TTSRequestHandler:
90
- def __init__(self, engine):
91
- self.engine = engine
92
- self.audio_queue = Queue()
93
- self.stream = TextToAudioStream(
94
- engine, on_audio_stream_stop=self.on_audio_stream_stop, muted=True
95
- )
96
- self.speaking = False
97
-
98
- def on_audio_chunk(self, chunk):
99
- self.audio_queue.put(chunk)
100
-
101
- def on_audio_stream_stop(self):
102
- self.audio_queue.put(None)
103
- self.speaking = False
104
 
105
- def play_text_to_speech(self, text):
106
- self.speaking = True
107
- self.stream.feed(text)
108
- logging.debug(f"Playing audio for text: {text}")
109
- print(f'Synthesizing: "{text}"')
110
- self.stream.play_async(on_audio_chunk=self.on_audio_chunk, muted=True)
111
 
112
- def audio_chunk_generator(self, send_wave_headers):
113
- first_chunk = False
114
- try:
115
- while True:
116
- chunk = self.audio_queue.get()
117
- if chunk is None:
118
- print("Terminating stream")
119
- break
120
- if not first_chunk:
121
- if send_wave_headers:
122
- print("Sending wave header")
123
- yield create_wave_header_for_engine(self.engine)
124
- first_chunk = True
125
- yield chunk
126
- except Exception as e:
127
- print(f"Error during streaming: {str(e)}")
128
 
 
129
 
130
- app = FastAPI()
131
- app.mount("/static", StaticFiles(directory="static"), name="static")
132
  app.add_middleware(
133
  CORSMiddleware,
134
- allow_origins=origins,
135
  allow_credentials=True,
136
  allow_methods=["*"],
137
  allow_headers=["*"],
138
  )
139
 
140
- # Define a CSP that allows 'self' for script sources for firefox
141
- csp = {
142
- "default-src": "'self'",
143
- "script-src": "'self'",
144
- "style-src": "'self' 'unsafe-inline'",
145
- "img-src": "'self' data:",
146
- "font-src": "'self' data:",
147
- "media-src": "'self' blob:",
148
- }
149
- csp_string = "; ".join(f"{key} {value}" for key, value in csp.items())
150
-
151
-
152
- @app.middleware("http")
153
- async def add_security_headers(request: Request, call_next):
154
- response = await call_next(request)
155
- response.headers["Content-Security-Policy"] = csp_string
156
- return response
157
-
158
-
159
- @app.get("/favicon.ico")
160
- async def favicon():
161
- return FileResponse("static/favicon.ico")
162
-
163
-
164
- def _set_engine(engine_name):
165
- global current_engine, stream
166
- if current_engine is None:
167
- current_engine = engines[engine_name]
168
- else:
169
- current_engine = engines[engine_name]
170
-
171
- if voices[engine_name]:
172
- engines[engine_name].set_voice(voices[engine_name][0].name)
173
-
174
-
175
- @app.get("/set_engine")
176
- def set_engine(request: Request, engine_name: str = Query(...)):
177
- if engine_name not in engines:
178
- return {"error": "Engine not supported"}
179
 
 
 
 
180
  try:
181
- _set_engine(engine_name)
182
- return {"message": f"Switched to {engine_name} engine"}
 
 
 
183
  except Exception as e:
184
- logging.error(f"Error switching engine: {str(e)}")
185
- return {"error": "Failed to switch engine"}
186
-
187
-
188
- def is_browser_request(request):
189
- user_agent = request.headers.get("user-agent", "").lower()
190
- is_browser = any(browser_id in user_agent for browser_id in BROWSER_IDENTIFIERS)
191
- return is_browser
192
-
193
-
194
- def create_wave_header_for_engine(engine):
195
- _, _, sample_rate = engine.get_stream_info()
196
-
197
- num_channels = 1
198
- sample_width = 2
199
- frame_rate = sample_rate
200
-
201
- wav_header = io.BytesIO()
202
- with wave.open(wav_header, "wb") as wav_file:
203
- wav_file.setnchannels(num_channels)
204
- wav_file.setsampwidth(sample_width)
205
- wav_file.setframerate(frame_rate)
206
-
207
- wav_header.seek(0)
208
- wave_header_bytes = wav_header.read()
209
- wav_header.close()
210
-
211
- # Create a new BytesIO with the correct MIME type for Firefox
212
- final_wave_header = io.BytesIO()
213
- final_wave_header.write(wave_header_bytes)
214
- final_wave_header.seek(0)
215
-
216
- return final_wave_header.getvalue()
217
 
 
 
 
 
218
 
219
  @app.get("/tts")
220
- async def tts(request: Request, text: str = Query(...)):
221
- with tts_lock:
222
- request_handler = TTSRequestHandler(current_engine)
223
- browser_request = is_browser_request(request)
224
-
225
- if play_text_to_speech_semaphore.acquire(blocking=False):
226
- try:
227
- threading.Thread(
228
- target=request_handler.play_text_to_speech,
229
- args=(text,),
230
- daemon=True,
231
- ).start()
232
- finally:
233
- play_text_to_speech_semaphore.release()
234
-
235
- return StreamingResponse(
236
- request_handler.audio_chunk_generator(browser_request),
237
- media_type="audio/wav",
238
- )
239
-
240
-
241
- @app.get("/engines")
242
- def get_engines():
243
- return list(engines.keys())
244
-
245
- @app.get("/voices")
246
- def get_voices():
247
- # falls noch keine Engine gewählt/initialisiert ist
248
- if engine is None: # <-- dein zentrales Engine-Objekt
249
- return []
250
-
251
- # OrpheusEngine.get_voices() liefert eine Liste von OrpheusVoice-Objekten
252
- return [v.__dict__ for v in engine.get_voices()]
253
-
254
-
255
- @app.get("/setvoice")
256
- def set_voice(request: Request, voice_name: str = Query(...)):
257
- print(f"Getting request: {voice_name}")
258
- if not current_engine:
259
- print("No engine is currently selected")
260
- return {"error": "No engine is currently selected"}
 
 
 
261
 
 
 
 
 
 
 
 
 
 
 
262
  try:
263
- print(f"Setting voice to {voice_name}")
264
- current_engine.set_voice(voice_name)
265
- return {"message": f"Voice set to {voice_name} successfully"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  except Exception as e:
267
- print(f"Error setting voice: {str(e)}")
268
- logging.error(f"Error setting voice: {str(e)}")
269
- return {"error": "Failed to set voice"}
270
 
 
 
 
 
 
 
 
271
 
272
  @app.get("/")
273
- def root_page():
274
- engines_options = "".join(
275
- [
276
- f'<option value="{engine}">{engine.title()}</option>'
277
- for engine in engines.keys()
278
- ]
279
- )
280
- content = f"""
281
- <!DOCTYPE html>
282
- <html>
283
- <head>
284
- <title>Text-To-Speech</title>
285
- <style>
286
- body {{
287
- font-family: Arial, sans-serif;
288
- background-color: #f0f0f0;
289
- margin: 0;
290
- padding: 0;
291
- }}
292
- h2 {{
293
- color: #333;
294
- text-align: center;
295
- }}
296
- #container {{
297
- width: 80%;
298
- margin: 50px auto;
299
- background-color: #fff;
300
- border-radius: 10px;
301
- padding: 20px;
302
- box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
303
- }}
304
- label {{
305
- font-weight: bold;
306
- }}
307
- select, textarea {{
308
- width: 100%;
309
- padding: 10px;
310
- margin: 10px 0;
311
- border: 1px solid #ccc;
312
- border-radius: 5px;
313
- box-sizing: border-box;
314
- font-size: 16px;
315
- }}
316
- button {{
317
- display: block;
318
- width: 100%;
319
- padding: 15px;
320
- background-color: #007bff;
321
- border: none;
322
- border-radius: 5px;
323
- color: #fff;
324
- font-size: 16px;
325
- cursor: pointer;
326
- transition: background-color 0.3s;
327
- }}
328
- button:hover {{
329
- background-color: #0056b3;
330
- }}
331
- audio {{
332
- width: 80%;
333
- margin: 10px auto;
334
- display: block;
335
- }}
336
- </style>
337
- </head>
338
- <body>
339
- <div id="container">
340
- <h2>Text to Speech</h2>
341
- <label for="engine">Select Engine:</label>
342
- <select id="engine">
343
- {engines_options}
344
- </select>
345
- <label for="voice">Select Voice:</label>
346
- <select id="voice">
347
- <!-- Options will be dynamically populated by JavaScript -->
348
- </select>
349
- <textarea id="text" rows="4" cols="50" placeholder="Enter text here..."></textarea>
350
- <button id="speakButton">Speak</button>
351
- <audio id="audio" controls></audio> <!-- Hidden audio player -->
352
- </div>
353
- <script src="/static/tts.js"></script>
354
- </body>
355
- </html>
356
- """
357
- return HTMLResponse(content=content)
358
-
359
 
360
  if __name__ == "__main__":
361
- print("Initializing TTS Engines")
362
-
363
- for engine_name in SUPPORTED_ENGINES:
364
- if "azure" == engine_name:
365
- azure_api_key = os.environ.get("AZURE_SPEECH_KEY")
366
- azure_region = os.environ.get("AZURE_SPEECH_REGION")
367
- if azure_api_key and azure_region:
368
- print("Initializing azure engine")
369
- engines["azure"] = AzureEngine(azure_api_key, azure_region)
370
-
371
- if "elevenlabs" == engine_name:
372
- elevenlabs_api_key = os.environ.get("ELEVENLABS_API_KEY")
373
- if elevenlabs_api_key:
374
- print("Initializing elevenlabs engine")
375
- engines["elevenlabs"] = ElevenlabsEngine(elevenlabs_api_key)
376
-
377
- if "system" == engine_name:
378
- print("Initializing system engine")
379
- engines["system"] = SystemEngine()
380
-
381
- if "coqui" == engine_name:
382
- print("Initializing coqui engine")
383
- engines["coqui"] = CoquiEngine()
384
-
385
- if "kokoro" == engine_name:
386
- print("Initializing kokoro engine")
387
- engines["kokoro"] = KokoroEngine()
388
-
389
- if "openai" == engine_name:
390
- print("Initializing openai engine")
391
- engines["openai"] = OpenAIEngine()
392
-
393
- for _engine in engines.keys():
394
- print(f"Retrieving voices for TTS Engine {_engine}")
395
- try:
396
- voices[_engine] = engines[_engine].get_voices()
397
- except Exception as e:
398
- voices[_engine] = []
399
- logging.error(f"Error retrieving voices for {_engine}: {str(e)}")
400
-
401
- _set_engine(START_ENGINE)
402
-
403
- print("Server ready")
404
- uvicorn.run(app, host="0.0.0.0", port=PORT)
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query
2
+ from fastapi.responses import StreamingResponse, JSONResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ import struct
5
+ import sys
 
 
 
 
 
 
 
6
  import os
7
+ import json
8
+ import asyncio
9
+ import logging
10
 
11
+ # Add the orpheus-tts module to the path
12
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'orpheus-tts'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ try:
15
+ from orpheus_tts.engine_class import OrpheusModel
16
+ except ImportError:
17
+ from engine_class import OrpheusModel
 
 
18
 
19
+ # Configure logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ app = FastAPI(title="Orpheus TTS Server", version="1.0.0")
24
 
25
+ # Add CORS middleware for web clients
 
26
  app.add_middleware(
27
  CORSMiddleware,
28
+ allow_origins=["*"],
29
  allow_credentials=True,
30
  allow_methods=["*"],
31
  allow_headers=["*"],
32
  )
33
 
34
+ # Initialize the Orpheus model
35
+ engine = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
@app.on_event("startup")
async def startup_event():
    """Load the Orpheus model when the application starts.

    The loaded model is stored in the module-level ``engine`` variable that
    the request handlers read. If construction fails, the error is logged
    together with its traceback and re-raised unchanged so the failure is
    not swallowed.

    Raises:
        Exception: whatever ``OrpheusModel`` raised while loading.
    """
    global engine

    # The same repository id is passed as both the model and the tokenizer.
    model_id = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

    try:
        engine = OrpheusModel(model_name=model_id, tokenizer=model_id)
    except Exception as e:
        # logger.exception keeps the traceback in the log; a bare ``raise``
        # re-raises the active exception without adding an extra frame.
        logger.exception(f"Error loading Orpheus model: {e}")
        raise

    logger.info("Orpheus model loaded successfully")
49
+
50
def create_wav_header(sample_rate=24000, bits_per_sample=16, channels=1, data_size=0):
    """Build a 44-byte RIFF/WAVE header for uncompressed PCM audio.

    Args:
        sample_rate: Samples per second per channel.
        bits_per_sample: Bit depth of one sample; must be a multiple of 8 so
            that the byte rate and block alignment are exact.
        channels: Number of interleaved channels.
        data_size: Length in bytes declared for the ``data`` chunk. The
            default of 0 keeps the original behaviour, where the header is
            sent before the audio length is known.

    Returns:
        The packed little-endian header as ``bytes``.

    Raises:
        ValueError: if an argument is non-positive, the bit depth is not a
            whole number of bytes, or ``data_size`` does not fit the 32-bit
            RIFF size field.
    """
    if sample_rate <= 0 or channels <= 0:
        raise ValueError("sample_rate and channels must be positive")
    if bits_per_sample <= 0 or bits_per_sample % 8 != 0:
        raise ValueError("bits_per_sample must be a positive multiple of 8")
    # The RIFF chunk size is data_size + 36 and must itself fit in 32 bits.
    if not 0 <= data_size <= 0xFFFFFFFF - 36:
        raise ValueError("data_size does not fit in a RIFF header")

    bytes_per_frame = channels * bits_per_sample // 8  # block align
    byte_rate = sample_rate * bytes_per_frame

    return struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        36 + data_size,   # size of everything after this field
        b'WAVE',
        b'fmt ',
        16,               # size of the fmt chunk body
        1,                # format tag 1 = PCM
        channels,
        sample_rate,
        byte_rate,
        bytes_per_frame,
        bits_per_sample,
        b'data',
        data_size,
    )
 
 
 
 
 
 
 
73
 
74
@app.get("/health")
async def health_check():
    """Report that the service is up and whether the model has been loaded."""
    model_ready = engine is not None
    return dict(status="healthy", model_loaded=model_ready)
78
 
79
  @app.get("/tts")
80
async def tts_stream(
    prompt: str = Query(..., description="Text to synthesize"),
    voice: str = Query("Jakob", description="Voice to use"),
    temperature: float = Query(0.4, description="Temperature for generation"),
    top_p: float = Query(0.9, description="Top-p for generation"),
    max_tokens: int = Query(2000, description="Maximum tokens"),
    repetition_penalty: float = Query(1.1, description="Repetition penalty")
):
    """Stream synthesized speech for ``prompt`` as a WAV response over HTTP.

    The body starts with a WAV header from ``create_wav_header()`` followed
    by every chunk produced by ``engine.generate_speech``.

    Args:
        prompt: Text to synthesize (required).
        voice: Voice name forwarded to the engine.
        temperature: Sampling temperature forwarded to the engine.
        top_p: Top-p value forwarded to the engine.
        max_tokens: Token limit forwarded to the engine.
        repetition_penalty: Repetition penalty forwarded to the engine.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 503 when the model has not been loaded.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    def generate_audio_stream():
        try:
            # Send WAV header first
            yield create_wav_header()

            # stop_token_ids is passed through as-is; presumably the model's
            # end-of-speech token -- confirm against the engine.
            syn_tokens = engine.generate_speech(
                prompt=prompt,
                voice=voice,
                repetition_penalty=repetition_penalty,
                stop_token_ids=[128258],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )

            # Stream audio chunks
            yield from syn_tokens

        except Exception as e:
            # The status line and WAV header are already on the wire here,
            # so an HTTPException can no longer become a 500 response.
            # Log and re-raise so the server aborts the stream instead.
            logger.error(f"Error in TTS generation: {e}")
            raise

    return StreamingResponse(
        generate_audio_stream(),
        media_type='audio/wav',
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
124
 
125
@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """Serve text-to-speech requests over a WebSocket connection.

    Each text frame from the client must be a JSON object. Keys read are
    ``prompt`` (required), ``voice``, ``temperature``, ``top_p``,
    ``max_tokens`` and ``repetition_penalty``. For every request the server
    sends a ``generating`` status, a WAV header as a binary frame, the audio
    chunks as binary frames with a ``streaming`` status every 10 chunks, and
    finally a ``completed`` status. Problems with a single request are
    reported as ``{"error": ...}`` and the connection stays open.
    """
    await websocket.accept()

    if engine is None:
        await websocket.send_json({"error": "Model not loaded"})
        await websocket.close()
        return

    loop = asyncio.get_running_loop()
    exhausted = object()  # sentinel: lets next() signal the end without StopIteration

    try:
        while True:
            # Receive request from client
            data = await websocket.receive_text()

            # A malformed frame must not tear down the whole connection.
            try:
                request = json.loads(data)
            except json.JSONDecodeError as e:
                await websocket.send_json({"error": f"Invalid JSON: {e}"})
                continue
            if not isinstance(request, dict):
                await websocket.send_json({"error": "Request must be a JSON object"})
                continue

            prompt = request.get("prompt", "")
            voice = request.get("voice", "Jakob")
            temperature = request.get("temperature", 0.4)
            top_p = request.get("top_p", 0.9)
            max_tokens = request.get("max_tokens", 2000)
            repetition_penalty = request.get("repetition_penalty", 1.1)

            if not prompt:
                await websocket.send_json({"error": "No prompt provided"})
                continue

            # Send status update
            await websocket.send_json({"status": "generating", "prompt": prompt})

            try:
                # Send WAV header
                await websocket.send_bytes(create_wav_header())

                # Generate and stream audio
                syn_tokens = iter(engine.generate_speech(
                    prompt=prompt,
                    voice=voice,
                    repetition_penalty=repetition_penalty,
                    stop_token_ids=[128258],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                ))

                chunk_count = 0
                while True:
                    # Pull each chunk in a worker thread so the synchronous
                    # generator does not stall the event loop, and with it
                    # every other request, while audio is synthesized.
                    chunk = await loop.run_in_executor(
                        None, next, syn_tokens, exhausted
                    )
                    if chunk is exhausted:
                        break

                    await websocket.send_bytes(chunk)
                    chunk_count += 1

                    # Send periodic status updates
                    if chunk_count % 10 == 0:
                        await websocket.send_json({
                            "status": "streaming",
                            "chunks_sent": chunk_count
                        })

                # Send completion status
                await websocket.send_json({
                    "status": "completed",
                    "total_chunks": chunk_count
                })

            except WebSocketDisconnect:
                # The peer is gone; reporting the error to it would fail too.
                raise
            except Exception as e:
                logger.error(f"Error in WebSocket TTS generation: {e}")
                await websocket.send_json({"error": str(e)})

    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        try:
            await websocket.close()
        except RuntimeError:
            # The socket was already closed; nothing left to clean up.
            pass
 
198
 
199
@app.get("/voices")
async def get_available_voices():
    """Return the loaded model's voices; respond 503 when no model is loaded."""
    if engine is not None:
        return {"voices": engine.available_voices}

    raise HTTPException(status_code=503, detail="Model not loaded")
206
 
207
  @app.get("/")
208
async def root():
    """Describe the API: server name, endpoint map, and model-loaded flag."""
    endpoint_map = {
        "health": "/health",
        "tts_http": "/tts?prompt=your_text&voice=Jakob",
        "tts_websocket": "/ws/tts",
        "voices": "/voices",
    }
    model_ready = engine is not None
    return {
        "message": "Orpheus TTS Server",
        "endpoints": endpoint_map,
        "model_loaded": model_ready,
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  if __name__ == "__main__":
222
+ import uvicorn
223
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")