|
import gradio as gr |
|
import time |
|
|
|
import sys |
|
import subprocess |
|
import time |
|
from pathlib import Path |
|
|
|
import hydra |
|
from omegaconf import DictConfig, OmegaConf |
|
from omegaconf.omegaconf import open_dict |
|
|
|
from utils.print_utils import cyan |
|
from utils.ckpt_utils import download_latest_checkpoint, is_run_id |
|
from utils.cluster_utils import submit_slurm_job |
|
from utils.distributed_utils import is_rank_zero |
|
import numpy as np |
|
import torch |
|
from datasets.video.minecraft_video_dataset import * |
|
import torchvision.transforms as transforms |
|
import cv2 |
|
import subprocess |
|
from PIL import Image |
|
from datetime import datetime |
|
import spaces |
|
from algorithms.worldmem import WorldMemMinecraft |
|
from huggingface_hub import hf_hub_download |
|
import tempfile |
|
|
|
torch.set_float32_matmul_precision("high")

# Column layout of one action vector. Every generation step is a row of
# len(ACTION_KEYS) (= 25) entries in exactly this order; parse_input_to_tensor()
# writes into it and the zero "seed" action further down is sized to match.
ACTION_KEYS = [
    "inventory",
    "ESC",
    *[f"hotbar.{slot}" for slot in range(1, 10)],  # "hotbar.1" .. "hotbar.9"
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

# UI key -> (ACTION_KEYS column name, value written into that column).
# "noop" is deliberately absent from ACTION_KEYS, so "N" produces an all-zero row.
KEY_TO_ACTION = {
    "Q": ("forward", 1),   # move forward
    "E": ("back", 1),      # move backward
    "W": ("cameraY", -1),  # turn up
    "S": ("cameraY", 1),   # turn down
    "A": ("cameraX", -1),  # turn left
    "D": ("cameraX", 1),   # turn right
    "U": ("drop", 1),
    "N": ("noop", 1),      # no-op
    "1": ("hotbar.1", 1),  # select hotbar slot 1
}
|
|
|
def load_custom_checkpoint(algo, checkpoint_path):
    """
    Download a checkpoint from the Hugging Face Hub and load its weights into ``algo``.

    ``checkpoint_path`` has the form ``"<owner>/<repo>/<path/inside/repo>"``: the first two
    ``/``-separated segments form the Hub repo id, everything after them is the filename
    inside that repo.

    Args:
        algo: object exposing ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        checkpoint_path: Hub location as described above (anything ``str()``-able).

    Returns:
        Whatever ``algo.load_state_dict(..., strict=False)`` returns — for a
        ``torch.nn.Module`` that is the (missing_keys, unexpected_keys) result, which lets
        callers inspect what did not match instead of it being silently dropped.

    Raises:
        ValueError: if ``checkpoint_path`` does not contain an owner, a repo and a filename.
    """
    parts = str(checkpoint_path).split("/", 2)
    if len(parts) < 3 or not all(parts):
        # Fail early with a clear message instead of inside hf_hub_download with an
        # empty or bogus filename.
        raise ValueError(
            f"checkpoint_path must look like '<owner>/<repo>/<file>', got {checkpoint_path!r}"
        )
    owner, repo, file_name = parts

    model_path = hf_hub_download(repo_id=f"{owner}/{repo}", filename=file_name)
    # NOTE(review): newer torch versions default torch.load to weights_only=True; confirm the
    # checkpoint format still loads under the pinned torch version.
    ckpt = torch.load(model_path, map_location=torch.device("cpu"))
    # Presumably a Lightning-style checkpoint with the weights under "state_dict" — confirm.
    return algo.load_state_dict(ckpt["state_dict"], strict=False)
|
|
|
|
|
def parse_input_to_tensor(input_str):
    """
    Convert an action-key string into a per-step action tensor.

    Each character becomes one row of width ``len(ACTION_KEYS)`` (25). Characters are looked
    up case-insensitively in ``KEY_TO_ACTION``, and the mapped value is written into that
    action's column. Two cases leave the row all zeros:

    - the character is not in ``KEY_TO_ACTION`` (e.g. a space or a typo);
    - it maps to an action that is not in ``ACTION_KEYS`` (e.g. ``"N"`` -> ``"noop"``).

    Args:
        input_str (str): key string such as ``"QQAADD"`` (see ``KEY_TO_ACTION``).

    Returns:
        torch.Tensor: float tensor of shape ``(len(input_str), len(ACTION_KEYS))``.
    """
    column_of = {name: col for col, name in enumerate(ACTION_KEYS)}
    action_tensor = torch.zeros((len(input_str), len(ACTION_KEYS)))

    for i, char in enumerate(input_str):
        # Default of (None, 0): an unmapped character yields an empty row instead of
        # crashing when unpacking None.
        action, value = KEY_TO_ACTION.get(char.upper(), (None, 0))
        col = column_of.get(action)
        if col is not None:
            action_tensor[i, col] = value

    return action_tensor
|
|
|
def load_image_as_tensor(image_path: str) -> torch.Tensor:
    """
    Load an image and convert it to a float tensor normalized to [0, 1].

    Args:
        image_path: a filesystem path (``str`` or any ``os.PathLike`` such as
            ``pathlib.Path``), which is opened and converted to RGB. Any other value is
            assumed to already be an image that ``transforms.ToTensor`` accepts (e.g. a PIL
            image) and is converted as-is.

    Returns:
        torch.Tensor: image tensor of shape (C, H, W) with values in [0, 1].
    """
    import os

    if isinstance(image_path, (str, os.PathLike)):
        # Use a context manager so the file handle is closed deterministically;
        # convert() forces the pixel data to be read before the file is closed.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
    else:
        image = image_path

    return transforms.ToTensor()(image)
|
|
|
def run_local(cfg: DictConfig):
    """
    Build the configured experiment and start its first task in interactive mode.

    The Hydra group choices for ``experiment``, ``dataset`` and ``algorithm`` are copied into
    ``cfg.<group>._name``, skipping groups whose choice is None. Then
    ``build_experiment(cfg, None, None)`` is called, and the result of
    ``exec_interactive(cfg.experiment.tasks[0])`` is returned.

    Must be called while a Hydra run is active (it reads ``HydraConfig.get()``).
    """
    # Imported lazily so the module can be loaded without the experiments package.
    from experiments import build_experiment

    runtime_choices = OmegaConf.to_container(
        hydra.core.hydra_config.HydraConfig.get().runtime.choices
    )

    with open_dict(cfg):
        for group in ("experiment", "dataset", "algorithm"):
            chosen = runtime_choices[group]
            if chosen is not None:
                cfg[group]._name = chosen

    experiment = build_experiment(cfg, None, None)
    return experiment.exec_interactive(cfg.experiment.tasks[0])
|
|
|
def enable_amp(model, precision="16-mixed", device_type="cuda"):
    """
    Patch ``model.forward`` so every call runs under ``torch.autocast``.

    Args:
        model: object with a ``forward`` method (e.g. a ``torch.nn.Module``). It is modified
            in place.
        precision: ``"16-mixed"`` selects float16; any other value selects bfloat16.
        device_type: device type passed to ``torch.autocast``. Defaults to ``"cuda"``, the
            previously hard-coded value.

    Returns:
        The same ``model`` object, with ``forward`` replaced by the autocast wrapper.
    """
    import functools

    original_forward = model.forward
    # Resolved once here rather than on every forward call.
    amp_dtype = torch.float16 if precision == "16-mixed" else torch.bfloat16

    @functools.wraps(original_forward)
    def amp_forward(*args, **kwargs):
        with torch.autocast(device_type, dtype=amp_dtype):
            return original_forward(*args, **kwargs)

    model.forward = amp_forward
    return model
|
|
|
# ---- Rollout state mutated by generate()/reset() via `global` ----------------------
# These are plain module globals, so they are shared by every caller in this process.
input_history = ""      # concatenation of every key string generated so far
memory_curr_frame = 0
memory_frames = []      # replaced by a stacked frame array once the model is loaded

# ---- Preset starting worlds (image paths) -----------------------------------------
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
DESERT_IMAGE = "assets/desert.png"
SAVANNA_IMAGE = "assets/savanna.png"
# NOTE(review): "plans.png" may be a typo for "plains.png"; this constant is not used
# elsewhere in the visible code — confirm the asset name before wiring it up.
PLAINS_IMAGE = "assets/plans.png"
PLACE_IMAGE = "assets/place.png"
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"

# World used at start-up and by reset(); on_image_click() rebinds it.
DEFAULT_IMAGE = ICE_PLAINS_IMAGE

device = torch.device('cuda')
|
|
|
def save_video(frames, path="output.mp4", fps=10):
    """
    Write RGB frames to a video at ``path`` (re-encoded to H.264 when ffmpeg works).

    The frames are first written with OpenCV (XVID) to a sibling intermediate file, then
    transcoded with the ``ffmpeg`` CLI into ``path``. ffmpeg refuses to use the same file as
    input and output, and its stderr is discarded, so the previous in-place re-encode never
    took effect. If ffmpeg is missing or exits non-zero, the OpenCV output is moved to
    ``path`` unchanged so the caller still gets a video.

    Args:
        frames: non-empty sequence of (H, W, 3) uint8 RGB frames, e.g. an (N, H, W, 3) array.
        path: destination file path.
        fps: frame rate of the written video.

    Returns:
        str: ``path``.

    Raises:
        ValueError: if ``frames`` is empty.
        RuntimeError: if OpenCV cannot open the intermediate file for writing.
    """
    import os

    if len(frames) == 0:
        raise ValueError("save_video() needs at least one frame")

    root, ext = os.path.splitext(path)
    raw_path = f"{root}.raw{ext or '.mp4'}"

    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(raw_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
    if not writer.isOpened():
        raise RuntimeError(f"cv2.VideoWriter could not open {raw_path!r}")
    try:
        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()

    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", raw_path,
        "-c:v", "libx264", "-crf", "23", "-preset", "medium", path,
    ]
    try:
        result = subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        transcoded = result.returncode == 0
    except FileNotFoundError:
        # ffmpeg binary not installed: keep the OpenCV output instead of crashing.
        transcoded = False

    if transcoded:
        os.remove(raw_path)
    else:
        os.replace(raw_path, path)
    return path
|
|
|
# ---- Model construction (runs once at import time) -------------------------------
cfg = OmegaConf.load("configurations/huggingface.yaml")
worldmem = WorldMemMinecraft(cfg)

# (attribute on worldmem, config key holding its Hub checkpoint path), loaded in this order.
_CHECKPOINT_TARGETS = (
    ("diffusion_model", "diffusion_path"),
    ("vae", "vae_path"),
    ("pose_prediction_model", "pose_predictor_path"),
)
for _module_attr, _cfg_key in _CHECKPOINT_TARGETS:
    load_custom_checkpoint(
        algo=getattr(worldmem, _module_attr),
        checkpoint_path=getattr(cfg, _cfg_key),
    )

worldmem.to("cuda").eval()
worldmem = enable_amp(worldmem, precision="16-mixed")

# All-zero single-step action / pose used to seed a fresh rollout (see generate()/reset()).
actions = np.zeros((1, 25), dtype=np.float32)
poses = np.zeros((1, 5), dtype=np.float32)

# Frame history starts with just the default world image, stacked as (1, C, H, W).
memory_frames = load_image_as_tensor(DEFAULT_IMAGE)[None].numpy()

# Opaque model state returned by worldmem.interactive(); None means "not seeded yet".
self_frames = self_actions = self_poses = self_memory_c2w = self_frame_idx = None
|
|
|
|
|
def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
                                            self_poses, self_memory_c2w, self_frame_idx):
    """
    Return the GPU time budget requested by ``spaces.GPU`` for one ``run_interactive`` call.

    It receives the same arguments as ``run_interactive``. The budget is a flat 5 for the
    seeding call (``self_actions`` is None), otherwise 5 per step in ``action``.
    """
    if self_actions is None:
        return 5
    return len(action) * 5
|
|
|
@spaces.GPU(duration=get_duration_single_image_to_long_video)
def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
                    self_poses, self_memory_c2w, self_frame_idx):
    """
    Advance the world model by ``action`` inside a ``spaces.GPU`` allocation.

    This is a thin wrapper around ``worldmem.interactive``. The decorator sizes the GPU
    request with ``get_duration_single_image_to_long_video``, which receives the same
    arguments.

    Returns:
        tuple: ``(new_frame, self_frames, self_actions, self_poses, self_memory_c2w,
        self_frame_idx)`` — the generated frame(s) plus the updated model state, which the
        caller passes back on the next call.
    """
    model_state = dict(
        self_frames=self_frames,
        self_actions=self_actions,
        self_poses=self_poses,
        self_memory_c2w=self_memory_c2w,
        self_frame_idx=self_frame_idx,
    )
    (generated, next_frames, next_actions, next_poses,
     next_memory_c2w, next_frame_idx) = worldmem.interactive(
        first_frame, action, first_pose, device=device, **model_state
    )
    return generated, next_frames, next_actions, next_poses, next_memory_c2w, next_frame_idx
|
|
|
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
    """
    Apply the "Denoising Steps" slider value to the model.

    The value is stored on both ``worldmem`` and ``worldmem.diffusion_model`` so they stay
    in sync. The incoming state value is ignored; the new value is returned so Gradio stores
    it in the bound ``gr.State``.
    """
    worldmem.sampling_timesteps = worldmem.diffusion_model.sampling_timesteps = denoising_steps
    print("set denoising steps to", worldmem.sampling_timesteps)
    return denoising_steps
|
|
|
def set_context_length(context_length, sampling_context_length_state):
    """
    Apply the "Context Length" slider value to ``worldmem.n_tokens``.

    The incoming state value is ignored; the new value is returned so Gradio stores it in
    the bound ``gr.State``.
    """
    worldmem.n_tokens = context_length
    print("set context length to", worldmem.n_tokens)
    return context_length
|
|
|
def set_memory_length(memory_length, sampling_memory_length_state):
    """
    Apply the "Memory Length" slider value to ``worldmem.condition_similar_length``.

    The incoming state value is ignored; the new value is returned so Gradio stores it in
    the bound ``gr.State``.
    """
    worldmem.condition_similar_length = memory_length
    print("set memory length to", worldmem.condition_similar_length)
    return memory_length
|
|
|
def generate(keys):
    """
    Extend the current rollout with the actions in ``keys`` and render all frames so far.

    On the first call (``self_frames`` is still None), the model state is seeded from the
    starting image with the all-zero ``actions[0]`` / ``poses[0]``. The new frames are
    appended to ``memory_frames``, the whole history is re-encoded to an MP4, and ``keys``
    is appended to ``input_history``.

    NOTE: all rollout state lives in module globals, so it is shared by every caller in
    this process rather than kept per browser session.

    Args:
        keys (str): action string such as ``"QQQAAA"`` (see ``KEY_TO_ACTION``).

    Returns:
        tuple: ``(last_frame, video_path, input_history)`` — the newest frame as an
        (H, W, 3) uint8 array, the path of an MP4 holding every frame so far, and the
        accumulated key history.

    Raises:
        gr.Error: if ``keys`` is empty or whitespace-only, so no zero-step GPU call is made.
    """
    import os

    global input_history, memory_frames
    global self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx

    if not keys or not keys.strip():
        raise gr.Error("Please enter an action sequence before clicking Generate.")

    input_actions = parse_input_to_tensor(keys)

    if self_frames is None:
        # Seed the model with the starting image, a zero action and a zero pose.
        _, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
            memory_frames[0],
            actions[0],
            poses[0],
            device=device,
            self_frames=self_frames,
            self_actions=self_actions,
            self_poses=self_poses,
            self_memory_c2w=self_memory_c2w,
            self_frame_idx=self_frame_idx,
        )

    new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
        memory_frames[0],
        input_actions,
        None,
        device=device,
        self_frames=self_frames,
        self_actions=self_actions,
        self_poses=self_poses,
        self_memory_c2w=self_memory_c2w,
        self_frame_idx=self_frame_idx,
    )

    memory_frames = np.concatenate([memory_frames, new_frame[:, 0]])

    # (N, C, H, W) floats -> (N, H, W, C) uint8 for the video writer / image widget.
    out_video = np.clip(memory_frames.transpose(0, 2, 3, 1), 0.0, 1.0)
    out_video = (out_video * 255).astype(np.uint8)

    # mkstemp creates the file atomically and keeps it; the old NamedTemporaryFile(...).name
    # relied on GC deleting the file and then reused the freed name.
    fd, video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    save_video(out_video, video_path)

    input_history += keys
    return out_video[-1], video_path, input_history
|
|
|
def reset():
    """
    Discard the current rollout and re-seed the model from ``DEFAULT_IMAGE``.

    The model state is cleared first, then the frame history is reloaded from
    ``DEFAULT_IMAGE``, the key history is emptied, and the model is re-seeded with the
    all-zero ``actions[0]`` / ``poses[0]``.

    Returns:
        tuple: ``(input_history, DEFAULT_IMAGE)`` — the now-empty history string and the
        image path to show as the last frame.
    """
    global input_history, memory_frames
    global self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx

    # Clear the model state before anything else, so a failure below leaves it unseeded.
    self_frames = self_poses = self_actions = self_memory_c2w = self_frame_idx = None
    memory_frames = load_image_as_tensor(DEFAULT_IMAGE).numpy()[None]
    input_history = ""

    seed = run_interactive(
        memory_frames[0],
        actions[0],
        poses[0],
        device=device,
        self_frames=self_frames,
        self_actions=self_actions,
        self_poses=self_poses,
        self_memory_c2w=self_memory_c2w,
        self_frame_idx=self_frame_idx,
    )
    _, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = seed

    return input_history, DEFAULT_IMAGE
|
|
|
def on_image_click(SELECTED_IMAGE):
    """
    Make the clicked preset the starting world and restart the rollout from it.

    Args:
        SELECTED_IMAGE: path of the preset image that was clicked.

    Returns:
        The same path, for display in the "Last Frame" widget.
    """
    global DEFAULT_IMAGE

    DEFAULT_IMAGE = SELECTED_IMAGE
    reset()  # re-seeds the model from the new DEFAULT_IMAGE
    return SELECTED_IMAGE
|
|
|
|
|
css = """
h1 {
    text-align: center;
    display:block;
}
"""

# ---- Gradio UI -------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # WORLDMEM: Long-term Consistent World Generation with Memory
        """
    )

    # Preset action strings offered as one-click examples (keys defined in KEY_TO_ACTION).
    example_actions = [
        "AAAAAAAAAAAADDDDDDDDDDDD",
        "AAAAAAAAAAAAAAAAAAAAAAAA",
        "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
        "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
        "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
        "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE",
    ]

    with gr.Row(variant="panel"):
        video_display = gr.Video(autoplay=True, loop=True)
        image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
            log_output = gr.Textbox(label="History Log", interactive=False)
            gr.Markdown("### Action sequence examples.")
            buttons = []
            # Two examples per row; each column's width is proportional to its string length.
            for row_start in range(0, len(example_actions), 2):
                with gr.Row():
                    for action in example_actions[row_start:row_start + 2]:
                        with gr.Column(scale=len(action)):
                            buttons.append(gr.Button(action))

        with gr.Column(scale=1):
            slider_denoising_step = gr.Slider(minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1, label="Denoising Steps")
            slider_context_length = gr.Slider(minimum=2, maximum=10, value=worldmem.n_tokens, step=1, label="Context Length")
            slider_memory_length = gr.Slider(minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1, label="Memory Length")
            submit_button = gr.Button("Generate")
            reset_btn = gr.Button("Reset")

    sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
    sampling_context_length_state = gr.State(worldmem.n_tokens)
    sampling_memory_length_state = gr.State(worldmem.condition_similar_length)

    def set_action(action):
        """Copy an example action string into the input box."""
        return action

    # Each button gets its own gr.State holding its string, so no late-binding issue.
    for button, action in zip(buttons, example_actions):
        button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)

    gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")

    with gr.Row():
        image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
        image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
        image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
        image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
        image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
        image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")

    # NOTE(review): the help text below says "U: use item", but KEY_TO_ACTION maps "U" to
    # the "drop" action column — confirm which one is intended.
    gr.Markdown(
        """
        ## Instructions & Notes:

        1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
        2. You can continue generation by clicking **"Generate"** again and again. Previous sequences are logged in the history panel.
        3. Click **"Reset"** to clear the current sequence and start fresh.
        4. Action sequences can be composed using the following keys:
            - W: turn up
            - S: turn down
            - A: turn left
            - D: turn right
            - Q: move forward
            - E: move backward
            - N: no-op (do nothing)
            - U: use item
            - 1: select hotbar slot 1
        5. Higher denoising steps produce more detailed results but take longer. 20 steps is a good balance between quality and speed. The same applies to context and memory length.
        6. For faster performance, we recommend running the demo locally (~1s/frame on H100 vs ~5s on Spaces).
        7. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
        8. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **[email protected]**.
        """
    )

    submit_button.click(generate, inputs=[input_box], outputs=[image_display, video_display, log_output])
    reset_btn.click(reset, outputs=[log_output, image_display])

    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)

    slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
    slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
    slider_memory_length.change(fn=set_memory_length, inputs=[slider_memory_length, sampling_memory_length_state], outputs=sampling_memory_length_state)

demo.launch()
|
|