Spaces:
Runtime error
Runtime error
Commit
·
bfa85db
1
Parent(s):
7a83d85
Update main.py
Browse files
main.py
CHANGED
@@ -50,9 +50,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
50 |
while len(image_sequence) < 7:
|
51 |
image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
52 |
|
53 |
-
# Convert the image sequence to a tensor
|
54 |
-
image_sequence_tensor = torch.from_numpy(np.
|
55 |
-
image_sequence_tensor = image_sequence_tensor.unsqueeze(0).to(device)
|
|
|
56 |
|
57 |
# Prepare the prompt based on the previous actions
|
58 |
action_descriptions = [f"{pos[0]}:{pos[1]}" for _, pos in previous_actions[-7:]]
|
|
|
50 |
while len(image_sequence) < 7:
|
51 |
image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
|
52 |
|
53 |
+
# Convert the image sequence to a tensor and concatenate in the channel dimension
|
54 |
+
image_sequence_tensor = torch.from_numpy(np.concatenate(image_sequence, axis=2)).float() / 127.5 - 1
|
55 |
+
image_sequence_tensor = image_sequence_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
|
56 |
+
|
57 |
|
58 |
# Prepare the prompt based on the previous actions
|
59 |
action_descriptions = [f"{pos[0]}:{pos[1]}" for _, pos in previous_actions[-7:]]
|