yuntian-deng commited on
Commit
bfa85db
·
1 Parent(s): 7a83d85

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -3
main.py CHANGED
@@ -50,9 +50,10 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
50
  while len(image_sequence) < 7:
51
  image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
52
 
53
- # Convert the image sequence to a tensor
54
- image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).permute(0, 3, 1, 2).float() / 127.5 - 1
55
- image_sequence_tensor = image_sequence_tensor.unsqueeze(0).to(device)
 
56
 
57
  # Prepare the prompt based on the previous actions
58
  action_descriptions = [f"{pos[0]}:{pos[1]}" for _, pos in previous_actions[-7:]]
 
50
  while len(image_sequence) < 7:
51
  image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
52
 
53
+ # Convert the image sequence to a tensor and concatenate in the channel dimension
54
+ image_sequence_tensor = torch.from_numpy(np.concatenate(image_sequence, axis=2)).float() / 127.5 - 1
55
+ image_sequence_tensor = image_sequence_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
56
+
57
 
58
  # Prepare the prompt based on the previous actions
59
  action_descriptions = [f"{pos[0]}:{pos[1]}" for _, pos in previous_actions[-7:]]