Spaces:
Runtime error
Runtime error
Commit
·
33a9da7
1
Parent(s):
bfa85db
Update main.py
Browse files
main.py
CHANGED
@@ -51,8 +51,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
51 |
image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
52 |
|
53 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
54 |
-
image_sequence_tensor = torch.from_numpy(np.
|
55 |
-
image_sequence_tensor = image_sequence_tensor.
|
56 |
|
57 |
|
58 |
# Prepare the prompt based on the previous actions
|
|
|
51 |
image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
52 |
|
53 |
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
54 |
+
image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
|
55 |
+
image_sequence_tensor = image_sequence_tensor.to(device)
|
56 |
|
57 |
|
58 |
# Prepare the prompt based on the previous actions
|