da03 commited on
Commit
6454884
·
1 Parent(s): e77b83d
Files changed (1) hide show
  1. main.py +62 -9
main.py CHANGED
@@ -172,14 +172,15 @@ async def websocket_endpoint(websocket: WebSocket):
172
  connection_start_time = time.perf_counter()
173
  frame_count = 0
174
 
175
- while True:
 
 
 
 
 
 
176
  try:
177
- # Receive user input with a timeout
178
- #data = await asyncio.wait_for(websocket.receive_json(), timeout=90000.0)
179
- data = await websocket.receive_json()
180
- if data.get("type") == "heartbeat":
181
- await websocket.send_json({"type": "heartbeat_response"})
182
- continue
183
  frame_num += 1
184
  frame_count += 1 # Increment total frame counter
185
  start_frame = time.perf_counter()
@@ -194,13 +195,15 @@ async def websocket_endpoint(websocket: WebSocket):
194
  is_right_click = data.get("is_right_click")
195
  keys_down_list = data.get("keys_down", []) # Get as list
196
  keys_up_list = data.get("keys_up", [])
197
- print (f'x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')
 
198
  # Update the set based on the received data
199
  for key in keys_down_list:
200
  keys_down.add(key)
201
  for key in keys_up_list:
202
  if key in keys_down: # Check if key exists to avoid KeyError
203
  keys_down.remove(key)
 
204
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
205
  previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
206
  timing_info['full_frame'] = time.perf_counter() - start_frame
@@ -218,10 +221,60 @@ async def websocket_endpoint(websocket: WebSocket):
218
 
219
  # Send the generated frame back to the client
220
  await websocket.send_json({"image": img_str})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  except asyncio.TimeoutError:
223
  print("WebSocket connection timed out")
224
- #break # Exit the loop on timeout
225
 
226
  except WebSocketDisconnect:
227
  print("WebSocket disconnected")
 
172
  connection_start_time = time.perf_counter()
173
  frame_count = 0
174
 
175
+ # Input queue management
176
+ input_queue = []
177
+ is_processing = False
178
+
179
+ async def process_input(data):
180
+ nonlocal previous_frame, hidden_states, keys_down, frame_num, frame_count, is_processing
181
+
182
  try:
183
+ is_processing = True
 
 
 
 
 
184
  frame_num += 1
185
  frame_count += 1 # Increment total frame counter
186
  start_frame = time.perf_counter()
 
195
  is_right_click = data.get("is_right_click")
196
  keys_down_list = data.get("keys_down", []) # Get as list
197
  keys_up_list = data.get("keys_up", [])
198
+ print(f'x: {x}, y: {y}, is_left_click: {is_left_click}, is_right_click: {is_right_click}, keys_down_list: {keys_down_list}, keys_up_list: {keys_up_list}')
199
+
200
  # Update the set based on the received data
201
  for key in keys_down_list:
202
  keys_down.add(key)
203
  for key in keys_up_list:
204
  if key in keys_down: # Check if key exists to avoid KeyError
205
  keys_down.remove(key)
206
+
207
  inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
208
  previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
209
  timing_info['full_frame'] = time.perf_counter() - start_frame
 
221
 
222
  # Send the generated frame back to the client
223
  await websocket.send_json({"image": img_str})
224
+ finally:
225
+ is_processing = False
226
+ # Check if we have more inputs to process after this one
227
+ process_next_input()
228
+
229
+ def process_next_input():
230
+ nonlocal input_queue
231
+
232
+ if not input_queue or is_processing:
233
+ return
234
+
235
+ # Find the most recent interesting input (click or key event)
236
+ interesting_indices = [i for i, data in enumerate(input_queue)
237
+ if data.get("is_left_click") or
238
+ data.get("is_right_click") or
239
+ data.get("keys_down") or
240
+ data.get("keys_up")]
241
+
242
+ if interesting_indices:
243
+ # There are interesting events - take the most recent one
244
+ idx = interesting_indices[-1]
245
+ next_input = input_queue[idx]
246
+
247
+ # Clear all inputs up to and including this one
248
+ input_queue = input_queue[idx+1:]
249
+
250
+ print(f"Processing interesting input (skipped {idx} events)")
251
+ else:
252
+ # No interesting events - just take the most recent movement
253
+ next_input = input_queue[-1]
254
+ input_queue = []
255
+ print(f"Processing latest movement (skipped {len(input_queue)} events)")
256
+
257
+ # Process the selected input asynchronously
258
+ asyncio.create_task(process_input(next_input))
259
+
260
+ while True:
261
+ try:
262
+ # Receive user input
263
+ data = await websocket.receive_json()
264
+
265
+ if data.get("type") == "heartbeat":
266
+ await websocket.send_json({"type": "heartbeat_response"})
267
+ continue
268
+
269
+ # Add the input to our queue
270
+ input_queue.append(data)
271
+
272
+ # If we're not currently processing, start processing this input
273
+ if not is_processing:
274
+ process_next_input()
275
 
276
  except asyncio.TimeoutError:
277
  print("WebSocket connection timed out")
 
278
 
279
  except WebSocketDisconnect:
280
  print("WebSocket disconnected")