Spaces · runtime error
Commit b2e55f9 by da03 · Parent: f024b1d
- image_0.png +0 -3
- image_1.png +0 -3
- image_10.png +0 -3
- image_2.png +0 -3
- image_3.png +0 -3
- image_4.png +0 -3
- image_5.png +0 -3
- image_6.png +0 -3
- image_7.png +0 -3
- image_8.png +0 -3
- image_9.png +0 -3
- main.py +145 -519
- static/index.html +70 -13
image_0.png through image_10.png: DELETED (each was a Git LFS pointer file)
main.py
CHANGED
@@ -11,137 +11,20 @@ from utils import initialize_model, sample_frame
 import torch
 import os
 import time
+from typing import Any, Dict, List, Tuple
+from ldm.models.diffusion.ddpm import LatentDiffusion, DDIMSampler
 
-DEBUG = False
-DEBUG_TEACHER_FORCING = False
-app = FastAPI()
-
-# Mount the static directory to serve HTML, JavaScript, and CSS files
-app.mount("/static", StaticFiles(directory="static"), name="static")
-
-# Add this at the top with other global variables
-all_click_positions = []  # Store all historical click positions
-
-def parse_action_string(action_str):
-    """Convert formatted action string to x, y coordinates
-    Args:
-        action_str: String like 'N N N N N : N N N N N' or '+ 0 2 1 3 : + 0 3 8 3'
-    Returns:
-        tuple: (x, y) coordinates or None if action is padding
-    """
-    action_type = action_str[0]
-    action_str = action_str[1:].strip()
-    if 'N' in action_str:
-        return (None, None, None)
-
-    # Split into x and y parts
-    action_str = action_str.replace(' ', '')
-    x_part, y_part = action_str.split(':')
-
-    # Parse x: remove sign, join digits, convert to int, apply sign
-    x = int(x_part)
-
-    # Parse y: remove sign, join digits, convert to int, apply sign
-    y = int(y_part)
-
-    return x, y, action_type
-
-def create_position_and_click_map(pos, action_type, image_height=48, image_width=64, original_width=512, original_height=384):
-    """Convert cursor position to a binary position map
-    Args:
-        x, y: Original cursor positions
-        image_size: Size of the output position map (square)
-        original_width: Original screen width (1024)
-        original_height: Original screen height (640)
-    Returns:
-        torch.Tensor: Binary position map of shape (1, image_size, image_size)
-    """
-    x, y = pos
-    if x is None:
-        return torch.zeros((1, image_height, image_width)), torch.zeros((1, image_height, image_width)), None, None
-    # Scale the positions to new size
-    #x_scaled = int((x / original_width) * image_size)
-    #y_scaled = int((y / original_height) * image_size)
-    #screen_width, screen_height = 512, 384
-    #video_width, video_height = 512, 384
-
-    #x_scaled = x - (screen_width / 2 - video_width / 2)
-    #y_scaled = y - (screen_height / 2 - video_height / 2)
-    x_scaled = int(x / original_width * image_width)
-    y_scaled = int(y / original_height * image_height)
-
-    # Clamp values to ensure they're within bounds
-    x_scaled = max(0, min(x_scaled, image_width - 1))
-    y_scaled = max(0, min(y_scaled, image_height - 1))
-
-    # Create binary position map
-    pos_map = torch.zeros((1, image_height, image_width))
-    pos_map[0, y_scaled, x_scaled] = 1.0
-
-    leftclick_map = torch.zeros((1, image_height, image_width))
-    if action_type == 'L':
-        print ('left click', x_scaled, y_scaled)
-        #print ('skipped')
-        if True:
-            leftclick_map[0, y_scaled, x_scaled] = 1.0
-
-    return pos_map, leftclick_map, x_scaled, y_scaled
-
-# Serve the index.html file at the root URL
-@app.get("/")
-async def get():
-    return HTMLResponse(open("static/index.html").read())
-
-    # Draw all historical click positions
-    for click_x, click_y in all_click_positions:
-        x_draw = click_x  # Scale factor for display
-        y_draw = click_y
-        # Draw historical clicks as red circles
-        draw.ellipse([x_draw-4, y_draw-4, x_draw+4, y_draw+4], fill=(255, 0, 0))
-
-    # Draw current trace
-    prev_x, prev_y = None, None
-    for i, (action_type, position) in enumerate(previous_actions):
-        x, y = position
-        if x == 0 and y == 0:
-            continue
-
-        x_draw = x
-        y_draw = y
-
-        # Draw movement positions as blue dots
-        draw.ellipse([x_draw-2, y_draw-2, x_draw+2, y_draw+2], fill=(0, 0, 255))
-
-        # Draw connecting lines
-        if prev_x is not None:
-            draw.line([prev_x, prev_y, x_draw, y_draw], fill=(0, 255, 0), width=1)
-        prev_x, prev_y = x_draw, y_draw
-
-    # Draw current position
-    if x_scaled >= 0 and y_scaled >= 0:
-        x_current = x_scaled * 8
-        y_current = y_scaled * 8
-        #if not DEBUG_TEACHER_FORCING:
-        #    x_current = x_current * 8
-        #    y_current = y_current * 8
-        print ('x_current, y_current', x_current, y_current)
-        draw.ellipse([x_current-3, y_current-3, x_current+3, y_current+3], fill=(0, 255, 0))
-    else:
-        assert False
-
-    return np.array(pil_image)
+
+SCREEN_WIDTH = 512
+SCREEN_HEIGHT = 384
+NUM_SAMPLING_STEPS = 8
+DATA_NORMALIZATION = {
+    'mean': -0.54,
+    'std': 6.78,
+}
+LATENT_DIMS = (1, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8, 4)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 # Initialize the model at the start of your application
 #model = initialize_model("config_csllm.yaml", "yuntian-deng/computer-model")
 model = initialize_model("standard_challenging_context32_nocond_all.yaml", "yuntian-deng/computer-model")
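The hunk above deletes the spaced-digit action-string helpers. For reference, the encoding they implemented round-trips as follows; this is a minimal self-contained sketch distilled from the removed parse_action_string (and format_action, removed in the next hunk), not the original functions verbatim:

    # Sketch of the removed action encoding: a type prefix ('N' = move,
    # 'L' = left click), then a sign and four space-separated digits per axis.
    def format_action(x, y, is_leftclick=False):
        prefix = 'L' if is_leftclick else 'N'
        x_spaced = ' '.join(f"{abs(x):04d}")
        y_spaced = ' '.join(f"{abs(y):04d}")
        return f"{prefix} {'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"

    def parse_action_string(action_str):
        action_type = action_str[0]
        rest = action_str[1:].strip()
        if 'N' in rest:  # padding slot, no coordinates
            return (None, None, None)
        x_part, y_part = rest.replace(' ', '').split(':')
        return int(x_part), int(y_part), action_type

    s = format_action(191, 87, is_leftclick=True)
    assert s == 'L + 0 1 9 1 : + 0 0 8 7'
    assert parse_action_string(s) == (191, 87, 'L')

The rewrite drops this string protocol entirely: the new code passes x, y, click flags, and key events to the model as tensors instead.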
@@ -150,424 +33,167 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.to(device)
 #model = torch.compile(model)
 
-    images = np.stack(images).astype(np.float32)
-    if target_range == (-1, 1):
-        return images / 127.5 - 1
-    elif target_range == (0, 1):
-        return images / 255.0
-    else:
-        raise ValueError(f"Unsupported target range: {target_range}")
-
-def normalize_image(image, target_range=(-1, 1)):
-    image = image.astype(np.float32)
-    if target_range == (-1, 1):
-        return image / 127.5 - 1
-    elif target_range == (0, 1):
-        return image / 255.0
-    else:
-        raise ValueError(f"Unsupported target range: {target_range}")
-
-        return ((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
-    elif source_range == (0, 1):
-        return (image * 255).clip(0, 255).astype(np.uint8)
-    else:
-        raise ValueError(f"Unsupported source range: {source_range}")
-
-def format_action(action_str, is_padding=False, is_leftclick=False):
-    if is_padding:
-        return "N N N N N N : N N N N N"
-
-    # Split the x~y coordinates
-    x, y = map(int, action_str.split('~'))
-    prefix = 'N'
-    if is_leftclick:
-        prefix = 'L'
-    # Convert numbers to padded strings and add spaces between digits
-    x_str = f"{abs(x):04d}"
-    y_str = f"{abs(y):04d}"
-    x_spaced = ' '.join(x_str)
-    y_spaced = ' '.join(y_str)
-
-    # Format with sign and proper spacing
-    return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
-
-padding_image = torch.zeros((height//8, width//8, 4)).to(device)
-data_mean = -0.54
-data_std = 6.78
-data_min = -27.681446075439453
-data_max = 30.854148864746094
-padding_image = (padding_image - data_mean) / data_std
-
-def
-
-    if False:
-        prompt = 'N + 0 4 1 6 : + 0 3 2 0 L + 0 2 0 0 : + 0 1 7 6 N + 0 3 8 4 : + 0 0 4 8 N + 0 3 6 0 : + 0 2 5 6 N + 0 3 6 8 : + 0 0 1 6 N + 0 0 3 2 : + 0 1 0 4 L + 0 2 8 0 : + 0 0 4 0 L + 0 5 0 4 : + 0 0 7 2'
-        previous_actions = [('move', (416, 320)), ('left_click', (200, 176)), ('move', (384, 48)), ('move', (360, 256)), ('move', (368, 16)), ('move', (32, 104)), ('left_click', (280, 40)), ('left_click', (504, 72))]
-        prompt = 'N + 0 3 4 4 : + 0 3 2 0 N + 0 4 8 0 : + 0 1 2 8 N + 0 4 4 8 : + 0 3 6 0 N + 0 4 4 8 : + 0 0 6 4 N + 0 4 6 4 : + 0 3 3 6 N + 0 0 2 4 : + 0 1 3 6 N + 0 1 2 8 : + 0 2 8 0 N + 0 4 4 0 : + 0 0 4 8'
-        previous_actions = [('move', (344, 320)), ('move', (480, 128)), ('move', (448, 360)), ('move', (448, 64)), ('move', (464, 336)), ('move', (24, 136)), ('move', (128, 280)), ('move', (440, 48))]
-        prompt = 'N + 0 4 7 2 : + 0 1 6 0 N + 0 3 0 4 : + 0 2 7 2 N + 0 0 0 0 : + 0 1 7 6 N + 0 2 0 0 : + 0 0 3 2 N + 0 1 6 8 : + 0 0 5 6 L + 0 4 3 2 : + 0 0 4 0 L + 0 2 0 8 : + 0 2 7 2 L + 0 1 8 4 : + 0 0 0 8'
-        previous_actions = [('move', (472, 160)), ('move', (304, 272)), ('move', (0, 176)), ('move', (200, 32)), ('left_click', (168, 56)), ('left_click', (432, 40)), ('left_click', (208, 272)), ('left_click', (184, 8))]
-        prompt = 'N + 0 0 1 6 : + 0 3 2 8 N + 0 3 0 4 : + 0 0 9 6 N + 0 2 4 0 : + 0 1 9 2 N + 0 1 5 2 : + 0 0 5 6 L + 0 2 8 8 : + 0 1 7 6 L + 0 0 5 6 : + 0 3 7 6 N + 0 1 3 6 : + 0 3 6 0 N + 0 1 1 2 : + 0 0 4 8'
-        previous_actions = [('move', (16, 328)), ('move', (304, 96)), ('move', (240, 192)), ('move', (152, 56)), ('left_click', (288, 176)), ('left_click', (56, 376)), ('move', (136, 360)), ('move', (112, 48))]
-        prompt = 'L + 0 0 5 6 : + 0 1 2 8 N + 0 4 0 0 : + 0 0 6 4 N + 0 5 0 4 : + 0 1 2 8 N + 0 4 2 4 : + 0 1 2 0 N + 0 3 2 0 : + 0 1 0 4 N + 0 2 8 0 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4'
-        previous_actions = [('left_click', (56, 128)), ('left_click', (400, 64)), ('move', (504, 128)), ('move', (424, 120)), ('left_click', (320, 104)), ('left_click', (280, 104)), ('move', (272, 104)), ('move', (272, 104))]
-    for action_type, pos in previous_actions[-33:]:
-        #print ('here3', action_type, pos)
-        if action_type == 'move':
-            action_type = 'N'
-        if action_type == 'left_click':
-            action_type = 'L'
-        if action_type == "N":
-            x, y = pos
-            #norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
-            #norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
-            #norm_x = x + (1920 - 512) / 2
-            #norm_y = y + (1080 - 512) / 2
-            norm_x = x
-            norm_y = y
-            if False and DEBUG_TEACHER_FORCING:
-                norm_x = x
-                norm_y = y
-            #action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
-            #action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
-            action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0))
-            prev_x = norm_x
-            prev_y = norm_y
-        elif action_type == "L":
-            x, y = pos
-            #norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
-            #norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
-            #norm_x = x + (1920 - 512) / 2
-            #norm_y = y + (1080 - 512) / 2
-            norm_x = x
-            norm_y = y
-            if False and DEBUG_TEACHER_FORCING:
-                norm_x = x #+ (1920 - 512) / 2
-                norm_y = y #+ (1080 - 512) / 2
-            #if DEBUG:
-            #    norm_x = x
-            #    norm_y = y
-            #action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
-            #action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}', x==0 and y==0))
-            action_descriptions.append(format_action(f'{norm_x:.0f}~{norm_y:.0f}', x==0 and y==0, True))
-        elif action_type == "right_click":
-            assert False
-            action_descriptions.append("right_click")
-        else:
-            assert False
-
-    prompt = " ".join(action_descriptions[-33:])
-    print(prompt)
-    #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
-    #x, y, action_type = parse_action_string(action_descriptions[-1])
-    #pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
-    leftclick_maps = []
-    pos_maps = []
-    for j in range(1, 34):
-        print ('fsfs', action_descriptions[-j])
-        x, y, action_type = parse_action_string(action_descriptions[-j])
-        pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
-        leftclick_maps.append(leftclick_map_j)
-        pos_maps.append(pos_map_j)
-        if j == 1:
-            x_scaled = x_scaled_j
-            y_scaled = y_scaled_j
-            if action_type == 'L':
-                all_click_positions.append((x, y))
-
-    #prompt = ''
-    #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
-    print(prompt)
-    #prompt = prompt.replace('L', 'N')
-    #print ('changing L to N')
-
-    # Generate the next frame
-    new_frame, new_frame_feedback = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
-
-    new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
-
-    # Draw the trace of previous actions
-    new_frame_with_trace = draw_trace(new_frame_denormalized, previous_actions, x_scaled, y_scaled)
-
-    # Track click positions
-    #x, y, action_type = parse_action_string(action_descriptions[-1])
-
-    return new_frame_with_trace, new_frame_denormalized, new_frame_feedback
+padding_image = torch.zeros(1, SCREEN_HEIGHT // 8, SCREEN_WIDTH // 8, 4)
+padding_image = (padding_image - DATA_NORMALIZATION['mean']) / DATA_NORMALIZATION['std']
+padding_image = padding_image.to(device)
+
+# Valid keyboard inputs
+KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(',
+    ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7',
+    '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
+    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
+    'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~',
+    'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace',
+    'browserback', 'browserfavorites', 'browserforward', 'browserhome',
+    'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear',
+    'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete',
+    'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10',
+    'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20',
+    'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9',
+    'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja',
+    'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail',
+    'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack',
+    'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6',
+    'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn',
+    'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn',
+    'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator',
+    'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab',
+    'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen',
+    'command', 'option', 'optionleft', 'optionright']
+INVALID_KEYS = ['f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20',
+    'f21', 'f22', 'f23', 'f24', 'select', 'separator', 'execute']
+VALID_KEYS = [key for key in KEYS if key not in INVALID_KEYS]
+itos = VALID_KEYS
+stoi = {key: i for i, key in enumerate(itos)}
+
+app = FastAPI()
+
+# Mount the static directory to serve HTML, JavaScript, and CSS files
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+# Add this at the top with other global variables
+
+def prepare_model_inputs(
+    previous_frame: torch.Tensor,
+    hidden_states: Any,
+    x: int,
+    y: int,
+    right_click: bool,
+    left_click: bool,
+    keys_down: List[str],
+    stoi: Dict[str, int],
+    itos: List[str],
+    time_step: int
+) -> Dict[str, torch.Tensor]:
+    """Prepare inputs for the model."""
+    inputs = {
+        'image_features': previous_frame.to(device),
+        'is_padding': torch.BoolTensor([time_step == 0]).to(device),
+        'x': torch.LongTensor([x if x is not None else 0]).unsqueeze(0).to(device),
+        'y': torch.LongTensor([y if y is not None else 0]).unsqueeze(0).to(device),
+        'is_leftclick': torch.BoolTensor([left_click]).unsqueeze(0).to(device),
+        'is_rightclick': torch.BoolTensor([right_click]).unsqueeze(0).to(device),
+        'key_events': torch.zeros(len(itos), dtype=torch.long).to(device)
+    }
+    for key in keys_down:
+        inputs['key_events'][stoi[key]] = 1
+
+    if hidden_states is not None:
+        inputs['hidden_states'] = hidden_states
+
+    return inputs
+
+@torch.no_grad()
+def process_frame(
+    model: LatentDiffusion,
+    inputs: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, np.ndarray, Any, Dict[str, float]]:
+    """Process a single frame through the model."""
+    timing = {}
+
+    # Temporal encoding
+    start = time.perf_counter()
+    output_from_rnn, hidden_states = model.temporal_encoder.forward_step(inputs)
+    timing['temporal_encoder'] = time.perf_counter() - start
+
+    # UNet sampling
+    start = time.perf_counter()
+    sampler = DDIMSampler(model)
+    sample_latent, _ = sampler.sample(
+        S=NUM_SAMPLING_STEPS,
+        conditioning={'c_concat': output_from_rnn},
+        batch_size=1,
+        shape=LATENT_DIMS,
+        verbose=False
+    )
+    timing['unet'] = time.perf_counter() - start
+
+    # Decoding
+    start = time.perf_counter()
+    sample = sample_latent * DATA_NORMALIZATION['std'] + DATA_NORMALIZATION['mean']
+    sample = model.decode_first_stage(sample)
+    sample = sample.squeeze(0).clamp(-1, 1)
+    timing['decode'] = time.perf_counter() - start
+
+    # Convert to image
+    sample_img = ((sample[:3].transpose(0, 1).transpose(1, 2).cpu().float().numpy() + 1) * 127.5).astype(np.uint8)
+
+    timing['total'] = sum(timing.values())
+
+    return sample_latent, sample_img, hidden_states, timing
+
+# Serve the index.html file at the root URL
+@app.get("/")
+async def get():
+    return HTMLResponse(open("static/index.html").read())
 
 # WebSocket endpoint for continuous user interaction
 @app.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
-    #global all_click_positions  # Add this line
-    #all_click_positions = []  # Reset at the start of each connection
-
+async def websocket_endpoint(websocket: WebSocket):
     client_id = id(websocket)  # Use a unique identifier for each connection
     print(f"New WebSocket connection: {client_id}")
     await websocket.accept()
-    previous_frames = []
-    previous_actions = []
-    positions = ['815~335', '787~342', '787~342', '749~345', '703~346', '703~346', '654~347', '654~347', '604~349', '555~353', '555~353', '509~357', '509~357', '468~362', '431~368', '431~368']
-    #positions = ['815~335', '787~342', '749~345', '703~346', '703~346', '654~347', '654~347', '604~349', '555~353', '555~353', '509~357', '509~357', '468~362', '431~368', '431~368']
-    positions = ['307~375']
-    positions = ['815~335']
-    #positions = ['787~342']
-    positions = ['300~800']
 
-    if DEBUG_TEACHER_FORCING:
-        #print ('here2')
-        # Use the predefined actions for image_81
-        debug_actions = [
-            'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
-            'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
-            'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
-        ]
-        debug_actions = [
-            'N + 1 1 6 5 : + 0 4 4 3', 'N + 1 1 7 0 : + 0 4 1 8',
-            'N + 1 1 7 5 : + 0 3 9 4', 'N + 1 1 8 1 : + 0 3 7 0',
-            'N + 1 1 8 4 : + 0 3 5 8', 'N + 1 1 8 9 : + 0 3 3 3',
-            'N + 1 1 9 4 : + 0 3 0 9', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7'
-        ]
-        debug_actions = [
-            'N + 1 1 6 5 : + 0 4 4 3', 'N + 1 1 7 0 : + 0 4 1 8',
-            'N + 1 1 7 5 : + 0 3 9 4', 'N + 1 1 8 1 : + 0 3 7 0',
-            'N + 1 1 8 4 : + 0 3 5 8', 'N + 1 1 8 9 : + 0 3 3 3',
-            'N + 1 1 9 4 : + 0 3 0 9', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7'
-        ]
-        debug_actions = ['N + 0 0 4 0 : + 0 2 0 4', 'N + 0 1 3 8 : + 0 1 9 0',
-            'N + 0 2 7 4 : + 0 3 8 3', 'N + 0 5 0 1 : + 0 1 7 3',
-            'L + 0 4 7 3 : + 0 0 8 7', 'N + 0 1 0 9 : + 0 3 4 4',
-            'N + 0 0 5 2 : + 0 1 9 4', 'N + 0 3 6 5 : + 0 2 3 2',
-            'N + 0 3 8 9 : + 0 2 4 5', 'N + 0 0 2 0 : + 0 0 5 9',
-            'N + 0 4 7 3 : + 0 1 5 7', 'L + 0 1 9 1 : + 0 0 8 7',
-            'L + 0 1 9 1 : + 0 0 8 7', 'N + 0 3 4 3 : + 0 2 6 3', ]
-            #'N + 0 2 0 5 : + 0 1 3 3']
-        previous_actions = []
-        for action in debug_actions[-8:]:
-            #action = action.replace('1 1', '0 4')
-            x, y, action_type = parse_action_string(action)
-            previous_actions.append((action_type, (x, y)))
-        positions = [
-            'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 1 8 : + 0 4 9 2',
-            'N + 0 9 0 8 : + 0 4 8 3', 'N + 0 8 9 8 : + 0 4 7 4',
-            'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
-            'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
-            'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
-            'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1'
-        ]
-        positions = [
-            #'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
-            'N + 1 1 7 9 : + 0 3 0 3', 'N + 1 1 4 2 : + 0 3 1 4',
-            'N + 1 1 0 6 : + 0 3 2 6', 'N + 1 0 6 9 : + 0 3 3 7',
-            'N + 1 0 5 1 : + 0 3 4 3', 'N + 1 0 1 4 : + 0 3 5 4',
-            'N + 0 9 7 8 : + 0 3 6 5', 'N + 0 9 4 2 : + 0 3 7 7',
-            'N + 0 9 0 5 : + 0 3 8 8', 'N + 0 8 6 8 : + 0 4 0 0',
-            'N + 0 8 3 2 : + 0 4 1 1'
-        ]
-        positions = ['L + 0 1 9 1 : + 0 0 8 7',
-            'L + 0 1 9 1 : + 0 0 8 7', 'N + 0 3 4 3 : + 0 2 6 3',
-            'N + 0 2 0 5 : + 0 1 3 3', 'N + 0 0 7 6 : + 0 3 4 5',
-            'N + 0 3 1 8 : + 0 3 3 3', 'N + 0 2 5 4 : + 0 2 9 0',
-            'N + 0 1 0 6 : + 0 1 6 4', 'N + 0 0 7 4 : + 0 2 8 4',
-            'N + 0 0 2 4 : + 0 0 4 1', 'N + 0 1 5 0 : + 0 3 8 3',
-            'N + 0 4 0 5 : + 0 1 6 8', 'N + 0 0 5 4 : + 0 3 2 4',
-            'N + 0 2 9 0 : + 0 1 4 1', 'N + 0 4 0 2 : + 0 0 0 9',
-            'N + 0 3 0 7 : + 0 3 3 2', 'N + 0 2 2 0 : + 0 3 7 1',
-            'N + 0 0 8 2 : + 0 1 5 1']
-        positions = positions[3:]
-        #positions = positions[:4]
-        #position = positions[0]
-        #positions = positions[1:]
-        #x, y, action_type = parse_action_string(position)
-        #mouse_position = (x, y)
-
-        #previous_actions.append((action_type, mouse_position))
-
-    if not DEBUG_TEACHER_FORCING:
-        previous_actions = []
-
-        for t in range(15):  # Generate 15 actions
-            # Random movement
-            x = np.random.randint(0, 64)
-            y = np.random.randint(0, 48)
-            #x = max(0, min(63, x + dx))
-            #y = max(0, min(47, y + dy))
-
-            # Random click with 20% probability
-            if np.random.random() < 0.2:
-                action_type = 'L'
-            else:
-                action_type = 'N'
-
-            # Format action string
-            previous_actions.append((action_type, (x*8, y*8)))
     try:
+        previous_frame = padding_image
+        hidden_states = None
+        keys_down = set()  # Initialize as an empty set
+        frame_num = -1
         while True:
             try:
                 # Receive user input with a timeout
                 #data = await asyncio.wait_for(websocket.receive_json(), timeout=90000.0)
                 data = await websocket.receive_json()
-
                 if data.get("type") == "heartbeat":
                     await websocket.send_json({"type": "heartbeat_response"})
                     continue
-
-                position = positions[0]
-                #positions = positions[1:]
-                #mouse_position = position.split('~')
-                #mouse_position = [int(item) for item in mouse_position]
-                #mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
-                if DEBUG_TEACHER_FORCING:
-                    position = positions[0]
-                    positions = positions[1:]
-                    x, y, action_type = parse_action_string(position)
-                    mouse_position = (x, y)
-                    previous_actions.append((action_type, mouse_position))
-                if True:
-                    previous_actions.append((action_type, mouse_position))
-                    #previous_actions = [(action_type, mouse_position)]
-                #if not DEBUG_TEACHER_FORCING:
-                #    x, y = mouse_position
-                #    x = x//8 * 8
-                #    y = y // 8 * 8
-                #    assert x % 8 == 0
-                #    assert y % 8 == 0
-                #    mouse_position = (x, y)
-                #    #mouse_position = (x//8, y//8)
-                #    previous_actions.append((action_type, mouse_position))
-                # Log the start time
-                start_time = time.time()
-
-                # Predict the next frame based on the previous frames and actions
-                #if DEBUG_TEACHER_FORCING:
-                #    print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
-                print ('previous_actions', previous_actions)
-                next_frame, next_frame_append, next_frame_feedback = predict_next_frame(previous_frames, previous_actions)
-                feedback = True
-                if feedback:
-                    next_frame_feedback = torch.einsum('chw->hwc', next_frame_feedback)
-                    print (f'appending feedback of shape {next_frame_feedback.shape}')
-                    previous_frames.append(next_frame_feedback)
-                else:
-                    #previous_frames = []
-                    previous_actions = []
-                processing_time = time.time() - start_time
-                print(f"Frame processing time: {processing_time:.2f} seconds")
-                frame_times.append(processing_time)
-                frames_since_update += 1
-                print (f"Average frame processing time: {np.mean(frame_times):.2f} seconds")
-                fps = 1 / np.mean(frame_times)
-                print (f"FPS: {fps:.2f}")
-
-                #previous_actions = []
-                # Load and append the corresponding ground truth image instead of model output
-                #print ('here4', len(previous_frames))
-                #if DEBUG_TEACHER_FORCING:
-                #    img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
-                #    previous_frames.append(np.array(img))
-                #else:
-                #    assert False
-                #    previous_frames.append(next_frame_append)
-                #    pass
-                #previous_frames = []
-                #previous_actions = []
-
-                # Convert the numpy array to a base64 encoded image
-                img = Image.fromarray(next_frame)
+                frame_num += 1
+                start_frame = time.perf_counter()
+                x = data.get("x")
+                y = data.get("y")
+                is_left_click = data.get("is_left_click")
+                is_right_click = data.get("is_right_click")
+                keys_down_list = data.get("keys_down", [])  # Get as list
+                keys_up_list = data.get("keys_up", [])
+
+                # Update the set based on the received data
+                for key in keys_down_list:
+                    keys_down.add(key)
+                for key in keys_up_list:
+                    if key in keys_down:  # Check if key exists to avoid KeyError
+                        keys_down.remove(key)
+
+                inputs = prepare_model_inputs(previous_frame, hidden_states, x, y, is_right_click, is_left_click, list(keys_down), stoi, itos, frame_num)
+
+                previous_frame, sample_img, hidden_states, timing_info = process_frame(model, inputs)
+                timing_info['full_frame'] = time.perf_counter() - start_frame
+
+                img = Image.fromarray(sample_img)
                 buffered = io.BytesIO()
                 img.save(buffered, format="PNG")
                 img_str = base64.b64encode(buffered.getvalue()).decode()
 
-                # Log the processing time
-
                 # Send the generated frame back to the client
                 await websocket.send_json({"image": img_str})
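The rewritten endpoint also changes the wire protocol: each client message is now a JSON object carrying the cursor position, click flags, and key transitions, and each reply carries the next predicted frame as a base64 PNG. A minimal driver for the new protocol, as a sketch; it assumes the third-party websockets package is installed and that the app is running locally (the ws://localhost:7860 address is an assumption, not taken from the code):

    import asyncio
    import base64
    import json

    import websockets  # third-party client library, assumed installed

    async def drive(uri="ws://localhost:7860/ws"):
        async with websockets.connect(uri) as ws:
            # One input event per frame: move to (256, 192) and left-click.
            await ws.send(json.dumps({
                "x": 256,
                "y": 192,
                "is_left_click": True,
                "is_right_click": False,
                "keys_down": [],
                "keys_up": [],
            }))
            reply = json.loads(await ws.recv())
            # The server answers {"image": <base64 PNG>} for every input event.
            with open("frame_0.png", "wb") as f:
                f.write(base64.b64decode(reply["image"]))

    asyncio.run(drive())

Messages of the form {"type": "heartbeat"} are answered with a heartbeat_response and do not advance the frame counter.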
static/index.html
CHANGED
@@ -45,10 +45,11 @@
         if (data.type === "heartbeat_response") {
             console.log("Heartbeat response received");
         } else if (data.image) {
-
+            let img = new Image();
             img.onload = function() {
-                ctx.
+                ctx.clearRect(0, 0, canvas.width, canvas.height);
+                ctx.drawImage(img, 0, 0);
+                //isProcessing = false; // Reset the processing flag when we get a response
             };
             img.src = 'data:image/png;base64,' + data.image;
         }
@@ -83,19 +84,30 @@
         let lastSentPosition = null;
         let lastSentTime = 0;
         const SEND_INTERVAL = 50; // Send updates every 50ms
+
+        // Track currently pressed keys
+        const pressedKeys = new Set();
 
-        function
+        function sendInputState(x, y, isLeftClick = false, isRightClick = false) {
             const currentTime = Date.now();
-            if (isConnected &&
+            if (isConnected && (isLeftClick || isRightClick || !lastSentPosition || currentTime - lastSentTime >= SEND_INTERVAL)) {
                 try {
                     socket.send(JSON.stringify({
-                        "
-                        "
+                        "x": x,
+                        "y": y,
+                        "is_left_click": isLeftClick,
+                        "is_right_click": isRightClick,
+                        "keys_down": Array.from(pressedKeys),
+                        "keys_up": [],
                     }));
                     lastSentPosition = { x, y };
                     lastSentTime = currentTime;
+
+                    //if (isLeftClick || isRightClick) {
+                    //    isProcessing = true; // Block further inputs until response
+                    //}
                 } catch (error) {
-                    console.error("Error sending
+                    console.error("Error sending input state:", error);
                 }
             }
         }
@@ -113,7 +125,7 @@
             ctx.lineTo(x, y);
             ctx.stroke();
 
-
+            sendInputState(x, y);
         });
 
         canvas.addEventListener("click", function (event) {
@@ -122,14 +134,59 @@
             let x = event.clientX - rect.left;
             let y = event.clientY - rect.top;
 
-
+            sendInputState(x, y, true, false);
+        });
+
+        // Handle right clicks
+        canvas.addEventListener("contextmenu", function (event) {
+            event.preventDefault(); // Prevent default context menu
+            if (!isConnected || isProcessing) return;
+
+            let rect = canvas.getBoundingClientRect();
+            let x = event.clientX - rect.left;
+            let y = event.clientY - rect.top;
+
+            sendInputState(x, y, false, true);
+        });
+
+        // Track keyboard events
+        document.addEventListener("keydown", function (event) {
+            if (!isConnected || isProcessing) return;
+
+            // Add the key to our set of pressed keys
+            pressedKeys.add(event.key);
+
+            // Get the current mouse position
+            let rect = canvas.getBoundingClientRect();
+            let x = lastSentPosition ? lastSentPosition.x : canvas.width / 2;
+            let y = lastSentPosition ? lastSentPosition.y : canvas.height / 2;
+
+            sendInputState(x, y);
+        });
+
+        document.addEventListener("keyup", function (event) {
+            if (!isConnected) return;
+
+            // Remove the key from our set of pressed keys
+            pressedKeys.delete(event.key);
+
+            // Get the current mouse position
+            let rect = canvas.getBoundingClientRect();
+            let x = lastSentPosition ? lastSentPosition.x : canvas.width / 2;
+            let y = lastSentPosition ? lastSentPosition.y : canvas.height / 2;
+
+            // For key up events, we send the key in the keys_up array
             try {
                 socket.send(JSON.stringify({
-                    "
-                    "
+                    "x": x,
+                    "y": y,
+                    "is_left_click": false,
+                    "is_right_click": false,
+                    "keys_down": Array.from(pressedKeys),
+                    "keys_up": [event.key],
                 }));
             } catch (error) {
-                console.error("Error sending
+                console.error("Error sending key up event:", error);
             }
         });
 
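Because keydown events send the full set of held keys in keys_down while keyup events report releases in keys_up, the server-side set in main.py tracks key state across messages. A small simulation of that bookkeeping (hypothetical message trace; the field names are the ones used above):

    # Mirror of the server-side key tracking in main.py's websocket loop.
    keys_down = set()
    trace = [
        {"keys_down": ["shift"], "keys_up": []},       # keydown: shift held
        {"keys_down": ["shift", "a"], "keys_up": []},  # keydown: 'a' while shift held
        {"keys_down": ["shift"], "keys_up": ["a"]},    # keyup: 'a' released
        {"keys_down": [], "keys_up": ["shift"]},       # keyup: shift released
    ]
    for msg in trace:
        for key in msg["keys_down"]:
            keys_down.add(key)
        for key in msg["keys_up"]:
            keys_down.discard(key)  # discard() mirrors the guarded remove() in main.py
        print(sorted(keys_down))
    # -> ['shift'], ['a', 'shift'], ['shift'], []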