Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
021111c
1
Parent(s):
4406096
main.py
CHANGED
@@ -203,9 +203,11 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
203 |
# Prepare the image sequence for the model
|
204 |
assert len(initial_images) == 7
|
205 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
|
|
206 |
while len(image_sequence) < 7:
|
207 |
-
|
208 |
-
|
|
|
209 |
|
210 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
211 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|
|
|
203 |
# Prepare the image sequence for the model
|
204 |
assert len(initial_images) == 7
|
205 |
image_sequence = previous_frames[-7:] # Take the last 7 frames
|
206 |
+
i = 1
|
207 |
while len(image_sequence) < 7:
|
208 |
+
image_sequence.insert(0, initial_images[-i])
|
209 |
+
i += 1
|
210 |
+
#image_sequence.append(initial_images[len(image_sequence)])
|
211 |
|
212 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
213 |
image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
|