Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -21,121 +21,121 @@ import spaces
|
|
21 |
from algorithms.worldmem import WorldMemMinecraft
|
22 |
from huggingface_hub import hf_hub_download
|
23 |
import tempfile
|
24 |
-
import os
|
25 |
-
import requests
|
26 |
-
from huggingface_hub import model_info
|
27 |
-
|
28 |
-
def is_huggingface_model(path: str) -> bool:
|
29 |
-
hf_ckpt = str(path).split('/')
|
30 |
-
repo_id = '/'.join(hf_ckpt[:2])
|
31 |
-
try:
|
32 |
-
model_info(repo_id)
|
33 |
-
return True
|
34 |
-
except:
|
35 |
-
return False
|
36 |
|
37 |
torch.set_float32_matmul_precision("high")
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def load_custom_checkpoint(algo, checkpoint_path):
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
for k, v in ckpt['state_dict'].items():
|
49 |
-
if "frame_timestep_embedder" in k:
|
50 |
-
new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
|
51 |
-
filtered_state_dict[new_k] = v
|
52 |
-
else:
|
53 |
-
filtered_state_dict[k] = v
|
54 |
-
|
55 |
-
algo.load_state_dict(filtered_state_dict, strict=True)
|
56 |
-
print("Load: ", model_path)
|
57 |
-
else:
|
58 |
-
ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
59 |
-
|
60 |
-
filtered_state_dict = {}
|
61 |
-
for k, v in ckpt['state_dict'].items():
|
62 |
-
if "frame_timestep_embedder" in k:
|
63 |
-
new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
|
64 |
-
filtered_state_dict[new_k] = v
|
65 |
-
else:
|
66 |
-
filtered_state_dict[k] = v
|
67 |
-
|
68 |
-
algo.load_state_dict(filtered_state_dict, strict=True)
|
69 |
-
print("Load: ", checkpoint_path)
|
70 |
-
|
71 |
-
def download_assets_if_needed():
|
72 |
-
ASSETS_URL_BASE = "https://huggingface.co/spaces/yslan/worldmem/resolve/main/assets/examples"
|
73 |
-
ASSETS_DIR = "assets/examples"
|
74 |
-
ASSETS = ['case1.npz', 'case2.npz', 'case3.npz', 'case4.npz']
|
75 |
-
|
76 |
-
if not os.path.exists(ASSETS_DIR):
|
77 |
-
os.makedirs(ASSETS_DIR)
|
78 |
-
|
79 |
-
for filename in ASSETS:
|
80 |
-
filepath = os.path.join(ASSETS_DIR, filename)
|
81 |
-
if not os.path.exists(filepath):
|
82 |
-
print(f"Downloading {filename}...")
|
83 |
-
url = f"{ASSETS_URL_BASE}/{filename}"
|
84 |
-
response = requests.get(url)
|
85 |
-
if response.status_code == 200:
|
86 |
-
with open(filepath, "wb") as f:
|
87 |
-
f.write(response.content)
|
88 |
-
else:
|
89 |
-
print(f"Failed to download {filename}: {response.status_code}")
|
90 |
|
91 |
def parse_input_to_tensor(input_str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
seq_len = len(input_str)
|
|
|
|
|
93 |
action_tensor = torch.zeros((seq_len, 25))
|
|
|
|
|
94 |
for i, char in enumerate(input_str):
|
95 |
-
action, value = KEY_TO_ACTION.get(char.upper()
|
96 |
if action and action in ACTION_KEYS:
|
97 |
index = ACTION_KEYS.index(action)
|
98 |
-
action_tensor[i, index] = value
|
|
|
99 |
return action_tensor
|
100 |
|
101 |
-
def load_image_as_tensor(image_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
if isinstance(image_path, str):
|
103 |
-
image = Image.open(image_path).convert("RGB")
|
104 |
else:
|
105 |
image = image_path
|
106 |
transform = transforms.Compose([
|
107 |
-
transforms.ToTensor(),
|
108 |
])
|
109 |
return transform(image)
|
110 |
|
111 |
def enable_amp(model, precision="16-mixed"):
|
112 |
original_forward = model.forward
|
|
|
113 |
def amp_forward(*args, **kwargs):
|
114 |
with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
|
115 |
return original_forward(*args, **kwargs)
|
|
|
116 |
model.forward = amp_forward
|
117 |
return model
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
ACTION_KEYS = [
|
122 |
-
"inventory", "ESC", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4",
|
123 |
-
"hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", "forward",
|
124 |
-
"back", "left", "right", "cameraY", "cameraX", "jump", "sneak", "sprint",
|
125 |
-
"swapHands", "attack", "use", "pickItem", "drop",
|
126 |
-
]
|
127 |
-
KEY_TO_ACTION = {
|
128 |
-
"Q": ("forward", 1), "E": ("back", 1), "W": ("cameraY", -1),
|
129 |
-
"S": ("cameraY", 1), "A": ("cameraX", -1), "D": ("cameraX", 1),
|
130 |
-
"U": ("drop", 1), "N": ("noop", 1), "1": ("hotbar.1", 1),
|
131 |
-
}
|
132 |
-
example_images = [
|
133 |
-
["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
|
134 |
-
["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
|
135 |
-
["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
|
136 |
-
["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
|
137 |
-
]
|
138 |
-
video_frames = []
|
139 |
input_history = ""
|
140 |
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
|
141 |
DESERT_IMAGE = "assets/desert.png"
|
@@ -144,22 +144,21 @@ PLAINS_IMAGE = "assets/plans.png"
|
|
144 |
PLACE_IMAGE = "assets/place.png"
|
145 |
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
|
146 |
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
|
|
|
147 |
device = torch.device('cuda')
|
148 |
|
149 |
def save_video(frames, path="output.mp4", fps=10):
|
150 |
-
temp_path = path[:-4] + "_temp.mp4"
|
151 |
h, w, _ = frames[0].shape
|
152 |
-
out = cv2.VideoWriter(
|
153 |
for frame in frames:
|
154 |
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
155 |
out.release()
|
|
|
156 |
ffmpeg_cmd = [
|
157 |
-
"ffmpeg", "-y", "-i",
|
158 |
-
"-c:v", "libx264", "-crf", "23", "-preset", "medium",
|
159 |
-
path
|
160 |
]
|
161 |
subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
162 |
-
|
163 |
|
164 |
cfg = OmegaConf.load("configurations/huggingface.yaml")
|
165 |
worldmem = WorldMemMinecraft(cfg)
|
@@ -167,26 +166,31 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
|
|
167 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
168 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
169 |
worldmem.to("cuda").eval()
|
|
|
|
|
170 |
actions = np.zeros((1, 25), dtype=np.float32)
|
171 |
poses = np.zeros((1, 5), dtype=np.float32)
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
176 |
|
177 |
@spaces.GPU(duration=get_duration_single_image_to_long_video)
|
178 |
-
def run_interactive(first_frame, action, first_pose, device,
|
179 |
-
|
180 |
-
new_frame,
|
181 |
action,
|
182 |
first_pose,
|
183 |
device=device,
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
190 |
|
191 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
192 |
worldmem.sampling_timesteps = denoising_steps
|
@@ -201,137 +205,144 @@ def set_context_length(context_length, sampling_context_length_state):
|
|
201 |
print("set context length to", worldmem.n_tokens)
|
202 |
return sampling_context_length_state
|
203 |
|
204 |
-
def
|
205 |
-
worldmem.
|
206 |
-
|
207 |
-
print("set memory length to", worldmem.
|
208 |
-
return
|
209 |
-
|
210 |
-
def set_next_frame_length(next_frame_length, sampling_next_frame_length_state):
|
211 |
-
worldmem.next_frame_length = next_frame_length
|
212 |
-
sampling_next_frame_length_state = next_frame_length
|
213 |
-
print("set next frame length to", worldmem.next_frame_length)
|
214 |
-
return sampling_next_frame_length_state
|
215 |
|
216 |
-
def generate(keys, input_history,
|
217 |
input_actions = parse_input_to_tensor(keys)
|
218 |
-
|
219 |
-
|
|
|
220 |
actions[0],
|
221 |
poses[0],
|
222 |
device=device,
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
229 |
input_actions,
|
230 |
None,
|
231 |
device=device,
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
|
|
|
|
|
|
239 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
240 |
out_video = (out_video * 255).astype(np.uint8)
|
|
|
241 |
last_frame = out_video[-1].copy()
|
242 |
border_thickness = 2
|
243 |
out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
|
244 |
out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
|
245 |
out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
|
246 |
out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
|
|
|
247 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
248 |
save_video(out_video, temporal_video_path)
|
249 |
input_history += keys
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
def reset(selected_image):
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
input_history = ""
|
260 |
-
|
|
|
261 |
actions[0],
|
262 |
poses[0],
|
263 |
device=device,
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
return input_history,
|
271 |
|
272 |
def on_image_click(selected_image):
|
273 |
-
input_history,
|
274 |
-
return input_history, selected_image, selected_image,
|
275 |
-
|
276 |
-
# === NEW: Custom image upload handler ===
|
277 |
-
def on_custom_image_upload(custom_image):
|
278 |
-
if custom_image is None:
|
279 |
-
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
280 |
-
image = Image.fromarray(custom_image.astype(np.uint8))
|
281 |
-
input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(image)
|
282 |
-
return (
|
283 |
-
input_history, # log_output
|
284 |
-
image, # selected_image
|
285 |
-
image, # image_display
|
286 |
-
video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
287 |
-
)
|
288 |
|
289 |
-
def set_memory(examples_case):
|
290 |
if examples_case == '1':
|
291 |
data_bundle = np.load("assets/examples/case1.npz")
|
292 |
input_history = data_bundle['input_history'].item()
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
elif examples_case == '2':
|
300 |
data_bundle = np.load("assets/examples/case2.npz")
|
301 |
input_history = data_bundle['input_history'].item()
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
elif examples_case == '3':
|
309 |
data_bundle = np.load("assets/examples/case3.npz")
|
310 |
input_history = data_bundle['input_history'].item()
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
elif examples_case == '4':
|
318 |
data_bundle = np.load("assets/examples/case4.npz")
|
319 |
input_history = data_bundle['input_history'].item()
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
out_video =
|
328 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
329 |
out_video = (out_video * 255).astype(np.uint8)
|
330 |
|
331 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
332 |
save_video(out_video, temporal_video_path)
|
333 |
|
334 |
-
return input_history, out_video[-1], temporal_video_path,
|
335 |
|
336 |
css = """
|
337 |
h1 {
|
@@ -344,27 +355,36 @@ with gr.Blocks(css=css) as demo:
|
|
344 |
gr.Markdown(
|
345 |
"""
|
346 |
# WORLDMEM: Long-term Consistent World Simulation with Memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
"""
|
348 |
-
|
|
|
349 |
gr.Markdown(
|
350 |
"""
|
351 |
## 🚀 How to Explore WorldMem
|
352 |
-
|
353 |
Follow these simple steps to get started:
|
354 |
-
|
355 |
-
1. **Choose a scene** or **upload your own**.
|
356 |
2. **Input your action sequence**.
|
357 |
3. **Click "Generate"**.
|
358 |
-
|
359 |
- You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
|
360 |
- For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
|
361 |
-
-
|
362 |
- 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
|
363 |
-
|
364 |
Happy exploring! 🌍
|
365 |
"""
|
366 |
)
|
367 |
|
|
|
368 |
example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
|
369 |
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
|
370 |
"turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
|
@@ -374,15 +394,6 @@ with gr.Blocks(css=css) as demo:
|
|
374 |
|
375 |
selected_image = gr.State(ICE_PLAINS_IMAGE)
|
376 |
|
377 |
-
# --- NEW: Custom image upload UI ---
|
378 |
-
with gr.Row():
|
379 |
-
gr.Markdown("🖼️ Or upload your own scene (PNG/JPG, 256x256 recommended):")
|
380 |
-
custom_image_upload = gr.Image(
|
381 |
-
label="Upload a Custom Scene",
|
382 |
-
type="numpy",
|
383 |
-
tool=None
|
384 |
-
)
|
385 |
-
|
386 |
with gr.Row(variant="panel"):
|
387 |
with gr.Column():
|
388 |
gr.Markdown("🖼️ Start from this frame.")
|
@@ -399,7 +410,8 @@ with gr.Blocks(css=css) as demo:
|
|
399 |
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
|
400 |
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
|
401 |
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
|
402 |
-
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
|
|
|
403 |
|
404 |
with gr.Row(variant="panel"):
|
405 |
with gr.Column(scale=2):
|
@@ -409,7 +421,6 @@ with gr.Blocks(css=css) as demo:
|
|
409 |
gr.Markdown(
|
410 |
"""
|
411 |
### 💡 Action Key Guide
|
412 |
-
|
413 |
<pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
|
414 |
W: Turn up S: Turn down A: Turn left D: Turn right
|
415 |
Q: Go forward E: Go backward N: No-op U: Use item
|
@@ -435,7 +446,9 @@ with gr.Blocks(css=css) as demo:
|
|
435 |
submit_button = gr.Button("🎬 Generate!", variant="primary")
|
436 |
reset_btn = gr.Button("🔄 Reset")
|
437 |
|
|
|
438 |
gr.Markdown("### ⚙️ Advanced Settings")
|
|
|
439 |
slider_denoising_step = gr.Slider(
|
440 |
minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
|
441 |
label="Denoising Steps",
|
@@ -446,70 +459,60 @@ with gr.Blocks(css=css) as demo:
|
|
446 |
label="Context Length",
|
447 |
info="How many previous frames in temporal context window."
|
448 |
)
|
449 |
-
|
450 |
-
minimum=4, maximum=16, value=worldmem.
|
451 |
label="Memory Length",
|
452 |
-
info="How many previous frames in memory window.
|
453 |
-
)
|
454 |
-
slider_next_frame_length = gr.Slider(
|
455 |
-
minimum=1, maximum=5, value=worldmem.next_frame_length, step=1,
|
456 |
-
label="Next Frame Length",
|
457 |
-
info="How many next frames to generate at once."
|
458 |
)
|
|
|
|
|
459 |
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
|
460 |
sampling_context_length_state = gr.State(worldmem.n_tokens)
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
|
470 |
def set_action(action):
|
471 |
return action
|
|
|
|
|
472 |
|
473 |
for button, action_key in zip(buttons, list(example_actions.keys())):
|
474 |
-
|
475 |
|
476 |
gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
|
|
|
477 |
example_case = gr.Textbox(label="Case", visible=False)
|
478 |
-
image_output = gr.Image(visible=False)
|
479 |
|
480 |
examples = gr.Examples(
|
481 |
examples=example_images,
|
482 |
-
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length,
|
483 |
cache_examples=False
|
484 |
)
|
|
|
485 |
example_case.change(
|
486 |
fn=set_memory,
|
487 |
-
inputs=[example_case],
|
488 |
-
outputs=[log_output, image_display, video_display,
|
489 |
)
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
501 |
-
image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
502 |
-
image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
503 |
slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
|
504 |
slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
|
505 |
-
|
506 |
-
slider_next_frame_length.change(fn=set_next_frame_length, inputs=[slider_next_frame_length, sampling_next_frame_length_state], outputs=sampling_next_frame_length_state)
|
507 |
-
|
508 |
-
# --- NEW: Custom image upload triggers reset ---
|
509 |
-
custom_image_upload.change(
|
510 |
-
on_custom_image_upload,
|
511 |
-
inputs=[custom_image_upload],
|
512 |
-
outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx],
|
513 |
-
)
|
514 |
|
515 |
-
demo.launch(
|
|
|
21 |
from algorithms.worldmem import WorldMemMinecraft
|
22 |
from huggingface_hub import hf_hub_download
|
23 |
import tempfile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
torch.set_float32_matmul_precision("high")
|
26 |
|
27 |
+
# Names of the 25 action channels. The position of a name in this list is the
# column that parse_input_to_tensor writes to in its (seq_len, 25) action
# tensor, so entries must not be reordered or removed casually.
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

# Mapping of input keys to action names
# Each value is an (action name, value) pair; the value is what gets written
# into the action's column, so -1 / 1 select the two camera directions.
# "noop" is not listed in ACTION_KEYS, so parse_input_to_tensor leaves the row
# for an "N" key entirely zero.
KEY_TO_ACTION = {
    "Q": ("forward", 1),
    "E": ("back", 1),
    "W": ("cameraY", -1),
    "S": ("cameraY", 1),
    "A": ("cameraX", -1),
    "D": ("cameraX", 1),
    "U": ("drop", 1),
    "N": ("noop", 1),
    "1": ("hotbar.1", 1),
}
|
67 |
+
|
68 |
+
# Rows fed to gr.Examples. Each row is
# [case id, start image path, human-readable action summary, 20, 3, 8].
# NOTE(review): the three trailing numbers look like the denoising-step,
# context-length and memory-length slider values — confirm against the
# gr.Examples `inputs` list.
example_images = [
    # The first step was fused with the second ("turn rightgo backward");
    # the arrow separator used by every other step is restored here.
    ["1", "assets/ice_plains.png", "turn right→go backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
    ["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
    ["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
    ["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
]
|
74 |
+
|
75 |
def load_custom_checkpoint(algo, checkpoint_path):
    """Download a checkpoint from the Hugging Face Hub and load it into ``algo``.

    Args:
        algo: Module whose ``load_state_dict`` receives the checkpoint weights.
        checkpoint_path: Hub location written as ``"<owner>/<repo>/<file path>"``.
            The first two components form the repo id, the rest is the file
            name inside that repo.

    Raises:
        ValueError: If ``checkpoint_path`` does not have an owner, a repo and a
            file name.
    """
    hf_ckpt = str(checkpoint_path).split('/')
    repo_id = '/'.join(hf_ckpt[:2])
    file_name = '/'.join(hf_ckpt[2:])
    # Fail early with a readable message instead of handing an empty repo id
    # or file name to the Hub client.
    if len(hf_ckpt) < 3 or not all(hf_ckpt[:2]) or not file_name:
        raise ValueError(
            f"checkpoint_path must look like '<owner>/<repo>/<file>', got {checkpoint_path!r}"
        )

    model_path = hf_hub_download(repo_id=repo_id, filename=file_name)
    # NOTE(review): torch.load unpickles the file, so only trusted repos should
    # be configured here.
    ckpt = torch.load(model_path, map_location=torch.device('cpu'))

    # strict=False tolerates key mismatches; report them so a partially loaded
    # model does not go unnoticed.
    result = algo.load_state_dict(ckpt['state_dict'], strict=False)
    missing = list(getattr(result, "missing_keys", []))
    unexpected = list(getattr(result, "unexpected_keys", []))
    if missing or unexpected:
        print(
            f"Warning: {checkpoint_path} loaded with {len(missing)} missing "
            f"and {len(unexpected)} unexpected keys"
        )
|
83 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def parse_input_to_tensor(input_str):
    """Convert a key string into a (sequence_length, 25) action tensor.

    Every character is looked up case-insensitively in KEY_TO_ACTION. When the
    resulting action name is listed in ACTION_KEYS, that action's column in the
    character's row is set to the mapped value (which may be -1 or 1, so rows
    are not strictly one-hot). Characters with no mapping, and actions that
    are not in ACTION_KEYS, leave the row all zeros.

    Args:
        input_str (str): Sequence of action keys, e.g. "WASDQE".

    Returns:
        torch.Tensor: Tensor of shape (len(input_str), 25).
    """
    # Get the length of the input sequence
    seq_len = len(input_str)

    # Initialize a zero tensor of shape (seq_len, 25)
    action_tensor = torch.zeros((seq_len, 25))

    # Iterate through the input string and update the corresponding positions
    for i, char in enumerate(input_str):
        # Unmapped characters (spaces, typos, ...) used to make .get() return
        # None and crash the tuple unpacking; they are now treated as no-ops.
        action, value = KEY_TO_ACTION.get(char.upper(), (None, 0))
        if action and action in ACTION_KEYS:
            index = ACTION_KEYS.index(action)
            action_tensor[i, index] = value

    return action_tensor
|
108 |
|
109 |
+
def load_image_as_tensor(image_path: str) -> torch.Tensor:
    """Turn an image into a tensor via ``transforms.ToTensor``.

    Args:
        image_path: Either a path string, which is opened and converted to
            RGB, or an already loaded image object, which is used unchanged.

    Returns:
        The output of ``transforms.ToTensor()`` for that image.
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    if not isinstance(image_path, str):
        # Caller handed over an image object rather than a path.
        return to_tensor(image_path)
    return to_tensor(Image.open(image_path).convert("RGB"))
|
127 |
|
128 |
def enable_amp(model, precision="16-mixed"):
|
129 |
original_forward = model.forward
|
130 |
+
|
131 |
def amp_forward(*args, **kwargs):
|
132 |
with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
|
133 |
return original_forward(*args, **kwargs)
|
134 |
+
|
135 |
model.forward = amp_forward
|
136 |
return model
|
137 |
|
138 |
+
memory_frames = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
input_history = ""
|
140 |
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
|
141 |
DESERT_IMAGE = "assets/desert.png"
|
|
|
144 |
PLACE_IMAGE = "assets/place.png"
|
145 |
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
|
146 |
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
|
147 |
+
|
148 |
device = torch.device('cuda')
|
149 |
|
150 |
def save_video(frames, path="output.mp4", fps=10):
    """Write RGB frames to ``path`` as an H.264 video.

    The frames are first written with OpenCV (XVID) to a sibling temporary
    file and then transcoded by ffmpeg into ``path``. ffmpeg cannot use one
    file as both input and output, so passing ``path`` for both made it fail
    silently and left the raw XVID stream at ``path``.

    Args:
        frames: Sequence of frames, each shaped (height, width, channels) in
            RGB order. Must contain at least one frame.
        path: Destination file.
        fps: Frame rate passed to the OpenCV writer.

    Returns:
        ``path``. If ffmpeg is missing or fails, the OpenCV-encoded file is
        moved to ``path`` so callers still receive a playable file.
    """
    import os  # local import: the file's top-level import list is not fully visible here

    root, ext = os.path.splitext(path)
    temp_path = f"{root}_temp{ext or '.mp4'}"

    h, w, _ = frames[0].shape
    out = cv2.VideoWriter(temp_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
    try:
        for frame in frames:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always flush and close the container, even if a frame fails to convert.
        out.release()

    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", temp_path, "-c:v", "libx264", "-crf", "23", "-preset", "medium", path
    ]
    try:
        completed = subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        transcoded = completed.returncode == 0
    except OSError:
        # ffmpeg binary not available.
        transcoded = False

    if os.path.exists(temp_path):
        if transcoded:
            os.remove(temp_path)
        else:
            print("ffmpeg transcode failed; keeping the OpenCV-encoded video")
            os.replace(temp_path, path)
    return path
|
162 |
|
163 |
cfg = OmegaConf.load("configurations/huggingface.yaml")
|
164 |
worldmem = WorldMemMinecraft(cfg)
|
|
|
166 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
167 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
168 |
worldmem.to("cuda").eval()
|
169 |
+
# worldmem = enable_amp(worldmem, precision="16-mixed")
|
170 |
+
|
171 |
actions = np.zeros((1, 25), dtype=np.float32)
|
172 |
poses = np.zeros((1, 5), dtype=np.float32)
|
173 |
|
174 |
+
|
175 |
+
|
176 |
+
def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
                                            self_poses, self_memory_c2w, self_frame_idx):
    """Return the duration value handed to ``spaces.GPU`` for ``run_interactive``.

    The result is 5 when ``self_actions`` is None, otherwise 5 per entry of
    ``action``. All other parameters are accepted only so the signature
    mirrors ``run_interactive``.
    """
    if self_actions is None:
        return 5
    return 5 * len(action)
|
179 |
|
180 |
# The decorator receives get_duration_single_image_to_long_video, which takes
# the same arguments as this function, as its `duration` argument.
@spaces.GPU(duration=get_duration_single_image_to_long_video)
def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
                    self_poses, self_memory_c2w, self_frame_idx):
    """Forward one interaction step to ``worldmem.interactive``.

    All arguments are passed through unchanged; the five ``self_*`` values are
    passed by keyword. NOTE(review): they appear to be the memory state carried
    between calls — confirm against ``WorldMemMinecraft.interactive``.

    Returns:
        The six values returned by ``worldmem.interactive``, as a tuple:
        ``(new_frame, self_frames, self_actions, self_poses, self_memory_c2w,
        self_frame_idx)``.
    """
    new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
                                    action,
                                    first_pose,
                                    device=device,
                                    self_frames=self_frames,
                                    self_actions=self_actions,
                                    self_poses=self_poses,
                                    self_memory_c2w=self_memory_c2w,
                                    self_frame_idx=self_frame_idx)

    return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
194 |
|
195 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
196 |
worldmem.sampling_timesteps = denoising_steps
|
|
|
205 |
print("set context length to", worldmem.n_tokens)
|
206 |
return sampling_context_length_state
|
207 |
|
208 |
+
def set_memory_length(memory_length, sampling_memory_length_state):
    """Apply a new memory length to the global model and echo it back.

    Writes ``memory_length`` to ``worldmem.condition_similar_length``, logs the
    stored value and returns ``memory_length``. The incoming
    ``sampling_memory_length_state`` value is not read.
    """
    worldmem.condition_similar_length = memory_length
    print("set memory length to", worldmem.condition_similar_length)
    return memory_length
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
+
def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
    """Generate frames for ``keys`` and render the whole history as a video.

    Args:
        keys: Action key string, parsed by ``parse_input_to_tensor``.
        input_history: Keys entered so far; ``keys`` is appended to it.
        memory_frames: Array of all frames so far; its first frame seeds the
            model and new frames are concatenated along axis 0.
        self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx:
            State passed to and returned from ``run_interactive``. When
            ``self_frames`` is None, one extra call using the zero
            action/pose is made first.

    Returns:
        Tuple ``(last_frame, video_path, input_history, memory_frames,
        self_frames, self_actions, self_poses, self_memory_c2w,
        self_frame_idx)``, where ``last_frame`` is the final uint8 frame
        without the red border.
    """
    input_actions = parse_input_to_tensor(keys)

    if self_frames is None:
        new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
            memory_frames[0],
            actions[0],
            poses[0],
            device=device,
            self_frames=self_frames,
            self_actions=self_actions,
            self_poses=self_poses,
            self_memory_c2w=self_memory_c2w,
            self_frame_idx=self_frame_idx)

    new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
        memory_frames[0],
        input_actions,
        None,
        device=device,
        self_frames=self_frames,
        self_actions=self_actions,
        self_poses=self_poses,
        self_memory_c2w=self_memory_c2w,
        self_frame_idx=self_frame_idx)

    memory_frames = np.concatenate([memory_frames, new_frame[:, 0]])

    # Move axis 1 to the end and convert to uint8 in the 0-255 range.
    # NOTE(review): presumably (N, C, H, W) -> (N, H, W, C) — confirm.
    out_video = memory_frames.transpose(0, 2, 3, 1).copy()
    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
    out_video = (out_video * 255).astype(np.uint8)

    # Copy the last frame before the border is painted over it.
    last_frame = out_video[-1].copy()

    # Paint a 2-pixel red border on the trailing len(new_frame) frames.
    border_thickness = 2
    out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
    out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
    out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
    out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]

    # Reserve a persistent temp path. Taking `.name` from a NamedTemporaryFile
    # that is discarded at once relies on the file being deleted on
    # finalisation and the name staying free afterwards; delete=False keeps
    # the file for save_video to overwrite.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        temporal_video_path = tmp_file.name
    save_video(out_video, temporal_video_path)
    input_history += keys

    # now = datetime.now()
    # folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
    # folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
    # os.makedirs(folder_path, exist_ok=True)
    # data_dict = {
    #     "input_history": input_history,
    #     "memory_frames": memory_frames,
    #     "self_frames": self_frames,
    #     "self_actions": self_actions,
    #     "self_poses": self_poses,
    #     "self_memory_c2w": self_memory_c2w,
    #     "self_frame_idx": self_frame_idx,
    # }

    # np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)

    return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
274 |
|
275 |
def reset(selected_image):
    """Start a fresh session from ``selected_image``.

    Loads the image as a single-frame array, runs one ``run_interactive`` call
    with the zero action/pose and no prior state, and returns an empty input
    history together with the frame array and the state from that call.

    Returns:
        Tuple ``(input_history, memory_frames, self_frames, self_actions,
        self_poses, self_memory_c2w, self_frame_idx)``.
    """
    # [None] adds a leading axis so the single image becomes a one-frame array.
    memory_frames = load_image_as_tensor(selected_image).numpy()[None]

    # The generated frame itself is not needed here, only the returned state.
    _, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
        memory_frames[0],
        actions[0],
        poses[0],
        device=device,
        self_frames=None,
        self_actions=None,
        self_poses=None,
        self_memory_c2w=None,
        self_frame_idx=None,
    )

    return "", memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
295 |
|
296 |
def on_image_click(selected_image):
    """Reset the session to ``selected_image`` and return the refreshed values.

    Returns:
        Tuple of the empty input history, ``selected_image`` twice, then the
        frame array and five state values produced by ``reset``.
    """
    (history, frames, state_frames, state_actions,
     state_poses, state_c2w, state_frame_idx) = reset(selected_image)
    return (
        history,
        selected_image,
        selected_image,
        frames,
        state_frames,
        state_actions,
        state_poses,
        state_c2w,
        state_frame_idx,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
|
300 |
+
def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
    """Load a pre-generated example bundle and return it as session state.

    The bundle ``assets/examples/case<N>.npz`` is selected by
    ``examples_case``. Its ``memory_frames`` array is additionally rendered
    to a temporary ``.mp4`` via ``save_video`` so the example can be previewed.

    Args:
        examples_case: Example identifier; must be one of ``'1'``, ``'2'``,
            ``'3'`` or ``'4'``.
        image_display, log_output, slider_denoising_step,
        slider_context_length, slider_memory_length: Accepted so the
            signature matches the values passed by the caller; they are not
            read by this function.

    Returns:
        A 9-tuple ``(input_history, last_frame, video_path, memory_frames,
        self_frames, self_actions, self_poses, self_memory_c2w,
        self_frame_idx)`` where ``last_frame`` is the final frame of the
        uint8 video array.

    Raises:
        ValueError: If ``examples_case`` is not a known example. Previously
            this surfaced as an unrelated ``UnboundLocalError``.
    """
    known_cases = ('1', '2', '3', '4')
    # Validate before building the path: the value comes from a UI component,
    # so only the whitelisted identifiers may reach the file system.
    if examples_case not in known_cases:
        raise ValueError(
            f"Unknown example case {examples_case!r}; expected one of {known_cases}."
        )

    # Read every array inside the context manager so the .npz archive is
    # closed again instead of being left open for the garbage collector.
    with np.load(f"assets/examples/case{examples_case}.npz") as data_bundle:
        input_history = data_bundle['input_history'].item()
        memory_frames = data_bundle['memory_frames']
        self_frames = data_bundle['self_frames']
        self_actions = data_bundle['self_actions']
        self_poses = data_bundle['self_poses']
        self_memory_c2w = data_bundle['self_memory_c2w']
        self_frame_idx = data_bundle['self_frame_idx']

    # Move axis 1 to the end (presumably channels-first -> channels-last;
    # TODO confirm), then map values from [0, 1] to uint8.
    out_video = memory_frames.transpose(0, 2, 3, 1)
    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
    out_video = (out_video * 255).astype(np.uint8)

    # NOTE(review): only the generated name is kept here; the
    # NamedTemporaryFile object itself is discarded right away.
    temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
    save_video(out_video, temporal_video_path)

    return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
346 |
|
347 |
css = """
|
348 |
h1 {
|
|
|
355 |
gr.Markdown(
|
356 |
"""
|
357 |
# WORLDMEM: Long-term Consistent World Simulation with Memory
|
358 |
+
<div style="text-align: center;">
|
359 |
+
<a style="display:inline-block; margin: 0 10px;" href="https://github.com/xizaoqu/WorldMem">
|
360 |
+
<img src="https://img.shields.io/badge/GitHub-Repository-black?logo=github"/>
|
361 |
+
</a>
|
362 |
+
<a style="display:inline-block; margin: 0 10px;" href="https://xizaoqu.github.io/worldmem/">
|
363 |
+
<img src="https://img.shields.io/badge/Project_Page-blue"/>
|
364 |
+
</a>
|
365 |
+
<a style="display:inline-block; margin: 0 10px;" href="https://arxiv.org/abs/2504.12369">
|
366 |
+
<img src="https://img.shields.io/badge/arXiv-Paper-red"/>
|
367 |
+
</a>
|
368 |
+
</div>
|
369 |
"""
|
370 |
+
)
|
371 |
+
|
372 |
gr.Markdown(
|
373 |
"""
|
374 |
## 🚀 How to Explore WorldMem
|
|
|
375 |
Follow these simple steps to get started:
|
376 |
+
1. **Choose a scene**.
|
|
|
377 |
2. **Input your action sequence**.
|
378 |
3. **Click "Generate"**.
|
|
|
379 |
- You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
|
380 |
- For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
|
381 |
+
- ⭐ If you like this project, please [give it a star on GitHub](https://github.com/xizaoqu/WorldMem)!
|
382 |
- 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
|
|
|
383 |
Happy exploring! 🌍
|
384 |
"""
|
385 |
)
|
386 |
|
387 |
+
|
388 |
example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
|
389 |
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
|
390 |
"turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
|
|
|
394 |
|
395 |
selected_image = gr.State(ICE_PLAINS_IMAGE)
|
396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
with gr.Row(variant="panel"):
|
398 |
with gr.Column():
|
399 |
gr.Markdown("🖼️ Start from this frame.")
|
|
|
410 |
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
|
411 |
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
|
412 |
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
|
413 |
+
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
|
414 |
+
|
415 |
|
416 |
with gr.Row(variant="panel"):
|
417 |
with gr.Column(scale=2):
|
|
|
421 |
gr.Markdown(
|
422 |
"""
|
423 |
### 💡 Action Key Guide
|
|
|
424 |
<pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
|
425 |
W: Turn up S: Turn down A: Turn left D: Turn right
|
426 |
Q: Go forward E: Go backward N: No-op U: Use item
|
|
|
446 |
submit_button = gr.Button("🎬 Generate!", variant="primary")
|
447 |
reset_btn = gr.Button("🔄 Reset")
|
448 |
|
449 |
+
|
450 |
gr.Markdown("### ⚙️ Advanced Settings")
|
451 |
+
|
452 |
slider_denoising_step = gr.Slider(
|
453 |
minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
|
454 |
label="Denoising Steps",
|
|
|
459 |
label="Context Length",
|
460 |
info="How many previous frames in temporal context window."
|
461 |
)
|
462 |
+
slider_memory_length = gr.Slider(
|
463 |
+
minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1,
|
464 |
label="Memory Length",
|
465 |
+
info="How many previous frames in memory window."
|
|
|
|
|
|
|
|
|
|
|
466 |
)
|
467 |
+
|
468 |
+
|
469 |
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
|
470 |
sampling_context_length_state = gr.State(worldmem.n_tokens)
|
471 |
+
sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
|
472 |
+
|
473 |
+
memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
|
474 |
+
self_frames = gr.State()
|
475 |
+
self_actions = gr.State()
|
476 |
+
self_poses = gr.State()
|
477 |
+
self_memory_c2w = gr.State()
|
478 |
+
self_frame_idx = gr.State()
|
479 |
|
480 |
def set_action(action):
    """Return ``action`` unchanged.

    Used as the click handler of the example-action buttons, which route a
    preset action string into the input box.
    """
    return action
|
482 |
+
|
483 |
+
|
484 |
|
485 |
for button, action_key in zip(buttons, list(example_actions.keys())):
|
486 |
+
button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
|
487 |
|
488 |
gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
|
489 |
+
|
490 |
example_case = gr.Textbox(label="Case", visible=False)
|
491 |
+
image_output = gr.Image(visible=False)
|
492 |
|
493 |
examples = gr.Examples(
|
494 |
examples=example_images,
|
495 |
+
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
|
496 |
cache_examples=False
|
497 |
)
|
498 |
+
|
499 |
example_case.change(
|
500 |
fn=set_memory,
|
501 |
+
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
|
502 |
+
outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
|
503 |
)
|
504 |
+
|
505 |
+
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
506 |
+
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
507 |
+
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
508 |
+
image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
509 |
+
image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
510 |
+
image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
511 |
+
image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
512 |
+
image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
513 |
+
|
|
|
|
|
|
|
514 |
slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
|
515 |
slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
|
516 |
+
slider_memory_length.change(fn=set_memory_length, inputs=[slider_memory_length, sampling_memory_length_state], outputs=sampling_memory_length_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
517 |
|
518 |
+
demo.launch()
|