xizaoqu committed
Commit f373311 · 1 Parent: eda3a61

update

Files changed:
- algorithms/worldmem/df_video.py (+29 -26)
- app.py (+23 -6)
algorithms/worldmem/df_video.py

@@ -354,10 +354,10 @@ class WorldMemMinecraft(DiffusionForcingBase):
 
         self.is_interactive = cfg.get("is_interactive", False)
         if self.is_interactive:
-            self.frames = None
-            self.poses = None
-            self.memory_c2w = None
-            self.frame_idx = None
+            self_frames = None
+            self_poses = None
+            self_memory_c2w = None
+            self_frame_idx = None
 
         super().__init__(cfg)
 
@@ -791,21 +791,23 @@ class WorldMemMinecraft(DiffusionForcingBase):
         return
 
     @torch.no_grad()
-    def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device):
+    def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
+                    self_frames, self_poses, self_memory_c2w, self_frame_idx):
+
         condition_similar_length = self.condition_similar_length
 
-        if self.frames is None:
+        if self_frames is None:
             first_frame_encode = self.encode(first_frame[None, None].to(device))
-            self.frames = first_frame_encode.cpu()
+            self_frames = first_frame_encode.cpu()
             self.actions = curr_actions[None, None].to(device)
-            self.poses = first_pose[None, None].to(device)
+            self_poses = first_pose[None, None].to(device)
             new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
-            self.memory_c2w = new_c2w_mat[None, None].to(device)
-            self.frame_idx = torch.tensor([[context_frames_idx]]).to(device)
-            return first_frame
+            self_memory_c2w = new_c2w_mat[None, None].to(device)
+            self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
+            return first_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx
         else:
-            last_frame = self.frames[-1].clone()
-            last_pose_condition = self.poses[-1].clone()
+            last_frame = self_frames[-1].clone()
+            last_pose_condition = self_poses[-1].clone()
             last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
             new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
 
@@ -814,15 +816,15 @@ class WorldMemMinecraft(DiffusionForcingBase):
         new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
         new_pose_condition[:,3:] %= 360
         self.actions = torch.cat([self.actions, curr_actions[None, None].to(device)])
-        self.poses = torch.cat([self.poses, new_pose_condition[None].to(device)])
+        self_poses = torch.cat([self_poses, new_pose_condition[None].to(device)])
         new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
-        self.memory_c2w = torch.cat([self.memory_c2w, new_c2w_mat[None].to(device)])
-        self.frame_idx = torch.cat([self.frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
+        self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None].to(device)])
+        self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
 
         conditions = self.actions.clone()
-        pose_conditions = self.poses.clone()
-        c2w_mat = self.memory_c2w.clone()
-        frame_idx = self.frame_idx.clone()
+        pose_conditions = self_poses.clone()
+        c2w_mat = self_memory_c2w.clone()
+        frame_idx = self_frame_idx.clone()
 
 
         curr_frame = 0
@@ -831,7 +833,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
         n_frames = curr_frame + horizon
         # context
         n_context_frames = context_frames_idx // self.frame_stack
-        xs_pred = self.frames[:n_context_frames].clone()
+        xs_pred = self_frames[:n_context_frames].clone()
         curr_frame += n_context_frames
 
         pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
@@ -894,14 +896,15 @@ class WorldMemMinecraft(DiffusionForcingBase):
         curr_frame += horizon
         pbar.update(horizon)
 
-        self.frames = torch.cat([self.frames, xs_pred[n_context_frames:]])
+        self_frames = torch.cat([self_frames, xs_pred[n_context_frames:]])
 
         xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
-        return xs_pred[-1,0]
+
+        return xs_pred[-1,0], self_frames, self_poses, self_memory_c2w, self_frame_idx
 
 
     def reset(self):
-        self.frames = None
-        self.poses = None
-        self.memory_c2w = None
-        self.frame_idx = None
+        self_frames = None
+        self_poses = None
+        self_memory_c2w = None
+        self_frame_idx = None
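
The change above is a mechanical refactor: the rollout memory that WorldMemMinecraft kept as instance attributes (self.frames, self.poses, self.memory_c2w, self.frame_idx) now enters interactive() as arguments (self_frames, self_poses, self_memory_c2w, self_frame_idx) and leaves through the return value. This is plausibly motivated by Hugging Face ZeroGPU: a function decorated with @spaces.GPU() may run in a separate worker process, so attribute mutations made inside it are not guaranteed to persist in the main Gradio process. A minimal sketch of the pattern, using hypothetical names (RolloutState, step) rather than the repo's actual API:

import torch
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class RolloutState:
    # Mirrors the self_frames / self_poses / self_memory_c2w / self_frame_idx
    # tuple that interactive() now threads in and out.
    frames: Optional[torch.Tensor] = None
    poses: Optional[torch.Tensor] = None

def step(obs: torch.Tensor, pose: torch.Tensor,
         state: RolloutState) -> Tuple[torch.Tensor, RolloutState]:
    # Stateless step: all memory arrives as an argument and leaves as a
    # return value, so nothing relies on attribute mutation surviving a
    # process boundary.
    if state.frames is None:
        # First call: seed the memory, like the `if self_frames is None:` branch.
        return obs, RolloutState(frames=obs[None], poses=pose[None])
    # Later calls: append to the memory and hand back the grown copy.
    return obs, RolloutState(
        frames=torch.cat([state.frames, obs[None]]),
        poses=torch.cat([state.poses, pose[None]]),
    )

The caller owns the state and feeds the returned copy back in on the next call, which is exactly what app.py below now does with module-level globals.
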
app.py

@@ -182,15 +182,28 @@ poses = torch.zeros((1, 5))
 
 memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
 
+self_frames = None
+self_poses = None
+self_memory_c2w = None
+self_frame_idx = None
+
+
 @spaces.GPU()
 def run_interactive(first_frame, action, first_pose, curr_frame, device):
-    global
-    new_frame = worldmem.interactive(first_frame,
+    global self_frames
+    global self_poses
+    global self_memory_c2w
+    global self_frame_idx
+
+    new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
                                      action,
                                      first_pose,
                                      curr_frame,
-                                     device=device)
-
+                                     device=device,
+                                     self_frames=self_frames,
+                                     self_poses=self_poses,
+                                     self_memory_c2w=self_memory_c2w,
+                                     self_frame_idx=self_frame_idx)
     return new_frame
 
 def set_denoising_steps(denoising_steps, sampling_timesteps_state):
@@ -201,7 +214,7 @@ def set_denoising_steps(denoising_steps, sampling_timesteps_state):
     return sampling_timesteps_state
 
 def generate(keys):
-    print("algo frame:", len(worldmem.frames))
+    # print("algo frame:", len(worldmem.frames))
    actions = parse_input_to_tensor(keys)
     global input_history
     global memory_curr_frame
@@ -236,7 +249,11 @@ def reset():
     global input_history
     global memory_frames
 
-    worldmem.reset()
+    # worldmem.reset()
+    self_frames = None
+    self_poses = None
+    self_memory_c2w = None
+    self_frame_idx = None
     memory_frames = []
     memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
     memory_curr_frame = 0
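
On the app side, run_interactive writes the returned state tuple back into module-level globals, so each call feeds the previous call's output in again. One caveat in the new reset() above: it assigns self_frames = None and friends without declaring them global, so those assignments bind function-locals and leave the module-level state untouched. A sketch of the declaration it would likely need (an assumption on my part; the commit itself does not include it):

def reset():
    global input_history
    global memory_frames
    # Assumption: without this declaration, the assignments below would only
    # create locals named self_frames etc., leaving module state unchanged.
    global self_frames, self_poses, self_memory_c2w, self_frame_idx

    self_frames = None
    self_poses = None
    self_memory_c2w = None
    self_frame_idx = None
    memory_frames = []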