da03 committed
Commit b2e55f9 · Parent: f024b1d
Files changed (13)
  1. image_0.png +0 -3
  2. image_1.png +0 -3
  3. image_10.png +0 -3
  4. image_2.png +0 -3
  5. image_3.png +0 -3
  6. image_4.png +0 -3
  7. image_5.png +0 -3
  8. image_6.png +0 -3
  9. image_7.png +0 -3
  10. image_8.png +0 -3
  11. image_9.png +0 -3
  12. main.py +145 -519
  13. static/index.html +70 -13
image_0.png DELETED

Git LFS Details

  • SHA256: 74950de39fe0fc18f237a36259ce70db6abc1485b7fd6d9fcaa2b08324973ff7
  • Pointer size: 130 Bytes
  • Size of remote file: 88.6 kB
image_1.png DELETED

Git LFS Details

  • SHA256: fff09b88ea7f151e6748522d16dfcca724f4960fe9177f57872ed772499e7dd7
  • Pointer size: 130 Bytes
  • Size of remote file: 89.2 kB
image_10.png DELETED

Git LFS Details

  • SHA256: fbec95664684da5021d6167d6ddfcfb7b88b037c0dec94be80025afb66b9e59e
  • Pointer size: 130 Bytes
  • Size of remote file: 89.4 kB
image_2.png DELETED

Git LFS Details

  • SHA256: 93528f81b0fdec9a8e1a0a33e203a61a9fdd0c164b0957f3b21dca6ad668e07c
  • Pointer size: 130 Bytes
  • Size of remote file: 89.3 kB
image_3.png DELETED

Git LFS Details

  • SHA256: f7db6bccfdc33309252e323e0be99cd649bc62e1e7d975065bd62c9aa68dba8c
  • Pointer size: 130 Bytes
  • Size of remote file: 89.4 kB
image_4.png DELETED

Git LFS Details

  • SHA256: 383b17d3282d4305c731cfdb64fa716577390735c5ae770fe7fbb790064d01f1
  • Pointer size: 130 Bytes
  • Size of remote file: 89.2 kB
image_5.png DELETED

Git LFS Details

  • SHA256: 09c2d538e1bd6ac5a98837b918f660477679aadc01746d8f17bd1b2ced0e3e62
  • Pointer size: 130 Bytes
  • Size of remote file: 89.3 kB
image_6.png DELETED

Git LFS Details

  • SHA256: 8890954f638df02991be520b0fd3e0544dadf724a6822395a15e1096c683fe07
  • Pointer size: 130 Bytes
  • Size of remote file: 89.3 kB
image_7.png DELETED

Git LFS Details

  • SHA256: 3f8c0a7d8e94def6357c11ad4118355adf7adc7f0155397458f5c470a0ca5183
  • Pointer size: 130 Bytes
  • Size of remote file: 89.2 kB
image_8.png DELETED

Git LFS Details

  • SHA256: d068653c7064717d20ab79756493a863289b667b41cbe96579c3bc2f0709bea1
  • Pointer size: 130 Bytes
  • Size of remote file: 89.3 kB
image_9.png DELETED

Git LFS Details

  • SHA256: 97bfdc7df961ccccbeea5e33c9e4be9245ad98d6b03d534ad33376f5bdd01b63
  • Pointer size: 130 Bytes
  • Size of remote file: 89.4 kB
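
Each deleted PNG above was tracked with Git LFS, so the repository itself stored only a small text pointer rather than the image bytes; the "Pointer size: 130 Bytes" figures refer to that stub, while "Size of remote file" is the actual image held on the LFS server. For reference, an LFS pointer file is just three lines of roughly this shape (shown with image_0.png's hash; the size field holds the exact byte count, which the viewer rounds to 88.6 kB):

    version https://git-lfs.github.com/spec/v1
    oid sha256:74950de39fe0fc18f237a36259ce70db6abc1485b7fd6d9fcaa2b08324973ff7
    size 88600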
main.py CHANGED
@@ -11,137 +11,20 @@ from utils import initialize_model, sample_frame
 import torch
 import os
 import time
+from typing import Any, Dict
+from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
 
-DEBUG = False
-DEBUG_TEACHER_FORCING = False
-app = FastAPI()
-
-# Mount the static directory to serve HTML, JavaScript, and CSS files
-app.mount("/static", StaticFiles(directory="static"), name="static")
-
-# Add this at the top with other global variables
-all_click_positions = []  # Store all historical click positions
-
-def parse_action_string(action_str):
-    """Convert a formatted action string to x, y coordinates
-    Args:
-        action_str: String like 'N N N N N : N N N N N' or '+ 0 2 1 3 : + 0 3 8 3'
-    Returns:
-        tuple: (x, y, action_type), or (None, None, None) if the action is padding
-    """
-    action_type = action_str[0]
-    action_str = action_str[1:].strip()
-    if 'N' in action_str:
-        return (None, None, None)
-
-    # Split into x and y parts
-    action_str = action_str.replace(' ', '')
-    x_part, y_part = action_str.split(':')
-
-    # Join the digits and convert to ints
-    x = int(x_part)
-    y = int(y_part)
-
-    return x, y, action_type
-
-def create_position_and_click_map(pos, action_type, image_height=48, image_width=64, original_width=512, original_height=384):
-    """Convert a cursor position to binary position and left-click maps
-    Args:
-        pos: (x, y) cursor position in original screen coordinates
-        image_height, image_width: size of the output maps
-        original_width, original_height: original screen dimensions
-    Returns:
-        Two binary maps of shape (1, image_height, image_width), plus the scaled x and y
-    """
-    x, y = pos
-    if x is None:
-        return torch.zeros((1, image_height, image_width)), torch.zeros((1, image_height, image_width)), None, None
-    # Scale the positions to the map size
-    #x_scaled = x - (screen_width / 2 - video_width / 2)
-    #y_scaled = y - (screen_height / 2 - video_height / 2)
-    x_scaled = int(x / original_width * image_width)
-    y_scaled = int(y / original_height * image_height)
-
-    # Clamp values to ensure they're within bounds
-    x_scaled = max(0, min(x_scaled, image_width - 1))
-    y_scaled = max(0, min(y_scaled, image_height - 1))
-
-    # Create binary position map
-    pos_map = torch.zeros((1, image_height, image_width))
-    pos_map[0, y_scaled, x_scaled] = 1.0
-
-    leftclick_map = torch.zeros((1, image_height, image_width))
-    if action_type == 'L':
-        print('left click', x_scaled, y_scaled)
-        leftclick_map[0, y_scaled, x_scaled] = 1.0
-
-    return pos_map, leftclick_map, x_scaled, y_scaled
-
-# Serve the index.html file at the root URL
-@app.get("/")
-async def get():
-    return HTMLResponse(open("static/index.html").read())
 
-def generate_random_image(width: int, height: int) -> np.ndarray:
-    return np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
-
-def draw_trace(image: np.ndarray, previous_actions: List[Tuple[str, List[int]]], x_scaled=-1, y_scaled=-1) -> np.ndarray:
-    if True:
-        return image
-    pil_image = Image.fromarray(image)
-    draw = ImageDraw.Draw(pil_image)
-
-    # Draw all historical click positions as red circles
-    for click_x, click_y in all_click_positions:
-        draw.ellipse([click_x - 4, click_y - 4, click_x + 4, click_y + 4], fill=(255, 0, 0))
-
-    # Draw current trace
-    prev_x, prev_y = None, None
-    for action_type, position in previous_actions:
-        x, y = position
-        if x == 0 and y == 0:
-            continue
-
-        # Draw movement positions as blue dots
-        draw.ellipse([x - 2, y - 2, x + 2, y + 2], fill=(0, 0, 255))
-
-        # Draw connecting lines
-        if prev_x is not None:
-            draw.line([prev_x, prev_y, x, y], fill=(0, 255, 0), width=1)
-        prev_x, prev_y = x, y
-
-    # Draw current position
-    if x_scaled >= 0 and y_scaled >= 0:
-        x_current = x_scaled * 8
-        y_current = y_scaled * 8
-        print('x_current, y_current', x_current, y_current)
-        draw.ellipse([x_current - 3, y_current - 3, x_current + 3, y_current + 3], fill=(0, 255, 0))
-    else:
-        assert False
-
-    return np.array(pil_image)
+SCREEN_WIDTH = 512
+SCREEN_HEIGHT = 384
+NUM_SAMPLING_STEPS = 8
+DATA_NORMALIZATION = {
+    'mean': -0.54,
+    'std': 6.78,
+}
+LATENT_DIMS = (1, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8, 4)
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 # Initialize the model at the start of your application
 #model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
 model = initialize_model("standard_challenging_context32_nocond_all.yaml", "yuntian-deng/computer-model")
@@ -150,424 +33,167 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.to(device)
 #model = torch.compile(model)
 
-def load_initial_images(width, height):
-    initial_images = []
-    if DEBUG_TEACHER_FORCING:
-        # Load the previous 7 frames for image_81
-        for i in range(117 - 7, 117):
-            img = Image.open(f"record_10003/image_{i}.png")  #.resize((width, height))
-            initial_images.append(np.array(img))
-    else:
-        for i in range(32):
-            initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
-    return initial_images
-
-def normalize_images(images, target_range=(-1, 1)):
-    images = np.stack(images).astype(np.float32)
-    if target_range == (-1, 1):
-        return images / 127.5 - 1
-    elif target_range == (0, 1):
-        return images / 255.0
-    else:
-        raise ValueError(f"Unsupported target range: {target_range}")
-
-def normalize_image(image, target_range=(-1, 1)):
-    image = image.astype(np.float32)
-    if target_range == (-1, 1):
-        return image / 127.5 - 1
-    elif target_range == (0, 1):
-        return image / 255.0
-    else:
-        raise ValueError(f"Unsupported target range: {target_range}")
-
-def denormalize_image(image, source_range=(-1, 1)):
-    if source_range == (-1, 1):
-        return ((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
-    elif source_range == (0, 1):
-        return (image * 255).clip(0, 255).astype(np.uint8)
-    else:
-        raise ValueError(f"Unsupported source range: {source_range}")
-
-def format_action(action_str, is_padding=False, is_leftclick=False):
-    if is_padding:
-        return "N N N N N N : N N N N N"
-
-    # Split the x~y coordinates
-    x, y = map(int, action_str.split('~'))
-    prefix = 'N'
-    if is_leftclick:
-        prefix = 'L'
-    # Convert numbers to padded strings and add spaces between digits
-    x_str = f"{abs(x):04d}"
-    y_str = f"{abs(y):04d}"
-    x_spaced = ' '.join(x_str)
-    y_spaced = ' '.join(y_str)
-
-    # Format with sign and proper spacing
-    return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
-
-width, height = 512, 384
-padding_image = torch.zeros((height // 8, width // 8, 4)).to(device)
-data_mean = -0.54
-data_std = 6.78
-data_min = -27.681446075439453
-data_max = 30.854148864746094
-padding_image = (padding_image - data_mean) / data_std
-
-def predict_next_frame(previous_frames, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
-    all_click_positions = []
-    # Prepare the image sequence for the model: take the last 32 frames,
-    # left-padding with blank latents when history is short
-    image_sequence = previous_frames[-32:]
-    while len(image_sequence) < 32:
-        image_sequence.insert(0, padding_image)
-
-    # Concatenate the image sequence in the channel dimension
-    image_sequence_tensor = torch.cat(image_sequence, dim=-1)
-
-    # Prepare the prompt based on the previous actions
-    action_descriptions = []
-    initial_actions = ['0:0'] * 32
-
-    def unnorm_coords(x, y):
-        return int(x), int(y)  #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
-
-    # Pad with initial actions if there are not enough previous actions
-    while len(previous_actions) < 33:
-        x, y = map(int, initial_actions.pop(0).split(':'))
-        previous_actions.insert(0, ("N", unnorm_coords(x, y)))
-    prev_x = 0
-    prev_y = 0
-
-    if False:
-        prompt = 'N + 0 4 1 6 : + 0 3 2 0 L + 0 2 0 0 : + 0 1 7 6 N + 0 3 8 4 : + 0 0 4 8 N + 0 3 6 0 : + 0 2 5 6 N + 0 3 6 8 : + 0 0 1 6 N + 0 0 3 2 : + 0 1 0 4 L + 0 2 8 0 : + 0 0 4 0 L + 0 5 0 4 : + 0 0 7 2'
-        previous_actions = [('move', (416, 320)), ('left_click', (200, 176)), ('move', (384, 48)), ('move', (360, 256)), ('move', (368, 16)), ('move', (32, 104)), ('left_click', (280, 40)), ('left_click', (504, 72))]
-        prompt = 'N + 0 3 4 4 : + 0 3 2 0 N + 0 4 8 0 : + 0 1 2 8 N + 0 4 4 8 : + 0 3 6 0 N + 0 4 4 8 : + 0 0 6 4 N + 0 4 6 4 : + 0 3 3 6 N + 0 0 2 4 : + 0 1 3 6 N + 0 1 2 8 : + 0 2 8 0 N + 0 4 4 0 : + 0 0 4 8'
-        previous_actions = [('move', (344, 320)), ('move', (480, 128)), ('move', (448, 360)), ('move', (448, 64)), ('move', (464, 336)), ('move', (24, 136)), ('move', (128, 280)), ('move', (440, 48))]
-        prompt = 'N + 0 4 7 2 : + 0 1 6 0 N + 0 3 0 4 : + 0 2 7 2 N + 0 0 0 0 : + 0 1 7 6 N + 0 2 0 0 : + 0 0 3 2 N + 0 1 6 8 : + 0 0 5 6 L + 0 4 3 2 : + 0 0 4 0 L + 0 2 0 8 : + 0 2 7 2 L + 0 1 8 4 : + 0 0 0 8'
-        previous_actions = [('move', (472, 160)), ('move', (304, 272)), ('move', (0, 176)), ('move', (200, 32)), ('left_click', (168, 56)), ('left_click', (432, 40)), ('left_click', (208, 272)), ('left_click', (184, 8))]
-        prompt = 'N + 0 0 1 6 : + 0 3 2 8 N + 0 3 0 4 : + 0 0 9 6 N + 0 2 4 0 : + 0 1 9 2 N + 0 1 5 2 : + 0 0 5 6 L + 0 2 8 8 : + 0 1 7 6 L + 0 0 5 6 : + 0 3 7 6 N + 0 1 3 6 : + 0 3 6 0 N + 0 1 1 2 : + 0 0 4 8'
-        previous_actions = [('move', (16, 328)), ('move', (304, 96)), ('move', (240, 192)), ('move', (152, 56)), ('left_click', (288, 176)), ('left_click', (56, 376)), ('move', (136, 360)), ('move', (112, 48))]
-        prompt = 'L + 0 0 5 6 : + 0 1 2 8 N + 0 4 0 0 : + 0 0 6 4 N + 0 5 0 4 : + 0 1 2 8 N + 0 4 2 4 : + 0 1 2 0 N + 0 3 2 0 : + 0 1 0 4 N + 0 2 8 0 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4'
-        previous_actions = [('left_click', (56, 128)), ('left_click', (400, 64)), ('move', (504, 128)), ('move', (424, 120)), ('left_click', (320, 104)), ('left_click', (280, 104)), ('move', (272, 104)), ('move', (272, 104))]
-
-    for action_type, pos in previous_actions[-33:]:
-        if action_type == 'move':
-            action_type = 'N'
-        if action_type == 'left_click':
-            action_type = 'L'
-        if action_type == "N":
-            x, y = pos
-            #norm_x = x + (1920 - 512) / 2
-            #norm_y = y + (1080 - 512) / 2
-            norm_x = x
-            norm_y = y
-            action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x == 0 and y == 0))
-            prev_x = norm_x
-            prev_y = norm_y
-        elif action_type == "L":
-            x, y = pos
-            #norm_x = x + (1920 - 512) / 2
-            #norm_y = y + (1080 - 512) / 2
-            norm_x = x
-            norm_y = y
-            action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x == 0 and y == 0, True))
-        elif action_type == "right_click":
-            assert False
-            action_descriptions.append("right_click")
-        else:
-            assert False
-
-    prompt = " ".join(action_descriptions[-33:])
-    print(prompt)
-    leftclick_maps = []
-    pos_maps = []
-    for j in range(1, 34):
-        print(action_descriptions[-j])
-        x, y, action_type = parse_action_string(action_descriptions[-j])
-        pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
-        leftclick_maps.append(leftclick_map_j)
-        pos_maps.append(pos_map_j)
-        if j == 1:
-            x_scaled = x_scaled_j
-            y_scaled = y_scaled_j
-        if action_type == 'L':
-            all_click_positions.append((x, y))
-
-    # Generate the next frame
-    new_frame, new_frame_feedback = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
-
-    # Convert the generated frame to the correct format
-    new_frame = new_frame.transpose(1, 2, 0)
-    print(new_frame.max(), new_frame.min())
-    new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
-
-    # Draw the trace of previous actions
-    new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
-
-    return new_frame_with_trace, new_frame_denormalized, new_frame_feedback
+padding_image = torch.zeros(1, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8, 4)
+padding_image = (padding_image - DATA_NORMALIZATION['mean']) / DATA_NORMALIZATION['std']
+padding_image = padding_image.to(device)
+
+# Valid keyboard inputs
+KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
+    ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
+    '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
+    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
+    'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
+    'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
+    'browserback', 'browserfavorites', 'browserforward', 'browserhome',
+    'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
+    'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
+    'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
+    'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
+    'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
+    'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
+    'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
+    'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
+    'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
+    'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
+    'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
+    'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
+    'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
+    'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
+    'command', 'option', 'optionleft', 'optionright']
+INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
+    'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
+VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
+itos = VALID_KEYS
+stoi = {key: i for i, key in enumerate(itos)}
+
+app = FastAPI()
+
+# Mount the static directory to serve HTML, JavaScript, and CSS files
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+def prepare_model_inputs(
+    previous_frame: torch.Tensor,
+    hidden_states: Any,
+    x: int,
+    y: int,
+    right_click: bool,
+    left_click: bool,
+    keys_down: List[str],
+    stoi: Dict[str, int],
+    itos: List[str],
+    time_step: int
+) -> Dict[str, torch.Tensor]:
+    """Prepare inputs for the model."""
+    inputs = {
+        'image_features': previous_frame.to(device),
+        'is_padding': torch.BoolTensor([time_step == 0]).to(device),
+        'x': torch.LongTensor([x if x is not None else 0]).unsqueeze(0).to(device),
+        'y': torch.LongTensor([y if y is not None else 0]).unsqueeze(0).to(device),
+        'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
+        'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
+        'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
+    }
+    for key in keys_down:
+        inputs['key_events'][stoi[key]] = 1
+
+    if hidden_states is not None:
+        inputs['hidden_states'] = hidden_states
+
+    return inputs
+
+@torch.no_grad()
+def process_frame(
+    model: LatentDiffusion,
+    inputs: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
+    """Process a single frame through the model."""
+    timing = {}
+
+    # Temporal encoding
+    start = time.perf_counter()
+    output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
+    timing['temporal_encoder'] = time.perf_counter() - start
+
+    # UNet sampling
+    start = time.perf_counter()
+    sampler = DDIMSampler(model)
+    sample_latent, _ = sampler.sample(
+        S=NUM_SAMPLING_STEPS,
+        conditioning={'c_concat': output_from_rnn},
+        batch_size=1,
+        shape=LATENT_DIMS,
+        verbose=False
+    )
+    timing['unet'] = time.perf_counter() - start
+
+    # Decoding
+    start = time.perf_counter()
+    sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
+    sample = model.decode_first_stage(sample)
+    sample = sample.squeeze(0).clamp(-1, 1)
+    timing['decode'] = time.perf_counter() - start
+
+    # Convert to image
+    sample_img = ((sample[:3].transpose(0, 1).transpose(1, 2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
+
+    timing['total'] = sum(timing.values())
+
+    return sample_latent, sample_img, hidden_states, timing
+
+# Serve the index.html file at the root URL
+@app.get("/")
+async def get():
+    return HTMLResponse(open("static/index.html").read())
 
 # WebSocket endpoint for continuous user interaction
 @app.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
-    #global all_click_positions
-    #all_click_positions = []  # Reset at the start of each connection
-
+async def websocket_endpoint(websocket: WebSocket):
     client_id = id(websocket)  # Use a unique identifier for each connection
     print(f"New WebSocket connection: {client_id}")
     await websocket.accept()
-    previous_frames = []
-    previous_actions = []
-    positions = ['815~335', '787~342', '787~342', '749~345', '703~346', '703~346', '654~347', '654~347', '604~349', '555~353', '555~353', '509~357', '509~357', '468~362', '431~368', '431~368']
-    positions = ['307~375']
-    positions = ['815~335']
-    positions = ['300~800']
-
-    if DEBUG_TEACHER_FORCING:
-        # Use the predefined actions for image_81
-        debug_actions = [
-            'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
-            'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'N + 0 9 2 7 : + 0 5 0 1',
-        ]
-        debug_actions = [
-            'N + 1 1 6 5 : + 0 4 4 3', 'N + 1 1 7 0 : + 0 4 1 8',
-            'N + 1 1 7 5 : + 0 3 9 4', 'N + 1 1 8 1 : + 0 3 7 0',
-            'N + 1 1 8 4 : + 0 3 5 8', 'N + 1 1 8 9 : + 0 3 3 3',
-            'N + 1 1 9 4 : + 0 3 0 9', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7'
-        ]
-        debug_actions = [
-            'N + 1 1 6 5 : + 0 4 4 3', 'N + 1 1 7 0 : + 0 4 1 8',
-            'N + 1 1 7 5 : + 0 3 9 4', 'N + 1 1 8 1 : + 0 3 7 0',
-            'N + 1 1 8 4 : + 0 3 5 8', 'N + 1 1 8 9 : + 0 3 3 3',
-            'N + 1 1 9 4 : + 0 3 0 9', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7'
-        ]
-        debug_actions = ['N + 0 0 4 0 : + 0 2 0 4', 'N + 0 1 3 8 : + 0 1 9 0',
-            'N + 0 2 7 4 : + 0 3 8 3', 'N + 0 5 0 1 : + 0 1 7 3',
-            'L + 0 4 7 3 : + 0 0 8 7', 'N + 0 1 0 9 : + 0 3 4 4',
-            'N + 0 0 5 2 : + 0 1 9 4', 'N + 0 3 6 5 : + 0 2 3 2',
-            'N + 0 3 8 9 : + 0 2 4 5', 'N + 0 0 2 0 : + 0 0 5 9',
-            'N + 0 4 7 3 : + 0 1 5 7', 'L + 0 1 9 1 : + 0 0 8 7',
-            'L + 0 1 9 1 : + 0 0 8 7', 'N + 0 3 4 3 : + 0 2 6 3', ]
-        previous_actions = []
-        for action in debug_actions[-8:]:
-            x, y, action_type = parse_action_string(action)
-            previous_actions.append((action_type, (x, y)))
-        positions = [
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 1 8 : + 0 4 9 2',
-            'N + 0 9 0 8 : + 0 4 8 3', 'N + 0 8 9 8 : + 0 4 7 4',
-            'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
-            'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
-            'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
-            'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1'
-        ]
-        positions = [
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 7 9 : + 0 3 0 3', 'N + 1 1 4 2 : + 0 3 1 4',
-            'N + 1 1 0 6 : + 0 3 2 6', 'N + 1 0 6 9 : + 0 3 3 7',
-            'N + 1 0 5 1 : + 0 3 4 3', 'N + 1 0 1 4 : + 0 3 5 4',
-            'N + 0 9 7 8 : + 0 3 6 5', 'N + 0 9 4 2 : + 0 3 7 7',
-            'N + 0 9 0 5 : + 0 3 8 8', 'N + 0 8 6 8 : + 0 4 0 0',
-            'N + 0 8 3 2 : + 0 4 1 1'
-        ]
-        positions = ['L + 0 1 9 1 : + 0 0 8 7',
-            'L + 0 1 9 1 : + 0 0 8 7', 'N + 0 3 4 3 : + 0 2 6 3',
-            'N + 0 2 0 5 : + 0 1 3 3', 'N + 0 0 7 6 : + 0 3 4 5',
-            'N + 0 3 1 8 : + 0 3 3 3', 'N + 0 2 5 4 : + 0 2 9 0',
-            'N + 0 1 0 6 : + 0 1 6 4', 'N + 0 0 7 4 : + 0 2 8 4',
-            'N + 0 0 2 4 : + 0 0 4 1', 'N + 0 1 5 0 : + 0 3 8 3',
-            'N + 0 4 0 5 : + 0 1 6 8', 'N + 0 0 5 4 : + 0 3 2 4',
-            'N + 0 2 9 0 : + 0 1 4 1', 'N + 0 4 0 2 : + 0 0 0 9',
-            'N + 0 3 0 7 : + 0 3 3 2', 'N + 0 2 2 0 : + 0 3 7 1',
-            'N + 0 0 8 2 : + 0 1 5 1']
-        positions = positions[3:]
-
-    if not DEBUG_TEACHER_FORCING:
-        previous_actions = []
-
-        for t in range(15):  # Generate 15 actions
-            # Random movement
-            x = np.random.randint(0, 64)
-            y = np.random.randint(0, 48)
-
-            # Random click with 20% probability
-            if np.random.random() < 0.2:
-                action_type = 'L'
-            else:
-                action_type = 'N'
-
-            # Format action string
-            previous_actions.append((action_type, (x * 8, y * 8)))
     try:
-        previous_actions = []
-        previous_frames = []
-        frames_since_update = 0
-        frame_times = []
+        previous_frame = padding_image
+        hidden_states = None
+        keys_down = set()  # Initialize as an empty set
+        frame_num = -1
         while True:
             try:
                 # Receive user input with a timeout
                 #data = await asyncio.wait_for(websocket.receive_json(), timeout=90000.0)
                 data = await websocket.receive_json()
                 if data.get("type") == "heartbeat":
                     await websocket.send_json({"type": "heartbeat_response"})
                     continue
-                action_type = data.get("action_type")
-                mouse_position = data.get("mouse_position")
-
-                # Store the actions
-                if False and DEBUG:
-                    position = positions[0]
-                if DEBUG_TEACHER_FORCING:
-                    position = positions[0]
-                    positions = positions[1:]
-                    x, y, action_type = parse_action_string(position)
-                    mouse_position = (x, y)
-                    previous_actions.append((action_type, mouse_position))
-                if True:
-                    previous_actions.append((action_type, mouse_position))
-                # Log the start time
-                start_time = time.time()
-
-                # Predict the next frame based on the previous frames and actions
-                print('previous_actions', previous_actions)
-                next_frame, next_frame_append, next_frame_feedback = predict_next_frame(previous_frames, previous_actions)
-                feedback = True
-                if feedback:
-                    next_frame_feedback = torch.einsum('chw->hwc', next_frame_feedback)
-                    print(f'appending feedback of shape {next_frame_feedback.shape}')
-                    previous_frames.append(next_frame_feedback)
-                else:
-                    previous_actions = []
-                processing_time = time.time() - start_time
-                print(f"Frame processing time: {processing_time:.2f} seconds")
-                frame_times.append(processing_time)
-                frames_since_update += 1
-                print(f"Average frame processing time: {np.mean(frame_times):.2f} seconds")
-                fps = 1 / np.mean(frame_times)
-                print(f"FPS: {fps:.2f}")
-
-                # Convert the numpy array to a base64 encoded image
-                img = Image.fromarray(next_frame)
+                frame_num += 1
+                start_frame = time.perf_counter()
+                x = data.get("x")
+                y = data.get("y")
+                is_left_click = data.get("is_left_click")
+                is_right_click = data.get("is_right_click")
+                keys_down_list = data.get("keys_down", [])  # Get as list
+                keys_up_list = data.get("keys_up", [])
+
+                # Update the set based on the received data
+                for key in keys_down_list:
+                    keys_down.add(key)
+                for key in keys_up_list:
+                    if key in keys_down:  # Check if key exists to avoid KeyError
+                        keys_down.remove(key)
+
+                inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
+
+                previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
+                timing_info['full_frame'] = time.perf_counter() - start_frame
+
+                img = Image.fromarray(sample_img)
                 buffered = io.BytesIO()
                 img.save(buffered, format="PNG")
                 img_str = base64.b64encode(buffered.getvalue()).decode()
 
-                # Log the processing time
-
                 # Send the generated frame back to the client
                 await websocket.send_json({"image": img_str})
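In sum, the new main.py drops the 32-frame sliding window of spelled-out action strings (parse_action_string, format_action, predict_next_frame) in favor of a stateful per-connection loop: each incoming JSON message carries the complete input state (cursor x/y, button flags, key transitions), the RNN temporal encoder threads history through hidden_states, and a DDIM sampler denoises the next latent frame in NUM_SAMPLING_STEPS = 8 steps before VAE decoding. A minimal client sketch of the new WebSocket protocol follows; the host and port are hypothetical and it assumes the third-party websockets package, neither of which is part of this commit:

    import asyncio, base64, json
    import websockets  # assumption: pip install websockets

    async def main():
        # Connect to the server's /ws endpoint (host/port are illustrative).
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            # Move the cursor to (100, 120) with no buttons or keys active.
            await ws.send(json.dumps({
                "x": 100, "y": 120,
                "is_left_click": False, "is_right_click": False,
                "keys_down": [], "keys_up": [],
            }))
            # The server replies with the next predicted frame as base64 PNG.
            reply = json.loads(await ws.recv())
            with open("frame.png", "wb") as f:
                f.write(base64.b64decode(reply["image"]))

    asyncio.run(main())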
static/index.html CHANGED
@@ -45,10 +45,11 @@
         if (data.type === "heartbeat_response") {
             console.log("Heartbeat response received");
         } else if (data.image) {
-            const img = new Image();
+            let img = new Image();
             img.onload = function() {
-                ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
-                isProcessing = false; // Allow new inputs after drawing the image
+                ctx.clearRect(0, 0, canvas.width, canvas.height);
+                ctx.drawImage(img, 0, 0);
+                //isProcessing = false; // Reset the processing flag when we get a response
             };
             img.src = 'data:image/png;base64,' + data.image;
         }
@@ -83,19 +84,30 @@
     let lastSentPosition = null;
     let lastSentTime = 0;
     const SEND_INTERVAL = 50; // Send updates every 50ms
+
+    // Track currently pressed keys
+    const pressedKeys = new Set();
 
-    function sendMousePosition(x, y, forceUpdate = false) {
+    function sendInputState(x, y, isLeftClick = false, isRightClick = false) {
         const currentTime = Date.now();
-        if (isConnected && !isProcessing && (forceUpdate || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
+        if (isConnected && (isLeftClick || isRightClick || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
             try {
                 socket.send(JSON.stringify({
-                    "action_type": "move",
-                    "mouse_position": [x, y]
+                    "x": x,
+                    "y": y,
+                    "is_left_click": isLeftClick,
+                    "is_right_click": isRightClick,
+                    "keys_down": Array.from(pressedKeys),
+                    "keys_up": [],
                 }));
                 lastSentPosition = { x, y };
                 lastSentTime = currentTime;
+
+                //if (isLeftClick || isRightClick) {
+                //    isProcessing = true; // Block further inputs until response
+                //}
             } catch (error) {
-                console.error("Error sending mouse position:", error);
+                console.error("Error sending input state:", error);
             }
         }
     }
@@ -113,7 +125,7 @@
         ctx.lineTo(x, y);
         ctx.stroke();
 
-        sendMousePosition(x, y);
+        sendInputState(x, y);
     });
 
     canvas.addEventListener("click", function (event) {
@@ -122,14 +134,59 @@
         let x = event.clientX - rect.left;
         let y = event.clientY - rect.top;
 
-        isProcessing = false;
+        sendInputState(x, y, true, false);
+    });
+
+    // Handle right clicks
+    canvas.addEventListener("contextmenu", function (event) {
+        event.preventDefault(); // Prevent default context menu
+        if (!isConnected || isProcessing) return;
+
+        let rect = canvas.getBoundingClientRect();
+        let x = event.clientX - rect.left;
+        let y = event.clientY - rect.top;
+
+        sendInputState(x, y, false, true);
+    });
+
+    // Track keyboard events
+    document.addEventListener("keydown", function (event) {
+        if (!isConnected || isProcessing) return;
+
+        // Add the key to our set of pressed keys
+        pressedKeys.add(event.key);
+
+        // Get the current mouse position
+        let rect = canvas.getBoundingClientRect();
+        let x = lastSentPosition ? lastSentPosition.x : canvas.width / 2;
+        let y = lastSentPosition ? lastSentPosition.y : canvas.height / 2;
+
+        sendInputState(x, y);
+    });
+
+    document.addEventListener("keyup", function (event) {
+        if (!isConnected) return;
+
+        // Remove the key from our set of pressed keys
+        pressedKeys.delete(event.key);
+
+        // Get the current mouse position
+        let rect = canvas.getBoundingClientRect();
+        let x = lastSentPosition ? lastSentPosition.x : canvas.width / 2;
+        let y = lastSentPosition ? lastSentPosition.y : canvas.height / 2;
+
+        // For key up events, we send the key in the keys_up array
         try {
             socket.send(JSON.stringify({
-                "action_type": "left_click",
-                "mouse_position": [x, y]
+                "x": x,
+                "y": y,
+                "is_left_click": false,
+                "is_right_click": false,
+                "keys_down": Array.from(pressedKeys),
+                "keys_up": [event.key],
             }));
         } catch (error) {
-            console.error("Error sending click action:", error);
+            console.error("Error sending key up event:", error);
         }
     });
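
The client-side changes mirror the server's new input schema: sendInputState always transmits the full state, and the keydown/keyup listeners keep pressedKeys in sync so held modifier keys persist across frames. To make the key-event flow concrete, pressing and then releasing "a" with the cursor last seen at (200, 150) emits two messages from the handlers above (coordinates illustrative):

    {"x": 200, "y": 150, "is_left_click": false, "is_right_click": false, "keys_down": ["a"], "keys_up": []}
    {"x": 200, "y": 150, "is_left_click": false, "is_right_click": false, "keys_down": [], "keys_up": ["a"]}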