xizaoqu committed · Commit cef86dc · 1 Parent(s): db555c7
update

app.py CHANGED
@@ -71,10 +71,10 @@ KEY_TO_ACTION = {
 }
 
 example_images = [
-    ["1", "assets/ice_plains.png", "turn
-    ["2", "assets/place.png", "put item
-    ["3", "assets/rain_sunflower_plains.png", "turn right
-    ["4", "assets/desert.png", "turn 360 degree
+    ["1", "assets/ice_plains.png", "turn right→go backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
+    ["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
+    ["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
+    ["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
 ]
 
 def load_custom_checkpoint(algo, checkpoint_path):
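Each example row now carries an arrow-separated action description plus three trailing numbers. Judging by the sliders introduced later in this commit (denoising steps in 10–50, context length in 2–10, memory length in 4–16), the trailing `20, 3, 8` plausibly map to denoising steps, context length, and memory length. A minimal sketch of how rows like these are typically fed to `gr.Examples`; the hidden components and the column mapping are assumptions, not code from this commit:

```python
import gradio as gr

# Assumed column layout (hypothetical): [case id, start frame, action
# description, denoising steps, context length, memory length].
with gr.Blocks() as demo:
    example_case = gr.Textbox(label="Case", visible=False)
    image_output = gr.Image(visible=False)
    desc = gr.Textbox(visible=False)
    steps = gr.Number(visible=False)
    context_len = gr.Number(visible=False)
    memory_len = gr.Number(visible=False)
    gr.Examples(
        examples=[["1", "assets/ice_plains.png", "turn right→go forward", 20, 3, 8]],
        inputs=[example_case, image_output, desc, steps, context_len, memory_len],
    )
```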
@@ -264,10 +264,18 @@ def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
 
     memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
 
-
+
+    out_video = memory_frames.transpose(0,2,3,1).copy()
     out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
     out_video = (out_video * 255).astype(np.uint8)
 
+    last_frame = out_video[-1].copy()
+    border_thickness = 2
+    out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
+    out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
+    out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
+    out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
+
     temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
     save_video(out_video, temporal_video_path)
     input_history += keys
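The rewritten block renders the preview from the entire `memory_frames` buffer, keeps the last frame for the image panel, and paints a 2-pixel red border onto only the newly generated frames via four slice assignments. A self-contained sketch of that indexing on a dummy `(T, H, W, C)` array (the sizes are made up), to make the slicing explicit:

```python
import numpy as np

# Toy stand-in: 6 frames of 64x64 RGB; pretend the last 2 were just generated.
out_video = np.zeros((6, 64, 64, 3), dtype=np.uint8)
n_new, t = 2, 2  # number of new frames, border thickness

out_video[-n_new:, :t, :, :] = [255, 0, 0]   # top rows
out_video[-n_new:, -t:, :, :] = [255, 0, 0]  # bottom rows
out_video[-n_new:, :, :t, :] = [255, 0, 0]   # left columns
out_video[-n_new:, :, -t:, :] = [255, 0, 0]  # right columns

assert (out_video[-1, 0, :, 0] == 255).all()  # newest frame: top edge is red
assert (out_video[0] == 0).all()              # older frames untouched
```

Separately, note that `tempfile.NamedTemporaryFile(suffix='.mp4').name` only borrows a unique path: the temporary file object is discarded immediately, so on CPython the empty file is deleted right away and `save_video` recreates the path from scratch.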
@@ -289,7 +297,7 @@ def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
 
     # np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
 
-    return
+    return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
 def reset(selected_image):
     self_frames = None
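The new `return` is load-bearing: Gradio maps it positionally onto the `outputs` list of `submit_button.click(...)` at the bottom of the file. Spelled out, using only names from this diff:

```python
# return value                      -> output component
# last_frame                        -> image_display  (current-frame panel)
# temporal_video_path               -> video_display  (the saved .mp4)
# input_history                     -> log_output     (history of key sequences)
# memory_frames ... self_frame_idx  -> the matching gr.State holders,
#                                      carried into the next "Generate" click
```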
@@ -381,6 +389,24 @@ with gr.Blocks(css=css) as demo:
     """
     )
 
+    gr.Markdown(
+        """
+        ## 🚀 How to Explore WorldMem
+
+        Follow these simple steps to get started:
+
+        1. **Choose a scene**.
+        2. **Input your action sequence**.
+        3. **Click "Generate"**.
+
+        - You can keep clicking **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
+        - For best performance, we recommend **running locally** (~1 s/frame on H100) instead of on Spaces (~5 s/frame).
+        - ⭐️ If you like this project, please [give it a star on GitHub]()!
+        - 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
+
+        Happy exploring! 🎉
+        """
+    )
     # <div style="text-align: center;">
     # <!-- Public Website -->
     # <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
@@ -403,25 +429,50 @@ with gr.Blocks(css=css) as demo:
     # </a>
     # </div>
 
-    example_actions = {"turn left
+    example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
                        "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
-                       "turn right
-                       "turn right
-                       "turn right
-                       "put item
+                       "turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
+                       "turn right→go forward→turn right": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
+                       "turn right→look up→turn right→look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
+                       "put item→go backward→put item→go backward": "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
 
     selected_image = gr.State(ICE_PLAINS_IMAGE)
 
     with gr.Row(variant="panel"):
-
-
+        with gr.Column():
+            gr.Markdown("🖼️ Start from this frame.")
+            image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
+        with gr.Column():
+            gr.Markdown("🎞️ Generated videos. New content is marked with a red box.")
+            video_display = gr.Video(autoplay=True, loop=True)
+
+    gr.Markdown("### 🏞️ Choose a scene and start generation.")
+
+    with gr.Row():
+        image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
+        image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
+        image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
+        image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
+        image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
+        image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
 
 
     with gr.Row(variant="panel"):
         with gr.Column(scale=2):
-
-
-            gr.
+            gr.Markdown("### 🕹️ Input action sequences for interaction.")
+            input_box = gr.Textbox(label="Action Sequences", placeholder="Enter action sequences here, e.g. (AAAAAAAAAAAADDDDDDDDDDDD)", lines=1, max_lines=1)
+            log_output = gr.Textbox(label="History Sequences", interactive=False)
+            gr.Markdown(
+                """
+                ### 💡 Action Key Guide
+
+                <pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
+                W: Turn up      S: Turn down     A: Turn left    D: Turn right
+                Q: Go forward   E: Go backward   N: No-op        U: Use item
+                </pre>
+                """
+            )
+            gr.Markdown("### 👇 Click to quickly set action sequence examples.")
             with gr.Row():
                 buttons = []
                 for action_key in list(example_actions.keys())[:2]:
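Each example button writes its key string into `input_box`, and `generate()` consumes the string character by character via the `KEY_TO_ACTION` table at the top of the file (only its closing brace is visible in this diff). A hypothetical decoding sketch; the real table maps keys to action vectors, here reduced to the names from the Action Key Guide:

```python
# Hypothetical stand-in for the real KEY_TO_ACTION dict (not shown in this diff).
KEY_TO_ACTION = {
    "W": "turn up", "S": "turn down", "A": "turn left", "D": "turn right",
    "Q": "go forward", "E": "go backward", "N": "no-op", "U": "use item",
}

def decode(keys: str) -> list:
    """One action per generated frame, e.g. 'AAAD' -> 3x left, 1x right."""
    return [KEY_TO_ACTION[k] for k in keys]

print(decode("AAAD"))  # ['turn left', 'turn left', 'turn left', 'turn right']
```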
@@ -437,11 +488,28 @@ with gr.Blocks(css=css) as demo:
                 buttons.append(gr.Button(action_key))
 
         with gr.Column(scale=1):
-
-
-
-
-
+            submit_button = gr.Button("🎬 Generate!", variant="primary")
+            reset_btn = gr.Button("🔄 Reset")
+            gr.Markdown("<div style='flex-grow:1; height: 100px'></div>")
+
+            gr.Markdown("### ⚙️ Advanced Settings")
+
+            slider_denoising_step = gr.Slider(
+                minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
+                label="Denoising Steps",
+                info="Higher values yield better quality but run slower."
+            )
+            slider_context_length = gr.Slider(
+                minimum=2, maximum=10, value=worldmem.n_tokens, step=1,
+                label="Context Length",
+                info="How many previous frames are kept in the temporal context window."
+            )
+            slider_memory_length = gr.Slider(
+                minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1,
+                label="Memory Length",
+                info="How many previous frames are kept in the memory window."
+            )
+
 
     sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
     sampling_context_length_state = gr.State(worldmem.n_tokens)
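The sliders take their defaults from the `worldmem` model object, and matching `gr.State` holders are created just below. How slider changes propagate back is not part of this hunk; one plausible wiring, where `apply_denoising_steps` is a hypothetical helper:

```python
# Hypothetical wiring, not from this diff: push the slider value into the
# model and mirror it in the gr.State holder.
def apply_denoising_steps(steps):
    worldmem.sampling_timesteps = steps
    return steps

slider_denoising_step.change(apply_denoising_steps,
                             inputs=[slider_denoising_step],
                             outputs=[sampling_timesteps_state])
```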
@@ -457,24 +525,12 @@ with gr.Blocks(css=css) as demo:
     def set_action(action):
         return action
 
-    # gr.Markdown("### Action sequence examples.")
 
 
     for button, action_key in zip(buttons, list(example_actions.keys())):
         button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
 
-
-    gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
-
-    with gr.Row():
-        image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
-        image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
-        image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
-        image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
-        image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
-        image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
-
-    gr.Markdown("### Click the examples below for a quick review, and continue generating based on them.")
+    gr.Markdown("### 📂 Click to review generated examples, and continue generating based on them.")
 
     example_case = gr.Textbox(label="Case", visible=False)
     image_output = gr.Image(visible=False)
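The old caption ("Click on the images below to reset the sequence and generate from the new image") implies the six scene thumbnails are clickable even though they are declared `interactive=False`; the event wiring is not shown in this diff. A plausible sketch using Gradio's `select` listener, with `choose_scene` as a hypothetical handler:

```python
# Hypothetical wiring, not from this diff: clicking a thumbnail makes it the
# new start frame; reset() (defined in this file) would then clear the rollout.
def choose_scene(img):
    return img

for thumb in [image_display_1, image_display_2, image_display_3,
              image_display_4, image_display_5, image_display_6]:
    thumb.select(choose_scene, inputs=[thumb], outputs=[selected_image])
```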
@@ -499,29 +555,6 @@ with gr.Blocks(css=css) as demo:
     )
 
 
-
-    gr.Markdown(
-        """
-        ## Instructions & Notes:
-
-        1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
-        2. You can continue generation by clicking **"Generate"** again and again. Previous sequences are logged in the history panel.
-        3. Click **"Reset"** to clear the current sequence and start fresh.
-        4. Action sequences can be composed using the following keys:
-            - W: turn up
-            - S: turn down
-            - A: turn left
-            - D: turn right
-            - Q: move forward
-            - E: move backward
-            - N: no-op (do nothing)
-            - U: use item
-        5. Higher denoising steps produce more detailed results but take longer. 20 steps is a good balance between quality and speed. The same applies to context and memory length.
-        6. For faster performance, we recommend running the demo locally (~1s/frame on H100 vs ~5s on Spaces).
-        7. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
-        8. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **[email protected]**.
-        """
-    )
     # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
     submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
     reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])