da03 commited on
Commit
98a4a00
·
1 Parent(s): 46f6899
Files changed (1) hide show
  1. main.py +29 -6
main.py CHANGED
@@ -131,10 +131,11 @@ def load_initial_images(width, height):
131
  initial_images = []
132
  if DEBUG_TEACHER_FORCING:
133
  # Load the previous 7 frames for image_81
134
- for i in range(75, 82): # Load images 74-80
135
  img = Image.open(f"record_100/image_{i}.png").resize((width, height))
136
  initial_images.append(np.array(img))
137
  else:
 
138
  for i in range(7):
139
  initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
140
  return initial_images
@@ -229,7 +230,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
229
  x, y = pos
230
  #norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
231
  #norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
232
- if False and DEBUG_TEACHER_FORCING:
233
  norm_x = x + (1920 - 512) / 2
234
  norm_y = y + (1080 - 512) / 2
235
  #if DEBUG:
@@ -306,6 +307,16 @@ async def websocket_endpoint(websocket: WebSocket):
306
  'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
307
  'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
308
  ]
 
 
 
 
 
 
 
 
 
 
309
  previous_actions = []
310
  for action in debug_actions[-8:]:
311
  x, y, action_type = parse_action_string(action)
@@ -316,7 +327,18 @@ async def websocket_endpoint(websocket: WebSocket):
316
  'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
317
  'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
318
  'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
319
- 'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1']
 
 
 
 
 
 
 
 
 
 
 
320
  #positions = positions[:4]
321
  try:
322
  while True:
@@ -340,12 +362,13 @@ async def websocket_endpoint(websocket: WebSocket):
340
  #mouse_position = position.split('~')
341
  #mouse_position = [int(item) for item in mouse_position]
342
  #mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
343
- if False and DEBUG_TEACHER_FORCING:
344
  position = positions[0]
345
  positions = positions[1:]
346
  x, y, action_type = parse_action_string(position)
347
  mouse_position = (x, y)
348
- previous_actions.append((action_type, mouse_position))
 
349
  #previous_actions = [(action_type, mouse_position)]
350
 
351
  # Log the start time
@@ -361,7 +384,7 @@ async def websocket_endpoint(websocket: WebSocket):
361
  if False and DEBUG_TEACHER_FORCING:
362
  img = Image.open(f"record_100/image_{82+len(previous_frames)}.png")
363
  previous_frames.append(img)
364
- else:
365
  previous_frames.append(next_frame_append)
366
 
367
  # Convert the numpy array to a base64 encoded image
 
131
  initial_images = []
132
  if DEBUG_TEACHER_FORCING:
133
  # Load the previous 7 frames for image_81
134
+ for i in range(222-7, 222): # Load images 74-80
135
  img = Image.open(f"record_100/image_{i}.png").resize((width, height))
136
  initial_images.append(np.array(img))
137
  else:
138
+ assert False
139
  for i in range(7):
140
  initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
141
  return initial_images
 
230
  x, y = pos
231
  #norm_x = int(round(x / 256 * 1024)) #x + (1920 - 256) / 2
232
  #norm_y = int(round(y / 256 * 640)) #y + (1080 - 256) / 2
233
+ if True and DEBUG_TEACHER_FORCING:
234
  norm_x = x + (1920 - 512) / 2
235
  norm_y = y + (1080 - 512) / 2
236
  #if DEBUG:
 
307
  'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
308
  'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
309
  ]
310
+ debug_actions = [
311
+ 'N + 1 1 6 5 : + 0 4 4 3', 'N + 1 1 7 0 : + 0 4 1 8',
312
+ 'N + 1 1 7 5 : + 0 3 9 4', 'N + 1 1 8 1 : + 0 3 7 0',
313
+ 'N + 1 1 8 4 : + 0 3 5 8', 'N + 1 1 8 9 : + 0 3 3 3',
314
+ 'N + 1 1 9 4 : + 0 3 0 9', 'N + 1 1 9 7 : + 0 2 9 7',
315
+ 'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
316
+ 'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
317
+ 'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
318
+ 'N + 1 1 9 7 : + 0 2 9 7'
319
+ ]
320
  previous_actions = []
321
  for action in debug_actions[-8:]:
322
  x, y, action_type = parse_action_string(action)
 
327
  'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
328
  'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
329
  'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
330
+ 'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1'
331
+ ]
332
+ positions = [
333
+ #'L + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
334
+ 'N + 1 1 9 7 : + 0 2 9 7', 'N + 1 1 9 7 : + 0 2 9 7',
335
+ 'N + 1 1 7 9 : + 0 3 0 3', 'N + 1 1 4 2 : + 0 3 1 4',
336
+ 'N + 1 1 0 6 : + 0 3 2 6', 'N + 1 0 6 9 : + 0 3 3 7',
337
+ 'N + 1 0 5 1 : + 0 3 4 3', 'N + 1 0 1 4 : + 0 3 5 4',
338
+ 'N + 0 9 7 8 : + 0 3 6 5', 'N + 0 9 4 2 : + 0 3 7 7',
339
+ 'N + 0 9 0 5 : + 0 3 8 8', 'N + 0 8 6 8 : + 0 4 0 0',
340
+ 'N + 0 8 3 2 : + 0 4 1 1'
341
+ ]
342
  #positions = positions[:4]
343
  try:
344
  while True:
 
362
  #mouse_position = position.split('~')
363
  #mouse_position = [int(item) for item in mouse_position]
364
  #mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
365
+ if True and DEBUG_TEACHER_FORCING:
366
  position = positions[0]
367
  positions = positions[1:]
368
  x, y, action_type = parse_action_string(position)
369
  mouse_position = (x, y)
370
+ if False:
371
+ previous_actions.append((action_type, mouse_position))
372
  #previous_actions = [(action_type, mouse_position)]
373
 
374
  # Log the start time
 
384
  if False and DEBUG_TEACHER_FORCING:
385
  img = Image.open(f"record_100/image_{82+len(previous_frames)}.png")
386
  previous_frames.append(img)
387
+ elif False:
388
  previous_frames.append(next_frame_append)
389
 
390
  # Convert the numpy array to a base64 encoded image