da03 committed
Commit 1a5d2e7 · 1 Parent(s): 8f6c968
Files changed (1)
  1. main.py +43 -33
main.py CHANGED
@@ -13,7 +13,7 @@ import os
 import time
 
 DEBUG = False
-DEBUG_TEACHER_FORCING = True
+DEBUG_TEACHER_FORCING = False
 app = FastAPI()
 
 # Mount the static directory to serve HTML, JavaScript, and CSS files
@@ -156,7 +156,7 @@ def load_initial_images(width, height):
             initial_images.append(np.array(img))
     else:
         #assert False
-        for i in range(7):
+        for i in range(32):
             initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
     return initial_images
 
@@ -202,10 +202,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     print ('length of previous_frames', len(previous_frames))
 
     # Prepare the image sequence for the model
-    assert len(initial_images) == 7
-    image_sequence = previous_frames[-7:] # Take the last 7 frames
+    assert len(initial_images) == 32
+    image_sequence = previous_frames[-32:] # Take the last 32 frames
     i = 1
-    while len(image_sequence) < 7:
+    while len(image_sequence) < 32:
         image_sequence.insert(0, initial_images[-i])
         i += 1
     #image_sequence.append(initial_images[len(image_sequence)])
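Together with the load_initial_images change in the previous hunk, this widens the visual context from the last 7 frames to the last 32, left-padding with blank frames whenever the real history is shorter. A minimal, self-contained sketch of that padding behavior (frame size assumed; names mirror the diff, but this is not the repo's exact code):

import numpy as np

CONTEXT_FRAMES = 32  # was 7 before this commit

def pad_frame_history(previous_frames, height=512, width=512):
    # Blank (all-zero) frames stand in for history that does not exist yet.
    initial_images = [np.zeros((height, width, 3), dtype=np.uint8)
                      for _ in range(CONTEXT_FRAMES)]
    # Keep at most the last CONTEXT_FRAMES real frames...
    image_sequence = previous_frames[-CONTEXT_FRAMES:]
    # ...and left-pad with blank frames until the window is full,
    # mirroring the `while len(image_sequence) < 32` loop in the diff.
    i = 1
    while len(image_sequence) < CONTEXT_FRAMES:
        image_sequence.insert(0, initial_images[-i])
        i += 1
    return image_sequence

# With only 3 real frames, 29 blank frames are prepended.
frames = [np.full((512, 512, 3), k, dtype=np.uint8) for k in range(3)]
assert len(pad_frame_history(frames)) == 32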
@@ -213,18 +213,23 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     # Convert the image sequence to a tensor and concatenate in the channel dimension
     image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
     image_sequence_tensor = image_sequence_tensor.to(device)
+    data_mean = -0.54
+    data_std = 6.78
+    data_min = -27.681446075439453
+    data_max = 30.854148864746094
+    image_sequence_tensor = (image_sequence_tensor - data_mean) / data_std
 
     # Prepare the prompt based on the previous actions
     action_descriptions = []
     #initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
-    initial_actions = ['0:0'] * 7
+    initial_actions = ['0:0'] * 32
     #initial_actions = ['N N N N N : N N N N N'] * 7
     def unnorm_coords(x, y):
         return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
 
     # Process initial actions if there are not enough previous actions
-    while len(previous_actions) < 8:
-        assert False
+    while len(previous_actions) < 33:
+        #assert False
         x, y = map(int, initial_actions.pop(0).split(':'))
         previous_actions.insert(0, ("N", unnorm_coords(x, y)))
     prev_x = 0
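Besides the new data_mean/data_std standardization (whose inverse appears in the @@ -318 hunk below, with a sketch there), the action history is now padded to 33 entries instead of 8, matching the old 8-actions-for-7-frames pattern, presumably 32 context actions plus the action for the frame being predicted. One caveat: only 32 '0:0' placeholders are created while the loop pads up to 33 entries, so a completely empty previous_actions list would exhaust initial_actions and raise an IndexError; the code seems to rely on at least one real action being present. A small sketch of the padding step (hypothetical helper, not the repo's exact code):

CONTEXT_ACTIONS = 33  # 32 context actions plus the action being generated

def pad_action_history(previous_actions):
    # Placeholder actions at the origin, as in `initial_actions = ['0:0'] * 32`.
    initial_actions = ['0:0'] * 32
    previous_actions = list(previous_actions)
    while len(previous_actions) < CONTEXT_ACTIONS:
        x, y = map(int, initial_actions.pop(0).split(':'))
        previous_actions.insert(0, ("N", (x, y)))
    return previous_actions

# One real click grows into a 33-entry history padded with no-op actions.
padded = pad_action_history([('left_click', (272, 104))])
assert len(padded) == 33 and padded[0] == ("N", (0, 0))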
@@ -242,7 +247,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     previous_actions = [('move', (16, 328)), ('move', (304, 96)), ('move', (240, 192)), ('move', (152, 56)), ('left_click', (288, 176)), ('left_click', (56, 376)), ('move', (136, 360)), ('move', (112, 48))]
     prompt = 'L + 0 0 5 6 : + 0 1 2 8 N + 0 4 0 0 : + 0 0 6 4 N + 0 5 0 4 : + 0 1 2 8 N + 0 4 2 4 : + 0 1 2 0 N + 0 3 2 0 : + 0 1 0 4 N + 0 2 8 0 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4 N + 0 2 7 2 : + 0 1 0 4'
     previous_actions = [('left_click', (56, 128)), ('left_click', (400, 64)), ('move', (504, 128)), ('move', (424, 120)), ('left_click', (320, 104)), ('left_click', (280, 104)), ('move', (272, 104)), ('move', (272, 104))]
-    for action_type, pos in previous_actions[-8:]:
+    for action_type, pos in previous_actions[-33:]:
         #print ('here3', action_type, pos)
         if action_type == 'move':
             action_type = 'N'
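Judging from the literal prompt in this hunk ('L + 0 0 5 6 : + 0 1 2 8 N ...') and the hard-coded previous_actions next to it, each action seems to be serialized as a type token ('N' for plain moves, 'L' for left clicks) followed by a sign and four space-separated digits per coordinate, which parse_action_string presumably reverses. A speculative sketch of such a formatter, inferred only from that example (not the repo's actual helper):

def format_action(action_type, x, y):
    # 'N' for plain moves, 'L' for left clicks, per the example prompt.
    type_token = 'L' if action_type == 'left_click' else 'N'
    def coord_tokens(v):
        sign = '+' if v >= 0 else '-'
        return ' '.join([sign] + list(f"{abs(v):04d}"))
    return f"{type_token} {coord_tokens(x)} : {coord_tokens(y)}"

# Reproduces the first action of the example prompt in the diff.
assert format_action('left_click', 56, 128) == 'L + 0 0 5 6 : + 0 1 2 8'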
@@ -287,14 +292,14 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
         else:
             assert False
 
-    prompt = " ".join(action_descriptions[-8:])
+    prompt = " ".join(action_descriptions[-33:])
     print(prompt)
     #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
     #x, y, action_type = parse_action_string(action_descriptions[-1])
     #pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
     leftclick_maps = []
     pos_maps = []
-    for j in range(1, 9):
+    for j in range(1, 34):
         print ('fsfs', action_descriptions[-j])
         x, y, action_type = parse_action_string(action_descriptions[-j])
         pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
@@ -318,6 +323,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     # Convert the generated frame to the correct format
     new_frame = new_frame.transpose(1, 2, 0)
     print (new_frame.max(), new_frame.min())
+    new_frame = new_frame * data_std + data_mean
     new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
 
     # Draw the trace of previous actions
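This closes the round trip on the standardization introduced in the @@ -213 hunk: inputs are shifted by data_mean and scaled by data_std before the model, and the generated frame is mapped back before the existing (-1, 1) denormalization. A minimal sketch of that round trip (constants copied from the diff; the random tensor is only a stand-in for the real frame data):

import torch

data_mean, data_std = -0.54, 6.78  # constants from the diff

def standardize(x):
    # Applied to the stacked input frames before the model.
    return (x - data_mean) / data_std

def destandardize(x):
    # Applied to the generated frame before denormalize_image(..., source_range=(-1, 1)).
    return x * data_std + data_mean

x = torch.randn(3, 64, 64)
assert torch.allclose(destandardize(standardize(x)), x, atol=1e-5)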
@@ -429,6 +435,7 @@ async def websocket_endpoint(websocket: WebSocket):
     #mouse_position = (x, y)
 
     #previous_actions.append((action_type, mouse_position))
+
     if not DEBUG_TEACHER_FORCING:
         previous_actions = []
 
@@ -448,6 +455,8 @@ async def websocket_endpoint(websocket: WebSocket):
         # Format action string
         previous_actions.append((action_type, (x*8, y*8)))
     try:
+        previous_actions = []
+        previous_frames = []
         while True:
             try:
                 # Receive user input with a timeout
@@ -483,36 +492,37 @@ async def websocket_endpoint(websocket: WebSocket):
                 x, y, action_type = parse_action_string(position)
                 mouse_position = (x, y)
                 previous_actions.append((action_type, mouse_position))
-                if False:
+                if True:
                     previous_actions.append((action_type, mouse_position))
                     #previous_actions = [(action_type, mouse_position)]
-                if not DEBUG_TEACHER_FORCING:
-                    x, y = mouse_position
-                    x = x//8 * 8
-                    y = y // 8 * 8
-                    assert x % 8 == 0
-                    assert y % 8 == 0
-                    mouse_position = (x, y)
-                    #mouse_position = (x//8, y//8)
-                    previous_actions.append((action_type, mouse_position))
+                #if not DEBUG_TEACHER_FORCING:
+                #    x, y = mouse_position
+                #    x = x//8 * 8
+                #    y = y // 8 * 8
+                #    assert x % 8 == 0
+                #    assert y % 8 == 0
+                #    mouse_position = (x, y)
+                #    #mouse_position = (x//8, y//8)
+                #    previous_actions.append((action_type, mouse_position))
                 # Log the start time
                 start_time = time.time()
 
                 # Predict the next frame based on the previous frames and actions
-                if DEBUG_TEACHER_FORCING:
-                    print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
-
+                #if DEBUG_TEACHER_FORCING:
+                #    print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
+                print ('previous_actions', previous_actions)
                 next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
                 # Load and append the corresponding ground truth image instead of model output
-                print ('here4', len(previous_frames))
-                if DEBUG_TEACHER_FORCING:
-                    img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
-                    previous_frames.append(np.array(img))
-                else:
-                    #assert False
-                    #previous_frames.append(next_frame_append)
-                    pass
-                #previous_frames = []
+                #print ('here4', len(previous_frames))
+                #if DEBUG_TEACHER_FORCING:
+                #    img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
+                #    previous_frames.append(np.array(img))
+                #else:
+                #    assert False
+                #    previous_frames.append(next_frame_append)
+                #    pass
+                previous_frames = []
+                previous_actions = []
 
                 # Convert the numpy array to a base64 encoded image
                 img = Image.fromarray(next_frame)
 