da03 commited on
Commit
021111c
·
1 Parent(s): 4406096
Files changed (1) hide show
  1. main.py +4 -2
main.py CHANGED
@@ -203,9 +203,11 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
203
  # Prepare the image sequence for the model
204
  assert len(initial_images) == 7
205
  image_sequence = previous_frames[-7:] # Take the last 7 frames
 
206
  while len(image_sequence) < 7:
207
- #image_sequence.insert(0, initial_images[len(image_sequence)])
208
- image_sequence.append(initial_images[len(image_sequence)])
 
209
 
210
  # Convert the image sequence to a tensor and concatenate in the channel dimension
211
  image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
 
203
  # Prepare the image sequence for the model
204
  assert len(initial_images) == 7
205
  image_sequence = previous_frames[-7:] # Take the last 7 frames
206
+ i = 1
207
  while len(image_sequence) < 7:
208
+ image_sequence.insert(0, initial_images[-i])
209
+ i += 1
210
+ #image_sequence.append(initial_images[len(image_sequence)])
211
 
212
  # Convert the image sequence to a tensor and concatenate in the channel dimension
213
  image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))