xizaoqu commited on
Commit
4652b57
·
1 Parent(s): 0256e9a
Files changed (1) hide show
  1. algorithms/worldmem/df_video.py +5 -2
algorithms/worldmem/df_video.py CHANGED
@@ -804,9 +804,12 @@ class WorldMemMinecraft(DiffusionForcingBase):
804
  new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
805
  self_memory_c2w = new_c2w_mat[None, None].to(device)
806
  self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
807
- return first_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx
808
  else:
809
  last_frame = self_frames[-1].clone()
 
 
 
810
  last_pose_condition = self_poses[-1].clone()
811
  last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
812
  new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
@@ -900,7 +903,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
900
 
901
  xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
902
 
903
- return xs_pred[-1,0], self_frames, self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
904
 
905
 
906
  def reset(self):
 
804
  new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
805
  self_memory_c2w = new_c2w_mat[None, None].to(device)
806
  self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
807
+ return first_frame.cpu(), self_frames.cpu(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
808
  else:
809
  last_frame = self_frames[-1].clone()
810
+ self_poses = self_poses.to(device)
811
+ self_memory_c2w = self_memory_c2w.to(device)
812
+ self_frame_idx = self_frame_idx.to(device)
813
  last_pose_condition = self_poses[-1].clone()
814
  last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
815
  new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
 
903
 
904
  xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
905
 
906
+ return xs_pred[-1,0].cpu(), self_frames.cpu(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
907
 
908
 
909
  def reset(self):