xizaoqu committed on
Commit
cef86dc
·
1 Parent(s): db555c7
Files changed (1) hide show
  1. app.py +90 -57
app.py CHANGED
@@ -71,10 +71,10 @@ KEY_TO_ACTION = {
71
  }
72
 
73
  example_images = [
74
- ["1", "assets/ice_plains.png", "turn right+go backward+look up+turn left+look down+turn right+go forward+turn left", 20, 3, 8],
75
- ["2", "assets/place.png", "put item+go backward+put item+go backward+go around", 20, 3, 8],
76
- ["3", "assets/rain_sunflower_plains.png", "turn right+look up+turn right+look down+turn left+go backward+turn left", 20, 3, 8],
77
- ["4", "assets/desert.png", "turn 360 degree+turn right+go forward+turn left", 20, 3, 8],
78
  ]
79
 
80
  def load_custom_checkpoint(algo, checkpoint_path):
@@ -264,10 +264,18 @@ def generate(keys, input_history, memory_frames, self_frames, self_actions, self
264
 
265
  memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
266
 
267
- out_video = memory_frames.transpose(0,2,3,1)
 
268
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
269
  out_video = (out_video * 255).astype(np.uint8)
270
 
 
 
 
 
 
 
 
271
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
272
  save_video(out_video, temporal_video_path)
273
  input_history += keys
@@ -289,7 +297,7 @@ def generate(keys, input_history, memory_frames, self_frames, self_actions, self
289
 
290
  # np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
291
 
292
- return out_video[-1], temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
293
 
294
  def reset(selected_image):
295
  self_frames = None
@@ -381,6 +389,24 @@ with gr.Blocks(css=css) as demo:
381
  """
382
  )
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  # <div style="text-align: center;">
385
  # <!-- Public Website -->
386
  # <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
@@ -403,25 +429,50 @@ with gr.Blocks(css=css) as demo:
403
  # </a>
404
  # </div>
405
 
406
- example_actions = {"turn left + turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
407
  "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
408
- "turn right+go backward+look up+turn left+look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
409
- "turn right+go forward+turn right": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
410
- "turn right+look up+turn right+look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
411
- "put item+go backward+put item+go backward":"SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
412
 
413
  selected_image = gr.State(ICE_PLAINS_IMAGE)
414
 
415
  with gr.Row(variant="panel"):
416
- video_display = gr.Video(autoplay=True, loop=True)
417
- image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
 
420
  with gr.Row(variant="panel"):
421
  with gr.Column(scale=2):
422
- input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
423
- log_output = gr.Textbox(label="History Log", interactive=False)
424
- gr.Markdown("### Action sequence examples.")
 
 
 
 
 
 
 
 
 
 
 
425
  with gr.Row():
426
  buttons = []
427
  for action_key in list(example_actions.keys())[:2]:
@@ -437,11 +488,28 @@ with gr.Blocks(css=css) as demo:
437
  buttons.append(gr.Button(action_key))
438
 
439
  with gr.Column(scale=1):
440
- slider_denoising_step = gr.Slider(minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1, label="Denoising Steps")
441
- slider_context_length = gr.Slider(minimum=2, maximum=10, value=worldmem.n_tokens, step=1, label="Context Length")
442
- slider_memory_length = gr.Slider(minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1, label="Memory Length")
443
- submit_button = gr.Button("Generate")
444
- reset_btn = gr.Button("Reset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
447
  sampling_context_length_state = gr.State(worldmem.n_tokens)
@@ -457,24 +525,12 @@ with gr.Blocks(css=css) as demo:
457
  def set_action(action):
458
  return action
459
 
460
- # gr.Markdown("### Action sequence examples.")
461
 
462
 
463
  for button, action_key in zip(buttons, list(example_actions.keys())):
464
  button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
465
 
466
-
467
- gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
468
-
469
- with gr.Row():
470
- image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
471
- image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
472
- image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
473
- image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
474
- image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
475
- image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
476
-
477
- gr.Markdown("### Click the examples below for a quick review, and continue generating based on them.")
478
 
479
  example_case = gr.Textbox(label="Case", visible=False)
480
  image_output = gr.Image(visible=False)
@@ -499,29 +555,6 @@ with gr.Blocks(css=css) as demo:
499
  )
500
 
501
 
502
-
503
- gr.Markdown(
504
- """
505
- ## Instructions & Notes:
506
-
507
- 1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
508
- 2. You can continue generation by clicking **"Generation"** again and again. Previous sequences are logged in the history panel.
509
- 3. Click **"Reset"** to clear the current sequence and start fresh.
510
- 4. Action sequences can be composed using the following keys:
511
- - W: turn up
512
- - S: turn down
513
- - A: turn left
514
- - D: turn right
515
- - Q: move forward
516
- - E: move backward
517
- - N: no-op (do nothing)
518
- - U: use item
519
- 5. Higher denoising steps produce more detailed results but take longer. 20 steps is a good balance between quality and speed. The same applies to context and memory length.
520
- 6. For faster performance, we recommend running the demo locally (~1s/frame on H100 vs ~5s on Spaces).
521
- 7. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
522
- 8. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **[email protected]**.
523
- """
524
- )
525
  # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
526
  submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
527
  reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
 
71
  }
72
 
73
  example_images = [
74
+ ["1", "assets/ice_plains.png", "turn rightgo backwardโ†’look upโ†’turn leftโ†’look downโ†’turn rightโ†’go forwardโ†’turn left", 20, 3, 8],
75
+ ["2", "assets/place.png", "put itemโ†’go backwardโ†’put itemโ†’go backwardโ†’go around", 20, 3, 8],
76
+ ["3", "assets/rain_sunflower_plains.png", "turn rightโ†’look upโ†’turn rightโ†’look downโ†’turn leftโ†’go backwardโ†’turn left", 20, 3, 8],
77
+ ["4", "assets/desert.png", "turn 360 degreeโ†’turn rightโ†’go forwardโ†’turn left", 20, 3, 8],
78
  ]
79
 
80
  def load_custom_checkpoint(algo, checkpoint_path):
 
264
 
265
  memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
266
 
267
+
268
+ out_video = memory_frames.transpose(0,2,3,1).copy()
269
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
270
  out_video = (out_video * 255).astype(np.uint8)
271
 
272
+ last_frame = out_video[-1].copy()
273
+ border_thickness = 2
274
+ out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
275
+ out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
276
+ out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
277
+ out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
278
+
279
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
280
  save_video(out_video, temporal_video_path)
281
  input_history += keys
 
297
 
298
  # np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
299
 
300
+ return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
301
 
302
  def reset(selected_image):
303
  self_frames = None
 
389
  """
390
  )
391
 
392
+ gr.Markdown(
393
+ """
394
+ ## 🚀 How to Explore WorldMem
395
+
396
+ Follow these simple steps to get started:
397
+
398
+ 1. **Choose a scene**.
399
+ 2. **Input your action sequence**.
400
+ 3. **Click "Generate"**.
401
+
402
+ - You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
403
+ - For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
404
+ - โญ๏ธ If you like this project, please [give it a star on GitHub]()!
405
+ - 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
406
+
407
+ Happy exploring! 🌍
408
+ """
409
+ )
410
  # <div style="text-align: center;">
411
  # <!-- Public Website -->
412
  # <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
 
429
  # </a>
430
  # </div>
431
 
432
+ example_actions = {"turn leftโ†’turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
433
  "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
434
+ "turn rightโ†’go backwardโ†’look upโ†’turn leftโ†’look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
435
+ "turn rightโ†’go forwardโ†’turn right": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
436
+ "turn rightโ†’look upโ†’turn rightโ†’look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
437
+ "put itemโ†’go backwardโ†’put itemโ†’go backward":"SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
438
 
439
  selected_image = gr.State(ICE_PLAINS_IMAGE)
440
 
441
  with gr.Row(variant="panel"):
442
+ with gr.Column():
443
+ gr.Markdown("๐Ÿ–ผ๏ธ Start from this frame.")
444
+ image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
445
+ with gr.Column():
446
+ gr.Markdown("๐ŸŽž๏ธ Generated videos. New contents are marked in red box.")
447
+ video_display = gr.Video(autoplay=True, loop=True)
448
+
449
+ gr.Markdown("### ๐Ÿž๏ธ Choose a scene and start generation.")
450
+
451
+ with gr.Row():
452
+ image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
453
+ image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
454
+ image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
455
+ image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
456
+ image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
457
+ image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
458
 
459
 
460
  with gr.Row(variant="panel"):
461
  with gr.Column(scale=2):
462
+ gr.Markdown("### ๐Ÿ•น๏ธ Input action sequences for interaction.")
463
+ input_box = gr.Textbox(label="Action Sequences", placeholder="Enter action sequences here, e.g. (AAAAAAAAAAAADDDDDDDDDDDD)", lines=1, max_lines=1)
464
+ log_output = gr.Textbox(label="History Sequences", interactive=False)
465
+ gr.Markdown(
466
+ """
467
+ ### 💡 Action Key Guide
468
+
469
+ <pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
470
+ W: Turn up S: Turn down A: Turn left D: Turn right
471
+ Q: Go forward E: Go backward N: No-op U: Use item
472
+ </pre>
473
+ """
474
+ )
475
+ gr.Markdown("### ๐Ÿ‘‡ Click to quickly set action sequence examples.")
476
  with gr.Row():
477
  buttons = []
478
  for action_key in list(example_actions.keys())[:2]:
 
488
  buttons.append(gr.Button(action_key))
489
 
490
  with gr.Column(scale=1):
491
+ submit_button = gr.Button("๐ŸŽฌ Generate!", variant="primary")
492
+ reset_btn = gr.Button("๐Ÿ”„ Reset")
493
+ gr.Markdown("<div style='flex-grow:1; height: 100px'></div>")
494
+
495
+ gr.Markdown("### โš™๏ธ Advanced Settings")
496
+
497
+ slider_denoising_step = gr.Slider(
498
+ minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
499
+ label="Denoising Steps",
500
+ info="Higher values yield better quality but slower speed"
501
+ )
502
+ slider_context_length = gr.Slider(
503
+ minimum=2, maximum=10, value=worldmem.n_tokens, step=1,
504
+ label="Context Length",
505
+ info="How many previous frames in temporal context window."
506
+ )
507
+ slider_memory_length = gr.Slider(
508
+ minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1,
509
+ label="Memory Length",
510
+ info="How many previous frames in memory window."
511
+ )
512
+
513
 
514
  sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
515
  sampling_context_length_state = gr.State(worldmem.n_tokens)
 
525
  def set_action(action):
526
  return action
527
 
 
528
 
529
 
530
  for button, action_key in zip(buttons, list(example_actions.keys())):
531
  button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
532
 
533
+ gr.Markdown("### ๐Ÿ‘‡ Click to review generated examples, and continue generation based on them.")
 
 
 
 
 
 
 
 
 
 
 
534
 
535
  example_case = gr.Textbox(label="Case", visible=False)
536
  image_output = gr.Image(visible=False)
 
555
  )
556
 
557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
559
  submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
560
  reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])