xizaoqu committed on
Commit
f373311
·
1 Parent(s): eda3a61
Files changed (2) hide show
  1. algorithms/worldmem/df_video.py +29 -26
  2. app.py +23 -6
algorithms/worldmem/df_video.py CHANGED
@@ -354,10 +354,10 @@ class WorldMemMinecraft(DiffusionForcingBase):
354
 
355
  self.is_interactive = cfg.get("is_interactive", False)
356
  if self.is_interactive:
357
- self.frames = None
358
- self.poses = None
359
- self.memory_c2w = None
360
- self.frame_idx = None
361
 
362
  super().__init__(cfg)
363
 
@@ -791,21 +791,23 @@ class WorldMemMinecraft(DiffusionForcingBase):
791
  return
792
 
793
  @torch.no_grad()
794
- def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device):
 
 
795
  condition_similar_length = self.condition_similar_length
796
 
797
- if self.frames is None:
798
  first_frame_encode = self.encode(first_frame[None, None].to(device))
799
- self.frames = first_frame_encode.cpu()
800
  self.actions = curr_actions[None, None].to(device)
801
- self.poses = first_pose[None, None].to(device)
802
  new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
803
- self.memory_c2w = new_c2w_mat[None, None].to(device)
804
- self.frame_idx = torch.tensor([[context_frames_idx]]).to(device)
805
- return first_frame
806
  else:
807
- last_frame = self.frames[-1].clone()
808
- last_pose_condition = self.poses[-1].clone()
809
  last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
810
  new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
811
 
@@ -814,15 +816,15 @@ class WorldMemMinecraft(DiffusionForcingBase):
814
  new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
815
  new_pose_condition[:,3:] %= 360
816
  self.actions = torch.cat([self.actions, curr_actions[None, None].to(device)])
817
- self.poses = torch.cat([self.poses, new_pose_condition[None].to(device)])
818
  new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
819
- self.memory_c2w = torch.cat([self.memory_c2w, new_c2w_mat[None].to(device)])
820
- self.frame_idx = torch.cat([self.frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
821
 
822
  conditions = self.actions.clone()
823
- pose_conditions = self.poses.clone()
824
- c2w_mat = self.memory_c2w .clone()
825
- frame_idx = self.frame_idx.clone()
826
 
827
 
828
  curr_frame = 0
@@ -831,7 +833,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
831
  n_frames = curr_frame + horizon
832
  # context
833
  n_context_frames = context_frames_idx // self.frame_stack
834
- xs_pred = self.frames[:n_context_frames].clone()
835
  curr_frame += n_context_frames
836
 
837
  pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
@@ -894,14 +896,15 @@ class WorldMemMinecraft(DiffusionForcingBase):
894
  curr_frame += horizon
895
  pbar.update(horizon)
896
 
897
- self.frames = torch.cat([self.frames, xs_pred[n_context_frames:]])
898
 
899
  xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
900
- return xs_pred[-1,0]
 
901
 
902
 
903
  def reset(self):
904
- self.frames = None
905
- self.poses = None
906
- self.memory_c2w = None
907
- self.frame_idx = None
 
354
 
355
  self.is_interactive = cfg.get("is_interactive", False)
356
  if self.is_interactive:
357
+ self_frames = None
358
+ self_poses = None
359
+ self_memory_c2w = None
360
+ self_frame_idx = None
361
 
362
  super().__init__(cfg)
363
 
 
791
  return
792
 
793
  @torch.no_grad()
794
+ def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
795
+ self_frames, self_poses, self_memory_c2w, self_frame_idx):
796
+
797
  condition_similar_length = self.condition_similar_length
798
 
799
+ if self_frames is None:
800
  first_frame_encode = self.encode(first_frame[None, None].to(device))
801
+ self_frames = first_frame_encode.cpu()
802
  self.actions = curr_actions[None, None].to(device)
803
+ self_poses = first_pose[None, None].to(device)
804
  new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
805
+ self_memory_c2w = new_c2w_mat[None, None].to(device)
806
+ self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
807
+ return first_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx
808
  else:
809
+ last_frame = self_frames[-1].clone()
810
+ last_pose_condition = self_poses[-1].clone()
811
  last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
812
  new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
813
 
 
816
  new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
817
  new_pose_condition[:,3:] %= 360
818
  self.actions = torch.cat([self.actions, curr_actions[None, None].to(device)])
819
+ self_poses = torch.cat([self_poses, new_pose_condition[None].to(device)])
820
  new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
821
+ self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None].to(device)])
822
+ self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
823
 
824
  conditions = self.actions.clone()
825
+ pose_conditions = self_poses.clone()
826
+ c2w_mat = self_memory_c2w .clone()
827
+ frame_idx = self_frame_idx.clone()
828
 
829
 
830
  curr_frame = 0
 
833
  n_frames = curr_frame + horizon
834
  # context
835
  n_context_frames = context_frames_idx // self.frame_stack
836
+ xs_pred = self_frames[:n_context_frames].clone()
837
  curr_frame += n_context_frames
838
 
839
  pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
 
896
  curr_frame += horizon
897
  pbar.update(horizon)
898
 
899
+ self_frames = torch.cat([self_frames, xs_pred[n_context_frames:]])
900
 
901
  xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
902
+
903
+ return xs_pred[-1,0], self_frames, self_poses, self_memory_c2w, self_frame_idx
904
 
905
 
906
  def reset(self):
907
+ self_frames = None
908
+ self_poses = None
909
+ self_memory_c2w = None
910
+ self_frame_idx = None
app.py CHANGED
@@ -182,15 +182,28 @@ poses = torch.zeros((1, 5))
182
 
183
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
184
 
 
 
 
 
 
 
185
  @spaces.GPU()
186
  def run_interactive(first_frame, action, first_pose, curr_frame, device):
187
- global worldmem
188
- new_frame = worldmem.interactive(first_frame,
 
 
 
 
189
  action,
190
  first_pose,
191
  curr_frame,
192
- device=device)
193
- print("algo frame:", len(worldmem.frames))
 
 
 
194
  return new_frame
195
 
196
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
@@ -201,7 +214,7 @@ def set_denoising_steps(denoising_steps, sampling_timesteps_state):
201
  return sampling_timesteps_state
202
 
203
  def generate(keys):
204
- print("algo frame:", len(worldmem.frames))
205
  actions = parse_input_to_tensor(keys)
206
  global input_history
207
  global memory_curr_frame
@@ -236,7 +249,11 @@ def reset():
236
  global input_history
237
  global memory_frames
238
 
239
- worldmem.reset()
 
 
 
 
240
  memory_frames = []
241
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
242
  memory_curr_frame = 0
 
182
 
183
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
184
 
185
+ self_frames = None
186
+ self_poses = None
187
+ self_memory_c2w = None
188
+ self_frame_idx = None
189
+
190
+
191
  @spaces.GPU()
192
  def run_interactive(first_frame, action, first_pose, curr_frame, device):
193
+ global self_frames
194
+ global self_poses
195
+ global self_memory_c2w
196
+ global self_frame_idx
197
+
198
+ new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
199
  action,
200
  first_pose,
201
  curr_frame,
202
+ device=device,
203
+ self_frames=self_frames,
204
+ self_poses=self_poses,
205
+ self_memory_c2w=self_memory_c2w,
206
+ self_frame_idx=self_frame_idx)
207
  return new_frame
208
 
209
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
 
214
  return sampling_timesteps_state
215
 
216
  def generate(keys):
217
+ # print("algo frame:", len(worldmem.frames))
218
  actions = parse_input_to_tensor(keys)
219
  global input_history
220
  global memory_curr_frame
 
249
  global input_history
250
  global memory_frames
251
 
252
+ # worldmem.reset()
253
+ self_frames = None
254
+ self_poses = None
255
+ self_memory_c2w = None
256
+ self_frame_idx = None
257
  memory_frames = []
258
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
259
  memory_curr_frame = 0