Tomtom84 committed (verified)
Commit 55145d2 · 1 Parent(s): 641d199

Update app.py

Files changed (1):
  1. app.py +266 -58

app.py CHANGED
@@ -24,43 +24,37 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     print("🔑 Logging in to Hugging Face Hub...")
-    # Consider adding error handling for login failure if necessary
     login(HF_TOKEN)
 
 # torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed for PyTorch 2.2 bug
 
 # 1) Constants ---------------------------------------------------------
 REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
+# CHUNK_TOKENS = 50  # Not directly used by us with the streamer approach
 START_TOKEN = 128259
-NEW_BLOCK = 128257     # Token indicating start of audio generation
-EOS_TOKEN = 128258     # End-of-speech token
-AUDIO_BASE = 128266    # Base ID for audio tokens
-AUDIO_SPAN = 4096 * 7  # 7 codebooks * 4096 codes per book = 28672 possible audio tokens
+NEW_BLOCK = 128257
+EOS_TOKEN = 128258
+AUDIO_BASE = 128266
+AUDIO_SPAN = 4096 * 7  # 28672 codes
 # Create AUDIO_IDS on the correct device later in load_models
 AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
 
 # 2) Logit mask --------------------------------------------------------
 class AudioMask(LogitsProcessor):
-    """
-    Manages allowed tokens during generation.
-    - Initially allows NEW_BLOCK and AUDIO tokens.
-    - Allows EOS_TOKEN only after at least one audio block has been sent.
-    """
     def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
         super().__init__()
         # Allow NEW_BLOCK and all valid audio tokens initially
-        self.allow_initial = torch.cat([
-            torch.tensor([new_block_token_id], device=audio_ids.device),
+        self.allow = torch.cat([
+            torch.tensor([new_block_token_id], device=audio_ids.device),  # Add NEW_BLOCK token ID
             audio_ids
         ], dim=0)
-        self.eos = torch.tensor([eos_token_id], device=audio_ids.device)
-        # Precompute combined tensor for allowing audio, NEW_BLOCK, and EOS
-        self.allow_with_eos = torch.cat([self.allow_initial, self.eos], dim=0)
+        self.eos = torch.tensor([eos_token_id], device=audio_ids.device)  # Store EOS token ID as tensor
+        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)  # Precompute combined tensor
         self.sent_blocks = 0  # State: number of audio blocks sent
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # Determine which tokens are allowed based on whether blocks have been sent
-        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow_initial
+        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
 
         # Create a mask initialized to negative infinity
         mask = torch.full_like(scores, float("-inf"))
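The hunk cuts off before the bottom half of __call__, but the pattern is the standard allow-list logits mask: fill a tensor with -inf, zero out the allowed ids, add it to the scores. Note also that the WebSocket handler further down calls masker.reset(), which no hunk in this diff defines. A minimal sketch of the whole processor under the assumption that reset() only rearms the EOS gate:

    import torch
    from transformers import LogitsProcessor

    class AudioMaskSketch(LogitsProcessor):
        """Hypothetical reconstruction: only __init__ and the start of __call__
        appear in the diff; reset() is inferred from the masker.reset() call
        in the WebSocket handler below."""

        def __init__(self, audio_ids: torch.Tensor, new_block_id: int, eos_id: int):
            super().__init__()
            self.allow = torch.cat(
                [torch.tensor([new_block_id], device=audio_ids.device), audio_ids])
            self.allow_with_eos = torch.cat(
                [self.allow, torch.tensor([eos_id], device=audio_ids.device)])
            self.sent_blocks = 0

        def __call__(self, input_ids, scores):
            allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
            mask = torch.full_like(scores, float("-inf"))
            mask[:, allow] = 0.0   # re-enable only the allowed token ids
            return scores + mask   # every other vocab entry stays at -inf

        def reset(self):
            self.sent_blocks = 0   # assumed: a new request starts EOS-gated again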
@@ -81,25 +75,20 @@ class EosStoppingCriteria(StoppingCriteria):
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         # Check if the *last* generated token is the EOS token
-        # Check input_ids shape to prevent index error on first token
         if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
-            print("StoppingCriteria: EOS detected.")
+            # print("StoppingCriteria: EOS detected.")
             return True
         return False
 
 # 4) Custom AudioStreamer ----------------------------------------------
 class AudioStreamer(BaseStreamer):
-    """
-    Custom streamer to process audio tokens, decode them using SNAC,
-    and send audio bytes over a WebSocket.
-    """
-    # Added target_device parameter
+    # --- Updated __init__ to accept target_device ---
     def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
         self.ws = ws
         self.snac = snac_decoder
         self.masker = audio_mask  # Reference to the mask to update sent_blocks
         self.loop = loop  # Event loop of the main thread for run_coroutine_threadsafe
-        # Use the passed target_device
+        # --- Use the passed target_device ---
         self.device = target_device
         self.buf: list[int] = []  # Buffer for audio token values (AUDIO_BASE subtracted)
         self.tasks = set()  # Keep track of pending send tasks
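For readers unfamiliar with the streamer hook: generate() first calls put() once with the full prompt ids, then once per sampling step with the newly chosen token(s), and finally end(). The AudioStreamer above relies on the prompt tokens being ignored because they fall outside the audio range (and the prompt's trailing NEW_BLOCK clears the buffer). A throwaway sketch of the contract:

    from transformers.generation.streamers import BaseStreamer

    class DebugStreamer(BaseStreamer):
        """Prints what generate() feeds a streamer; for illustration only."""
        def put(self, value):
            # First call: the whole prompt; later calls: one token per step.
            print("put:", tuple(value.shape), value.view(-1)[:5].tolist())
        def end(self):
            print("end of generation")

    # model.generate(..., streamer=DebugStreamer())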
@@ -108,46 +97,46 @@ class AudioStreamer(BaseStreamer):
     """
     Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
     NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3)
-    is CRITICAL and based on the structure used by the specific model.
-    This implementation uses the mapping derived from the user's previous code.
-    If audio is distorted, try the alternative mapping commented out below.
+    is based on the structure found in the previous while-loop version.
+    If audio is distorted, this mapping is the primary suspect.
+    Ensure this mapping is correct for the specific model!
     """
     if len(block7) != 7:
         print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
         return b""  # Return empty bytes if block is incomplete
 
-    # --- Mapping based on user's previous version ---
+    # --- Mapping derived from previous user version (indices [0], [1,4], [2,3,5,6]) ---
+    # This seems more likely to be correct for Kartoffel_Orpheus if the previous version worked.
     try:
         l1 = [block7[0]]                                   # Index 0
         l2 = [block7[1], block7[4]]                        # Indices 1, 4
         l3 = [block7[2], block7[3], block7[5], block7[6]]  # Indices 2, 3, 5, 6
-        # --- Alternative hypothesis mapping (try if the above fails) ---
-        # l1 = [block7[0], block7[3], block7[6]]  # Indices 0, 3, 6
-        # l2 = [block7[1], block7[4]]             # Indices 1, 4
-        # l3 = [block7[2], block7[5]]             # Indices 2, 5
-    except IndexError as e:
-        print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}, Error: {e}")
+    except IndexError:
+        print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
         return b""
 
+    # --- Alternative hypothesis (commented out): interleaving mapping ---
+    # try:
+    #     l1 = [block7[0], block7[3], block7[6]]  # Codebook 1 indices: 0, 3, 6
+    #     l2 = [block7[1], block7[4]]             # Codebook 2 indices: 1, 4
+    #     l3 = [block7[2], block7[5]]             # Codebook 3 indices: 2, 5
+    # except IndexError:
+    #     print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
+    #     return b""
+    # --- End alternative hypothesis ---
+
     # Convert lists to tensors on the correct device
-    try:
-        codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
-        codes = [codes_l1, codes_l2, codes_l3]  # List of tensors for SNAC
-    except Exception as e:
-        print(f"Streamer Error: Failed converting lists to tensors. Error: {e}")
-        return b""
+    # Use self.device which was set correctly in __init__
+    codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
+    codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
+    codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
+    codes = [codes_l1, codes_l2, codes_l3]  # List of tensors for SNAC
 
     # Decode using SNAC
-    try:
-        with torch.no_grad():
-            # Ensure snac_decoder is on the correct device already (done via .to(device))
-            audio = self.snac.decode(codes)[0]  # Decode expects list of tensors, result might have batch dim
-    except Exception as e:
-        print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}")
-        return b""
+    with torch.no_grad():
+        # self.snac should already be on self.device from load_models_startup
+        audio = self.snac.decode(codes)[0]  # Decode expects list of tensors, result might have batch dim
 
     # Squeeze, move to CPU, convert to numpy
     audio_np = audio.squeeze().detach().cpu().numpy()
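The 1/2/4 split above lines up with SNAC's three hierarchical codebooks (coarse to fine, at rates T, 2T and 4T per frame), which is why a 7-token frame decodes at all. A worked example of the resulting shapes, assuming the [0], [1,4], [2,3,5,6] layout is the right one for this model:

    import torch

    block7 = [11, 22, 33, 44, 55, 66, 77]  # one 7-token frame (base already subtracted)
    l1 = [block7[0]]                                   # coarse:  1 code per frame
    l2 = [block7[1], block7[4]]                        # medium:  2 codes per frame
    l3 = [block7[2], block7[3], block7[5], block7[6]]  # fine:    4 codes per frame

    codes = [torch.tensor(l, dtype=torch.long).unsqueeze(0) for l in (l1, l2, l3)]
    print([tuple(c.shape) for c in codes])  # [(1, 1), (1, 2), (1, 4)]
    # snac.decode(codes) returns the waveform for this single frame; buffering
    # N frames first would give shapes (1, N), (1, 2N), (1, 4N) and more
    # temporal context for the decoder, at the cost of added latency.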
@@ -162,7 +151,7 @@ class AudioStreamer(BaseStreamer):
         return
     try:
         await self.ws.send_bytes(data)
-        # print(f"Streamer: Sent {len(data)} audio bytes.")  # Optional: Debug log
+        # print(f"Streamer: Sent {len(data)} audio bytes.")
     except WebSocketDisconnect:
         print("Streamer: WebSocket disconnected during send.")
     except Exception as e:
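_send_audio_bytes is a coroutine, but put() runs on the generation thread (see the asyncio.to_thread call in the endpoint below), where await is unavailable; run_coroutine_threadsafe is the standard bridge back to the loop captured in tts(). A self-contained demo of the pattern, with illustrative names only:

    import asyncio

    async def consumer(queue: asyncio.Queue):
        while (item := await queue.get()) is not None:
            print("got", item)

    async def main():
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()      # captured like main_loop in tts()

        def worker():                          # stands in for the generate() thread
            for i in range(3):
                # schedule a coroutine on the loop from a foreign thread
                asyncio.run_coroutine_threadsafe(queue.put(i), loop).result()
            asyncio.run_coroutine_threadsafe(queue.put(None), loop).result()

        consumer_task = asyncio.create_task(consumer(queue))
        await asyncio.to_thread(worker)
        await consumer_task

    asyncio.run(main())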
@@ -176,12 +165,9 @@ class AudioStreamer(BaseStreamer):
     # Ensure value is on CPU and flatten to a list of ints
     if value.numel() == 0:
         return
-    # Handle potential shape issues, ensure it's iterable
-    try:
-        new_token_ids = value.view(-1).tolist()
-    except Exception as e:
-        print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}")
-        return
+    new_token_ids = value.squeeze().tolist()
+    if isinstance(new_token_ids, int):  # Handle single token case
+        new_token_ids = [new_token_ids]
 
     for t in new_token_ids:
         if t == EOS_TOKEN:
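The isinstance guard exists because .tolist() on a 0-d tensor returns a bare int rather than a list:

    import torch

    print(torch.tensor([[7]]).squeeze().tolist())     # 7  (plain int, not [7])
    print(torch.tensor([[7, 8]]).squeeze().tolist())  # [7, 8]
    # The removed value.view(-1).tolist() always yields a list, so both
    # variants work; squeeze() simply needs this extra guard.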
@@ -189,4 +175,226 @@ class AudioStreamer(BaseStreamer):
                 # EOS is handled by StoppingCriteria, no action needed here except maybe logging.
                 break  # Stop processing this batch if EOS is found
 
-            if t == NEW_BLOCK:
+            if t == NEW_BLOCK:
+                # print("Streamer: NEW_BLOCK token encountered.")
+                # NEW_BLOCK indicates the start of audio, might reset buffer if needed
+                self.buf.clear()
+                continue  # Move to the next token
+
+            # Check if token is within the expected audio range
+            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
+                # Store value relative to base (IMPORTANT for _decode_block)
+                self.buf.append(t - AUDIO_BASE)
+            else:
+                # Log unexpected tokens (like START_TOKEN or others if generation goes wrong)
+                # print(f"Streamer Warning: Ignoring unexpected token {t}")
+                pass  # Ignore tokens outside the audio range
+
+            # If buffer has 7 tokens, decode and send
+            if len(self.buf) == 7:
+                audio_bytes = self._decode_block(self.buf)
+                self.buf.clear()  # Clear buffer after processing
+
+                if audio_bytes:  # Only send if decoding was successful
+                    # Schedule the async send function to run on the main event loop
+                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
+                    self.tasks.add(future)
+                    # Optional: Remove completed tasks to prevent memory leak if generation is very long
+                    future.add_done_callback(self.tasks.discard)
+
+                    # Allow EOS only after the first full block has been processed and scheduled for sending
+                    if self.masker.sent_blocks == 0:
+                        # print("Streamer: First audio block processed, allowing EOS.")
+                        self.masker.sent_blocks = 1  # Update state in the mask
+
+        # Note: No need to explicitly wait for tasks here. put() should return quickly.
+
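Flipping masker.sent_blocks here is what arms EOS in AudioMask: until one full 7-token block has been scheduled for sending, the mask keeps EOS_TOKEN at -inf, so the model cannot terminate before producing any audio. A tiny simulation of that gate:

    # Illustrative only; real tokens come from the model.
    class Gate:
        sent_blocks = 0

    gate, buf = Gate(), []
    for t in range(7):            # pretend 7 audio tokens arrive
        buf.append(t)
        if len(buf) == 7:         # first full block completed ...
            buf.clear()
            gate.sent_blocks = 1  # ... so the mask may now offer EOS_TOKEN

    assert gate.sent_blocks == 1  # EOS is reachable only after >= 1 block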
+    def end(self):
+        """Called by generate() when generation finishes."""
+        # Handle any remaining tokens in the buffer (optional; here we discard them)
+        if len(self.buf) > 0:
+            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
+            self.buf.clear()
+
+        # Optional: wait briefly for any outstanding send tasks to complete?
+        # This is tricky because end() is sync. A robust solution might involve
+        # signaling the WebSocket handler to wait before closing.
+        # For simplicity, we rely on FastAPI/Uvicorn's graceful shutdown handling.
+        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
+        pass
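The comment above leaves draining self.tasks unresolved. Since run_coroutine_threadsafe returns concurrent.futures.Future objects, and end() runs on the generation thread rather than the event loop, one possible approach (a sketch, not part of this commit) is a bounded synchronous wait:

    import concurrent.futures

    def end(self):
        if self.buf:
            print(f"Streamer: discarding incomplete block ({len(self.buf)} tokens).")
            self.buf.clear()
        pending = list(self.tasks)
        if pending:
            # Blocks the generation thread, not the event loop, for up to 5 s.
            done, not_done = concurrent.futures.wait(pending, timeout=5.0)
            if not_done:
                print(f"Streamer: {len(not_done)} send task(s) still pending.")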
+
+# 5) FastAPI App -------------------------------------------------------
+app = FastAPI()
+
+@app.on_event("startup")
+async def load_models_startup():  # Make startup async if needed for future async loads
+    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
+
+    print(f"🚀 Starting up on device: {device}")
+    print("⏳ Lade Modelle …", flush=True)
+
+    tok = AutoTokenizer.from_pretrained(REPO)
+    print("Tokenizer loaded.")
+
+    # Load SNAC first (usually smaller)
+    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
+    # --- FIXED print statement ---
+    print(f"SNAC loaded to {device}.")  # Use the global device variable
+
+    # Load the main model
+    # Determine the appropriate dtype based on device and support
+    model_dtype = torch.float32  # Default to float32 for CPU
+    if device == "cuda":
+        if torch.cuda.is_bf16_supported():
+            model_dtype = torch.bfloat16
+            print("Using bfloat16 for model.")
+        else:
+            model_dtype = torch.float16  # Fall back to float16 if bfloat16 is not supported
+            print("Using float16 for model.")
+
+    model = AutoModelForCausalLM.from_pretrained(
+        REPO,
+        device_map={"": 0} if device == "cuda" else None,  # Assign to GPU 0 if cuda
+        torch_dtype=model_dtype,
+        low_cpu_mem_usage=True,  # Good practice for large models
+    )
+    model.config.pad_token_id = model.config.eos_token_id  # Set pad token
+    print(f"Model loaded to {model.device} with dtype {model.dtype}.")
+
+    # Ensure model is in evaluation mode
+    model.eval()
+
+    # Initialize AudioMask (needs AUDIO_IDS on the correct device)
+    audio_ids_device = AUDIO_IDS_CPU.to(device)
+    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
+    print("AudioMask initialized.")
+
+    # Initialize StoppingCriteria
+    # IMPORTANT: create the list and add the criteria instance
+    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
+    print("StoppingCriteria initialized.")
+
+    print("✅ Modelle geladen und bereit!", flush=True)
+
+@app.get("/")
+def hello():
+    return {"status": "ok", "message": "TTS Service is running"}
+
+# 6) Prompt-building helper --------------------------------------------
+def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
+    """Builds the input_ids and attention_mask for the model."""
+    # Format: <START> <VOICE>: <TEXT> <NEW_BLOCK>
+    prompt_text = f"{voice}: {text}"
+    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
+
+    # Construct input_ids tensor
+    input_ids = torch.cat([
+        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),  # Start token
+        prompt_ids,                                                      # Encoded prompt
+        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)     # New-block token to trigger audio
+    ], dim=1)
+
+    # Create attention mask (all ones)
+    attention_mask = torch.ones_like(input_ids)
+    return input_ids, attention_mask
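For the default request the assembled sequence is START_TOKEN, the tokenized "Jakob: …" prompt, then a trailing NEW_BLOCK. A quick sanity check, assuming the startup hook has already populated tok and device:

    ids, attn = build_prompt("Hallo Welt", "Jakob")
    # ids = [[128259, t1, ..., tn, 128257]]  (START_TOKEN ... NEW_BLOCK)
    assert ids[0, 0].item() == START_TOKEN
    assert ids[0, -1].item() == NEW_BLOCK
    assert attn.shape == ids.shape  # attention mask is all ones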
+
+# 7) WebSocket endpoint (simplified with streamer) ----------------------
+@app.websocket("/ws/tts")
+async def tts(ws: WebSocket):
+    await ws.accept()
+    print("🔌 Client connected")
+    streamer = None  # Initialize for the finally block
+    main_loop = asyncio.get_running_loop()  # Get the current event loop
+
+    try:
+        # Receive configuration
+        req_text = await ws.receive_text()
+        print(f"Received request: {req_text}")
+        req = json.loads(req_text)
+        text = req.get("text", "Hallo Welt, wie geht es dir heute?")  # Default text
+        voice = req.get("voice", "Jakob")  # Default voice
+
+        if not text:
+            print("⚠️ Request text is empty.")
+            await ws.close(code=1003, reason="Text cannot be empty")  # 1003 = unsupported data
+            return
+
+        print(f"Generating audio for: '{text}' with voice '{voice}'")
+
+        # Prepare prompt
+        ids, attn = build_prompt(text, voice)
+
+        # --- Reset stateful components ---
+        masker.reset()  # CRITICAL: Reset the mask state for the new request
+
+        # --- Create streamer instance ---
+        # --- Pass the global 'device' variable ---
+        streamer = AudioStreamer(ws, snac, masker, main_loop, device)
+
+        # --- Run model.generate in a separate thread ---
+        # This prevents blocking the main FastAPI event loop
+        print("Starting generation in background thread...")
+        await asyncio.to_thread(
+            model.generate,
+            input_ids=ids,
+            attention_mask=attn,
+            max_new_tokens=1500,  # Limit generation length (adjust as needed)
+            logits_processor=[masker],
+            stopping_criteria=stopping_criteria,
+            do_sample=False,  # Use greedy decoding for potentially more stable audio
+            # do_sample=True, temperature=0.7, top_p=0.95,  # Or use sampling
+            use_cache=True,
+            streamer=streamer  # Pass the custom streamer
+            # No need to manage past_key_values manually
+        )
+        print("Generation thread finished.")
+
+    except WebSocketDisconnect:
+        print("🔌 Client disconnected.")
+    except json.JSONDecodeError:
+        print("❌ Invalid JSON received.")
+        if ws.client_state.name == "CONNECTED":
+            await ws.close(code=1003, reason="Invalid JSON format")
+    except Exception as e:
+        error_details = traceback.format_exc()
+        print(f"❌ WS error: {e}\n{error_details}", flush=True)
+        # Try to send an error message before closing, if possible
+        error_payload = json.dumps({"error": str(e)})
+        try:
+            if ws.client_state.name == "CONNECTED":
+                await ws.send_text(error_payload)  # Send error as text/JSON
+        except Exception:
+            pass  # Ignore errors during error reporting
+        # Close with internal server error code
+        if ws.client_state.name == "CONNECTED":
+            await ws.close(code=1011)  # 1011 = internal server error
+    finally:
+        # Ensure the streamer's end method is called if it exists
+        if streamer:
+            try:
+                # print("Calling streamer.end()")
+                streamer.end()
+            except Exception as e_end:
+                print(f"Error during streamer.end(): {e_end}")
+
+        # Ensure the WebSocket is closed
+        print("Closing connection.")
+        if ws.client_state.name == "CONNECTED":
+            try:
+                await ws.close(code=1000)  # 1000 = normal closure
+            except RuntimeError as e_close:
+                # Can happen if the connection is already closing/closed
+                print(f"Runtime error closing websocket: {e_close}")
+            except Exception as e_close_final:
+                print(f"Error closing websocket: {e_close_final}")
+        elif ws.client_state.name != "DISCONNECTED":
+            print(f"WebSocket final state: {ws.client_state.name}")
+        print("Connection closed.")
+
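The wire protocol implied by this endpoint: the client sends one JSON text message ({"text": ..., "voice": ...}) and then receives raw audio as binary frames until the server closes with code 1000. A minimal client sketch (not part of the commit; it assumes the third-party websockets package, and since the numpy-to-bytes conversion lies outside the hunks shown, the sample format is an assumption — SNAC outputs 24 kHz mono, but int16 vs float32 must be verified against app.py):

    import asyncio, json
    import websockets

    async def main():
        async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
            await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
            with open("out.pcm", "wb") as f:
                try:
                    async for frame in ws:
                        if isinstance(frame, bytes):
                            f.write(frame)  # raw audio for one 7-token block
                        else:
                            print("server message:", frame)  # e.g. JSON error payload
                except websockets.ConnectionClosed:
                    pass  # server closes with code 1000 when generation finishes

    asyncio.run(main())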
+# 8) Dev start ----------------------------------------------------------
+if __name__ == "__main__":
+    import uvicorn
+    print("Starting Uvicorn server...")
+    # Use reload=True only for development; remove it for production.
+    # Consider adding --workers 1 if you experience issues with multiple workers and global state/GPU memory.
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")  # , reload=True