Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
6466ec0
1
Parent(s):
11a5f81
main.py
CHANGED
|
@@ -87,11 +87,15 @@ def prepare_model_inputs(
|
|
| 87 |
time_step: int
|
| 88 |
) -> Dict[str, torch.Tensor]:
|
| 89 |
"""Prepare inputs for the model."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
inputs = {
|
| 91 |
'image_features': previous_frame.to(device),
|
| 92 |
'is_padding': torch.BoolTensor([time_step == 0]).to(device),
|
| 93 |
-
'x': torch.LongTensor([x
|
| 94 |
-
'y': torch.LongTensor([y
|
| 95 |
'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
|
| 96 |
'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
|
| 97 |
'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
|
|
@@ -182,7 +186,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 182 |
try:
|
| 183 |
process_start_time = time.perf_counter()
|
| 184 |
print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {len(input_queue)}")
|
| 185 |
-
is_processing = True
|
| 186 |
frame_num += 1
|
| 187 |
frame_count += 1 # Increment total frame counter
|
| 188 |
|
|
@@ -234,7 +237,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 234 |
process_next_input()
|
| 235 |
|
| 236 |
def process_next_input():
|
| 237 |
-
nonlocal input_queue
|
| 238 |
|
| 239 |
current_time = time.perf_counter()
|
| 240 |
if not input_queue:
|
|
@@ -245,6 +248,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 245 |
print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
|
| 246 |
return
|
| 247 |
|
|
|
|
|
|
|
|
|
|
| 248 |
print(f"[{current_time:.3f}] Processing next input. Queue size: {len(input_queue)}")
|
| 249 |
|
| 250 |
# Find the most recent interesting input (click or key event)
|
|
|
|
| 87 |
time_step: int
|
| 88 |
) -> Dict[str, torch.Tensor]:
|
| 89 |
"""Prepare inputs for the model."""
|
| 90 |
+
# Clamp coordinates to valid ranges
|
| 91 |
+
x = min(max(0, x), SCREEN_WIDTH - 1) if x is not None else 0
|
| 92 |
+
y = min(max(0, y), SCREEN_HEIGHT - 1) if y is not None else 0
|
| 93 |
+
|
| 94 |
inputs = {
|
| 95 |
'image_features': previous_frame.to(device),
|
| 96 |
'is_padding': torch.BoolTensor([time_step == 0]).to(device),
|
| 97 |
+
'x': torch.LongTensor([x]).unsqueeze(0).to(device),
|
| 98 |
+
'y': torch.LongTensor([y]).unsqueeze(0).to(device),
|
| 99 |
'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
|
| 100 |
'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
|
| 101 |
'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
|
|
|
|
| 186 |
try:
|
| 187 |
process_start_time = time.perf_counter()
|
| 188 |
print(f"[{process_start_time:.3f}] Starting to process input. Queue size before: {len(input_queue)}")
|
|
|
|
| 189 |
frame_num += 1
|
| 190 |
frame_count += 1 # Increment total frame counter
|
| 191 |
|
|
|
|
| 237 |
process_next_input()
|
| 238 |
|
| 239 |
def process_next_input():
|
| 240 |
+
nonlocal input_queue, is_processing
|
| 241 |
|
| 242 |
current_time = time.perf_counter()
|
| 243 |
if not input_queue:
|
|
|
|
| 248 |
print(f"[{current_time:.3f}] Already processing an input. Will check again later.")
|
| 249 |
return
|
| 250 |
|
| 251 |
+
# Set is_processing to True BEFORE creating the task
|
| 252 |
+
is_processing = True
|
| 253 |
+
|
| 254 |
print(f"[{current_time:.3f}] Processing next input. Queue size: {len(input_queue)}")
|
| 255 |
|
| 256 |
# Find the most recent interesting input (click or key event)
|