da03 commited on
Commit
6466ec0
·
1 Parent(s): 11a5f81
Files changed (1) hide show
  1. main.py +10 -4
main.py CHANGED
@@ -87,11 +87,15 @@ def prepare_model_inputs(
87
  time_step: int
88
  ) -> Dict[str, torch.Tensor]:
89
  """Prepare inputs for the model."""
 
 
 
 
90
  inputs = {
91
  'image_features': previous_frame.to(device),
92
  'is_padding': torch.BoolTensor([time_step == 0]).to(device),
93
- 'x': torch.LongTensor([x if x is not None else 0]).unsqueeze(0).to(device),
94
- 'y': torch.LongTensor([y if y is not None else 0]).unsqueeze(0).to(device),
95
  'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
96
  'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
97
  'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
@@ -182,7 +186,6 @@ async def websocket_endpoint(websocket: WebSocket):
182
  try:
183
  process_start_time = time.perf_counter()
184
  print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {len(input_queue)}")
185
- is_processing = True
186
  frame_num += 1
187
  frame_count += 1 # Increment total frame counter
188
 
@@ -234,7 +237,7 @@ async def websocket_endpoint(websocket: WebSocket):
234
  process_next_input()
235
 
236
  def process_next_input():
237
- nonlocal input_queue
238
 
239
  current_time = time.perf_counter()
240
  if not input_queue:
@@ -245,6 +248,9 @@ async def websocket_endpoint(websocket: WebSocket):
245
  print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
246
  return
247
 
 
 
 
248
  print(f"[{current_time:.3f}] Processing next input. Queue size: {len(input_queue)}")
249
 
250
  # Find the most recent interesting input (click or key event)
 
87
  time_step: int
88
  ) -> Dict[str, torch.Tensor]:
89
  """Prepare inputs for the model."""
90
+ # Clamp coordinates to valid ranges
91
+ x = min(max(0, x), SCREEN_WIDTH - 1) if x is not None else 0
92
+ y = min(max(0, y), SCREEN_HEIGHT - 1) if y is not None else 0
93
+
94
  inputs = {
95
  'image_features': previous_frame.to(device),
96
  'is_padding': torch.BoolTensor([time_step == 0]).to(device),
97
+ 'x': torch.LongTensor([x]).unsqueeze(0).to(device),
98
+ 'y': torch.LongTensor([y]).unsqueeze(0).to(device),
99
  'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
100
  'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
101
  'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
 
186
  try:
187
  process_start_time = time.perf_counter()
188
  print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {len(input_queue)}")
 
189
  frame_num += 1
190
  frame_count += 1 # Increment total frame counter
191
 
 
237
  process_next_input()
238
 
239
  def process_next_input():
240
+ nonlocal input_queue, is_processing
241
 
242
  current_time = time.perf_counter()
243
  if not input_queue:
 
248
  print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
249
  return
250
 
251
+ # Set is_processing to True BEFORE creating the task
252
+ is_processing = True
253
+
254
  print(f"[{current_time:.3f}] Processing next input. Queue size: {len(input_queue)}")
255
 
256
  # Find the most recent interesting input (click or key event)