Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
c0820d5
1
Parent(s):
14aa927
main.py
CHANGED
@@ -96,6 +96,7 @@ app = FastAPI()
|
|
96 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
97 |
|
98 |
# Add this at the top with other global variables
|
|
|
99 |
|
100 |
# Create a thread pool executor
|
101 |
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
@@ -226,7 +227,9 @@ async def get():
|
|
226 |
# WebSocket endpoint for continuous user interaction
|
227 |
@app.websocket("/ws")
|
228 |
async def websocket_endpoint(websocket: WebSocket):
|
229 |
-
|
|
|
|
|
230 |
print(f"New WebSocket connection: {client_id}")
|
231 |
await websocket.accept()
|
232 |
|
@@ -350,6 +353,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
350 |
print(f"[{time.perf_counter():.3f}] Sending image to client...")
|
351 |
await websocket.send_json({"image": img_str})
|
352 |
print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
|
|
|
|
|
|
|
353 |
finally:
|
354 |
is_processing = False
|
355 |
print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
|
@@ -457,7 +463,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
457 |
print("WebSocket connection timed out")
|
458 |
|
459 |
except WebSocketDisconnect:
|
460 |
-
|
|
|
|
|
461 |
break
|
462 |
|
463 |
except Exception as e:
|
@@ -475,3 +483,44 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
475 |
print(f" Average FPS: {frame_count/total_time:.2f}")
|
476 |
|
477 |
print(f"WebSocket connection closed: {client_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
97 |
|
98 |
# Add this at the top with other global variables
|
99 |
+
connection_counter = 0
|
100 |
|
101 |
# Create a thread pool executor
|
102 |
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
|
227 |
# WebSocket endpoint for continuous user interaction
|
228 |
@app.websocket("/ws")
|
229 |
async def websocket_endpoint(websocket: WebSocket):
|
230 |
+
global connection_counter
|
231 |
+
connection_counter += 1
|
232 |
+
client_id = f"{int(time.time())}_{connection_counter}"
|
233 |
print(f"New WebSocket connection: {client_id}")
|
234 |
await websocket.accept()
|
235 |
|
|
|
353 |
print(f"[{time.perf_counter():.3f}] Sending image to client...")
|
354 |
await websocket.send_json({"image": img_str})
|
355 |
print(f"[{time.perf_counter():.3f}] Image sent. Queue size before next_input: {input_queue.qsize()}")
|
356 |
+
|
357 |
+
# Log the input
|
358 |
+
log_interaction(client_id, data, generated_frame=sample_img)
|
359 |
finally:
|
360 |
is_processing = False
|
361 |
print(f"[{time.perf_counter():.3f}] Processing complete. Queue size before checking next input: {input_queue.qsize()}")
|
|
|
463 |
print("WebSocket connection timed out")
|
464 |
|
465 |
except WebSocketDisconnect:
|
466 |
+
# Log final EOS entry
|
467 |
+
log_interaction(client_id, {}, is_end_of_session=True)
|
468 |
+
print(f"WebSocket disconnected: {client_id}")
|
469 |
break
|
470 |
|
471 |
except Exception as e:
|
|
|
483 |
print(f" Average FPS: {frame_count/total_time:.2f}")
|
484 |
|
485 |
print(f"WebSocket connection closed: {client_id}")
|
486 |
+
|
487 |
+
def log_interaction(client_id, data, generated_frame=None, is_end_of_session=False):
|
488 |
+
"""Log user interaction and optionally the generated frame."""
|
489 |
+
timestamp = time.time()
|
490 |
+
|
491 |
+
# Create directory structure if it doesn't exist
|
492 |
+
os.makedirs("interaction_logs", exist_ok=True)
|
493 |
+
|
494 |
+
# Structure the log entry
|
495 |
+
log_entry = {
|
496 |
+
"timestamp": timestamp,
|
497 |
+
"client_id": client_id,
|
498 |
+
"is_eos": is_end_of_session
|
499 |
+
}
|
500 |
+
|
501 |
+
# Only include input data if this isn't an EOS token or if data is provided
|
502 |
+
if not is_end_of_session or data:
|
503 |
+
log_entry["inputs"] = {
|
504 |
+
"x": data.get("x"),
|
505 |
+
"y": data.get("y"),
|
506 |
+
"is_left_click": data.get("is_left_click"),
|
507 |
+
"is_right_click": data.get("is_right_click"),
|
508 |
+
"keys_down": data.get("keys_down", []),
|
509 |
+
"keys_up": data.get("keys_up", [])
|
510 |
+
}
|
511 |
+
else:
|
512 |
+
# For EOS records with empty data, just include minimal info
|
513 |
+
log_entry["inputs"] = None
|
514 |
+
|
515 |
+
# Save to a file (one file per session)
|
516 |
+
session_file = f"interaction_logs/session_{client_id}.jsonl"
|
517 |
+
with open(session_file, "a") as f:
|
518 |
+
f.write(json.dumps(log_entry) + "\n")
|
519 |
+
|
520 |
+
# Optionally save the frame if provided
|
521 |
+
if generated_frame is not None and not is_end_of_session:
|
522 |
+
frame_dir = f"interaction_logs/frames_{client_id}"
|
523 |
+
os.makedirs(frame_dir, exist_ok=True)
|
524 |
+
frame_file = f"{frame_dir}/{timestamp:.6f}.png"
|
525 |
+
# Save the frame as PNG
|
526 |
+
Image.fromarray(generated_frame).save(frame_file)
|