da03 commited on
Commit
313bb52
·
1 Parent(s): 94b146f
Files changed (1) hide show
  1. main.py +33 -20
main.py CHANGED
@@ -203,23 +203,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
203
  prev_y = 0
204
  #print ('here')
205
 
206
- if DEBUG_TEACHER_FORCING:
207
- #print ('here2')
208
- # Use the predefined actions for image_81
209
- debug_actions = [
210
- 'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
211
- 'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
212
- 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
213
- 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
214
- 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
215
- 'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
216
- 'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
217
- 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1'
218
- ]
219
- previous_actions = []
220
- for action in debug_actions[-8:]:
221
- x, y, action_type = parse_action_string(action)
222
- previous_actions.append((action_type, (x, y)))
223
 
224
  for action_type, pos in previous_actions: #[-8:]:
225
  print ('here3', action_type, pos)
@@ -302,6 +286,31 @@ async def websocket_endpoint(websocket: WebSocket):
302
  positions = ['815~335']
303
  #positions = ['787~342']
304
  positions = ['300~800']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  #positions = positions[:4]
306
  try:
307
  while True:
@@ -325,9 +334,13 @@ async def websocket_endpoint(websocket: WebSocket):
325
  #mouse_position = position.split('~')
326
  #mouse_position = [int(item) for item in mouse_position]
327
  #mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
328
-
 
 
 
 
329
  previous_actions.append((action_type, mouse_position))
330
- previous_actions = [(action_type, mouse_position)]
331
 
332
  # Log the start time
333
  start_time = time.time()
@@ -336,7 +349,7 @@ async def websocket_endpoint(websocket: WebSocket):
336
  next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
337
  # Load and append the corresponding ground truth image instead of model output
338
  #img = Image.open(f"image_{len(previous_frames)%7}.png")
339
- #previous_frames.append(next_frame_append)
340
 
341
  # Convert the numpy array to a base64 encoded image
342
  img = Image.fromarray(next_frame)
 
203
  prev_y = 0
204
  #print ('here')
205
 
206
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  for action_type, pos in previous_actions: #[-8:]:
209
  print ('here3', action_type, pos)
 
286
  positions = ['815~335']
287
  #positions = ['787~342']
288
  positions = ['300~800']
289
+
290
+ if DEBUG_TEACHER_FORCING:
291
+ #print ('here2')
292
+ # Use the predefined actions for image_81
293
+ debug_actions = [
294
+ 'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
295
+ 'N + 0 8 9 0 : + 0 4 7 5', 'N + 0 9 0 8 : + 0 4 8 8',
296
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
297
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
298
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
299
+ 'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
300
+ 'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
301
+ 'N + 0 9 2 7 : + 0 5 0 1', #'N + 0 9 2 7 : + 0 5 0 1'
302
+ ]
303
+ previous_actions = []
304
+ for action in debug_actions[-8:]:
305
+ x, y, action_type = parse_action_string(action)
306
+ previous_actions.append((action_type, (x, y)))
307
+ positions = [
308
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 1 8 : + 0 4 9 2',
309
+ 'N + 0 9 0 8 : + 0 4 8 3', 'N + 0 8 9 8 : + 0 4 7 4',
310
+ 'N + 0 8 8 9 : + 0 4 6 5', 'N + 0 8 8 0 : + 0 4 5 6',
311
+ 'N + 0 8 7 0 : + 0 4 4 7', 'N + 0 8 6 0 : + 0 4 3 8',
312
+ 'N + 0 8 5 1 : + 0 4 2 9', 'N + 0 8 4 2 : + 0 4 2 0',
313
+ 'N + 0 8 3 2 : + 0 4 1 1', 'N + 0 8 3 2 : + 0 4 1 1']
314
  #positions = positions[:4]
315
  try:
316
  while True:
 
334
  #mouse_position = position.split('~')
335
  #mouse_position = [int(item) for item in mouse_position]
336
  #mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
337
+ if DEBUG_TEACHER_FORCING:
338
+ position = positions[0]
339
+ positions = positions[1:]
340
+ x, y, action_type = parse_action_string(position)
341
+ mouse_position = (x, y)
342
  previous_actions.append((action_type, mouse_position))
343
+ #previous_actions = [(action_type, mouse_position)]
344
 
345
  # Log the start time
346
  start_time = time.time()
 
349
  next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
350
  # Load and append the corresponding ground truth image instead of model output
351
  #img = Image.open(f"image_{len(previous_frames)%7}.png")
352
+ previous_frames.append(next_frame_append)
353
 
354
  # Convert the numpy array to a base64 encoded image
355
  img = Image.fromarray(next_frame)