xizaoqu committed
Commit 0d5deae · 1 Parent(s): 1e18469

update precision

algorithms/worldmem/df_video.py CHANGED
@@ -791,22 +791,22 @@ class WorldMemMinecraft(DiffusionForcingBase):
         return
 
     @torch.no_grad()
-    def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
+    def interactive(self, first_frame, new_actions, first_pose, device,
                     self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
 
         condition_similar_length = self.condition_similar_length
 
         if self_frames is None:
             first_frame = torch.from_numpy(first_frame)
-            curr_actions = torch.from_numpy(curr_actions)
+            new_actions = torch.from_numpy(new_actions)
             first_pose = torch.from_numpy(first_pose)
             first_frame_encode = self.encode(first_frame[None, None].to(device))
             self_frames = first_frame_encode.cpu()
-            self_actions = curr_actions[None, None].to(device)
+            self_actions = new_actions[None, None].to(device)
             self_poses = first_pose[None, None].to(device)
             new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
             self_memory_c2w = new_c2w_mat[None, None].to(device)
-            self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
+            self_frame_idx = torch.tensor([[0]]).to(device)
             return first_frame.cpu().numpy(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
         else:
             self_frames = torch.from_numpy(self_frames)
@@ -814,9 +814,26 @@ class WorldMemMinecraft(DiffusionForcingBase):
             self_poses = torch.from_numpy(self_poses).to(device)
             self_memory_c2w = torch.from_numpy(self_memory_c2w).to(device)
             self_frame_idx = torch.from_numpy(self_frame_idx).to(device)
-            curr_actions = curr_actions.to(device)
+            new_actions = new_actions.to(device)
 
-            last_frame = self_frames[-1].clone()
+            curr_frame = 0
+            horizon = 1
+            batch_size = 1
+            n_frames = curr_frame + horizon
+            # context
+            n_context_frames = len(self_frames)
+            xs_pred = self_frames[:n_context_frames].clone()
+            curr_frame += n_context_frames
+
+            pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
+
+
+            for ai in range(len(new_actions)):
+                from time import time
+                start_time = time()
+
+                last_frame = xs_pred[-1].clone()
+                curr_actions = new_actions[ai]
             last_pose_condition = self_poses[-1].clone()
             last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
             new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None], last_pose_condition)
@@ -829,88 +846,80 @@ class WorldMemMinecraft(DiffusionForcingBase):
             self_poses = torch.cat([self_poses, new_pose_condition[None]])
             new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
             self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None]])
-            self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
+                self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[self_frame_idx[-1,0]+1]]).to(device)])
 
-            conditions = self_actions.clone()
-            pose_conditions = self_poses.clone()
-            c2w_mat = self_memory_c2w .clone()
-            frame_idx = self_frame_idx.clone()
-
-
-            curr_frame = 0
-            horizon = 1
-            batch_size = 1
-            n_frames = curr_frame + horizon
-            # context
-            n_context_frames = context_frames_idx // self.frame_stack
-            xs_pred = self_frames[:n_context_frames].clone()
-            curr_frame += n_context_frames
-
-            pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
-
-            # generation on frame
-            scheduling_matrix = self._generate_scheduling_matrix(horizon)
-            chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
-            chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
-
-            xs_pred = torch.cat([xs_pred, chunk], 0)
-
-            # sliding window: only input the last n_tokens frames
-            start_frame = max(0, curr_frame + horizon - self.n_tokens)
-
-            pbar.set_postfix(
-                {
-                    "start": start_frame,
-                    "end": curr_frame + horizon,
-                }
-            )
-
-            # Handle condition similarity logic
-            if condition_similar_length:
-                random_idx = self._generate_condition_indices(
-                    curr_frame, condition_similar_length, xs_pred, pose_conditions, frame_idx
-                )
-
-                # random_idx = np.unique(random_idx)[:, None]
-                # condition_similar_length = len(random_idx)
-                xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
-
-            # Prepare input conditions and pose conditions
-            input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
-                start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
-                image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
-            )
-
-            # Perform sampling for each step in the scheduling matrix
-            for m in range(scheduling_matrix.shape[0] - 1):
-                from_noise_levels, to_noise_levels = self._prepare_noise_levels(
-                    scheduling_matrix, m, curr_frame, batch_size, condition_similar_length
-                )
-
-                xs_pred[start_frame:] = self.diffusion_model.sample_step(
-                    xs_pred[start_frame:].to(input_condition.device),
-                    input_condition,
-                    input_pose_condition,
-                    from_noise_levels[start_frame:],
-                    to_noise_levels[start_frame:],
-                    current_frame=curr_frame,
-                    mode="validation",
-                    reference_length=condition_similar_length,
-                    frame_idx=frame_idx_list
-                ).cpu()
-
-
-            if condition_similar_length:
-                xs_pred = xs_pred[:-condition_similar_length]
-
-            curr_frame += horizon
-            pbar.update(horizon)
-
+                conditions = self_actions.clone()
+                pose_conditions = self_poses.clone()
+                c2w_mat = self_memory_c2w .clone()
+                frame_idx = self_frame_idx.clone()
+
+                # generation on frame
+                scheduling_matrix = self._generate_scheduling_matrix(horizon)
+                chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
+                chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
+
+                xs_pred = torch.cat([xs_pred, chunk], 0)
+
+                # sliding window: only input the last n_tokens frames
+                start_frame = max(0, curr_frame + horizon - self.n_tokens)
+
+                pbar.set_postfix(
+                    {
+                        "start": start_frame,
+                        "end": curr_frame + horizon,
+                    }
+                )
+
+                # Handle condition similarity logic
+                if condition_similar_length:
+                    random_idx = self._generate_condition_indices(
+                        curr_frame, condition_similar_length, xs_pred, pose_conditions, frame_idx
+                    )
+
+                    # random_idx = np.unique(random_idx)[:, None]
+                    # condition_similar_length = len(random_idx)
+                    xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
+
+                # Prepare input conditions and pose conditions
+                input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
+                    start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
+                    image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
+                )
+
+                mid_time = time()
+                # Perform sampling for each step in the scheduling matrix
+                for m in range(scheduling_matrix.shape[0] - 1):
+                    from_noise_levels, to_noise_levels = self._prepare_noise_levels(
+                        scheduling_matrix, m, curr_frame, batch_size, condition_similar_length
+                    )
+
+                    xs_pred[start_frame:] = self.diffusion_model.sample_step(
+                        xs_pred[start_frame:].to(input_condition.device),
+                        input_condition,
+                        input_pose_condition,
+                        from_noise_levels[start_frame:],
+                        to_noise_levels[start_frame:],
+                        current_frame=curr_frame,
+                        mode="validation",
+                        reference_length=condition_similar_length,
+                        frame_idx=frame_idx_list
+                    ).cpu()
+
+                end_time = time()
+
+                print("time:", end_time - start_time, "mid time:", mid_time - start_time)
+
+
+                if condition_similar_length:
+                    xs_pred = xs_pred[:-condition_similar_length]
+
+                curr_frame += horizon
+                pbar.update(horizon)
+
             self_frames = torch.cat([self_frames, xs_pred[n_context_frames:]])
-
             xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
 
-            return xs_pred[-1,0].cpu().numpy(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), \
+            return xs_pred.cpu().numpy(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), \
                 self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
 
 
 
algorithms/worldmem/models/dit.py CHANGED
@@ -487,6 +487,8 @@ class DiT(nn.Module):
             t: (B, T,) tensor of diffusion timesteps
         """
 
+        from time import time
+        start = time()
         B, T, C, H, W = x.shape
 
         # add spatial embeddings
@@ -550,6 +552,8 @@ class DiT(nn.Module):
         # print("self.blocks[0].r_adaLN_modulation[1].weight:", self.blocks[0].r_adaLN_modulation[1].weight)
         # print("self.blocks[0].t_adaLN_modulation[1].weight:", self.blocks[0].t_adaLN_modulation[1].weight)
 
+        end_time = time()
+        print("in model time:", end_time - start)
         return x
 
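The hunks above wrap the DiT forward pass with time.time() calls. As a general PyTorch caveat (not something this commit changes), CUDA kernels launch asynchronously, so host-side wall-clock timing only reflects the GPU work if the device is synchronized or CUDA events are used. A small, self-contained sketch of that pattern, assuming nothing beyond stock PyTorch:

# General GPU-timing sketch: synchronize (or use CUDA events) so the measured
# interval covers the asynchronous kernel execution, not just the launch.
import time
import torch

def timed_forward(module, x):
    if x.is_cuda:
        torch.cuda.synchronize()
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        out = module(x)
        end_evt.record()
        torch.cuda.synchronize()
        print("forward time (ms):", start_evt.elapsed_time(end_evt))
    else:
        start = time.time()
        out = module(x)
        print("forward time (ms):", (time.time() - start) * 1e3)
    return out

layer = torch.nn.Linear(64, 64)
x = torch.randn(8, 64)
if torch.cuda.is_available():
    layer, x = layer.cuda(), x.cuda()
out = timed_forward(layer, x)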
 
app.py CHANGED
@@ -26,6 +26,8 @@ import spaces
 from algorithms.worldmem import WorldMemMinecraft
 from huggingface_hub import hf_hub_download
 
+torch.set_float32_matmul_precision("high")
+
 ACTION_KEYS = [
     "inventory",
     "ESC",
@@ -142,6 +144,16 @@ def run_local(cfg: DictConfig):
     experiment = build_experiment(cfg, None, None)
     return experiment.exec_interactive(cfg.experiment.tasks[0])
 
+def enable_amp(model, precision="16-mixed"):
+    original_forward = model.forward
+
+    def amp_forward(*args, **kwargs):
+        with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
+            return original_forward(*args, **kwargs)
+
+    model.forward = amp_forward
+    return model
+
 memory_frames = []
 memory_curr_frame = 0
 input_history = ""
@@ -175,12 +187,12 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
 load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
 load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
 worldmem.to("cuda").eval()
-
+worldmem = enable_amp(worldmem, precision="16-mixed")
 
 actions = np.zeros((1, 25), dtype=np.float32)
 poses = np.zeros((1, 5), dtype=np.float32)
 
-memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
+memory_frames = load_image_as_tensor(DEFAULT_IMAGE)[None].numpy()
 
 self_frames = None
 self_actions = None
@@ -190,12 +202,11 @@ self_frame_idx = None
 
 
 @spaces.GPU()
-def run_interactive(first_frame, action, first_pose, curr_frame, device, self_frames, self_actions,
+def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
                     self_poses, self_memory_c2w, self_frame_idx):
     new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
         action,
         first_pose,
-        curr_frame,
         device=device,
         self_frames=self_frames,
         self_actions=self_actions,
@@ -216,6 +227,7 @@ def generate(keys):
     # print("algo frame:", len(worldmem.frames))
     actions = parse_input_to_tensor(keys)
     global input_history
+    global memory_frames
     global memory_curr_frame
     global self_frames
     global self_actions
@@ -223,26 +235,19 @@ def generate(keys):
     global self_memory_c2w
     global self_frame_idx
 
-    for i in range(len(actions)):
-        memory_curr_frame += 1
-
-        new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
-            actions[i],
-            None,
-            memory_curr_frame,
-            device=device,
-            self_frames=self_frames,
-            self_actions=self_actions,
-            self_poses=self_poses,
-            self_memory_c2w=self_memory_c2w,
-            self_frame_idx=self_frame_idx)
-
-        # print("algo frame:", len(runner.algo.frames))
-
-        memory_frames.append(new_frame)
-
-    out_video = np.stack(memory_frames)
-    out_video = out_video.transpose(0,2,3,1)
+    new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
+        actions,
+        None,
+        device=device,
+        self_frames=self_frames,
+        self_actions=self_actions,
+        self_poses=self_poses,
+        self_memory_c2w=self_memory_c2w,
+        self_frame_idx=self_frame_idx)
+
+    memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
+
+    out_video = memory_frames.transpose(0,2,3,1)
     out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
     out_video = (out_video * 255).astype(np.uint8)
 
@@ -268,15 +273,12 @@ def reset():
     self_poses = None
     self_memory_c2w = None
     self_frame_idx = None
-    memory_frames = []
-    memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE).numpy())
-    memory_curr_frame = 0
+    memory_frames = load_image_as_tensor(DEFAULT_IMAGE).numpy()[None]
    input_history = ""
 
     new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
        actions[0],
        poses[0],
-        memory_curr_frame,
        device=device,
        self_frames=self_frames,
        self_actions=self_actions,
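The precision changes in app.py are the torch.set_float32_matmul_precision("high") call (which allows TF32 for remaining fp32 matmuls) and the enable_amp wrapper, which monkey-patches model.forward so every call runs inside torch.autocast. Below is a minimal, self-contained sketch of the same wrapping pattern applied to a toy module; TinyNet is a hypothetical stand-in, not the WorldMem model. Parameters stay fp32 while autocast-eligible ops (here, the linear layer) execute in half precision, so the output comes back as float16.

# Sketch of the autocast monkey-patch pattern from app.py, applied to a toy module.
import torch
import torch.nn as nn

def enable_amp(model, precision="16-mixed"):
    original_forward = model.forward

    def amp_forward(*args, **kwargs):
        dtype = torch.float16 if precision == "16-mixed" else torch.bfloat16
        with torch.autocast("cuda", dtype=dtype):
            return original_forward(*args, **kwargs)

    model.forward = amp_forward
    return model

class TinyNet(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 32)

    def forward(self, x):
        return self.proj(x)

if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")  # allow TF32 for fp32 matmuls
    net = enable_amp(TinyNet().cuda().eval(), precision="16-mixed")
    with torch.no_grad():
        y = net(torch.randn(4, 32, device="cuda"))
    print(y.dtype)  # torch.float16: the linear ran under autocast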