import os
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

import cv2
import gradio as gr
import hydra
import numpy as np
import torch
import torchvision.transforms as transforms
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from PIL import Image

from utils.print_utils import cyan
from utils.ckpt_utils import download_latest_checkpoint, is_run_id
from utils.cluster_utils import submit_slurm_job
from utils.distributed_utils import is_rank_zero
from datasets.video.minecraft_video_dataset import *
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]
# Mapping of input characters to (action name, value) pairs.
# "noop" is not in ACTION_KEYS, so it produces an all-zero action row.
KEY_TO_ACTION = {
    "Q": ("forward", 1),
    "E": ("back", 1),
    "W": ("cameraY", -1),
    "S": ("cameraY", 1),
    "A": ("cameraX", -1),
    "D": ("cameraX", 1),
    "U": ("drop", 1),
    "N": ("noop", 1),
    "1": ("hotbar.1", 1),
}
def parse_input_to_tensor(input_str):
    """
    Convert an input string into a (sequence_length, 25) tensor, where each row
    encodes the corresponding action key.

    Args:
        input_str (str): A string of action characters (e.g., "WASDWS").

    Returns:
        torch.Tensor: A tensor of shape (sequence_length, 25), where each row has
            the mapped action index set to its value (+1 or -1).
    """
    seq_len = len(input_str)
    # Initialize a zero tensor of shape (seq_len, 25)
    action_tensor = torch.zeros((seq_len, 25))
    for i, char in enumerate(input_str):
        # .get() returns None for unmapped characters, so guard before unpacking.
        mapping = KEY_TO_ACTION.get(char.upper())  # Uppercase for case insensitivity
        if mapping:
            action, value = mapping
            if action in ACTION_KEYS:
                index = ACTION_KEYS.index(action)
                action_tensor[i, index] = value  # +1 or -1 depending on the action
    return action_tensor
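
# Example (sketch, not executed): "WD" yields a (2, 25) tensor where row 0 sets
# cameraY (index 15) to -1 for "W" and row 1 sets cameraX (index 16) to +1 for "D":
#
#   demo_actions = parse_input_to_tensor("WD")
#   assert demo_actions.shape == (2, 25)
#   assert demo_actions[0, ACTION_KEYS.index("cameraY")] == -1
#   assert demo_actions[1, ACTION_KEYS.index("cameraX")] == 1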
def load_image_as_tensor(image_path) -> torch.Tensor:
    """
    Load an image and convert it to a tensor normalized to [0, 1].

    Args:
        image_path (str | PIL.Image.Image): Path to the image file, or an
            already-loaded PIL image.

    Returns:
        torch.Tensor: Image tensor of shape (C, H, W), normalized to [0, 1].
    """
    if isinstance(image_path, str):
        image = Image.open(image_path).convert("RGB")  # Ensure it's RGB
    else:
        image = image_path
    transform = transforms.Compose([
        transforms.ToTensor(),  # Converts to tensor and normalizes to [0, 1]
    ])
    return transform(image)
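
# Example (sketch, not executed): a path or a PIL image both work and produce a
# float tensor in [0, 1]:
#
#   frame = load_image_as_tensor("assets/ice_plains.png")
#   assert frame.dtype == torch.float32 and frame.dim() == 3  # (C, H, W)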
def run_local(cfg: DictConfig):
    # Delay this import so non-local environments (e.g., cluster submission)
    # do not need the dependency.
    from experiments import build_experiment

    # Resolve the yaml names hydra chose for this run.
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    with open_dict(cfg):
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # Launch the experiment and return the interactive algorithm handle.
    experiment = build_experiment(cfg, None, cfg.checkpoint_path)
    return experiment.exec_interactive(cfg.experiment.tasks[0])
memory_frames = []
memory_curr_frame = 0
input_history = ""

ICE_PLAINS_IMAGE = "assets/ice_plains.png"
DESERT_IMAGE = "assets/desert.png"
SAVANNA_IMAGE = "assets/savanna.png"
PLAINS_IMAGE = "assets/plans.png"
PLACE_IMAGE = "assets/place.png"
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
DEFAULT_IMAGE = ICE_PLAINS_IMAGE

device = "cuda:0"
def save_video(frames, path="output.mp4", fps=10):
    h, w, _ = frames[0].shape
    # Write raw frames to a temporary file first; ffmpeg cannot safely read and
    # write the same file in place.
    tmp_path = path + ".raw.avi"
    out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*"XVID"), fps, (w, h))
    for frame in frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()
    # Re-encode with H.264 so the result plays in browsers.
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", tmp_path, "-c:v", "libx264", "-crf", "23", "-preset", "medium", path
    ]
    subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    os.remove(tmp_path)
    return path
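
# Example (sketch, not executed): frames are HxWx3 uint8 RGB arrays; the call
# assumes ffmpeg is on PATH and the output directory exists:
#
#   clip = [np.zeros((360, 640, 3), dtype=np.uint8) for _ in range(10)]
#   save_video(clip, "outputs_gradio/demo.mp4", fps=10)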
# NOTE: the original entry point relies on hydra injecting `cfg` (see the pylint
# disable at the bottom); the exact config_path/config_name are assumptions here.
@hydra.main(version_base=None, config_path="configurations", config_name="config")
def run(cfg: DictConfig):
    algo = run_local(cfg)
    algo.to(device)

    actions = torch.zeros((1, 25))
    poses = torch.zeros((1, 5))

    # Seed the memory with the default first frame and prime the model.
    memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
    _ = algo.interactive(memory_frames[0],
                         actions[0],
                         poses[0],
                         memory_curr_frame,
                         device=device)

    def set_denoising_steps(denoising_steps, sampling_timesteps_state):
        # Keep the wrapper and the underlying diffusion model in sync.
        algo.sampling_timesteps = denoising_steps
        algo.diffusion_model.sampling_timesteps = denoising_steps
        sampling_timesteps_state = denoising_steps
        print("set denoising steps to", algo.sampling_timesteps)
        return sampling_timesteps_state
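
    # Example (sketch, not executed): the slider wiring below routes the value
    # through gr.State, so set_denoising_steps(20, state) updates both
    # algo.sampling_timesteps and the diffusion model, then returns 20.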
    def update_image_and_log(keys):
        actions = parse_input_to_tensor(keys)
        global input_history
        global memory_curr_frame

        # Roll the model forward one frame per action; frame 0 stays the
        # conditioning frame for every step.
        for i in range(len(actions)):
            memory_curr_frame += 1
            new_frame = algo.interactive(memory_frames[0],
                                         actions[i],
                                         None,
                                         memory_curr_frame,
                                         device=device)
            memory_frames.append(new_frame)

        # Stack to (T, C, H, W), then convert to (T, H, W, C) uint8 for encoding.
        out_video = torch.stack(memory_frames)
        out_video = out_video.permute(0, 2, 3, 1).numpy()
        out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
        out_video = (out_video * 255).astype(np.uint8)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs("outputs_gradio", exist_ok=True)
        filename = f"outputs_gradio/{timestamp}.mp4"
        save_video(out_video, filename)

        input_history += keys
        return out_video[-1], filename, input_history
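
    # Example (sketch, not executed): update_image_and_log("WD") advances two
    # frames (turn up, then turn right) and returns the last frame as a uint8
    # array, the saved mp4 path, and the accumulated history string "WD".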
    def reset():
        global memory_curr_frame
        global input_history
        global memory_frames
        algo.reset()
        memory_frames = []
        memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
        memory_curr_frame = 0
        input_history = ""
        # Re-prime the model on the (possibly new) first frame.
        _ = algo.interactive(memory_frames[0],
                             actions[0],
                             poses[0],
                             memory_curr_frame,
                             device=device)
        return input_history, DEFAULT_IMAGE

    def on_image_click(SELECTED_IMAGE):
        global DEFAULT_IMAGE
        DEFAULT_IMAGE = SELECTED_IMAGE
        reset()
        return SELECTED_IMAGE
    css = """
    h1 {
        text-align: center;
        display: block;
    }
    """

    # update_image_and_log("W")
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # WORLDMEM: Long-term Consistent World Generation with Memory

            <div style="text-align: center;">
                <!-- Public Website -->
                <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
                    <img src="https://img.shields.io/badge/public_website-8A2BE2">
                </a>
                <!-- GitHub Stars -->
                <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
                    <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
                </a>
                <!-- Project Page -->
                <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
                    <img src="https://img.shields.io/badge/project_page-blue">
                </a>
                <!-- arXiv Paper -->
                <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
                    <img src="https://img.shields.io/badge/arXiv-paper-red">
                </a>
            </div>
            """
        )
        with gr.Row(variant="panel"):
            video_display = gr.Video(autoplay=True, loop=True)
            image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
                log_output = gr.Textbox(label="History Log", interactive=False)
            with gr.Column(scale=1):
                slider = gr.Slider(minimum=10, maximum=50, value=algo.sampling_timesteps, step=1, label="Denoising Steps")
                submit_button = gr.Button("Generate")
                reset_btn = gr.Button("Reset")

        sampling_timesteps_state = gr.State(algo.sampling_timesteps)

        example_actions = [
            "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
            "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
            "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA",
            "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEEAAAAAAAAAAAAAAAAAAAAAA",
        ]

        def set_action(action):
            return action

        gr.Markdown("### Action sequence examples.")
        # Lay out the four examples two per row.
        buttons = []
        with gr.Row():
            for action in example_actions[:2]:
                with gr.Column(scale=len(action)):
                    buttons.append(gr.Button(action))
        with gr.Row():
            for action in example_actions[2:4]:
                with gr.Column(scale=len(action)):
                    buttons.append(gr.Button(action))

        for button, action in zip(buttons, example_actions):
            button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)
        gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
        with gr.Row():
            image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
            image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
            image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
            image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
            image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
            image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
        gr.Markdown(
            """
            ## Instructions & Notes:

            1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
            2. You can continue generation by clicking **"Generate"** again and again. Previous sequences are logged in the history panel.
            3. Click **"Reset"** to clear the current sequence and start fresh.
            4. Action sequences can be composed of the following keys:
                - W: turn up
                - S: turn down
                - A: turn left
                - D: turn right
                - Q: move forward
                - E: move backward
                - N: no-op (do nothing)
                - 1: switch to hotbar 1
                - U: drop item
            5. More denoising steps produce more detailed results but take longer; **20 steps** is a good balance between quality and speed.
            6. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
            7. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **[email protected]**.
            """
        )
        # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
        submit_button.click(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
        reset_btn.click(reset, outputs=[log_output, image_display])
        image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
        image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
        image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
        image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
        image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
        image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)
        slider.change(fn=set_denoising_steps, inputs=[slider, sampling_timesteps_state], outputs=sampling_timesteps_state)

    # Allow public access. `launch` blocks, so only one call can run; the
    # fixed-port variant is kept commented out as an alternative deployment.
    demo.launch(share=True)
    # demo.launch(server_name="0.0.0.0", server_port=30066)
if __name__ == "__main__":
    run()  # pylint: disable=no-value-for-parameter