xizaoqu
commited on
Commit
·
4652b57
1
Parent(s):
0256e9a
update
Browse files
algorithms/worldmem/df_video.py
CHANGED
@@ -804,9 +804,12 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
804 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
805 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
806 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
807 |
-
return first_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx
|
808 |
else:
|
809 |
last_frame = self_frames[-1].clone()
|
|
|
|
|
|
|
810 |
last_pose_condition = self_poses[-1].clone()
|
811 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
812 |
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
|
@@ -900,7 +903,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
900 |
|
901 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
902 |
|
903 |
-
return xs_pred[-1,0], self_frames, self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
904 |
|
905 |
|
906 |
def reset(self):
|
|
|
804 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
805 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
806 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
807 |
+
return first_frame.cpu(), self_frames.cpu(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
808 |
else:
|
809 |
last_frame = self_frames[-1].clone()
|
810 |
+
self_poses = self_poses.to(device)
|
811 |
+
self_memory_c2w = self_memory_c2w.to(device)
|
812 |
+
self_frame_idx = self_frame_idx.to(device)
|
813 |
last_pose_condition = self_poses[-1].clone()
|
814 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
815 |
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
|
|
|
903 |
|
904 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
905 |
|
906 |
+
return xs_pred[-1,0].cpu(), self_frames.cpu(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
907 |
|
908 |
|
909 |
def reset(self):
|