Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -21,123 +21,121 @@ import spaces
|
|
21 |
from algorithms.worldmem import WorldMemMinecraft
|
22 |
from huggingface_hub import hf_hub_download
|
23 |
import tempfile
|
|
|
|
|
|
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
"hotbar.4",
|
34 |
-
"hotbar.5",
|
35 |
-
"hotbar.6",
|
36 |
-
"hotbar.7",
|
37 |
-
"hotbar.8",
|
38 |
-
"hotbar.9",
|
39 |
-
"forward",
|
40 |
-
"back",
|
41 |
-
"left",
|
42 |
-
"right",
|
43 |
-
"cameraY",
|
44 |
-
"cameraX",
|
45 |
-
"jump",
|
46 |
-
"sneak",
|
47 |
-
"sprint",
|
48 |
-
"swapHands",
|
49 |
-
"attack",
|
50 |
-
"use",
|
51 |
-
"pickItem",
|
52 |
-
"drop",
|
53 |
-
]
|
54 |
-
|
55 |
-
# Mapping of input keys to action names
|
56 |
-
KEY_TO_ACTION = {
|
57 |
-
"Q": ("forward", 1),
|
58 |
-
"E": ("back", 1),
|
59 |
-
"W": ("cameraY", -1),
|
60 |
-
"S": ("cameraY", 1),
|
61 |
-
"A": ("cameraX", -1),
|
62 |
-
"D": ("cameraX", 1),
|
63 |
-
"U": ("drop", 1),
|
64 |
-
"N": ("noop", 1),
|
65 |
-
"1": ("hotbar.1", 1),
|
66 |
-
}
|
67 |
|
68 |
-
|
69 |
-
["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
|
70 |
-
["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
|
71 |
-
["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
|
72 |
-
["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
|
73 |
-
]
|
74 |
|
75 |
def load_custom_checkpoint(algo, checkpoint_path):
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def parse_input_to_tensor(input_str):
|
86 |
-
"""
|
87 |
-
Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
|
88 |
-
of the corresponding action key.
|
89 |
-
|
90 |
-
Args:
|
91 |
-
input_str (str): A string consisting of "WASD" characters (e.g., "WASDWS").
|
92 |
-
|
93 |
-
Returns:
|
94 |
-
torch.Tensor: A tensor of shape (sequence_length, 25), where each row is a one-hot encoded action.
|
95 |
-
"""
|
96 |
-
# Get the length of the input sequence
|
97 |
seq_len = len(input_str)
|
98 |
-
|
99 |
-
# Initialize a zero tensor of shape (seq_len, 25)
|
100 |
action_tensor = torch.zeros((seq_len, 25))
|
101 |
-
|
102 |
-
# Iterate through the input string and update the corresponding positions
|
103 |
for i, char in enumerate(input_str):
|
104 |
-
action, value = KEY_TO_ACTION.get(char.upper()
|
105 |
if action and action in ACTION_KEYS:
|
106 |
index = ACTION_KEYS.index(action)
|
107 |
-
action_tensor[i, index] = value
|
108 |
-
|
109 |
return action_tensor
|
110 |
|
111 |
-
def load_image_as_tensor(image_path
|
112 |
-
"""
|
113 |
-
Load an image and convert it to a 0-1 normalized tensor.
|
114 |
-
|
115 |
-
Args:
|
116 |
-
image_path (str): Path to the image file.
|
117 |
-
|
118 |
-
Returns:
|
119 |
-
torch.Tensor: Image tensor of shape (C, H, W), normalized to [0,1].
|
120 |
-
"""
|
121 |
if isinstance(image_path, str):
|
122 |
-
image = Image.open(image_path).convert("RGB")
|
123 |
else:
|
124 |
image = image_path
|
125 |
transform = transforms.Compose([
|
126 |
-
transforms.ToTensor(),
|
127 |
])
|
128 |
return transform(image)
|
129 |
|
130 |
def enable_amp(model, precision="16-mixed"):
|
131 |
original_forward = model.forward
|
132 |
-
|
133 |
def amp_forward(*args, **kwargs):
|
134 |
with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
|
135 |
return original_forward(*args, **kwargs)
|
136 |
-
|
137 |
model.forward = amp_forward
|
138 |
return model
|
139 |
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
input_history = ""
|
142 |
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
|
143 |
DESERT_IMAGE = "assets/desert.png"
|
@@ -146,21 +144,22 @@ PLAINS_IMAGE = "assets/plans.png"
|
|
146 |
PLACE_IMAGE = "assets/place.png"
|
147 |
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
|
148 |
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
|
149 |
-
|
150 |
device = torch.device('cuda')
|
151 |
|
152 |
def save_video(frames, path="output.mp4", fps=10):
|
|
|
153 |
h, w, _ = frames[0].shape
|
154 |
-
out = cv2.VideoWriter(
|
155 |
for frame in frames:
|
156 |
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
157 |
out.release()
|
158 |
-
|
159 |
ffmpeg_cmd = [
|
160 |
-
"ffmpeg", "-y", "-i",
|
|
|
|
|
161 |
]
|
162 |
subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
163 |
-
|
164 |
|
165 |
cfg = OmegaConf.load("configurations/huggingface.yaml")
|
166 |
worldmem = WorldMemMinecraft(cfg)
|
@@ -168,31 +167,26 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
|
|
168 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
169 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
170 |
worldmem.to("cuda").eval()
|
171 |
-
# worldmem = enable_amp(worldmem, precision="16-mixed")
|
172 |
-
|
173 |
actions = np.zeros((1, 25), dtype=np.float32)
|
174 |
poses = np.zeros((1, 5), dtype=np.float32)
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
self_poses, self_memory_c2w, self_frame_idx):
|
180 |
-
return 5 * len(action) if self_actions is not None else 5
|
181 |
|
182 |
@spaces.GPU(duration=get_duration_single_image_to_long_video)
|
183 |
-
def run_interactive(first_frame, action, first_pose, device,
|
184 |
-
|
185 |
-
new_frame,
|
186 |
action,
|
187 |
first_pose,
|
188 |
device=device,
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
196 |
|
197 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
198 |
worldmem.sampling_timesteps = denoising_steps
|
@@ -207,144 +201,137 @@ def set_context_length(context_length, sampling_context_length_state):
|
|
207 |
print("set context length to", worldmem.n_tokens)
|
208 |
return sampling_context_length_state
|
209 |
|
210 |
-
def
|
211 |
-
worldmem.
|
212 |
-
|
213 |
-
print("set memory length to", worldmem.
|
214 |
-
return
|
215 |
|
216 |
-
def
|
217 |
-
|
|
|
|
|
|
|
218 |
|
219 |
-
|
220 |
-
|
|
|
|
|
221 |
actions[0],
|
222 |
poses[0],
|
223 |
device=device,
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
231 |
input_actions,
|
232 |
None,
|
233 |
device=device,
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
out_video = memory_frames.transpose(0,2,3,1).copy()
|
244 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
245 |
out_video = (out_video * 255).astype(np.uint8)
|
246 |
-
|
247 |
last_frame = out_video[-1].copy()
|
248 |
border_thickness = 2
|
249 |
out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
|
250 |
out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
|
251 |
out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
|
252 |
out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
|
253 |
-
|
254 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
255 |
save_video(out_video, temporal_video_path)
|
256 |
input_history += keys
|
257 |
-
|
258 |
-
|
259 |
-
# now = datetime.now()
|
260 |
-
# folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
|
261 |
-
# folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
|
262 |
-
# os.makedirs(folder_path, exist_ok=True)
|
263 |
-
# data_dict = {
|
264 |
-
# "input_history": input_history,
|
265 |
-
# "memory_frames": memory_frames,
|
266 |
-
# "self_frames": self_frames,
|
267 |
-
# "self_actions": self_actions,
|
268 |
-
# "self_poses": self_poses,
|
269 |
-
# "self_memory_c2w": self_memory_c2w,
|
270 |
-
# "self_frame_idx": self_frame_idx,
|
271 |
-
# }
|
272 |
-
|
273 |
-
# np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
|
274 |
-
|
275 |
-
return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
276 |
|
277 |
def reset(selected_image):
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
input_history = ""
|
285 |
-
|
286 |
-
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
287 |
actions[0],
|
288 |
poses[0],
|
289 |
device=device,
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
return input_history,
|
297 |
|
298 |
def on_image_click(selected_image):
|
299 |
-
input_history,
|
300 |
-
return input_history, selected_image, selected_image,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
-
def set_memory(examples_case
|
303 |
if examples_case == '1':
|
304 |
data_bundle = np.load("assets/examples/case1.npz")
|
305 |
input_history = data_bundle['input_history'].item()
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
elif examples_case == '2':
|
313 |
data_bundle = np.load("assets/examples/case2.npz")
|
314 |
input_history = data_bundle['input_history'].item()
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
elif examples_case == '3':
|
322 |
data_bundle = np.load("assets/examples/case3.npz")
|
323 |
input_history = data_bundle['input_history'].item()
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
elif examples_case == '4':
|
331 |
data_bundle = np.load("assets/examples/case4.npz")
|
332 |
input_history = data_bundle['input_history'].item()
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
out_video =
|
341 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
342 |
out_video = (out_video * 255).astype(np.uint8)
|
343 |
|
344 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
345 |
save_video(out_video, temporal_video_path)
|
346 |
|
347 |
-
return input_history, out_video[-1], temporal_video_path,
|
348 |
|
349 |
css = """
|
350 |
h1 {
|
@@ -357,41 +344,27 @@ with gr.Blocks(css=css) as demo:
|
|
357 |
gr.Markdown(
|
358 |
"""
|
359 |
# WORLDMEM: Long-term Consistent World Simulation with Memory
|
360 |
-
|
361 |
-
<div style="text-align: center;">
|
362 |
-
<a style="display:inline-block; margin: 0 10px;" href="https://github.com/xizaoqu/WorldMem">
|
363 |
-
<img src="https://img.shields.io/badge/GitHub-Repository-black?logo=github"/>
|
364 |
-
</a>
|
365 |
-
<a style="display:inline-block; margin: 0 10px;" href="https://xizaoqu.github.io/worldmem/">
|
366 |
-
<img src="https://img.shields.io/badge/Project_Page-blue"/>
|
367 |
-
</a>
|
368 |
-
<a style="display:inline-block; margin: 0 10px;" href="https://arxiv.org/abs/2504.12369">
|
369 |
-
<img src="https://img.shields.io/badge/arXiv-Paper-red"/>
|
370 |
-
</a>
|
371 |
-
</div>
|
372 |
"""
|
373 |
-
|
374 |
-
|
375 |
gr.Markdown(
|
376 |
"""
|
377 |
## 🚀 How to Explore WorldMem
|
378 |
|
379 |
Follow these simple steps to get started:
|
380 |
|
381 |
-
1. **Choose a scene**.
|
382 |
2. **Input your action sequence**.
|
383 |
3. **Click "Generate"**.
|
384 |
|
385 |
- You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
|
386 |
- For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
|
387 |
-
- ⭐️ If you like this project, please [give it a star on GitHub](
|
388 |
- 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
|
389 |
|
390 |
Happy exploring! 🌍
|
391 |
"""
|
392 |
)
|
393 |
|
394 |
-
|
395 |
example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
|
396 |
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
|
397 |
"turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
|
@@ -401,6 +374,15 @@ with gr.Blocks(css=css) as demo:
|
|
401 |
|
402 |
selected_image = gr.State(ICE_PLAINS_IMAGE)
|
403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
with gr.Row(variant="panel"):
|
405 |
with gr.Column():
|
406 |
gr.Markdown("🖼️ Start from this frame.")
|
@@ -417,8 +399,7 @@ with gr.Blocks(css=css) as demo:
|
|
417 |
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
|
418 |
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
|
419 |
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
|
420 |
-
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
|
421 |
-
|
422 |
|
423 |
with gr.Row(variant="panel"):
|
424 |
with gr.Column(scale=2):
|
@@ -454,9 +435,7 @@ with gr.Blocks(css=css) as demo:
|
|
454 |
submit_button = gr.Button("🎬 Generate!", variant="primary")
|
455 |
reset_btn = gr.Button("🔄 Reset")
|
456 |
|
457 |
-
|
458 |
gr.Markdown("### ⚙️ Advanced Settings")
|
459 |
-
|
460 |
slider_denoising_step = gr.Slider(
|
461 |
minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
|
462 |
label="Denoising Steps",
|
@@ -467,60 +446,70 @@ with gr.Blocks(css=css) as demo:
|
|
467 |
label="Context Length",
|
468 |
info="How many previous frames in temporal context window."
|
469 |
)
|
470 |
-
|
471 |
-
minimum=4, maximum=16, value=worldmem.
|
472 |
label="Memory Length",
|
473 |
-
info="How many previous frames in memory window."
|
|
|
|
|
|
|
|
|
|
|
474 |
)
|
475 |
-
|
476 |
-
|
477 |
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
|
478 |
sampling_context_length_state = gr.State(worldmem.n_tokens)
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
|
488 |
def set_action(action):
|
489 |
return action
|
490 |
-
|
491 |
-
|
492 |
|
493 |
for button, action_key in zip(buttons, list(example_actions.keys())):
|
494 |
-
|
495 |
|
496 |
gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
|
497 |
-
|
498 |
example_case = gr.Textbox(label="Case", visible=False)
|
499 |
-
image_output = gr.Image(visible=False)
|
500 |
|
501 |
examples = gr.Examples(
|
502 |
examples=example_images,
|
503 |
-
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length,
|
504 |
cache_examples=False
|
505 |
)
|
506 |
-
|
507 |
example_case.change(
|
508 |
fn=set_memory,
|
509 |
-
inputs=[example_case
|
510 |
-
outputs=[log_output, image_display, video_display,
|
511 |
)
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
|
|
|
|
|
|
522 |
slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
|
523 |
slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
|
524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
|
526 |
-
demo.launch()
|
|
|
21 |
from algorithms.worldmem import WorldMemMinecraft
|
22 |
from huggingface_hub import hf_hub_download
|
23 |
import tempfile
|
24 |
+
import os
|
25 |
+
import requests
|
26 |
+
from huggingface_hub import model_info
|
27 |
|
28 |
+
def is_huggingface_model(path: str) -> bool:
|
29 |
+
hf_ckpt = str(path).split('/')
|
30 |
+
repo_id = '/'.join(hf_ckpt[:2])
|
31 |
+
try:
|
32 |
+
model_info(repo_id)
|
33 |
+
return True
|
34 |
+
except:
|
35 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
torch.set_float32_matmul_precision("high")
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
def load_custom_checkpoint(algo, checkpoint_path):
|
40 |
+
if is_huggingface_model(str(checkpoint_path)):
|
41 |
+
hf_ckpt = str(checkpoint_path).split('/')
|
42 |
+
repo_id = '/'.join(hf_ckpt[:2])
|
43 |
+
file_name = '/'.join(hf_ckpt[2:])
|
44 |
+
model_path = hf_hub_download(repo_id=repo_id, filename=file_name)
|
45 |
+
ckpt = torch.load(model_path, map_location=torch.device('cpu'))
|
46 |
+
|
47 |
+
filtered_state_dict = {}
|
48 |
+
for k, v in ckpt['state_dict'].items():
|
49 |
+
if "frame_timestep_embedder" in k:
|
50 |
+
new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
|
51 |
+
filtered_state_dict[new_k] = v
|
52 |
+
else:
|
53 |
+
filtered_state_dict[k] = v
|
54 |
+
|
55 |
+
algo.load_state_dict(filtered_state_dict, strict=True)
|
56 |
+
print("Load: ", model_path)
|
57 |
+
else:
|
58 |
+
ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
59 |
+
|
60 |
+
filtered_state_dict = {}
|
61 |
+
for k, v in ckpt['state_dict'].items():
|
62 |
+
if "frame_timestep_embedder" in k:
|
63 |
+
new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
|
64 |
+
filtered_state_dict[new_k] = v
|
65 |
+
else:
|
66 |
+
filtered_state_dict[k] = v
|
67 |
+
|
68 |
+
algo.load_state_dict(filtered_state_dict, strict=True)
|
69 |
+
print("Load: ", checkpoint_path)
|
70 |
+
|
71 |
+
def download_assets_if_needed():
|
72 |
+
ASSETS_URL_BASE = "https://huggingface.co/spaces/yslan/worldmem/resolve/main/assets/examples"
|
73 |
+
ASSETS_DIR = "assets/examples"
|
74 |
+
ASSETS = ['case1.npz', 'case2.npz', 'case3.npz', 'case4.npz']
|
75 |
+
|
76 |
+
if not os.path.exists(ASSETS_DIR):
|
77 |
+
os.makedirs(ASSETS_DIR)
|
78 |
+
|
79 |
+
for filename in ASSETS:
|
80 |
+
filepath = os.path.join(ASSETS_DIR, filename)
|
81 |
+
if not os.path.exists(filepath):
|
82 |
+
print(f"Downloading {filename}...")
|
83 |
+
url = f"{ASSETS_URL_BASE}/{filename}"
|
84 |
+
response = requests.get(url)
|
85 |
+
if response.status_code == 200:
|
86 |
+
with open(filepath, "wb") as f:
|
87 |
+
f.write(response.content)
|
88 |
+
else:
|
89 |
+
print(f"Failed to download {filename}: {response.status_code}")
|
90 |
|
91 |
def parse_input_to_tensor(input_str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
seq_len = len(input_str)
|
|
|
|
|
93 |
action_tensor = torch.zeros((seq_len, 25))
|
|
|
|
|
94 |
for i, char in enumerate(input_str):
|
95 |
+
action, value = KEY_TO_ACTION.get(char.upper(), (None, None))
|
96 |
if action and action in ACTION_KEYS:
|
97 |
index = ACTION_KEYS.index(action)
|
98 |
+
action_tensor[i, index] = value
|
|
|
99 |
return action_tensor
|
100 |
|
101 |
+
def load_image_as_tensor(image_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
if isinstance(image_path, str):
|
103 |
+
image = Image.open(image_path).convert("RGB")
|
104 |
else:
|
105 |
image = image_path
|
106 |
transform = transforms.Compose([
|
107 |
+
transforms.ToTensor(),
|
108 |
])
|
109 |
return transform(image)
|
110 |
|
111 |
def enable_amp(model, precision="16-mixed"):
|
112 |
original_forward = model.forward
|
|
|
113 |
def amp_forward(*args, **kwargs):
|
114 |
with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
|
115 |
return original_forward(*args, **kwargs)
|
|
|
116 |
model.forward = amp_forward
|
117 |
return model
|
118 |
|
119 |
+
download_assets_if_needed()
|
120 |
+
|
121 |
+
ACTION_KEYS = [
|
122 |
+
"inventory", "ESC", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4",
|
123 |
+
"hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", "forward",
|
124 |
+
"back", "left", "right", "cameraY", "cameraX", "jump", "sneak", "sprint",
|
125 |
+
"swapHands", "attack", "use", "pickItem", "drop",
|
126 |
+
]
|
127 |
+
KEY_TO_ACTION = {
|
128 |
+
"Q": ("forward", 1), "E": ("back", 1), "W": ("cameraY", -1),
|
129 |
+
"S": ("cameraY", 1), "A": ("cameraX", -1), "D": ("cameraX", 1),
|
130 |
+
"U": ("drop", 1), "N": ("noop", 1), "1": ("hotbar.1", 1),
|
131 |
+
}
|
132 |
+
example_images = [
|
133 |
+
["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
|
134 |
+
["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
|
135 |
+
["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
|
136 |
+
["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
|
137 |
+
]
|
138 |
+
video_frames = []
|
139 |
input_history = ""
|
140 |
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
|
141 |
DESERT_IMAGE = "assets/desert.png"
|
|
|
144 |
PLACE_IMAGE = "assets/place.png"
|
145 |
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
|
146 |
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
|
|
|
147 |
device = torch.device('cuda')
|
148 |
|
149 |
def save_video(frames, path="output.mp4", fps=10):
|
150 |
+
temp_path = path[:-4] + "_temp.mp4"
|
151 |
h, w, _ = frames[0].shape
|
152 |
+
out = cv2.VideoWriter(temp_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
153 |
for frame in frames:
|
154 |
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
155 |
out.release()
|
|
|
156 |
ffmpeg_cmd = [
|
157 |
+
"ffmpeg", "-y", "-i", temp_path,
|
158 |
+
"-c:v", "libx264", "-crf", "23", "-preset", "medium",
|
159 |
+
path
|
160 |
]
|
161 |
subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
162 |
+
os.remove(temp_path)
|
163 |
|
164 |
cfg = OmegaConf.load("configurations/huggingface.yaml")
|
165 |
worldmem = WorldMemMinecraft(cfg)
|
|
|
167 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
168 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
169 |
worldmem.to("cuda").eval()
|
|
|
|
|
170 |
actions = np.zeros((1, 25), dtype=np.float32)
|
171 |
poses = np.zeros((1, 5), dtype=np.float32)
|
172 |
|
173 |
+
def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
|
174 |
+
memory_poses, memory_c2w, memory_frame_idx):
|
175 |
+
return 5 * len(action) if memory_actions is not None else 5
|
|
|
|
|
176 |
|
177 |
@spaces.GPU(duration=get_duration_single_image_to_long_video)
|
178 |
+
def run_interactive(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
|
179 |
+
memory_poses, memory_c2w, memory_frame_idx):
|
180 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = worldmem.interactive(first_frame,
|
181 |
action,
|
182 |
first_pose,
|
183 |
device=device,
|
184 |
+
memory_latent_frames=memory_latent_frames,
|
185 |
+
memory_actions=memory_actions,
|
186 |
+
memory_poses=memory_poses,
|
187 |
+
memory_c2w=memory_c2w,
|
188 |
+
memory_frame_idx=memory_frame_idx)
|
189 |
+
return new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
|
|
190 |
|
191 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
192 |
worldmem.sampling_timesteps = denoising_steps
|
|
|
201 |
print("set context length to", worldmem.n_tokens)
|
202 |
return sampling_context_length_state
|
203 |
|
204 |
+
def set_memory_condition_length(memory_condition_length, sampling_memory_condition_length_state):
|
205 |
+
worldmem.memory_condition_length = memory_condition_length
|
206 |
+
sampling_memory_condition_length_state = memory_condition_length
|
207 |
+
print("set memory length to", worldmem.memory_condition_length)
|
208 |
+
return sampling_memory_condition_length_state
|
209 |
|
210 |
+
def set_next_frame_length(next_frame_length, sampling_next_frame_length_state):
|
211 |
+
worldmem.next_frame_length = next_frame_length
|
212 |
+
sampling_next_frame_length_state = next_frame_length
|
213 |
+
print("set next frame length to", worldmem.next_frame_length)
|
214 |
+
return sampling_next_frame_length_state
|
215 |
|
216 |
+
def generate(keys, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
|
217 |
+
input_actions = parse_input_to_tensor(keys)
|
218 |
+
if memory_latent_frames is None:
|
219 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
|
220 |
actions[0],
|
221 |
poses[0],
|
222 |
device=device,
|
223 |
+
memory_latent_frames=memory_latent_frames,
|
224 |
+
memory_actions=memory_actions,
|
225 |
+
memory_poses=memory_poses,
|
226 |
+
memory_c2w=memory_c2w,
|
227 |
+
memory_frame_idx=memory_frame_idx)
|
228 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
|
|
|
229 |
input_actions,
|
230 |
None,
|
231 |
device=device,
|
232 |
+
memory_latent_frames=memory_latent_frames,
|
233 |
+
memory_actions=memory_actions,
|
234 |
+
memory_poses=memory_poses,
|
235 |
+
memory_c2w=memory_c2w,
|
236 |
+
memory_frame_idx=memory_frame_idx)
|
237 |
+
video_frames = np.concatenate([video_frames, new_frame[:,0]])
|
238 |
+
out_video = video_frames.transpose(0,2,3,1).copy()
|
|
|
|
|
|
|
239 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
240 |
out_video = (out_video * 255).astype(np.uint8)
|
|
|
241 |
last_frame = out_video[-1].copy()
|
242 |
border_thickness = 2
|
243 |
out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
|
244 |
out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
|
245 |
out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
|
246 |
out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
|
|
|
247 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
248 |
save_video(out_video, temporal_video_path)
|
249 |
input_history += keys
|
250 |
+
return last_frame, temporal_video_path, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
def reset(selected_image):
|
253 |
+
memory_latent_frames = None
|
254 |
+
memory_poses = None
|
255 |
+
memory_actions = None
|
256 |
+
memory_c2w = None
|
257 |
+
memory_frame_idx = None
|
258 |
+
video_frames = load_image_as_tensor(selected_image).numpy()[None]
|
259 |
input_history = ""
|
260 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
|
|
|
261 |
actions[0],
|
262 |
poses[0],
|
263 |
device=device,
|
264 |
+
memory_latent_frames=memory_latent_frames,
|
265 |
+
memory_actions=memory_actions,
|
266 |
+
memory_poses=memory_poses,
|
267 |
+
memory_c2w=memory_c2w,
|
268 |
+
memory_frame_idx=memory_frame_idx,
|
269 |
+
)
|
270 |
+
return input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
271 |
|
272 |
def on_image_click(selected_image):
|
273 |
+
input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(selected_image)
|
274 |
+
return input_history, selected_image, selected_image, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
275 |
+
|
276 |
+
# === NEW: Custom image upload handler ===
|
277 |
+
def on_custom_image_upload(custom_image):
|
278 |
+
if custom_image is None:
|
279 |
+
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
280 |
+
image = Image.fromarray(custom_image.astype(np.uint8))
|
281 |
+
input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(image)
|
282 |
+
return (
|
283 |
+
input_history, # log_output
|
284 |
+
image, # selected_image
|
285 |
+
image, # image_display
|
286 |
+
video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
287 |
+
)
|
288 |
|
289 |
+
def set_memory(examples_case):
|
290 |
if examples_case == '1':
|
291 |
data_bundle = np.load("assets/examples/case1.npz")
|
292 |
input_history = data_bundle['input_history'].item()
|
293 |
+
video_frames = data_bundle['memory_frames']
|
294 |
+
memory_latent_frames = data_bundle['self_frames']
|
295 |
+
memory_actions = data_bundle['self_actions']
|
296 |
+
memory_poses = data_bundle['self_poses']
|
297 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
298 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
299 |
elif examples_case == '2':
|
300 |
data_bundle = np.load("assets/examples/case2.npz")
|
301 |
input_history = data_bundle['input_history'].item()
|
302 |
+
video_frames = data_bundle['memory_frames']
|
303 |
+
memory_latent_frames = data_bundle['self_frames']
|
304 |
+
memory_actions = data_bundle['self_actions']
|
305 |
+
memory_poses = data_bundle['self_poses']
|
306 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
307 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
308 |
elif examples_case == '3':
|
309 |
data_bundle = np.load("assets/examples/case3.npz")
|
310 |
input_history = data_bundle['input_history'].item()
|
311 |
+
video_frames = data_bundle['memory_frames']
|
312 |
+
memory_latent_frames = data_bundle['self_frames']
|
313 |
+
memory_actions = data_bundle['self_actions']
|
314 |
+
memory_poses = data_bundle['self_poses']
|
315 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
316 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
317 |
elif examples_case == '4':
|
318 |
data_bundle = np.load("assets/examples/case4.npz")
|
319 |
input_history = data_bundle['input_history'].item()
|
320 |
+
video_frames = data_bundle['memory_frames']
|
321 |
+
memory_latent_frames = data_bundle['self_frames']
|
322 |
+
memory_actions = data_bundle['self_actions']
|
323 |
+
memory_poses = data_bundle['self_poses']
|
324 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
325 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
326 |
+
|
327 |
+
out_video = video_frames.transpose(0,2,3,1)
|
328 |
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
329 |
out_video = (out_video * 255).astype(np.uint8)
|
330 |
|
331 |
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
332 |
save_video(out_video, temporal_video_path)
|
333 |
|
334 |
+
return input_history, out_video[-1], temporal_video_path, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
335 |
|
336 |
css = """
|
337 |
h1 {
|
|
|
344 |
gr.Markdown(
|
345 |
"""
|
346 |
# WORLDMEM: Long-term Consistent World Simulation with Memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
"""
|
348 |
+
)
|
|
|
349 |
gr.Markdown(
|
350 |
"""
|
351 |
## 🚀 How to Explore WorldMem
|
352 |
|
353 |
Follow these simple steps to get started:
|
354 |
|
355 |
+
1. **Choose a scene** or **upload your own**.
|
356 |
2. **Input your action sequence**.
|
357 |
3. **Click "Generate"**.
|
358 |
|
359 |
- You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
|
360 |
- For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
|
361 |
+
- ⭐️ If you like this project, please [give it a star on GitHub]()!
|
362 |
- 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
|
363 |
|
364 |
Happy exploring! 🌍
|
365 |
"""
|
366 |
)
|
367 |
|
|
|
368 |
example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
|
369 |
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
|
370 |
"turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
|
|
|
374 |
|
375 |
selected_image = gr.State(ICE_PLAINS_IMAGE)
|
376 |
|
377 |
+
# --- NEW: Custom image upload UI ---
|
378 |
+
with gr.Row():
|
379 |
+
gr.Markdown("🖼️ Or upload your own scene (PNG/JPG, 256x256 recommended):")
|
380 |
+
custom_image_upload = gr.Image(
|
381 |
+
label="Upload a Custom Scene",
|
382 |
+
type="numpy",
|
383 |
+
tool=None
|
384 |
+
)
|
385 |
+
|
386 |
with gr.Row(variant="panel"):
|
387 |
with gr.Column():
|
388 |
gr.Markdown("🖼️ Start from this frame.")
|
|
|
399 |
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
|
400 |
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
|
401 |
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
|
402 |
+
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
|
|
|
403 |
|
404 |
with gr.Row(variant="panel"):
|
405 |
with gr.Column(scale=2):
|
|
|
435 |
submit_button = gr.Button("🎬 Generate!", variant="primary")
|
436 |
reset_btn = gr.Button("🔄 Reset")
|
437 |
|
|
|
438 |
gr.Markdown("### ⚙️ Advanced Settings")
|
|
|
439 |
slider_denoising_step = gr.Slider(
|
440 |
minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
|
441 |
label="Denoising Steps",
|
|
|
446 |
label="Context Length",
|
447 |
info="How many previous frames in temporal context window."
|
448 |
)
|
449 |
+
slider_memory_condition_length = gr.Slider(
|
450 |
+
minimum=4, maximum=16, value=worldmem.memory_condition_length, step=1,
|
451 |
label="Memory Length",
|
452 |
+
info="How many previous frames in memory window. (Recommended: 1, multi-frame generation is not stable yet)"
|
453 |
+
)
|
454 |
+
slider_next_frame_length = gr.Slider(
|
455 |
+
minimum=1, maximum=5, value=worldmem.next_frame_length, step=1,
|
456 |
+
label="Next Frame Length",
|
457 |
+
info="How many next frames to generate at once."
|
458 |
)
|
|
|
|
|
459 |
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
|
460 |
sampling_context_length_state = gr.State(worldmem.n_tokens)
|
461 |
+
sampling_memory_condition_length_state = gr.State(worldmem.memory_condition_length)
|
462 |
+
sampling_next_frame_length_state = gr.State(worldmem.next_frame_length)
|
463 |
+
video_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
|
464 |
+
memory_latent_frames = gr.State()
|
465 |
+
memory_actions = gr.State()
|
466 |
+
memory_poses = gr.State()
|
467 |
+
memory_c2w = gr.State()
|
468 |
+
memory_frame_idx = gr.State()
|
469 |
|
470 |
def set_action(action):
|
471 |
return action
|
|
|
|
|
472 |
|
473 |
for button, action_key in zip(buttons, list(example_actions.keys())):
|
474 |
+
button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
|
475 |
|
476 |
gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
|
|
|
477 |
example_case = gr.Textbox(label="Case", visible=False)
|
478 |
+
image_output = gr.Image(visible=False)
|
479 |
|
480 |
examples = gr.Examples(
|
481 |
examples=example_images,
|
482 |
+
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_condition_length],
|
483 |
cache_examples=False
|
484 |
)
|
|
|
485 |
example_case.change(
|
486 |
fn=set_memory,
|
487 |
+
inputs=[example_case],
|
488 |
+
outputs=[log_output, image_display, video_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx]
|
489 |
)
|
490 |
+
submit_button.click(generate, inputs=[input_box, log_output, video_frames,
|
491 |
+
memory_latent_frames, memory_actions, memory_poses,
|
492 |
+
memory_c2w, memory_frame_idx],
|
493 |
+
outputs=[image_display, video_display, log_output,
|
494 |
+
video_frames, memory_latent_frames, memory_actions, memory_poses,
|
495 |
+
memory_c2w, memory_frame_idx])
|
496 |
+
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
497 |
+
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
498 |
+
image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
499 |
+
image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
500 |
+
image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
501 |
+
image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
502 |
+
image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
503 |
slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
|
504 |
slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
|
505 |
+
slider_memory_condition_length.change(fn=set_memory_condition_length, inputs=[slider_memory_condition_length, sampling_memory_condition_length_state], outputs=sampling_memory_condition_length_state)
|
506 |
+
slider_next_frame_length.change(fn=set_next_frame_length, inputs=[slider_next_frame_length, sampling_next_frame_length_state], outputs=sampling_next_frame_length_state)
|
507 |
+
|
508 |
+
# --- NEW: Custom image upload triggers reset ---
|
509 |
+
custom_image_upload.change(
|
510 |
+
on_custom_image_upload,
|
511 |
+
inputs=[custom_image_upload],
|
512 |
+
outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx],
|
513 |
+
)
|
514 |
|
515 |
+
demo.launch(share=True)
|