da03 commited on
Commit
2163e7f
·
1 Parent(s): 5670558
Files changed (1) hide show
  1. main.py +4 -1
main.py CHANGED
@@ -201,9 +201,11 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
201
  initial_images = load_initial_images(width, height)
202
 
203
  # Prepare the image sequence for the model
 
204
  image_sequence = previous_frames[-7:] # Take the last 7 frames
205
  while len(image_sequence) < 7:
206
- image_sequence.insert(0, initial_images[len(image_sequence)])
 
207
 
208
  # Convert the image sequence to a tensor and concatenate in the channel dimension
209
  image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
@@ -219,6 +221,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
219
 
220
  # Process initial actions if there are not enough previous actions
221
  while len(previous_actions) < 8:
 
222
  x, y = map(int, initial_actions.pop(0).split(':'))
223
  previous_actions.insert(0, ("N", unnorm_coords(x, y)))
224
  prev_x = 0
 
201
  initial_images = load_initial_images(width, height)
202
 
203
  # Prepare the image sequence for the model
204
+ assert len(initial_images) == 7
205
  image_sequence = previous_frames[-7:] # Take the last 7 frames
206
  while len(image_sequence) < 7:
207
+ #image_sequence.insert(0, initial_images[len(image_sequence)])
208
+ image_sequence.append(initial_images[len(image_sequence)])
209
 
210
  # Convert the image sequence to a tensor and concatenate in the channel dimension
211
  image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
 
221
 
222
  # Process initial actions if there are not enough previous actions
223
  while len(previous_actions) < 8:
224
+ assert False
225
  x, y = map(int, initial_actions.pop(0).split(':'))
226
  previous_actions.insert(0, ("N", unnorm_coords(x, y)))
227
  prev_x = 0