da03 commited on
Commit
0adb69d
·
1 Parent(s): 406c8c8
Files changed (1) hide show
  1. main.py +35 -13
main.py CHANGED
@@ -180,13 +180,14 @@ async def websocket_endpoint(websocket: WebSocket):
180
  nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing
181
 
182
  try:
 
 
183
  is_processing = True
184
  frame_num += 1
185
  frame_count += 1 # Increment total frame counter
186
- start_frame = time.perf_counter()
187
 
188
  # Calculate global FPS
189
- total_elapsed = start_frame - connection_start_time
190
  global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
191
 
192
  x = data.get("x")
@@ -195,7 +196,7 @@ async def websocket_endpoint(websocket: WebSocket):
195
  is_right_click = data.get("is_right_click")
196
  keys_down_list = data.get("keys_down", []) # Get as list
197
  keys_up_list = data.get("keys_up", [])
198
- print(f'x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')
199
 
200
  # Update the set based on the received data
201
  for key in keys_down_list:
@@ -205,9 +206,11 @@ async def websocket_endpoint(websocket: WebSocket):
205
  keys_down.remove(key)
206
 
207
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
 
208
  previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
209
- timing_info['full_frame'] = time.perf_counter() - start_frame
210
 
 
211
  # Use the provided function to print timing statistics
212
  print_timing_stats(timing_info, frame_num)
213
 
@@ -220,48 +223,63 @@ async def websocket_endpoint(websocket: WebSocket):
220
  img_str = base64.b64encode(buffered.getvalue()).decode()
221
 
222
  # Send the generated frame back to the client
 
223
  await websocket.send_json({"image": img_str})
 
224
  finally:
225
  is_processing = False
 
226
  # Check if we have more inputs to process after this one
227
  process_next_input()
228
 
229
  def process_next_input():
230
  nonlocal input_queue
231
 
232
- if not input_queue or is_processing:
 
 
233
  return
234
 
 
 
 
 
 
 
235
  # Find the most recent interesting input (click or key event)
236
  interesting_indices = [i for i, data in enumerate(input_queue)
237
  if data.get("is_left_click") or
238
  data.get("is_right_click") or
239
- data.get("keys_down") or
240
- data.get("keys_up")]
241
 
242
  if interesting_indices:
243
  # There are interesting events - take the most recent one
244
  idx = interesting_indices[-1]
245
  next_input = input_queue[idx]
 
246
 
247
  # Clear all inputs up to and including this one
248
  input_queue = input_queue[idx+1:]
249
 
250
- print(f"Processing interesting input (skipped {idx} events)")
251
  else:
252
  # No interesting events - just take the most recent movement
 
253
  next_input = input_queue[-1]
254
- skipped_count = len(input_queue) - 1 # We're processing one, so skipped = total - 1
255
  input_queue = []
256
- print(f"Processing latest movement (skipped {skipped_count} events)")
257
 
258
  # Process the selected input asynchronously
 
259
  asyncio.create_task(process_input(next_input))
260
 
261
  while True:
262
  try:
263
  # Receive user input
 
264
  data = await websocket.receive_json()
 
265
 
266
  if data.get("type") == "heartbeat":
267
  await websocket.send_json({"type": "heartbeat_response"})
@@ -269,21 +287,26 @@ async def websocket_endpoint(websocket: WebSocket):
269
 
270
  # Add the input to our queue
271
  input_queue.append(data)
272
- print (f"Input queue length: {len(input_queue)}")
273
 
274
  # If we're not currently processing, start processing this input
275
  if not is_processing:
 
276
  process_next_input()
 
 
277
 
278
  except asyncio.TimeoutError:
279
  print("WebSocket connection timed out")
280
 
281
  except WebSocketDisconnect:
282
  print("WebSocket disconnected")
283
- #break # Exit the loop on disconnect
284
 
285
  except Exception as e:
286
  print(f"Error in WebSocket connection {client_id}: {e}")
 
 
287
 
288
  finally:
289
  # Print final FPS statistics when connection ends
@@ -295,4 +318,3 @@ async def websocket_endpoint(websocket: WebSocket):
295
  print(f" Average FPS: {frame_count/total_time:.2f}")
296
 
297
  print(f"WebSocket connection closed: {client_id}")
298
- #await websocket.close() # Ensure the WebSocket is closed
 
180
  nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing
181
 
182
  try:
183
+ process_start_time = time.perf_counter()
184
+ print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {len(input_queue)}")
185
  is_processing = True
186
  frame_num += 1
187
  frame_count += 1 # Increment total frame counter
 
188
 
189
  # Calculate global FPS
190
+ total_elapsed = process_start_time - connection_start_time
191
  global_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
192
 
193
  x = data.get("x")
 
196
  is_right_click = data.get("is_right_click")
197
  keys_down_list = data.get("keys_down", []) # Get as list
198
  keys_up_list = data.get("keys_up", [])
199
+ print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')
200
 
201
  # Update the set based on the received data
202
  for key in keys_down_list:
 
206
  keys_down.remove(key)
207
 
208
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
209
+ print(f"[{time.perf_counter():.3f}] Starting model inference...")
210
  previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
211
+ timing_info['full_frame'] = time.perf_counter() - process_start_time
212
 
213
+ print(f"[{time.perf_counter():.3f}] Model inference complete. Queue size now: {len(input_queue)}")
214
  # Use the provided function to print timing statistics
215
  print_timing_stats(timing_info, frame_num)
216
 
 
223
  img_str = base64.b64encode(buffered.getvalue()).decode()
224
 
225
  # Send the generated frame back to the client
226
+ print(f"[{time.perf_counter():.3f}] Sending image to client...")
227
  await websocket.send_json({"image": img_str})
228
+ print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {len(input_queue)}")
229
  finally:
230
  is_processing = False
231
+ print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {len(input_queue)}")
232
  # Check if we have more inputs to process after this one
233
  process_next_input()
234
 
235
  def process_next_input():
236
  nonlocal input_queue
237
 
238
+ current_time = time.perf_counter()
239
+ if not input_queue:
240
+ print(f"[{current_time:.3f}] No inputs to process. Queue is empty.")
241
  return
242
 
243
+ if is_processing:
244
+ print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
245
+ return
246
+
247
+ print(f"[{current_time:.3f}] Processing next input. Queue size: {len(input_queue)}")
248
+
249
  # Find the most recent interesting input (click or key event)
250
  interesting_indices = [i for i, data in enumerate(input_queue)
251
  if data.get("is_left_click") or
252
  data.get("is_right_click") or
253
+ (data.get("keys_down") and len(data.get("keys_down")) > 0) or
254
+ (data.get("keys_up") and len(data.get("keys_up")) > 0)]
255
 
256
  if interesting_indices:
257
  # There are interesting events - take the most recent one
258
  idx = interesting_indices[-1]
259
  next_input = input_queue[idx]
260
+ skipped = idx # Number of events we're skipping
261
 
262
  # Clear all inputs up to and including this one
263
  input_queue = input_queue[idx+1:]
264
 
265
+ print(f"[{current_time:.3f}] Processing interesting input (skipped {skipped} events). Queue size now: {len(input_queue)}")
266
  else:
267
  # No interesting events - just take the most recent movement
268
+ skipped = len(input_queue) - 1 # We're processing one, so skipped = total - 1
269
  next_input = input_queue[-1]
 
270
  input_queue = []
271
+ print(f"[{current_time:.3f}] Processing latest movement (skipped {skipped} events). Queue now empty.")
272
 
273
  # Process the selected input asynchronously
274
+ print(f"[{current_time:.3f}] Creating task to process input...")
275
  asyncio.create_task(process_input(next_input))
276
 
277
  while True:
278
  try:
279
  # Receive user input
280
+ print(f"[{time.perf_counter():.3f}] Waiting for input... Queue size: {len(input_queue)}, is_processing: {is_processing}")
281
  data = await websocket.receive_json()
282
+ receive_time = time.perf_counter()
283
 
284
  if data.get("type") == "heartbeat":
285
  await websocket.send_json({"type": "heartbeat_response"})
 
287
 
288
  # Add the input to our queue
289
  input_queue.append(data)
290
+ print(f"[{receive_time:.3f}] Received input. Queue size now: {len(input_queue)}")
291
 
292
  # If we're not currently processing, start processing this input
293
  if not is_processing:
294
+ print(f"[{receive_time:.3f}] Not currently processing, will call process_next_input()")
295
  process_next_input()
296
+ else:
297
+ print(f"[{receive_time:.3f}] Currently processing, new input queued for later")
298
 
299
  except asyncio.TimeoutError:
300
  print("WebSocket connection timed out")
301
 
302
  except WebSocketDisconnect:
303
  print("WebSocket disconnected")
304
+ break
305
 
306
  except Exception as e:
307
  print(f"Error in WebSocket connection {client_id}: {e}")
308
+ import traceback
309
+ traceback.print_exc()
310
 
311
  finally:
312
  # Print final FPS statistics when connection ends
 
318
  print(f" Average FPS: {frame_count/total_time:.2f}")
319
 
320
  print(f"WebSocket connection closed: {client_id}")