da03 commited on
Commit
c0820d5
·
1 Parent(s): 14aa927
Files changed (1) hide show
  1. main.py +51 -2
main.py CHANGED
@@ -96,6 +96,7 @@ app = FastAPI()
96
  app.mount("/static", StaticFiles(directory="static"), name="static")
97
 
98
  # Add this at the top with other global variables
 
99
 
100
  # Create a thread pool executor
101
  thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -226,7 +227,9 @@ async def get():
226
  # WebSocket endpoint for continuous user interaction
227
  @app.websocket("/ws")
228
  async def websocket_endpoint(websocket: WebSocket):
229
- client_id = id(websocket) # Use a unique identifier for each connection
 
 
230
  print(f"New WebSocket connection: {client_id}")
231
  await websocket.accept()
232
 
@@ -350,6 +353,9 @@ async def websocket_endpoint(websocket: WebSocket):
350
  print(f"[{time.perf_counter():.3f}] Sending image to client...")
351
  await websocket.send_json({"image": img_str})
352
  print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
 
 
 
353
  finally:
354
  is_processing = False
355
  print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
@@ -457,7 +463,9 @@ async def websocket_endpoint(websocket: WebSocket):
457
  print("WebSocket connection timed out")
458
 
459
  except WebSocketDisconnect:
460
- print("WebSocket disconnected")
 
 
461
  break
462
 
463
  except Exception as e:
@@ -475,3 +483,44 @@ async def websocket_endpoint(websocket: WebSocket):
475
  print(f" Average FPS: {frame_count/total_time:.2f}")
476
 
477
  print(f"WebSocket connection closed: {client_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  app.mount("/static", StaticFiles(directory="static"), name="static")
97
 
98
  # Add this at the top with other global variables
99
+ connection_counter = 0
100
 
101
  # Create a thread pool executor
102
  thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
 
227
  # WebSocket endpoint for continuous user interaction
228
  @app.websocket("/ws")
229
  async def websocket_endpoint(websocket: WebSocket):
230
+ global connection_counter
231
+ connection_counter += 1
232
+ client_id = f"{int(time.time())}_{connection_counter}"
233
  print(f"New WebSocket connection: {client_id}")
234
  await websocket.accept()
235
 
 
353
  print(f"[{time.perf_counter():.3f}] Sending image to client...")
354
  await websocket.send_json({"image": img_str})
355
  print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
356
+
357
+ # Log the input
358
+ log_interaction(client_id, data, generated_frame=sample_img)
359
  finally:
360
  is_processing = False
361
  print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
 
463
  print("WebSocket connection timed out")
464
 
465
  except WebSocketDisconnect:
466
+ # Log final EOS entry
467
+ log_interaction(client_id, {}, is_end_of_session=True)
468
+ print(f"WebSocket disconnected: {client_id}")
469
  break
470
 
471
  except Exception as e:
 
483
  print(f" Average FPS: {frame_count/total_time:.2f}")
484
 
485
  print(f"WebSocket connection closed: {client_id}")
486
+
487
+ def log_interaction(client_id, data, generated_frame=None, is_end_of_session=False):
488
+ """Log user interaction and optionally the generated frame."""
489
+ timestamp = time.time()
490
+
491
+ # Create directory structure if it doesn't exist
492
+ os.makedirs("interaction_logs", exist_ok=True)
493
+
494
+ # Structure the log entry
495
+ log_entry = {
496
+ "timestamp": timestamp,
497
+ "client_id": client_id,
498
+ "is_eos": is_end_of_session
499
+ }
500
+
501
+ # Only include input data if this isn't an EOS token or if data is provided
502
+ if not is_end_of_session or data:
503
+ log_entry["inputs"] = {
504
+ "x": data.get("x"),
505
+ "y": data.get("y"),
506
+ "is_left_click": data.get("is_left_click"),
507
+ "is_right_click": data.get("is_right_click"),
508
+ "keys_down": data.get("keys_down", []),
509
+ "keys_up": data.get("keys_up", [])
510
+ }
511
+ else:
512
+ # For EOS records with empty data, just include minimal info
513
+ log_entry["inputs"] = None
514
+
515
+ # Save to a file (one file per session)
516
+ session_file = f"interaction_logs/session_{client_id}.jsonl"
517
+ with open(session_file, "a") as f:
518
+ f.write(json.dumps(log_entry) + "\n")
519
+
520
+ # Optionally save the frame if provided
521
+ if generated_frame is not None and not is_end_of_session:
522
+ frame_dir = f"interaction_logs/frames_{client_id}"
523
+ os.makedirs(frame_dir, exist_ok=True)
524
+ frame_file = f"{frame_dir}/{timestamp:.6f}.png"
525
+ # Save the frame as PNG
526
+ Image.fromarray(generated_frame).save(frame_file)