da03 commited on
Commit
a9363f4
·
1 Parent(s): 09a0506
Files changed (1) hide show
  1. main.py +4 -4
main.py CHANGED
@@ -131,7 +131,7 @@ def load_initial_images(width, height):
131
  initial_images = []
132
  if DEBUG_TEACHER_FORCING:
133
  # Load the previous 7 frames for image_81
134
- for i in range(74, 81): # Load images 74-80
135
  img = Image.open(f"record_100/image_{i}.png").resize((width, height))
136
  initial_images.append(np.array(img))
137
  else:
@@ -201,10 +201,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
201
  previous_actions.insert(0, ("N", unnorm_coords(x, y)))
202
  prev_x = 0
203
  prev_y = 0
204
- print ('here')
205
 
206
  if DEBUG_TEACHER_FORCING:
207
- print ('here2')
208
  # Use the predefined actions for image_81
209
  debug_actions = [
210
  'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
@@ -214,7 +214,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
214
  'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
215
  'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
216
  'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
217
- 'N + 0 9 2 7 : + 0 5 0 1'
218
  ]
219
  previous_actions = []
220
  for action in debug_actions[-8:]:
 
131
  initial_images = []
132
  if DEBUG_TEACHER_FORCING:
133
  # Load the previous 7 frames for image_81
134
+ for i in range(75, 82): # Load images 74-80
135
  img = Image.open(f"record_100/image_{i}.png").resize((width, height))
136
  initial_images.append(np.array(img))
137
  else:
 
201
  previous_actions.insert(0, ("N", unnorm_coords(x, y)))
202
  prev_x = 0
203
  prev_y = 0
204
+ #print ('here')
205
 
206
  if DEBUG_TEACHER_FORCING:
207
+ #print ('here2')
208
  # Use the predefined actions for image_81
209
  debug_actions = [
210
  'N + 0 8 5 3 : + 0 4 5 0', 'N + 0 8 7 1 : + 0 4 6 3',
 
214
  'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
215
  'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
216
  'L + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1',
217
+ 'N + 0 9 2 7 : + 0 5 0 1', 'N + 0 9 2 7 : + 0 5 0 1'
218
  ]
219
  previous_actions = []
220
  for action in debug_actions[-8:]: