Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
98a4a00
1
Parent(s):
46f6899
main.py
CHANGED
@@ -131,10 +131,11 @@ def load_initial_images(width, height):
|
|
131 |
initial_images = []
|
132 |
if DEBUG_TEACHER_FORCING:
|
133 |
# Load the previous 7 frames for image_81
|
134 |
-
for i in range(
|
135 |
img = Image.open(f"record_100/image_{i}.png").resize((width, height))
|
136 |
initial_images.append(np.array(img))
|
137 |
else:
|
|
|
138 |
for i in range(7):
|
139 |
initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
|
140 |
return initial_images
|
@@ -229,7 +230,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
229 |
x, y = pos
|
230 |
#norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
|
231 |
#norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
|
232 |
-
if
|
233 |
norm_x = x + (1920 - 512) / 2
|
234 |
norm_y = y + (1080 - 512) / 2
|
235 |
#if DEBUG:
|
@@ -306,6 +307,16 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
306 |
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
307 |
'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
|
308 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
previous_actions = []
|
310 |
for action in debug_actions[-8:]:
|
311 |
x, y, action_type = parse_action_string(action)
|
@@ -316,7 +327,18 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
316 |
'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
|
317 |
'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
|
318 |
'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
|
319 |
-
'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
#positions = positions[:4]
|
321 |
try:
|
322 |
while True:
|
@@ -340,12 +362,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
340 |
#mouse_position = position.split('~')
|
341 |
#mouse_position = [int(item) for item in mouse_position]
|
342 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
343 |
-
if
|
344 |
position = positions[0]
|
345 |
positions = positions[1:]
|
346 |
x, y, action_type = parse_action_string(position)
|
347 |
mouse_position = (x, y)
|
348 |
-
|
|
|
349 |
#previous_actions = [(action_type, mouse_position)]
|
350 |
|
351 |
# Log the start time
|
@@ -361,7 +384,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
361 |
if False and DEBUG_TEACHER_FORCING:
|
362 |
img = Image.open(f"record_100/image_{82+len(previous_frames)}.png")
|
363 |
previous_frames.append(img)
|
364 |
-
|
365 |
previous_frames.append(next_frame_append)
|
366 |
|
367 |
# Convert the numpy array to a base64 encoded image
|
|
|
131 |
initial_images = []
|
132 |
if DEBUG_TEACHER_FORCING:
|
133 |
# Load the previous 7 frames for image_81
|
134 |
+
for i in range(222-7, 222): # Load images 74-80
|
135 |
img = Image.open(f"record_100/image_{i}.png").resize((width, height))
|
136 |
initial_images.append(np.array(img))
|
137 |
else:
|
138 |
+
assert False
|
139 |
for i in range(7):
|
140 |
initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
|
141 |
return initial_images
|
|
|
230 |
x, y = pos
|
231 |
#norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
|
232 |
#norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
|
233 |
+
if True and DEBUG_TEACHER_FORCING:
|
234 |
norm_x = x + (1920 - 512) / 2
|
235 |
norm_y = y + (1080 - 512) / 2
|
236 |
#if DEBUG:
|
|
|
307 |
'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
|
308 |
'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
|
309 |
]
|
310 |
+
debug_actions = [
|
311 |
+
'N + 1 1 6 5 : + 0 4 4 3', 'N + 1 1 7 0 : + 0 4 1 8',
|
312 |
+
'N + 1 1 7 5 : + 0 3 9 4', 'N + 1 1 8 1 : + 0 3 7 0',
|
313 |
+
'N + 1 1 8 4 : + 0 3 5 8', 'N + 1 1 8 9 : + 0 3 3 3',
|
314 |
+
'N + 1 1 9 4 : + 0 3 0 9', 'N + 1 1 9 7 : + 0 2 9 7',
|
315 |
+
'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
316 |
+
'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
317 |
+
'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
318 |
+
'N + 1 1 9 7 : + 0 2 9 7'
|
319 |
+
]
|
320 |
previous_actions = []
|
321 |
for action in debug_actions[-8:]:
|
322 |
x, y, action_type = parse_action_string(action)
|
|
|
327 |
'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
|
328 |
'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
|
329 |
'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
|
330 |
+
'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1'
|
331 |
+
]
|
332 |
+
positions = [
|
333 |
+
#'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
334 |
+
'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
|
335 |
+
'N + 1 1 7 9 : + 0 3 0 3', 'N + 1 1 4 2 : + 0 3 1 4',
|
336 |
+
'N + 1 1 0 6 : + 0 3 2 6', 'N + 1 0 6 9 : + 0 3 3 7',
|
337 |
+
'N + 1 0 5 1 : + 0 3 4 3', 'N + 1 0 1 4 : + 0 3 5 4',
|
338 |
+
'N + 0 9 7 8 : + 0 3 6 5', 'N + 0 9 4 2 : + 0 3 7 7',
|
339 |
+
'N + 0 9 0 5 : + 0 3 8 8', 'N + 0 8 6 8 : + 0 4 0 0',
|
340 |
+
'N + 0 8 3 2 : + 0 4 1 1'
|
341 |
+
]
|
342 |
#positions = positions[:4]
|
343 |
try:
|
344 |
while True:
|
|
|
362 |
#mouse_position = position.split('~')
|
363 |
#mouse_position = [int(item) for item in mouse_position]
|
364 |
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
365 |
+
if True and DEBUG_TEACHER_FORCING:
|
366 |
position = positions[0]
|
367 |
positions = positions[1:]
|
368 |
x, y, action_type = parse_action_string(position)
|
369 |
mouse_position = (x, y)
|
370 |
+
if False:
|
371 |
+
previous_actions.append((action_type, mouse_position))
|
372 |
#previous_actions = [(action_type, mouse_position)]
|
373 |
|
374 |
# Log the start time
|
|
|
384 |
if False and DEBUG_TEACHER_FORCING:
|
385 |
img = Image.open(f"record_100/image_{82+len(previous_frames)}.png")
|
386 |
previous_frames.append(img)
|
387 |
+
elif False:
|
388 |
previous_frames.append(next_frame_append)
|
389 |
|
390 |
# Convert the numpy array to a base64 encoded image
|