xizaoqu committed on
Commit
4170d69
·
1 Parent(s): c09e983

add huggingface_load

Browse files
app.py CHANGED
@@ -164,7 +164,7 @@ def save_video(frames, path="output.mp4", fps=10):
164
  def run(cfg: DictConfig):
165
 
166
  algo = run_local(cfg)
167
- algo.to("cuda:0")
168
 
169
  actions = torch.zeros((1, 25))
170
  poses = torch.zeros((1, 5))
@@ -247,32 +247,31 @@ def run(cfg: DictConfig):
247
  gr.Markdown(
248
  """
249
  # WORLDMEM: Long-term Consistent World Generation with Memory
 
 
250
 
251
- <div style="text-align: center;">
252
- <!-- Public Website -->
253
- <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
254
- <img src="https://img.shields.io/badge/public_website-8A2BE2">
255
- </a>
256
 
257
- <!-- GitHub Stars -->
258
- <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
259
- <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
260
- </a>
261
 
262
- <!-- Project Page -->
263
- <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
264
- <img src="https://img.shields.io/badge/project_page-blue">
265
- </a>
266
 
267
- <!-- arXiv Paper -->
268
- <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
269
- <img src="https://img.shields.io/badge/arXiv-paper-red">
270
- </a>
271
- </div>
272
 
273
- """
274
- )
275
-
276
  with gr.Row(variant="panel"):
277
  video_display = gr.Video(autoplay=True, loop=True)
278
  image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
@@ -289,7 +288,7 @@ def run(cfg: DictConfig):
289
  sampling_timesteps_state = gr.State(algo.sampling_timesteps)
290
 
291
  example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
292
- "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEEAAAAAAAAAAAAAAAAAAAAAA"]
293
 
294
  def set_action(action):
295
  return action
 
164
  def run(cfg: DictConfig):
165
 
166
  algo = run_local(cfg)
167
+ algo.to(device)
168
 
169
  actions = torch.zeros((1, 25))
170
  poses = torch.zeros((1, 5))
 
247
  gr.Markdown(
248
  """
249
  # WORLDMEM: Long-term Consistent World Generation with Memory
250
+ """
251
+ )
252
 
253
+ # <div style="text-align: center;">
254
+ # <!-- Public Website -->
255
+ # <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
256
+ # <img src="https://img.shields.io/badge/public_website-8A2BE2">
257
+ # </a>
258
 
259
+ # <!-- GitHub Stars -->
260
+ # <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
261
+ # <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
262
+ # </a>
263
 
264
+ # <!-- Project Page -->
265
+ # <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
266
+ # <img src="https://img.shields.io/badge/project_page-blue">
267
+ # </a>
268
 
269
+ # <!-- arXiv Paper -->
270
+ # <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
271
+ # <img src="https://img.shields.io/badge/arXiv-paper-red">
272
+ # </a>
273
+ # </div>
274
 
 
 
 
275
  with gr.Row(variant="panel"):
276
  video_display = gr.Video(autoplay=True, loop=True)
277
  image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
 
288
  sampling_timesteps_state = gr.State(algo.sampling_timesteps)
289
 
290
  example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
291
+ "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAA1NNNNNNNNNSSUNNWW"]
292
 
293
  def set_action(action):
294
  return action
configurations/huggingface.yaml CHANGED
@@ -30,9 +30,9 @@ experiment:
30
  load_t_to_r: false
31
  zero_init_gate: false
32
  only_tune_refer: false
33
- diffusion_path: checkpoints/diffusion_only.ckpt
34
- vae_path: checkpoints/vae_only.ckpt
35
- pose_predictor_path: checkpoints/pose_prediction_model_only.ckpt
36
  customized_load: true
37
 
38
  algorithm:
 
30
  load_t_to_r: false
31
  zero_init_gate: false
32
  only_tune_refer: false
33
+ diffusion_path: yslan/worldmem_checkpoints/diffusion_only.ckpt
34
+ vae_path: yslan/worldmem_checkpoints/vae_only.ckpt
35
+ pose_predictor_path: yslan/worldmem_checkpoints/pose_prediction_model_only.ckpt
36
  customized_load: true
37
 
38
  algorithm:
experiments/exp_base.py CHANGED
@@ -26,7 +26,7 @@ from utils.print_utils import cyan
26
  from utils.distributed_utils import is_rank_zero
27
  from safetensors.torch import load_model
28
  from pathlib import Path
29
-
30
 
31
  torch.set_float32_matmul_precision("high")
32
 
@@ -38,7 +38,16 @@ def load_custom_checkpoint(algo, optimizer, checkpoint_path):
38
  if not isinstance(checkpoint_path, Path):
39
  checkpoint_path = Path(checkpoint_path)
40
 
41
- if checkpoint_path.suffix == ".pt":
 
 
 
 
 
 
 
 
 
42
  ckpt = torch.load(checkpoint_path, weights_only=True)
43
  algo.load_state_dict(ckpt, strict=False)
44
  elif checkpoint_path.suffix == ".ckpt":
 
26
  from utils.distributed_utils import is_rank_zero
27
  from safetensors.torch import load_model
28
  from pathlib import Path
29
+ from huggingface_hub import hf_hub_download
30
 
31
  torch.set_float32_matmul_precision("high")
32
 
 
38
  if not isinstance(checkpoint_path, Path):
39
  checkpoint_path = Path(checkpoint_path)
40
 
41
+ if "yslan" in str(checkpoint_path):
42
+ hf_ckpt = str(checkpoint_path).split('/')
43
+ repo_id = '/'.join(hf_ckpt[:2])
44
+ file_name = '/'.join(hf_ckpt[2:])
45
+ model_path = hf_hub_download(repo_id=repo_id,
46
+ filename=file_name)
47
+ ckpt = torch.load(model_path, map_location=torch.device('cpu'))
48
+ algo.load_state_dict(ckpt['state_dict'], strict=False)
49
+
50
+ elif checkpoint_path.suffix == ".pt":
51
  ckpt = torch.load(checkpoint_path, weights_only=True)
52
  algo.load_state_dict(ckpt, strict=False)
53
  elif checkpoint_path.suffix == ".ckpt":