xizaoqu
commited on
Commit
·
f07d258
1
Parent(s):
cef86dc
update
Browse files
app.py
CHANGED
@@ -132,26 +132,6 @@ def load_image_as_tensor(image_path: str) -> torch.Tensor:
|
|
132 |
])
|
133 |
return transform(image)
|
134 |
|
135 |
-
def run_local(cfg: DictConfig):
|
136 |
-
# delay some imports in case they are not needed in non-local envs for submission
|
137 |
-
from experiments import build_experiment
|
138 |
-
|
139 |
-
# Get yaml names
|
140 |
-
hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
|
141 |
-
cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
|
142 |
-
|
143 |
-
with open_dict(cfg):
|
144 |
-
if cfg_choice["experiment"] is not None:
|
145 |
-
cfg.experiment._name = cfg_choice["experiment"]
|
146 |
-
if cfg_choice["dataset"] is not None:
|
147 |
-
cfg.dataset._name = cfg_choice["dataset"]
|
148 |
-
if cfg_choice["algorithm"] is not None:
|
149 |
-
cfg.algorithm._name = cfg_choice["algorithm"]
|
150 |
-
|
151 |
-
# launch experiment
|
152 |
-
experiment = build_experiment(cfg, None, None)
|
153 |
-
return experiment.exec_interactive(cfg.experiment.tasks[0])
|
154 |
-
|
155 |
def enable_amp(model, precision="16-mixed"):
|
156 |
original_forward = model.forward
|
157 |
|
@@ -193,7 +173,7 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
|
|
193 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
194 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
195 |
worldmem.to("cuda").eval()
|
196 |
-
worldmem = enable_amp(worldmem, precision="16-mixed")
|
197 |
|
198 |
actions = np.zeros((1, 25), dtype=np.float32)
|
199 |
poses = np.zeros((1, 5), dtype=np.float32)
|
@@ -555,7 +535,6 @@ with gr.Blocks(css=css) as demo:
|
|
555 |
)
|
556 |
|
557 |
|
558 |
-
# input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
|
559 |
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
560 |
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
561 |
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
|
|
132 |
])
|
133 |
return transform(image)
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
def enable_amp(model, precision="16-mixed"):
|
136 |
original_forward = model.forward
|
137 |
|
|
|
173 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
174 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
175 |
worldmem.to("cuda").eval()
|
176 |
+
# worldmem = enable_amp(worldmem, precision="16-mixed")
|
177 |
|
178 |
actions = np.zeros((1, 25), dtype=np.float32)
|
179 |
poses = np.zeros((1, 5), dtype=np.float32)
|
|
|
535 |
)
|
536 |
|
537 |
|
|
|
538 |
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
539 |
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
540 |
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|