xizaoqu commited on
Commit
f07d258
·
1 Parent(s): cef86dc
Files changed (1) hide show
  1. app.py +1 -22
app.py CHANGED
@@ -132,26 +132,6 @@ def load_image_as_tensor(image_path: str) -> torch.Tensor:
132
  ])
133
  return transform(image)
134
 
135
- def run_local(cfg: DictConfig):
136
- # delay some imports in case they are not needed in non-local envs for submission
137
- from experiments import build_experiment
138
-
139
- # Get yaml names
140
- hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
141
- cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
142
-
143
- with open_dict(cfg):
144
- if cfg_choice["experiment"] is not None:
145
- cfg.experiment._name = cfg_choice["experiment"]
146
- if cfg_choice["dataset"] is not None:
147
- cfg.dataset._name = cfg_choice["dataset"]
148
- if cfg_choice["algorithm"] is not None:
149
- cfg.algorithm._name = cfg_choice["algorithm"]
150
-
151
- # launch experiment
152
- experiment = build_experiment(cfg, None, None)
153
- return experiment.exec_interactive(cfg.experiment.tasks[0])
154
-
155
  def enable_amp(model, precision="16-mixed"):
156
  original_forward = model.forward
157
 
@@ -193,7 +173,7 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
193
  load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
194
  load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
195
  worldmem.to("cuda").eval()
196
- worldmem = enable_amp(worldmem, precision="16-mixed")
197
 
198
  actions = np.zeros((1, 25), dtype=np.float32)
199
  poses = np.zeros((1, 5), dtype=np.float32)
@@ -555,7 +535,6 @@ with gr.Blocks(css=css) as demo:
555
  )
556
 
557
 
558
- # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
559
  submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
560
  reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
561
  image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
 
132
  ])
133
  return transform(image)
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def enable_amp(model, precision="16-mixed"):
136
  original_forward = model.forward
137
 
 
173
  load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
174
  load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
175
  worldmem.to("cuda").eval()
176
+ # worldmem = enable_amp(worldmem, precision="16-mixed")
177
 
178
  actions = np.zeros((1, 25), dtype=np.float32)
179
  poses = np.zeros((1, 5), dtype=np.float32)
 
535
  )
536
 
537
 
 
538
  submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
539
  reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
540
  image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])