xizaoqu
commited on
Commit
·
4170d69
1
Parent(s):
c09e983
add huggingface_load
Browse files- app.py +22 -23
- configurations/huggingface.yaml +3 -3
- experiments/exp_base.py +11 -2
app.py
CHANGED
@@ -164,7 +164,7 @@ def save_video(frames, path="output.mp4", fps=10):
|
|
164 |
def run(cfg: DictConfig):
|
165 |
|
166 |
algo = run_local(cfg)
|
167 |
-
algo.to(
|
168 |
|
169 |
actions = torch.zeros((1, 25))
|
170 |
poses = torch.zeros((1, 5))
|
@@ -247,32 +247,31 @@ def run(cfg: DictConfig):
|
|
247 |
gr.Markdown(
|
248 |
"""
|
249 |
# WORLDMEM: Long-term Consistent World Generation with Memory
|
|
|
|
|
250 |
|
251 |
-
<div style="text-align: center;">
|
252 |
-
<!-- Public Website -->
|
253 |
-
<a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
|
254 |
-
|
255 |
-
</a>
|
256 |
|
257 |
-
<!-- GitHub Stars -->
|
258 |
-
<a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
|
259 |
-
|
260 |
-
</a>
|
261 |
|
262 |
-
<!-- Project Page -->
|
263 |
-
<a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
|
264 |
-
|
265 |
-
</a>
|
266 |
|
267 |
-
<!-- arXiv Paper -->
|
268 |
-
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
|
269 |
-
|
270 |
-
</a>
|
271 |
-
</div>
|
272 |
|
273 |
-
"""
|
274 |
-
)
|
275 |
-
|
276 |
with gr.Row(variant="panel"):
|
277 |
video_display = gr.Video(autoplay=True, loop=True)
|
278 |
image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
|
@@ -289,7 +288,7 @@ def run(cfg: DictConfig):
|
|
289 |
sampling_timesteps_state = gr.State(algo.sampling_timesteps)
|
290 |
|
291 |
example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
|
292 |
-
"DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "
|
293 |
|
294 |
def set_action(action):
|
295 |
return action
|
|
|
164 |
def run(cfg: DictConfig):
|
165 |
|
166 |
algo = run_local(cfg)
|
167 |
+
algo.to(device)
|
168 |
|
169 |
actions = torch.zeros((1, 25))
|
170 |
poses = torch.zeros((1, 5))
|
|
|
247 |
gr.Markdown(
|
248 |
"""
|
249 |
# WORLDMEM: Long-term Consistent World Generation with Memory
|
250 |
+
"""
|
251 |
+
)
|
252 |
|
253 |
+
# <div style="text-align: center;">
|
254 |
+
# <!-- Public Website -->
|
255 |
+
# <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
|
256 |
+
# <img src="https://img.shields.io/badge/public_website-8A2BE2">
|
257 |
+
# </a>
|
258 |
|
259 |
+
# <!-- GitHub Stars -->
|
260 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
|
261 |
+
# <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
|
262 |
+
# </a>
|
263 |
|
264 |
+
# <!-- Project Page -->
|
265 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
|
266 |
+
# <img src="https://img.shields.io/badge/project_page-blue">
|
267 |
+
# </a>
|
268 |
|
269 |
+
# <!-- arXiv Paper -->
|
270 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
|
271 |
+
# <img src="https://img.shields.io/badge/arXiv-paper-red">
|
272 |
+
# </a>
|
273 |
+
# </div>
|
274 |
|
|
|
|
|
|
|
275 |
with gr.Row(variant="panel"):
|
276 |
video_display = gr.Video(autoplay=True, loop=True)
|
277 |
image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
|
|
|
288 |
sampling_timesteps_state = gr.State(algo.sampling_timesteps)
|
289 |
|
290 |
example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
|
291 |
+
"DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAA1NNNNNNNNNSSUNNWW"]
|
292 |
|
293 |
def set_action(action):
|
294 |
return action
|
configurations/huggingface.yaml
CHANGED
@@ -30,9 +30,9 @@ experiment:
|
|
30 |
load_t_to_r: false
|
31 |
zero_init_gate: false
|
32 |
only_tune_refer: false
|
33 |
-
diffusion_path:
|
34 |
-
vae_path:
|
35 |
-
pose_predictor_path:
|
36 |
customized_load: true
|
37 |
|
38 |
algorithm:
|
|
|
30 |
load_t_to_r: false
|
31 |
zero_init_gate: false
|
32 |
only_tune_refer: false
|
33 |
+
diffusion_path: yslan/worldmem_checkpoints/diffusion_only.ckpt
|
34 |
+
vae_path: yslan/worldmem_checkpoints/vae_only.ckpt
|
35 |
+
pose_predictor_path: yslan/worldmem_checkpoints/pose_prediction_model_only.ckpt
|
36 |
customized_load: true
|
37 |
|
38 |
algorithm:
|
experiments/exp_base.py
CHANGED
@@ -26,7 +26,7 @@ from utils.print_utils import cyan
|
|
26 |
from utils.distributed_utils import is_rank_zero
|
27 |
from safetensors.torch import load_model
|
28 |
from pathlib import Path
|
29 |
-
|
30 |
|
31 |
torch.set_float32_matmul_precision("high")
|
32 |
|
@@ -38,7 +38,16 @@ def load_custom_checkpoint(algo, optimizer, checkpoint_path):
|
|
38 |
if not isinstance(checkpoint_path, Path):
|
39 |
checkpoint_path = Path(checkpoint_path)
|
40 |
|
41 |
-
if checkpoint_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
ckpt = torch.load(checkpoint_path, weights_only=True)
|
43 |
algo.load_state_dict(ckpt, strict=False)
|
44 |
elif checkpoint_path.suffix == ".ckpt":
|
|
|
26 |
from utils.distributed_utils import is_rank_zero
|
27 |
from safetensors.torch import load_model
|
28 |
from pathlib import Path
|
29 |
+
from huggingface_hub import hf_hub_download
|
30 |
|
31 |
torch.set_float32_matmul_precision("high")
|
32 |
|
|
|
38 |
if not isinstance(checkpoint_path, Path):
|
39 |
checkpoint_path = Path(checkpoint_path)
|
40 |
|
41 |
+
if "yslan" in str(checkpoint_path):
|
42 |
+
hf_ckpt = str(checkpoint_path).split('/')
|
43 |
+
repo_id = '/'.join(hf_ckpt[:2])
|
44 |
+
file_name = '/'.join(hf_ckpt[2:])
|
45 |
+
model_path = hf_hub_download(repo_id=repo_id,
|
46 |
+
filename=file_name)
|
47 |
+
ckpt = torch.load(model_path, map_location=torch.device('cpu'))
|
48 |
+
algo.load_state_dict(ckpt['state_dict'], strict=False)
|
49 |
+
|
50 |
+
elif checkpoint_path.suffix == ".pt":
|
51 |
ckpt = torch.load(checkpoint_path, weights_only=True)
|
52 |
algo.load_state_dict(ckpt, strict=False)
|
53 |
elif checkpoint_path.suffix == ".ckpt":
|