Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
21002d0
1
Parent(s):
dc7c204
- main.py +67 -0
- static/index.html +73 -1
main.py
CHANGED
@@ -131,6 +131,10 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
131 |
# Add this at the top with other global variables
|
132 |
connection_counter = 0
|
133 |
|
|
|
|
|
|
|
|
|
134 |
# Create a thread pool executor
|
135 |
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
136 |
|
@@ -289,6 +293,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
289 |
"sampling_steps": NUM_SAMPLING_STEPS # Start with default global value
|
290 |
}
|
291 |
|
|
|
|
|
|
|
|
|
|
|
292 |
# Start timing for global FPS calculation
|
293 |
connection_start_time = time.perf_counter()
|
294 |
frame_count = 0
|
@@ -373,6 +382,49 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
373 |
# Send confirmation to client
|
374 |
await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
async def process_input(data):
|
377 |
nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing
|
378 |
|
@@ -401,6 +453,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
401 |
is_auto_input = data.get("is_auto_input", False)
|
402 |
if is_auto_input:
|
403 |
print (f'[{time.perf_counter():.3f}] Auto-input detected')
|
|
|
|
|
|
|
404 |
print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')
|
405 |
|
406 |
# Update the set based on the received data
|
@@ -542,24 +597,28 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
542 |
# Handle reset command
|
543 |
if data.get("type") == "reset":
|
544 |
print(f"[{receive_time:.3f}] Received reset command")
|
|
|
545 |
await reset_simulation()
|
546 |
continue
|
547 |
|
548 |
# Handle sampling steps update
|
549 |
if data.get("type") == "update_sampling_steps":
|
550 |
print(f"[{receive_time:.3f}] Received request to update sampling steps")
|
|
|
551 |
await update_sampling_steps(data.get("steps", 32))
|
552 |
continue
|
553 |
|
554 |
# Handle USE_RNN update
|
555 |
if data.get("type") == "update_use_rnn":
|
556 |
print(f"[{receive_time:.3f}] Received request to update USE_RNN")
|
|
|
557 |
await update_use_rnn(data.get("use_rnn", False))
|
558 |
continue
|
559 |
|
560 |
# Handle settings request
|
561 |
if data.get("type") == "get_settings":
|
562 |
print(f"[{receive_time:.3f}] Received request for current settings")
|
|
|
563 |
await websocket.send_json({
|
564 |
"type": "settings",
|
565 |
"sampling_steps": client_settings["sampling_steps"],
|
@@ -594,6 +653,14 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
594 |
traceback.print_exc()
|
595 |
|
596 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
597 |
# Print final FPS statistics when connection ends
|
598 |
if frame_num >= 0: # Only if we processed at least one frame
|
599 |
total_time = time.perf_counter() - connection_start_time
|
|
|
131 |
# Add this at the top with other global variables
|
132 |
connection_counter = 0
|
133 |
|
134 |
+
# Connection timeout settings
|
135 |
+
CONNECTION_TIMEOUT = 60 # 1 minute timeout
|
136 |
+
WARNING_TIME = 30 # 30 seconds warning before timeout
|
137 |
+
|
138 |
# Create a thread pool executor
|
139 |
thread_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
140 |
|
|
|
293 |
"sampling_steps": NUM_SAMPLING_STEPS # Start with default global value
|
294 |
}
|
295 |
|
296 |
+
# Connection timeout tracking
|
297 |
+
last_user_activity_time = time.perf_counter()
|
298 |
+
timeout_warning_sent = False
|
299 |
+
timeout_task = None
|
300 |
+
|
301 |
# Start timing for global FPS calculation
|
302 |
connection_start_time = time.perf_counter()
|
303 |
frame_count = 0
|
|
|
382 |
# Send confirmation to client
|
383 |
await websocket.send_json({"type": "rnn_updated", "use_rnn": use_rnn})
|
384 |
|
385 |
+
# Add timeout checking function
|
386 |
+
async def check_timeout():
|
387 |
+
nonlocal timeout_warning_sent, timeout_task
|
388 |
+
|
389 |
+
while True:
|
390 |
+
try:
|
391 |
+
current_time = time.perf_counter()
|
392 |
+
time_since_activity = current_time - last_user_activity_time
|
393 |
+
|
394 |
+
# Send warning at 30 seconds
|
395 |
+
if time_since_activity >= WARNING_TIME and not timeout_warning_sent:
|
396 |
+
print(f"[{current_time:.3f}] Sending timeout warning to client {client_id}")
|
397 |
+
await websocket.send_json({
|
398 |
+
"type": "timeout_warning",
|
399 |
+
"timeout_in": CONNECTION_TIMEOUT - WARNING_TIME
|
400 |
+
})
|
401 |
+
timeout_warning_sent = True
|
402 |
+
|
403 |
+
# Close connection at 1 minute
|
404 |
+
if time_since_activity >= CONNECTION_TIMEOUT:
|
405 |
+
print(f"[{current_time:.3f}] Closing connection {client_id} due to timeout")
|
406 |
+
await websocket.close(code=1000, reason="User inactivity timeout")
|
407 |
+
return
|
408 |
+
|
409 |
+
await asyncio.sleep(1) # Check every second
|
410 |
+
|
411 |
+
except Exception as e:
|
412 |
+
print(f"[{time.perf_counter():.3f}] Error in timeout check for client {client_id}: {e}")
|
413 |
+
break
|
414 |
+
|
415 |
+
# Function to update user activity
|
416 |
+
def update_user_activity():
|
417 |
+
nonlocal last_user_activity_time, timeout_warning_sent
|
418 |
+
last_user_activity_time = time.perf_counter()
|
419 |
+
if timeout_warning_sent:
|
420 |
+
print(f"[{time.perf_counter():.3f}] User activity detected, resetting timeout warning for client {client_id}")
|
421 |
+
timeout_warning_sent = False
|
422 |
+
# Send activity reset notification to client
|
423 |
+
asyncio.create_task(websocket.send_json({"type": "activity_reset"}))
|
424 |
+
|
425 |
+
# Start timeout checking
|
426 |
+
timeout_task = asyncio.create_task(check_timeout())
|
427 |
+
|
428 |
async def process_input(data):
|
429 |
nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing
|
430 |
|
|
|
453 |
is_auto_input = data.get("is_auto_input", False)
|
454 |
if is_auto_input:
|
455 |
print (f'[{time.perf_counter():.3f}] Auto-input detected')
|
456 |
+
else:
|
457 |
+
# Update user activity for non-auto inputs
|
458 |
+
update_user_activity()
|
459 |
print(f'[{time.perf_counter():.3f}] Processing: x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')
|
460 |
|
461 |
# Update the set based on the received data
|
|
|
597 |
# Handle reset command
|
598 |
if data.get("type") == "reset":
|
599 |
print(f"[{receive_time:.3f}] Received reset command")
|
600 |
+
update_user_activity() # Reset activity timer
|
601 |
await reset_simulation()
|
602 |
continue
|
603 |
|
604 |
# Handle sampling steps update
|
605 |
if data.get("type") == "update_sampling_steps":
|
606 |
print(f"[{receive_time:.3f}] Received request to update sampling steps")
|
607 |
+
update_user_activity() # Reset activity timer
|
608 |
await update_sampling_steps(data.get("steps", 32))
|
609 |
continue
|
610 |
|
611 |
# Handle USE_RNN update
|
612 |
if data.get("type") == "update_use_rnn":
|
613 |
print(f"[{receive_time:.3f}] Received request to update USE_RNN")
|
614 |
+
update_user_activity() # Reset activity timer
|
615 |
await update_use_rnn(data.get("use_rnn", False))
|
616 |
continue
|
617 |
|
618 |
# Handle settings request
|
619 |
if data.get("type") == "get_settings":
|
620 |
print(f"[{receive_time:.3f}] Received request for current settings")
|
621 |
+
update_user_activity() # Reset activity timer
|
622 |
await websocket.send_json({
|
623 |
"type": "settings",
|
624 |
"sampling_steps": client_settings["sampling_steps"],
|
|
|
653 |
traceback.print_exc()
|
654 |
|
655 |
finally:
|
656 |
+
# Clean up timeout task
|
657 |
+
if timeout_task and not timeout_task.done():
|
658 |
+
timeout_task.cancel()
|
659 |
+
try:
|
660 |
+
await timeout_task
|
661 |
+
except asyncio.CancelledError:
|
662 |
+
pass
|
663 |
+
|
664 |
# Print final FPS statistics when connection ends
|
665 |
if frame_num >= 0: # Only if we processed at least one frame
|
666 |
total_time = time.perf_counter() - connection_start_time
|
static/index.html
CHANGED
@@ -114,6 +114,12 @@
|
|
114 |
<input class="form-check-input" type="checkbox" role="switch" id="autoInputToggle" checked>
|
115 |
<label class="form-check-label" for="autoInputToggle" id="autoInputLabel">Auto Input</label>
|
116 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
</div>
|
118 |
</div>
|
119 |
|
@@ -186,6 +192,7 @@
|
|
186 |
console.log("WebSocket connection closed. Attempting to reconnect...");
|
187 |
isConnected = false;
|
188 |
stopAutoInput(); // Stop auto-input when connection is lost
|
|
|
189 |
clearInterval(heartbeatInterval);
|
190 |
scheduleReconnection();
|
191 |
};
|
@@ -219,6 +226,12 @@
|
|
219 |
console.log(`Received settings from server: SAMPLING_STEPS=${data.sampling_steps}, USE_RNN=${data.use_rnn}`);
|
220 |
document.getElementById('samplingSteps').value = data.sampling_steps;
|
221 |
document.getElementById('useRnnToggle').checked = data.use_rnn;
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
}
|
223 |
};
|
224 |
}
|
@@ -264,6 +277,11 @@
|
|
264 |
const AUTO_INPUT_INTERVAL = 500; // Send auto-input every 0.5 second once active
|
265 |
let autoInputEnabled = true; // Default to enabled
|
266 |
|
|
|
|
|
|
|
|
|
|
|
267 |
// Track currently pressed keys
|
268 |
const pressedKeys = new Set();
|
269 |
|
@@ -281,7 +299,7 @@
|
|
281 |
|
282 |
// Check if we should start auto-input mode
|
283 |
if (!autoInputActive && currentTime - lastUserInputTime >= INITIAL_AUTO_INPUT_DELAY) {
|
284 |
-
console.log("Starting auto-input mode (no user activity for
|
285 |
autoInputActive = true;
|
286 |
lastAutoInputTime = currentTime;
|
287 |
// Update UI to show auto-input is active
|
@@ -344,6 +362,59 @@
|
|
344 |
}
|
345 |
}
|
346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
function sendInputState(x, y, isLeftClick = false, isRightClick = false) {
|
348 |
const currentTime = Date.now();
|
349 |
if (isConnected && (isLeftClick || isRightClick || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
|
@@ -455,6 +526,7 @@
|
|
455 |
// Graceful disconnection
|
456 |
window.addEventListener('beforeunload', function (e) {
|
457 |
stopAutoInput(); // Clean up auto-input interval
|
|
|
458 |
if (isConnected) {
|
459 |
try {
|
460 |
//socket.send(JSON.stringify({ type: "disconnect" }));
|
|
|
114 |
<input class="form-check-input" type="checkbox" role="switch" id="autoInputToggle" checked>
|
115 |
<label class="form-check-label" for="autoInputToggle" id="autoInputLabel">Auto Input</label>
|
116 |
</div>
|
117 |
+
|
118 |
+
<div id="timeoutWarning" class="alert alert-warning" style="display: none; margin-top: 10px;">
|
119 |
+
<strong>Connection Timeout Warning:</strong>
|
120 |
+
No user activity detected. Connection will be dropped in <span id="timeoutCountdown">30</span> seconds.
|
121 |
+
<button type="button" class="btn btn-sm btn-primary ms-2" onclick="resetTimeout()">Stay Connected</button>
|
122 |
+
</div>
|
123 |
</div>
|
124 |
</div>
|
125 |
|
|
|
192 |
console.log("WebSocket connection closed. Attempting to reconnect...");
|
193 |
isConnected = false;
|
194 |
stopAutoInput(); // Stop auto-input when connection is lost
|
195 |
+
stopTimeoutCountdown(); // Stop timeout countdown when connection is lost
|
196 |
clearInterval(heartbeatInterval);
|
197 |
scheduleReconnection();
|
198 |
};
|
|
|
226 |
console.log(`Received settings from server: SAMPLING_STEPS=${data.sampling_steps}, USE_RNN=${data.use_rnn}`);
|
227 |
document.getElementById('samplingSteps').value = data.sampling_steps;
|
228 |
document.getElementById('useRnnToggle').checked = data.use_rnn;
|
229 |
+
} else if (data.type === "timeout_warning") {
|
230 |
+
console.log(`Received timeout warning: ${data.timeout_in} seconds remaining`);
|
231 |
+
startTimeoutCountdown();
|
232 |
+
} else if (data.type === "activity_reset") {
|
233 |
+
console.log("Server detected user activity, resetting timeout");
|
234 |
+
stopTimeoutCountdown();
|
235 |
}
|
236 |
};
|
237 |
}
|
|
|
277 |
const AUTO_INPUT_INTERVAL = 500; // Send auto-input every 0.5 second once active
|
278 |
let autoInputEnabled = true; // Default to enabled
|
279 |
|
280 |
+
// Timeout countdown mechanism
|
281 |
+
let timeoutCountdownInterval = null;
|
282 |
+
let timeoutCountdown = 30;
|
283 |
+
let timeoutWarningActive = false;
|
284 |
+
|
285 |
// Track currently pressed keys
|
286 |
const pressedKeys = new Set();
|
287 |
|
|
|
299 |
|
300 |
// Check if we should start auto-input mode
|
301 |
if (!autoInputActive && currentTime - lastUserInputTime >= INITIAL_AUTO_INPUT_DELAY) {
|
302 |
+
console.log("Starting auto-input mode (no user activity for 2 seconds)");
|
303 |
autoInputActive = true;
|
304 |
lastAutoInputTime = currentTime;
|
305 |
// Update UI to show auto-input is active
|
|
|
362 |
}
|
363 |
}
|
364 |
|
365 |
+
function startTimeoutCountdown() {
|
366 |
+
if (timeoutCountdownInterval) {
|
367 |
+
clearInterval(timeoutCountdownInterval);
|
368 |
+
}
|
369 |
+
|
370 |
+
timeoutCountdown = 30;
|
371 |
+
timeoutWarningActive = true;
|
372 |
+
|
373 |
+
// Show warning
|
374 |
+
const warning = document.getElementById('timeoutWarning');
|
375 |
+
if (warning) {
|
376 |
+
warning.style.display = 'block';
|
377 |
+
}
|
378 |
+
|
379 |
+
// Start countdown
|
380 |
+
timeoutCountdownInterval = setInterval(() => {
|
381 |
+
timeoutCountdown--;
|
382 |
+
const countdownElement = document.getElementById('timeoutCountdown');
|
383 |
+
if (countdownElement) {
|
384 |
+
countdownElement.textContent = timeoutCountdown;
|
385 |
+
}
|
386 |
+
|
387 |
+
if (timeoutCountdown <= 0) {
|
388 |
+
stopTimeoutCountdown();
|
389 |
+
console.log("Connection timeout countdown finished");
|
390 |
+
}
|
391 |
+
}, 1000);
|
392 |
+
}
|
393 |
+
|
394 |
+
function stopTimeoutCountdown() {
|
395 |
+
if (timeoutCountdownInterval) {
|
396 |
+
clearInterval(timeoutCountdownInterval);
|
397 |
+
timeoutCountdownInterval = null;
|
398 |
+
}
|
399 |
+
|
400 |
+
timeoutWarningActive = false;
|
401 |
+
timeoutCountdown = 30;
|
402 |
+
|
403 |
+
// Hide warning
|
404 |
+
const warning = document.getElementById('timeoutWarning');
|
405 |
+
if (warning) {
|
406 |
+
warning.style.display = 'none';
|
407 |
+
}
|
408 |
+
}
|
409 |
+
|
410 |
+
function resetTimeout() {
|
411 |
+
// Send a heartbeat to reset the server's timeout
|
412 |
+
if (socket && socket.readyState === WebSocket.OPEN) {
|
413 |
+
socket.send(JSON.stringify({ type: "heartbeat" }));
|
414 |
+
}
|
415 |
+
stopTimeoutCountdown();
|
416 |
+
}
|
417 |
+
|
418 |
function sendInputState(x, y, isLeftClick = false, isRightClick = false) {
|
419 |
const currentTime = Date.now();
|
420 |
if (isConnected && (isLeftClick || isRightClick || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
|
|
|
526 |
// Graceful disconnection
|
527 |
window.addEventListener('beforeunload', function (e) {
|
528 |
stopAutoInput(); // Clean up auto-input interval
|
529 |
+
stopTimeoutCountdown(); // Clean up timeout countdown
|
530 |
if (isConnected) {
|
531 |
try {
|
532 |
//socket.send(JSON.stringify({ type: "disconnect" }));
|