Spaces: da03 · Runtime error

da03 committed
Commit · 1a5d2e7
1 Parent(s): 8f6c968

main.py CHANGED
@@ -13,7 +13,7 @@ import os
 import time
 
 DEBUG = False
-DEBUG_TEACHER_FORCING =
+DEBUG_TEACHER_FORCING = False
 app = FastAPI()
 
 # Mount the static directory to serve HTML, JavaScript, and CSS files
@@ -156,7 +156,7 @@ def load_initial_images(width, height):
 initial_images.append(np.array(img))
 else:
 #assert False
-for i in range(
+for i in range(32):
 initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
 return initial_images
 
@@ -202,10 +202,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
 print ('length of previous_frames', len(previous_frames))
 
 # Prepare the image sequence for the model
-assert len(initial_images) ==
-image_sequence = previous_frames[-
+assert len(initial_images) == 32
+image_sequence = previous_frames[-32:] # Take the last 7 frames
 i = 1
-while len(image_sequence) <
+while len(image_sequence) < 32:
 image_sequence.insert(0, initial_images[-i])
 i += 1
 #image_sequence.append(initial_images[len(image_sequence)])
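This hunk fixes the frame context at 32 frames: the history is clipped with previous_frames[-32:] and, when it is shorter than that, left-padded from the tail of initial_images. A minimal standalone sketch of the same padding step, assuming plain lists of H×W×3 uint8 arrays (the helper name pad_frame_history and the 64×64 size are illustrative, not part of the Space):

```python
import numpy as np

CONTEXT_FRAMES = 32  # window size used in this commit

def pad_frame_history(previous_frames, initial_images, context=CONTEXT_FRAMES):
    """Return the last `context` frames, left-padding from initial_images when short."""
    sequence = list(previous_frames[-context:])   # mirrors previous_frames[-32:]
    i = 1
    while len(sequence) < context:                # mirrors the while-loop in the hunk
        sequence.insert(0, initial_images[-i])    # pull placeholder frames from the end
        i += 1
    return sequence

# Example: 3 real frames get left-padded with 29 placeholders to a 32-frame window
real = [np.zeros((64, 64, 3), dtype=np.uint8)] * 3
placeholders = [np.zeros((64, 64, 3), dtype=np.uint8)] * 32
assert len(pad_frame_history(real, placeholders)) == 32
```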
@@ -213,18 +213,23 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
 # Convert the image sequence to a tensor and concatenate in the channel dimension
 image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
 image_sequence_tensor = image_sequence_tensor.to(device)
+data_mean = -0.54
+data_std = 6.78
+data_min = -27.681446075439453
+data_max = 30.854148864746094
+image_sequence_tensor = (image_sequence_tensor - data_mean) / data_std
 
 # Prepare the prompt based on the previous actions
 action_descriptions = []
 #initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
-initial_actions = ['0:0'] *
+initial_actions = ['0:0'] * 32
 #initial_actions = ['N N N N N : N N N N N'] * 7
 def unnorm_coords(x, y):
 return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
 
 # Process initial actions if there are not enough previous actions
-while len(previous_actions) <
-assert False
+while len(previous_actions) < 33:
+#assert False
 x, y = map(int, initial_actions.pop(0).split(':'))
 previous_actions.insert(0, ("N", unnorm_coords(x, y)))
 prev_x = 0
@@ -242,7 +247,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
 previous_actions = [('move', (16, 328)), ('move', (304, 96)), ('move', (240, 192)), ('move', (152, 56)), ('left_click', (288, 176)), ('left_click', (56, 376)), ('move', (136, 360)), ('move', (112, 48))]
 prompt = 'L + 0 0 5 6 : + 0 1 2 8 N + 0 4 0 0 : + 0 0 6 4 N + 0 5 0 4 : + 0 1 2 8 N + 0 4 2 4 : + 0 1 2 0 N + 0 3 2 0 : + 0 1 0 4 N + 0 2 8 0 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4'
 previous_actions = [('left_click', (56, 128)), ('left_click', (400, 64)), ('move', (504, 128)), ('move', (424, 120)), ('left_click', (320, 104)), ('left_click', (280, 104)), ('move', (272, 104)), ('move', (272, 104))]
-for action_type, pos in previous_actions[-
+for action_type, pos in previous_actions[-33:]:
 #print ('here3', action_type, pos)
 if action_type == 'move':
 action_type = 'N'
@@ -287,14 +292,14 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
 else:
 assert False
 
-prompt = " ".join(action_descriptions[-
+prompt = " ".join(action_descriptions[-33:])
 print(prompt)
 #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
 #x, y, action_type = parse_action_string(action_descriptions[-1])
 #pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
 leftclick_maps = []
 pos_maps = []
-for j in range(1,
+for j in range(1, 34):
 print ('fsfs', action_descriptions[-j])
 x, y, action_type = parse_action_string(action_descriptions[-j])
 pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
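With this change the prompt is assembled from the last 33 action descriptions rather than the shorter window used before. Judging from the hard-coded debug prompt in the earlier hunk (prompt = 'L + 0 0 5 6 : + 0 1 2 8 ...'), each action appears to be rendered as an action token (N for move, L for left_click) followed by sign- and digit-tokenized coordinates. A rough sketch of that encoding, inferred from the example string rather than taken from the Space's own helpers (the function name and the handling of negative coordinates are assumptions):

```python
def encode_action(action_type, pos):
    """Render one (action_type, (x, y)) pair in the digit-tokenized prompt format."""
    token = 'L' if action_type == 'left_click' else 'N'  # 'move' collapses to 'N'
    x, y = pos
    x_sign = '+' if x >= 0 else '-'                       # negative case is a guess
    y_sign = '+' if y >= 0 else '-'
    x_digits = ' '.join(f"{abs(x):04d}")                  # 56 -> "0 0 5 6"
    y_digits = ' '.join(f"{abs(y):04d}")
    return f"{token} {x_sign} {x_digits} : {y_sign} {y_digits}"

# Reproduces the first entry of the debug prompt shown above
assert encode_action('left_click', (56, 128)) == 'L + 0 0 5 6 : + 0 1 2 8'
```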
@@ -318,6 +323,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
 # Convert the generated frame to the correct format
 new_frame = new_frame.transpose(1, 2, 0)
 print (new_frame.max(), new_frame.min())
+new_frame = new_frame * data_std + data_mean
 new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
 
 # Draw the trace of previous actions
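Together with the data_mean/data_std constants added earlier, this hunk completes a standardization round-trip around the model: inputs are shifted and scaled after normalize_images, and the predicted frame is mapped back with new_frame * data_std + data_mean before denormalize_image. A minimal sketch of that round-trip in isolation (only the two constants come from this commit; the helper names are illustrative):

```python
import numpy as np

DATA_MEAN = -0.54  # constants introduced in this commit
DATA_STD = 6.78

def standardize(x):
    """Map data into the statistics the model expects."""
    return (x - DATA_MEAN) / DATA_STD

def destandardize(y):
    """Invert the standardization on the model's output."""
    return y * DATA_STD + DATA_MEAN

# Round-trip check: destandardize(standardize(x)) recovers x up to float error
x = np.random.uniform(-1.0, 1.0, size=(3, 64, 64)).astype(np.float32)
assert np.allclose(destandardize(standardize(x)), x, atol=1e-5)
```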
@@ -429,6 +435,7 @@ async def websocket_endpoint(websocket: WebSocket):
 #mouse_position = (x, y)
 
 #previous_actions.append((action_type, mouse_position))
+
 if not DEBUG_TEACHER_FORCING:
 previous_actions = []
 
@@ -448,6 +455,8 @@ async def websocket_endpoint(websocket: WebSocket):
 # Format action string
 previous_actions.append((action_type, (x*8, y*8)))
 try:
+previous_actions = []
+previous_frames = []
 while True:
 try:
 # Receive user input with a timeout
@@ -483,36 +492,37 @@ async def websocket_endpoint(websocket: WebSocket):
 x, y, action_type = parse_action_string(position)
 mouse_position = (x, y)
 previous_actions.append((action_type, mouse_position))
-if
+if True:
 previous_actions.append((action_type, mouse_position))
 #previous_actions = [(action_type, mouse_position)]
-if not DEBUG_TEACHER_FORCING:
-
-
-
-
-
-
-
-
+#if not DEBUG_TEACHER_FORCING:
+# x, y = mouse_position
+# x = x//8 * 8
+# y = y // 8 * 8
+# assert x % 8 == 0
+# assert y % 8 == 0
+# mouse_position = (x, y)
+# #mouse_position = (x//8, y//8)
+# previous_actions.append((action_type, mouse_position))
 # Log the start time
 start_time = time.time()
 
 # Predict the next frame based on the previous frames and actions
-if DEBUG_TEACHER_FORCING:
-
-
+#if DEBUG_TEACHER_FORCING:
+# print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
+print ('previous_actions', previous_actions)
 next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
 # Load and append the corresponding ground truth image instead of model output
-print ('here4', len(previous_frames))
-if DEBUG_TEACHER_FORCING:
-
-
-else:
-
-
-
-
+#print ('here4', len(previous_frames))
+#if DEBUG_TEACHER_FORCING:
+# img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
+# previous_frames.append(np.array(img))
+#else:
+# assert False
+# previous_frames.append(next_frame_append)
+# pass
+previous_frames = []
+previous_actions = []
 
 # Convert the numpy array to a base64 encoded image
 img = Image.fromarray(next_frame)