xizaoqu commited on
Commit
594fef7
·
1 Parent(s): 0d5deae

update precision

Browse files
algorithms/worldmem/df_video.py CHANGED
@@ -829,8 +829,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
829
 
830
 
831
  for ai in range(len(new_actions)):
832
- from time import time
833
- start_time = time()
834
 
835
  last_frame = xs_pred[-1].clone()
836
  curr_actions = new_actions[ai]
@@ -886,7 +884,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
886
  image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
887
  )
888
 
889
- mid_time = time()
890
  # Perform sampling for each step in the scheduling matrix
891
  for m in range(scheduling_matrix.shape[0] - 1):
892
  from_noise_levels, to_noise_levels = self._prepare_noise_levels(
@@ -905,10 +902,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
905
  frame_idx=frame_idx_list
906
  ).cpu()
907
 
908
- end_time = time()
909
-
910
- print("time:", end_time - start_time, "mid time:", mid_time - start_time)
911
-
912
 
913
  if condition_similar_length:
914
  xs_pred = xs_pred[:-condition_similar_length]
 
829
 
830
 
831
  for ai in range(len(new_actions)):
 
 
832
 
833
  last_frame = xs_pred[-1].clone()
834
  curr_actions = new_actions[ai]
 
884
  image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
885
  )
886
 
 
887
  # Perform sampling for each step in the scheduling matrix
888
  for m in range(scheduling_matrix.shape[0] - 1):
889
  from_noise_levels, to_noise_levels = self._prepare_noise_levels(
 
902
  frame_idx=frame_idx_list
903
  ).cpu()
904
 
 
 
 
 
905
 
906
  if condition_similar_length:
907
  xs_pred = xs_pred[:-condition_similar_length]
algorithms/worldmem/models/dit.py CHANGED
@@ -487,8 +487,6 @@ class DiT(nn.Module):
487
  t: (B, T,) tensor of diffusion timesteps
488
  """
489
 
490
- from time import time
491
- start = time()
492
  B, T, C, H, W = x.shape
493
 
494
  # add spatial embeddings
@@ -552,8 +550,6 @@ class DiT(nn.Module):
552
  # print("self.blocks[0].r_adaLN_modulation[1].weight:", self.blocks[0].r_adaLN_modulation[1].weight)
553
  # print("self.blocks[0].t_adaLN_modulation[1].weight:", self.blocks[0].t_adaLN_modulation[1].weight)
554
 
555
- end_time = time()
556
- print("in model time:", end_time - start)
557
  return x
558
 
559
 
 
487
  t: (B, T,) tensor of diffusion timesteps
488
  """
489
 
 
 
490
  B, T, C, H, W = x.shape
491
 
492
  # add spatial embeddings
 
550
  # print("self.blocks[0].r_adaLN_modulation[1].weight:", self.blocks[0].r_adaLN_modulation[1].weight)
551
  # print("self.blocks[0].t_adaLN_modulation[1].weight:", self.blocks[0].t_adaLN_modulation[1].weight)
552
 
 
 
553
  return x
554
 
555
 
app.py CHANGED
@@ -201,7 +201,11 @@ self_memory_c2w = None
201
  self_frame_idx = None
202
 
203
 
204
- @spaces.GPU()
 
 
 
 
205
  def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
206
  self_poses, self_memory_c2w, self_frame_idx):
207
  new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
@@ -271,6 +275,7 @@ def reset():
271
 
272
  self_frames = None
273
  self_poses = None
 
274
  self_memory_c2w = None
275
  self_frame_idx = None
276
  memory_frames = load_image_as_tensor(DEFAULT_IMAGE).numpy()[None]
 
201
  self_frame_idx = None
202
 
203
 
204
+ def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
205
+ self_poses, self_memory_c2w, self_frame_idx):
206
+ return 5 * len(action) is self_actions is not None else 5
207
+
208
+ @spaces.GPU(duration=get_duration_single_image_to_long_video)
209
  def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
210
  self_poses, self_memory_c2w, self_frame_idx):
211
  new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
 
275
 
276
  self_frames = None
277
  self_poses = None
278
+ self_actions = None
279
  self_memory_c2w = None
280
  self_frame_idx = None
281
  memory_frames = load_image_as_tensor(DEFAULT_IMAGE).numpy()[None]