xizaoqu
commited on
Commit
·
4170d69
1
Parent(s):
c09e983
add huggingface_load
Browse files- app.py +22 -23
- configurations/huggingface.yaml +3 -3
- experiments/exp_base.py +11 -2
app.py
CHANGED
|
@@ -164,7 +164,7 @@ def save_video(frames, path="output.mp4", fps=10):
|
|
| 164 |
def run(cfg: DictConfig):
|
| 165 |
|
| 166 |
algo = run_local(cfg)
|
| 167 |
-
algo.to(
|
| 168 |
|
| 169 |
actions = torch.zeros((1, 25))
|
| 170 |
poses = torch.zeros((1, 5))
|
|
@@ -247,32 +247,31 @@ def run(cfg: DictConfig):
|
|
| 247 |
gr.Markdown(
|
| 248 |
"""
|
| 249 |
# WORLDMEM: Long-term Consistent World Generation with Memory
|
|
|
|
|
|
|
| 250 |
|
| 251 |
-
<div style="text-align: center;">
|
| 252 |
-
<!-- Public Website -->
|
| 253 |
-
<a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
|
| 254 |
-
|
| 255 |
-
</a>
|
| 256 |
|
| 257 |
-
<!-- GitHub Stars -->
|
| 258 |
-
<a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
|
| 259 |
-
|
| 260 |
-
</a>
|
| 261 |
|
| 262 |
-
<!-- Project Page -->
|
| 263 |
-
<a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
|
| 264 |
-
|
| 265 |
-
</a>
|
| 266 |
|
| 267 |
-
<!-- arXiv Paper -->
|
| 268 |
-
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
|
| 269 |
-
|
| 270 |
-
</a>
|
| 271 |
-
</div>
|
| 272 |
|
| 273 |
-
"""
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
with gr.Row(variant="panel"):
|
| 277 |
video_display = gr.Video(autoplay=True, loop=True)
|
| 278 |
image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
|
|
@@ -289,7 +288,7 @@ def run(cfg: DictConfig):
|
|
| 289 |
sampling_timesteps_state = gr.State(algo.sampling_timesteps)
|
| 290 |
|
| 291 |
example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
|
| 292 |
-
"DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "
|
| 293 |
|
| 294 |
def set_action(action):
|
| 295 |
return action
|
|
|
|
| 164 |
def run(cfg: DictConfig):
|
| 165 |
|
| 166 |
algo = run_local(cfg)
|
| 167 |
+
algo.to(device)
|
| 168 |
|
| 169 |
actions = torch.zeros((1, 25))
|
| 170 |
poses = torch.zeros((1, 5))
|
|
|
|
| 247 |
gr.Markdown(
|
| 248 |
"""
|
| 249 |
# WORLDMEM: Long-term Consistent World Generation with Memory
|
| 250 |
+
"""
|
| 251 |
+
)
|
| 252 |
|
| 253 |
+
# <div style="text-align: center;">
|
| 254 |
+
# <!-- Public Website -->
|
| 255 |
+
# <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
|
| 256 |
+
# <img src="https://img.shields.io/badge/public_website-8A2BE2">
|
| 257 |
+
# </a>
|
| 258 |
|
| 259 |
+
# <!-- GitHub Stars -->
|
| 260 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
|
| 261 |
+
# <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
|
| 262 |
+
# </a>
|
| 263 |
|
| 264 |
+
# <!-- Project Page -->
|
| 265 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
|
| 266 |
+
# <img src="https://img.shields.io/badge/project_page-blue">
|
| 267 |
+
# </a>
|
| 268 |
|
| 269 |
+
# <!-- arXiv Paper -->
|
| 270 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
|
| 271 |
+
# <img src="https://img.shields.io/badge/arXiv-paper-red">
|
| 272 |
+
# </a>
|
| 273 |
+
# </div>
|
| 274 |
|
|
|
|
|
|
|
|
|
|
| 275 |
with gr.Row(variant="panel"):
|
| 276 |
video_display = gr.Video(autoplay=True, loop=True)
|
| 277 |
image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
|
|
|
|
| 288 |
sampling_timesteps_state = gr.State(algo.sampling_timesteps)
|
| 289 |
|
| 290 |
example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
|
| 291 |
+
"DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAA1NNNNNNNNNSSUNNWW"]
|
| 292 |
|
| 293 |
def set_action(action):
|
| 294 |
return action
|
configurations/huggingface.yaml
CHANGED
|
@@ -30,9 +30,9 @@ experiment:
|
|
| 30 |
load_t_to_r: false
|
| 31 |
zero_init_gate: false
|
| 32 |
only_tune_refer: false
|
| 33 |
-
diffusion_path:
|
| 34 |
-
vae_path:
|
| 35 |
-
pose_predictor_path:
|
| 36 |
customized_load: true
|
| 37 |
|
| 38 |
algorithm:
|
|
|
|
| 30 |
load_t_to_r: false
|
| 31 |
zero_init_gate: false
|
| 32 |
only_tune_refer: false
|
| 33 |
+
diffusion_path: yslan/worldmem_checkpoints/diffusion_only.ckpt
|
| 34 |
+
vae_path: yslan/worldmem_checkpoints/vae_only.ckpt
|
| 35 |
+
pose_predictor_path: yslan/worldmem_checkpoints/pose_prediction_model_only.ckpt
|
| 36 |
customized_load: true
|
| 37 |
|
| 38 |
algorithm:
|
experiments/exp_base.py
CHANGED
|
@@ -26,7 +26,7 @@ from utils.print_utils import cyan
|
|
| 26 |
from utils.distributed_utils import is_rank_zero
|
| 27 |
from safetensors.torch import load_model
|
| 28 |
from pathlib import Path
|
| 29 |
-
|
| 30 |
|
| 31 |
torch.set_float32_matmul_precision("high")
|
| 32 |
|
|
@@ -38,7 +38,16 @@ def load_custom_checkpoint(algo, optimizer, checkpoint_path):
|
|
| 38 |
if not isinstance(checkpoint_path, Path):
|
| 39 |
checkpoint_path = Path(checkpoint_path)
|
| 40 |
|
| 41 |
-
if checkpoint_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
ckpt = torch.load(checkpoint_path, weights_only=True)
|
| 43 |
algo.load_state_dict(ckpt, strict=False)
|
| 44 |
elif checkpoint_path.suffix == ".ckpt":
|
|
|
|
| 26 |
from utils.distributed_utils import is_rank_zero
|
| 27 |
from safetensors.torch import load_model
|
| 28 |
from pathlib import Path
|
| 29 |
+
from huggingface_hub import hf_hub_download
|
| 30 |
|
| 31 |
torch.set_float32_matmul_precision("high")
|
| 32 |
|
|
|
|
| 38 |
if not isinstance(checkpoint_path, Path):
|
| 39 |
checkpoint_path = Path(checkpoint_path)
|
| 40 |
|
| 41 |
+
if "yslan" in str(checkpoint_path):
|
| 42 |
+
hf_ckpt = str(checkpoint_path).split('/')
|
| 43 |
+
repo_id = '/'.join(hf_ckpt[:2])
|
| 44 |
+
file_name = '/'.join(hf_ckpt[2:])
|
| 45 |
+
model_path = hf_hub_download(repo_id=repo_id,
|
| 46 |
+
filename=file_name)
|
| 47 |
+
ckpt = torch.load(model_path, map_location=torch.device('cpu'))
|
| 48 |
+
algo.load_state_dict(ckpt['state_dict'], strict=False)
|
| 49 |
+
|
| 50 |
+
elif checkpoint_path.suffix == ".pt":
|
| 51 |
ckpt = torch.load(checkpoint_path, weights_only=True)
|
| 52 |
algo.load_state_dict(ckpt, strict=False)
|
| 53 |
elif checkpoint_path.suffix == ".ckpt":
|