da03 committed · Commit 2163e7f
1 Parent(s): 5670558

main.py CHANGED
@@ -201,9 +201,11 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     initial_images = load_initial_images(width, height)
 
     # Prepare the image sequence for the model
+    assert len(initial_images) == 7
     image_sequence = previous_frames[-7:]  # Take the last 7 frames
     while len(image_sequence) < 7:
-        image_sequence.insert(0, initial_images[len(image_sequence)])
+        #image_sequence.insert(0, initial_images[len(image_sequence)])
+        image_sequence.append(initial_images[len(image_sequence)])
 
     # Convert the image sequence to a tensor and concatenate in the channel dimension
     image_sequence_tensor = torch.from_numpy(normalize_images(image_sequence, target_range=(-1, 1)))
@@ -219,6 +221,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
 
     # Process initial actions if there are not enough previous actions
     while len(previous_actions) < 8:
+        assert False
         x, y = map(int, initial_actions.pop(0).split(':'))
         previous_actions.insert(0, ("N", unnorm_coords(x, y)))
         prev_x = 0
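For context on what the padding change does: the old loop prepended entries from initial_images so the real frames stayed at the end of the 7-frame window, while the new loop appends them after the real frames. Below is a minimal sketch of the two behaviors using dummy arrays in place of the output of load_initial_images(); the helper names pad_front and pad_back are illustrative, not functions from this Space.

```python
import numpy as np

def pad_front(frames, fillers, target=7):
    # Old behavior: insert fillers at the front of the window.
    seq = list(frames)[-target:]
    while len(seq) < target:
        seq.insert(0, fillers[len(seq)])
    return seq

def pad_back(frames, fillers, target=7):
    # New behavior: append fillers after the real frames.
    seq = list(frames)[-target:]
    while len(seq) < target:
        seq.append(fillers[len(seq)])
    return seq

# Dummy stand-ins: filler i is a constant image of value i;
# the two "real" frames carry values 100 and 101.
fillers = [np.full((4, 4, 3), i, dtype=np.uint8) for i in range(7)]
frames = [np.full((4, 4, 3), 100 + i, dtype=np.uint8) for i in range(2)]

print([int(f[0, 0, 0]) for f in pad_front(frames, fillers)])
# [6, 5, 4, 3, 2, 100, 101]
print([int(f[0, 0, 0]) for f in pad_back(frames, fillers)])
# [100, 101, 2, 3, 4, 5, 6]
```

Note that front-padding ends up with the fillers in reverse index order, because each insert happens at position 0 with an index that grows as the list grows, whereas the appended variant keeps fillers in index order after the real frames. The two assert statements in the commit read as debugging guards: one checks that load_initial_images() really returns 7 frames, and assert False makes the action-padding path fail loudly if it is ever reached.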