yuntian-deng commited on
Commit
33a9da7
·
1 Parent(s): bfa85db

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +2 -2
main.py CHANGED
@@ -51,8 +51,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
51
  image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
52
 
53
  # Convert the image sequence to a tensor and concatenate in the channel dimension
54
- image_sequence_tensor = torch.from_numpy(np.concatenate(image_sequence, axis=2)).float() / 127.5 - 1
55
- image_sequence_tensor = image_sequence_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
56
 
57
 
58
  # Prepare the prompt based on the previous actions
 
51
  image_sequence.insert(0, np.zeros((height, width, 3), dtype=np.uint8))
52
 
53
  # Convert the image sequence to a tensor and concatenate in the channel dimension
54
+ image_sequence_tensor = torch.from_numpy(np.stack(image_sequence)).float() / 127.5 - 1
55
+ image_sequence_tensor = image_sequence_tensor.to(device)
56
 
57
 
58
  # Prepare the prompt based on the previous actions