Jensin committed on
Commit
22f9b0f
·
verified ·
1 Parent(s): 8cc527c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +247 -258
app.py CHANGED
@@ -21,123 +21,121 @@ import spaces
21
  from algorithms.worldmem import WorldMemMinecraft
22
  from huggingface_hub import hf_hub_download
23
  import tempfile
 
 
 
24
 
25
- torch.set_float32_matmul_precision("high")
26
-
27
- ACTION_KEYS = [
28
- "inventory",
29
- "ESC",
30
- "hotbar.1",
31
- "hotbar.2",
32
- "hotbar.3",
33
- "hotbar.4",
34
- "hotbar.5",
35
- "hotbar.6",
36
- "hotbar.7",
37
- "hotbar.8",
38
- "hotbar.9",
39
- "forward",
40
- "back",
41
- "left",
42
- "right",
43
- "cameraY",
44
- "cameraX",
45
- "jump",
46
- "sneak",
47
- "sprint",
48
- "swapHands",
49
- "attack",
50
- "use",
51
- "pickItem",
52
- "drop",
53
- ]
54
-
55
- # Mapping of input keys to action names
56
- KEY_TO_ACTION = {
57
- "Q": ("forward", 1),
58
- "E": ("back", 1),
59
- "W": ("cameraY", -1),
60
- "S": ("cameraY", 1),
61
- "A": ("cameraX", -1),
62
- "D": ("cameraX", 1),
63
- "U": ("drop", 1),
64
- "N": ("noop", 1),
65
- "1": ("hotbar.1", 1),
66
- }
67
 
68
- example_images = [
69
- ["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
70
- ["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
71
- ["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
72
- ["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
73
- ]
74
 
75
  def load_custom_checkpoint(algo, checkpoint_path):
76
- hf_ckpt = str(checkpoint_path).split('/')
77
- repo_id = '/'.join(hf_ckpt[:2])
78
- file_name = '/'.join(hf_ckpt[2:])
79
- model_path = hf_hub_download(repo_id=repo_id,
80
- filename=file_name)
81
- ckpt = torch.load(model_path, map_location=torch.device('cpu'))
82
- algo.load_state_dict(ckpt['state_dict'], strict=False)
83
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def parse_input_to_tensor(input_str):
86
- """
87
- Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
88
- of the corresponding action key.
89
-
90
- Args:
91
- input_str (str): A string consisting of "WASD" characters (e.g., "WASDWS").
92
-
93
- Returns:
94
- torch.Tensor: A tensor of shape (sequence_length, 25), where each row is a one-hot encoded action.
95
- """
96
- # Get the length of the input sequence
97
  seq_len = len(input_str)
98
-
99
- # Initialize a zero tensor of shape (seq_len, 25)
100
  action_tensor = torch.zeros((seq_len, 25))
101
-
102
- # Iterate through the input string and update the corresponding positions
103
  for i, char in enumerate(input_str):
104
- action, value = KEY_TO_ACTION.get(char.upper()) # Convert to uppercase to handle case insensitivity
105
  if action and action in ACTION_KEYS:
106
  index = ACTION_KEYS.index(action)
107
- action_tensor[i, index] = value # Set the corresponding action index to 1
108
-
109
  return action_tensor
110
 
111
- def load_image_as_tensor(image_path: str) -> torch.Tensor:
112
- """
113
- Load an image and convert it to a 0-1 normalized tensor.
114
-
115
- Args:
116
- image_path (str): Path to the image file.
117
-
118
- Returns:
119
- torch.Tensor: Image tensor of shape (C, H, W), normalized to [0,1].
120
- """
121
  if isinstance(image_path, str):
122
- image = Image.open(image_path).convert("RGB") # Ensure it's RGB
123
  else:
124
  image = image_path
125
  transform = transforms.Compose([
126
- transforms.ToTensor(), # Converts to tensor and normalizes to [0,1]
127
  ])
128
  return transform(image)
129
 
130
  def enable_amp(model, precision="16-mixed"):
131
  original_forward = model.forward
132
-
133
  def amp_forward(*args, **kwargs):
134
  with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
135
  return original_forward(*args, **kwargs)
136
-
137
  model.forward = amp_forward
138
  return model
139
 
140
- memory_frames = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  input_history = ""
142
  ICE_PLAINS_IMAGE = "assets/ice_plains.png"
143
  DESERT_IMAGE = "assets/desert.png"
@@ -146,21 +144,22 @@ PLAINS_IMAGE = "assets/plans.png"
146
  PLACE_IMAGE = "assets/place.png"
147
  SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
148
  SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
149
-
150
  device = torch.device('cuda')
151
 
152
  def save_video(frames, path="output.mp4", fps=10):
 
153
  h, w, _ = frames[0].shape
154
- out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
155
  for frame in frames:
156
  out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
157
  out.release()
158
-
159
  ffmpeg_cmd = [
160
- "ffmpeg", "-y", "-i", path, "-c:v", "libx264", "-crf", "23", "-preset", "medium", path
 
 
161
  ]
162
  subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
163
- return path
164
 
165
  cfg = OmegaConf.load("configurations/huggingface.yaml")
166
  worldmem = WorldMemMinecraft(cfg)
@@ -168,31 +167,26 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
168
  load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
169
  load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
170
  worldmem.to("cuda").eval()
171
- # worldmem = enable_amp(worldmem, precision="16-mixed")
172
-
173
  actions = np.zeros((1, 25), dtype=np.float32)
174
  poses = np.zeros((1, 5), dtype=np.float32)
175
 
176
-
177
-
178
- def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
179
- self_poses, self_memory_c2w, self_frame_idx):
180
- return 5 * len(action) if self_actions is not None else 5
181
 
182
  @spaces.GPU(duration=get_duration_single_image_to_long_video)
183
- def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
184
- self_poses, self_memory_c2w, self_frame_idx):
185
- new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
186
  action,
187
  first_pose,
188
  device=device,
189
- self_frames=self_frames,
190
- self_actions=self_actions,
191
- self_poses=self_poses,
192
- self_memory_c2w=self_memory_c2w,
193
- self_frame_idx=self_frame_idx)
194
-
195
- return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
196
 
197
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
198
  worldmem.sampling_timesteps = denoising_steps
@@ -207,144 +201,137 @@ def set_context_length(context_length, sampling_context_length_state):
207
  print("set context length to", worldmem.n_tokens)
208
  return sampling_context_length_state
209
 
210
- def set_memory_length(memory_length, sampling_memory_length_state):
211
- worldmem.condition_similar_length = memory_length
212
- sampling_memory_length_state = memory_length
213
- print("set memory length to", worldmem.condition_similar_length)
214
- return sampling_memory_length_state
215
 
216
- def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
217
- input_actions = parse_input_to_tensor(keys)
 
 
 
218
 
219
- if self_frames is None:
220
- new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
 
 
221
  actions[0],
222
  poses[0],
223
  device=device,
224
- self_frames=self_frames,
225
- self_actions=self_actions,
226
- self_poses=self_poses,
227
- self_memory_c2w=self_memory_c2w,
228
- self_frame_idx=self_frame_idx)
229
-
230
- new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
231
  input_actions,
232
  None,
233
  device=device,
234
- self_frames=self_frames,
235
- self_actions=self_actions,
236
- self_poses=self_poses,
237
- self_memory_c2w=self_memory_c2w,
238
- self_frame_idx=self_frame_idx)
239
-
240
- memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
241
-
242
-
243
- out_video = memory_frames.transpose(0,2,3,1).copy()
244
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
245
  out_video = (out_video * 255).astype(np.uint8)
246
-
247
  last_frame = out_video[-1].copy()
248
  border_thickness = 2
249
  out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
250
  out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
251
  out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
252
  out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
253
-
254
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
255
  save_video(out_video, temporal_video_path)
256
  input_history += keys
257
-
258
-
259
- # now = datetime.now()
260
- # folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
261
- # folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
262
- # os.makedirs(folder_path, exist_ok=True)
263
- # data_dict = {
264
- # "input_history": input_history,
265
- # "memory_frames": memory_frames,
266
- # "self_frames": self_frames,
267
- # "self_actions": self_actions,
268
- # "self_poses": self_poses,
269
- # "self_memory_c2w": self_memory_c2w,
270
- # "self_frame_idx": self_frame_idx,
271
- # }
272
-
273
- # np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
274
-
275
- return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
276
 
277
  def reset(selected_image):
278
- self_frames = None
279
- self_poses = None
280
- self_actions = None
281
- self_memory_c2w = None
282
- self_frame_idx = None
283
- memory_frames = load_image_as_tensor(selected_image).numpy()[None]
284
  input_history = ""
285
-
286
- new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
287
  actions[0],
288
  poses[0],
289
  device=device,
290
- self_frames=self_frames,
291
- self_actions=self_actions,
292
- self_poses=self_poses,
293
- self_memory_c2w=self_memory_c2w,
294
- self_frame_idx=self_frame_idx)
295
-
296
- return input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
297
 
298
  def on_image_click(selected_image):
299
- input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = reset(selected_image)
300
- return input_history, selected_image, selected_image, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
- def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
303
  if examples_case == '1':
304
  data_bundle = np.load("assets/examples/case1.npz")
305
  input_history = data_bundle['input_history'].item()
306
- memory_frames = data_bundle['memory_frames']
307
- self_frames = data_bundle['self_frames']
308
- self_actions = data_bundle['self_actions']
309
- self_poses = data_bundle['self_poses']
310
- self_memory_c2w = data_bundle['self_memory_c2w']
311
- self_frame_idx = data_bundle['self_frame_idx']
312
  elif examples_case == '2':
313
  data_bundle = np.load("assets/examples/case2.npz")
314
  input_history = data_bundle['input_history'].item()
315
- memory_frames = data_bundle['memory_frames']
316
- self_frames = data_bundle['self_frames']
317
- self_actions = data_bundle['self_actions']
318
- self_poses = data_bundle['self_poses']
319
- self_memory_c2w = data_bundle['self_memory_c2w']
320
- self_frame_idx = data_bundle['self_frame_idx']
321
  elif examples_case == '3':
322
  data_bundle = np.load("assets/examples/case3.npz")
323
  input_history = data_bundle['input_history'].item()
324
- memory_frames = data_bundle['memory_frames']
325
- self_frames = data_bundle['self_frames']
326
- self_actions = data_bundle['self_actions']
327
- self_poses = data_bundle['self_poses']
328
- self_memory_c2w = data_bundle['self_memory_c2w']
329
- self_frame_idx = data_bundle['self_frame_idx']
330
  elif examples_case == '4':
331
  data_bundle = np.load("assets/examples/case4.npz")
332
  input_history = data_bundle['input_history'].item()
333
- memory_frames = data_bundle['memory_frames']
334
- self_frames = data_bundle['self_frames']
335
- self_actions = data_bundle['self_actions']
336
- self_poses = data_bundle['self_poses']
337
- self_memory_c2w = data_bundle['self_memory_c2w']
338
- self_frame_idx = data_bundle['self_frame_idx']
339
-
340
- out_video = memory_frames.transpose(0,2,3,1)
341
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
342
  out_video = (out_video * 255).astype(np.uint8)
343
 
344
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
345
  save_video(out_video, temporal_video_path)
346
 
347
- return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
348
 
349
  css = """
350
  h1 {
@@ -357,41 +344,27 @@ with gr.Blocks(css=css) as demo:
357
  gr.Markdown(
358
  """
359
  # WORLDMEM: Long-term Consistent World Simulation with Memory
360
-
361
- <div style="text-align: center;">
362
- <a style="display:inline-block; margin: 0 10px;" href="https://github.com/xizaoqu/WorldMem">
363
- <img src="https://img.shields.io/badge/GitHub-Repository-black?logo=github"/>
364
- </a>
365
- <a style="display:inline-block; margin: 0 10px;" href="https://xizaoqu.github.io/worldmem/">
366
- <img src="https://img.shields.io/badge/Project_Page-blue"/>
367
- </a>
368
- <a style="display:inline-block; margin: 0 10px;" href="https://arxiv.org/abs/2504.12369">
369
- <img src="https://img.shields.io/badge/arXiv-Paper-red"/>
370
- </a>
371
- </div>
372
  """
373
- )
374
-
375
  gr.Markdown(
376
  """
377
  ## 🚀 How to Explore WorldMem
378
 
379
  Follow these simple steps to get started:
380
 
381
- 1. **Choose a scene**.
382
  2. **Input your action sequence**.
383
  3. **Click "Generate"**.
384
 
385
  - You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
386
  - For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
387
- - ⭐️ If you like this project, please [give it a star on GitHub](https://github.com/xizaoqu/WorldMem)!
388
  - 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
389
 
390
  Happy exploring! 🌍
391
  """
392
  )
393
 
394
-
395
  example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
396
  "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
397
  "turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
@@ -401,6 +374,15 @@ with gr.Blocks(css=css) as demo:
401
 
402
  selected_image = gr.State(ICE_PLAINS_IMAGE)
403
 
 
 
 
 
 
 
 
 
 
404
  with gr.Row(variant="panel"):
405
  with gr.Column():
406
  gr.Markdown("🖼️ Start from this frame.")
@@ -417,8 +399,7 @@ with gr.Blocks(css=css) as demo:
417
  image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
418
  image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
419
  image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
420
- image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
421
-
422
 
423
  with gr.Row(variant="panel"):
424
  with gr.Column(scale=2):
@@ -454,9 +435,7 @@ with gr.Blocks(css=css) as demo:
454
  submit_button = gr.Button("🎬 Generate!", variant="primary")
455
  reset_btn = gr.Button("🔄 Reset")
456
 
457
-
458
  gr.Markdown("### ⚙️ Advanced Settings")
459
-
460
  slider_denoising_step = gr.Slider(
461
  minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
462
  label="Denoising Steps",
@@ -467,60 +446,70 @@ with gr.Blocks(css=css) as demo:
467
  label="Context Length",
468
  info="How many previous frames in temporal context window."
469
  )
470
- slider_memory_length = gr.Slider(
471
- minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1,
472
  label="Memory Length",
473
- info="How many previous frames in memory window."
 
 
 
 
 
474
  )
475
-
476
-
477
  sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
478
  sampling_context_length_state = gr.State(worldmem.n_tokens)
479
- sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
480
-
481
- memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
482
- self_frames = gr.State()
483
- self_actions = gr.State()
484
- self_poses = gr.State()
485
- self_memory_c2w = gr.State()
486
- self_frame_idx = gr.State()
487
 
488
  def set_action(action):
489
  return action
490
-
491
-
492
 
493
  for button, action_key in zip(buttons, list(example_actions.keys())):
494
- button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
495
 
496
  gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
497
-
498
  example_case = gr.Textbox(label="Case", visible=False)
499
- image_output = gr.Image(visible=False)
500
 
501
  examples = gr.Examples(
502
  examples=example_images,
503
- inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
504
  cache_examples=False
505
  )
506
-
507
  example_case.change(
508
  fn=set_memory,
509
- inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
510
- outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
511
  )
512
-
513
- submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
514
- reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
515
- image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
516
- image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
517
- image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
518
- image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
519
- image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
520
- image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
521
-
 
 
 
522
  slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
523
  slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
524
- slider_memory_length.change(fn=set_memory_length, inputs=[slider_memory_length, sampling_memory_length_state], outputs=sampling_memory_length_state)
 
 
 
 
 
 
 
 
525
 
526
- demo.launch()
 
21
  from algorithms.worldmem import WorldMemMinecraft
22
  from huggingface_hub import hf_hub_download
23
  import tempfile
24
+ import os
25
+ import requests
26
+ from huggingface_hub import model_info
27
 
28
def is_huggingface_model(path: str) -> bool:
    """Report whether ``path`` points into a repository on the Hugging Face Hub.

    The first two ``/``-separated segments of ``path`` are joined into a
    repository id and looked up with ``model_info``.

    Args:
        path: Checkpoint location, either ``<owner>/<repo>/<file...>`` on the
            Hub or a path on the local filesystem.

    Returns:
        ``True`` when the ``model_info`` lookup succeeds; ``False`` when the
        path already exists locally or the lookup fails for any reason.
    """
    import os

    path = str(path)
    # A checkpoint that is already on disk is a local file: answer without a
    # network round-trip (which would also fail when running offline).
    if os.path.exists(path):
        return False

    repo_id = '/'.join(path.split('/')[:2])
    try:
        model_info(repo_id)
    except Exception:
        # Deliberately not a bare ``except:`` so that KeyboardInterrupt and
        # SystemExit still propagate while the lookup is in flight.
        return False
    return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ torch.set_float32_matmul_precision("high")
 
 
 
 
 
38
 
39
  def load_custom_checkpoint(algo, checkpoint_path):
40
+ if is_huggingface_model(str(checkpoint_path)):
41
+ hf_ckpt = str(checkpoint_path).split('/')
42
+ repo_id = '/'.join(hf_ckpt[:2])
43
+ file_name = '/'.join(hf_ckpt[2:])
44
+ model_path = hf_hub_download(repo_id=repo_id, filename=file_name)
45
+ ckpt = torch.load(model_path, map_location=torch.device('cpu'))
46
+
47
+ filtered_state_dict = {}
48
+ for k, v in ckpt['state_dict'].items():
49
+ if "frame_timestep_embedder" in k:
50
+ new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
51
+ filtered_state_dict[new_k] = v
52
+ else:
53
+ filtered_state_dict[k] = v
54
+
55
+ algo.load_state_dict(filtered_state_dict, strict=True)
56
+ print("Load: ", model_path)
57
+ else:
58
+ ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
59
+
60
+ filtered_state_dict = {}
61
+ for k, v in ckpt['state_dict'].items():
62
+ if "frame_timestep_embedder" in k:
63
+ new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
64
+ filtered_state_dict[new_k] = v
65
+ else:
66
+ filtered_state_dict[k] = v
67
+
68
+ algo.load_state_dict(filtered_state_dict, strict=True)
69
+ print("Load: ", checkpoint_path)
70
+
71
def download_assets_if_needed():
    """Fetch the example ``.npz`` bundles into ``assets/examples`` when missing.

    Files that already exist are left untouched. A failed download, whether
    a non-200 response or a transport error, is reported with ``print`` and
    skipped, so the remaining files are still attempted.
    """
    ASSETS_URL_BASE = "https://huggingface.co/spaces/yslan/worldmem/resolve/main/assets/examples"
    ASSETS_DIR = "assets/examples"
    ASSETS = ['case1.npz', 'case2.npz', 'case3.npz', 'case4.npz']

    os.makedirs(ASSETS_DIR, exist_ok=True)

    for filename in ASSETS:
        filepath = os.path.join(ASSETS_DIR, filename)
        if os.path.exists(filepath):
            continue

        print(f"Downloading {filename}...")
        url = f"{ASSETS_URL_BASE}/{filename}"
        try:
            # Without a timeout a stalled connection would block startup forever.
            response = requests.get(url, timeout=60)
        except requests.RequestException as exc:
            print(f"Failed to download {filename}: {exc}")
            continue

        if response.status_code != 200:
            print(f"Failed to download {filename}: {response.status_code}")
            continue

        # Write to a side file and rename, so an interrupted write never leaves
        # a truncated bundle that the exists() check above would later accept.
        partial_path = filepath + ".part"
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, filepath)
90
 
91
  def parse_input_to_tensor(input_str):
 
 
 
 
 
 
 
 
 
 
 
92
  seq_len = len(input_str)
 
 
93
  action_tensor = torch.zeros((seq_len, 25))
 
 
94
  for i, char in enumerate(input_str):
95
+ action, value = KEY_TO_ACTION.get(char.upper(), (None, None))
96
  if action and action in ACTION_KEYS:
97
  index = ACTION_KEYS.index(action)
98
+ action_tensor[i, index] = value
 
99
  return action_tensor
100
 
101
def load_image_as_tensor(image_path):
    """Turn an image into a tensor via torchvision's ``ToTensor`` transform.

    Args:
        image_path: A filesystem path (opened with PIL and converted to RGB)
            or an already-loaded image object, which is passed through as is.

    Returns:
        The output of ``transforms.ToTensor()`` applied to the image.
    """
    if not isinstance(image_path, str):
        picture = image_path
    else:
        picture = Image.open(image_path).convert("RGB")
    to_tensor = transforms.Compose([transforms.ToTensor()])
    return to_tensor(picture)
110
 
111
  def enable_amp(model, precision="16-mixed"):
112
  original_forward = model.forward
 
113
  def amp_forward(*args, **kwargs):
114
  with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
115
  return original_forward(*args, **kwargs)
 
116
  model.forward = amp_forward
117
  return model
118
 
119
+ download_assets_if_needed()
120
+
121
+ ACTION_KEYS = [
122
+ "inventory", "ESC", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4",
123
+ "hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", "forward",
124
+ "back", "left", "right", "cameraY", "cameraX", "jump", "sneak", "sprint",
125
+ "swapHands", "attack", "use", "pickItem", "drop",
126
+ ]
127
+ KEY_TO_ACTION = {
128
+ "Q": ("forward", 1), "E": ("back", 1), "W": ("cameraY", -1),
129
+ "S": ("cameraY", 1), "A": ("cameraX", -1), "D": ("cameraX", 1),
130
+ "U": ("drop", 1), "N": ("noop", 1), "1": ("hotbar.1", 1),
131
+ }
132
# Rows shown by the examples widget. Each row starts with the case id, the
# start image and a human-readable action summary; the three trailing numbers
# were bound to the denoising-step, context-length and memory-length sliders in
# the previous revision (NOTE(review): confirm against the current
# gr.Examples inputs).
example_images = [
    ["1", "assets/ice_plains.png", "turn right→go backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
    ["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
    ["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
    ["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
]
138
+ video_frames = []
139
  input_history = ""
140
  ICE_PLAINS_IMAGE = "assets/ice_plains.png"
141
  DESERT_IMAGE = "assets/desert.png"
 
144
  PLACE_IMAGE = "assets/place.png"
145
  SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
146
  SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
 
147
  device = torch.device('cuda')
148
 
149
  def save_video(frames, path="output.mp4", fps=10):
150
+ temp_path = path[:-4] + "_temp.mp4"
151
  h, w, _ = frames[0].shape
152
+ out = cv2.VideoWriter(temp_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
153
  for frame in frames:
154
  out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
155
  out.release()
 
156
  ffmpeg_cmd = [
157
+ "ffmpeg", "-y", "-i", temp_path,
158
+ "-c:v", "libx264", "-crf", "23", "-preset", "medium",
159
+ path
160
  ]
161
  subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
162
+ os.remove(temp_path)
163
 
164
  cfg = OmegaConf.load("configurations/huggingface.yaml")
165
  worldmem = WorldMemMinecraft(cfg)
 
167
  load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
168
  load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
169
  worldmem.to("cuda").eval()
 
 
170
  actions = np.zeros((1, 25), dtype=np.float32)
171
  poses = np.zeros((1, 5), dtype=np.float32)
172
 
173
def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
                                            memory_poses, memory_c2w, memory_frame_idx):
    """Estimate the GPU time budget, in seconds, for one ``run_interactive`` call.

    Passed as ``duration=`` to ``spaces.GPU`` and therefore called with the
    same arguments as ``run_interactive``. Only ``action`` and
    ``memory_actions`` are read.

    Returns:
        5 when no memory exists yet (``memory_actions is None``); otherwise
        5 seconds per entry of ``action``, never less than 5.
    """
    seconds_per_step = 5
    if memory_actions is None:
        return seconds_per_step
    # An empty action sequence used to yield a zero-second budget; keep the
    # same floor as the no-memory case instead.
    return max(seconds_per_step, seconds_per_step * len(action))
 
 
176
 
177
  @spaces.GPU(duration=get_duration_single_image_to_long_video)
178
+ def run_interactive(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
179
+ memory_poses, memory_c2w, memory_frame_idx):
180
+ new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = worldmem.interactive(first_frame,
181
  action,
182
  first_pose,
183
  device=device,
184
+ memory_latent_frames=memory_latent_frames,
185
+ memory_actions=memory_actions,
186
+ memory_poses=memory_poses,
187
+ memory_c2w=memory_c2w,
188
+ memory_frame_idx=memory_frame_idx)
189
+ return new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
 
190
 
191
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
192
  worldmem.sampling_timesteps = denoising_steps
 
201
  print("set context length to", worldmem.n_tokens)
202
  return sampling_context_length_state
203
 
204
def set_memory_condition_length(memory_condition_length, sampling_memory_condition_length_state):
    """Store the new memory condition length on ``worldmem`` and return it.

    The incoming state value is ignored; the returned value becomes the new
    state.
    """
    worldmem.memory_condition_length = memory_condition_length
    print("set memory length to", worldmem.memory_condition_length)
    return memory_condition_length
209
 
210
def set_next_frame_length(next_frame_length, sampling_next_frame_length_state):
    """Store the new next-frame length on ``worldmem`` and return it.

    The incoming state value is ignored; the returned value becomes the new
    state.
    """
    worldmem.next_frame_length = next_frame_length
    print("set next frame length to", worldmem.next_frame_length)
    return next_frame_length
215
 
216
def generate(keys, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
    """Extend the current video with frames generated for the key string ``keys``.

    Args:
        keys: Action characters, encoded with ``parse_input_to_tensor``.
        input_history: Keys entered so far; ``keys`` is appended to it.
        video_frames: Array of frames generated so far; transposed with
            ``(0, 2, 3, 1)`` for display, so presumably ``(N, C, H, W)`` in
            ``[0, 1]`` (TODO confirm).
        memory_latent_frames, memory_actions, memory_poses, memory_c2w,
        memory_frame_idx: Model memory carried between calls; all ``None``
            before the first generation.

    Returns:
        ``(last_frame, video_path, input_history, video_frames,
        memory_latent_frames, memory_actions, memory_poses, memory_c2w,
        memory_frame_idx)``. ``last_frame`` is the newest frame as uint8
        without the highlight border.
    """
    input_actions = parse_input_to_tensor(keys)

    if memory_latent_frames is None:
        # No memory yet: prime it from the first frame with the zero action/pose.
        new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(
            video_frames[0],
            actions[0],
            poses[0],
            device=device,
            memory_latent_frames=memory_latent_frames,
            memory_actions=memory_actions,
            memory_poses=memory_poses,
            memory_c2w=memory_c2w,
            memory_frame_idx=memory_frame_idx)

    new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(
        video_frames[0],
        input_actions,
        None,
        device=device,
        memory_latent_frames=memory_latent_frames,
        memory_actions=memory_actions,
        memory_poses=memory_poses,
        memory_c2w=memory_c2w,
        memory_frame_idx=memory_frame_idx)

    video_frames = np.concatenate([video_frames, new_frame[:, 0]])

    # Copy before drawing so the border never leaks into the stored frames.
    out_video = video_frames.transpose(0, 2, 3, 1).copy()
    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
    out_video = (out_video * 255).astype(np.uint8)
    last_frame = out_video[-1].copy()

    # Mark the frames produced by this call with a red border.
    border_thickness = 2
    newest = slice(-len(new_frame), None)
    out_video[newest, :border_thickness, :, :] = [255, 0, 0]
    out_video[newest, -border_thickness:, :, :] = [255, 0, 0]
    out_video[newest, :, :border_thickness, :] = [255, 0, 0]
    out_video[newest, :, -border_thickness:, :] = [255, 0, 0]

    # delete=False keeps the name reserved on disk. Taking ``.name`` from a
    # discarded NamedTemporaryFile only works because the object is collected
    # (and the file removed) immediately, and leaves the path up for grabs.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        temporal_video_path = tmp.name
    save_video(out_video, temporal_video_path)

    input_history += keys
    return last_frame, temporal_video_path, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  def reset(selected_image):
253
+ memory_latent_frames = None
254
+ memory_poses = None
255
+ memory_actions = None
256
+ memory_c2w = None
257
+ memory_frame_idx = None
258
+ video_frames = load_image_as_tensor(selected_image).numpy()[None]
259
  input_history = ""
260
+ new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
 
261
  actions[0],
262
  poses[0],
263
  device=device,
264
+ memory_latent_frames=memory_latent_frames,
265
+ memory_actions=memory_actions,
266
+ memory_poses=memory_poses,
267
+ memory_c2w=memory_c2w,
268
+ memory_frame_idx=memory_frame_idx,
269
+ )
270
+ return input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
271
 
272
  def on_image_click(selected_image):
273
+ input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(selected_image)
274
+ return input_history, selected_image, selected_image, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
275
+
276
+ # === NEW: Custom image upload handler ===
277
def on_custom_image_upload(custom_image):
    """Reset the session to a user-uploaded image.

    Args:
        custom_image: The upload as an array (it is cast with
            ``astype(np.uint8)``), or ``None`` when nothing was uploaded.

    Returns:
        Nine values: log text, selected image, displayed image and the six
        session/memory values from ``reset``. With no upload, nine
        ``gr.update()`` placeholders are returned instead.
    """
    if custom_image is None:
        return tuple(gr.update() for _ in range(9))

    image = Image.fromarray(custom_image.astype(np.uint8))
    history, frames, latents, acts, pose_hist, c2w, frame_idx = reset(image)
    return (
        history,  # log_output
        image,    # selected_image
        image,    # image_display
        frames, latents, acts, pose_hist, c2w, frame_idx,
    )
288
 
289
def set_memory(examples_case):
    """Load a pre-generated example and return it as the current session.

    Args:
        examples_case: Example identifier as a string, one of ``'1'`` to
            ``'4'``. It selects ``assets/examples/case<N>.npz``.

    Returns:
        Nine values: ``(input_history, last_frame, video_path, video_frames,
        memory_latent_frames, memory_actions, memory_poses, memory_c2w,
        memory_frame_idx)``. ``last_frame`` is the final uint8 frame of the
        example and ``video_path`` is the path of an ``.mp4`` written with
        ``save_video``. For any other ``examples_case`` nine no-op
        ``gr.update()`` objects are returned, so the UI is left unchanged
        instead of failing on unbound locals.
    """
    if examples_case not in ("1", "2", "3", "4"):
        return tuple(gr.update() for _ in range(9))

    # The four examples share one layout; only the file number differs.
    # The context manager closes the archive once the arrays are read.
    with np.load(f"assets/examples/case{examples_case}.npz") as data_bundle:
        # .item() unwraps the stored array element into a Python object.
        input_history = data_bundle['input_history'].item()
        video_frames = data_bundle['memory_frames']
        memory_latent_frames = data_bundle['self_frames']
        memory_actions = data_bundle['self_actions']
        memory_poses = data_bundle['self_poses']
        memory_c2w = data_bundle['self_memory_c2w']
        memory_frame_idx = data_bundle['self_frame_idx']

    # Move axis 1 to the end, then rescale [0, 1] values to uint8.
    # NOTE(review): assumes frames are stored channel-first - confirm.
    out_video = video_frames.transpose(0, 2, 3, 1)
    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
    out_video = (out_video * 255).astype(np.uint8)

    # delete=False keeps the file after the handle closes, so the path is
    # reserved for save_video and is not removed when the object is
    # garbage-collected.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        temporal_video_path = tmp_file.name
    save_video(out_video, temporal_video_path)

    return (
        input_history,
        out_video[-1],
        temporal_video_path,
        video_frames,
        memory_latent_frames,
        memory_actions,
        memory_poses,
        memory_c2w,
        memory_frame_idx,
    )
335
 
336
  css = """
337
  h1 {
 
344
  gr.Markdown(
345
  """
346
  # WORLDMEM: Long-term Consistent World Simulation with Memory
 
 
 
 
 
 
 
 
 
 
 
 
347
  """
348
+ )
 
349
  gr.Markdown(
350
  """
351
  ## 🚀 How to Explore WorldMem
352
 
353
  Follow these simple steps to get started:
354
 
355
+ 1. **Choose a scene** or **upload your own**.
356
  2. **Input your action sequence**.
357
  3. **Click "Generate"**.
358
 
359
  - You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
360
  - For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
361
+ - ⭐️ If you like this project, please [give it a star on GitHub]()!
362
  - 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
363
 
364
  Happy exploring! 🌍
365
  """
366
  )
367
 
 
368
  example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
369
  "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
370
  "turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
 
374
 
375
  selected_image = gr.State(ICE_PLAINS_IMAGE)
376
 
377
+ # --- NEW: Custom image upload UI ---
378
+ with gr.Row():
379
+ gr.Markdown("🖼️ Or upload your own scene (PNG/JPG, 256x256 recommended):")
380
+ custom_image_upload = gr.Image(
381
+ label="Upload a Custom Scene",
382
+ type="numpy",
383
+ tool=None
384
+ )
385
+
386
  with gr.Row(variant="panel"):
387
  with gr.Column():
388
  gr.Markdown("🖼️ Start from this frame.")
 
399
  image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
400
  image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
401
  image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
402
+ image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
 
403
 
404
  with gr.Row(variant="panel"):
405
  with gr.Column(scale=2):
 
435
  submit_button = gr.Button("🎬 Generate!", variant="primary")
436
  reset_btn = gr.Button("🔄 Reset")
437
 
 
438
  gr.Markdown("### ⚙️ Advanced Settings")
 
439
  slider_denoising_step = gr.Slider(
440
  minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
441
  label="Denoising Steps",
 
446
  label="Context Length",
447
  info="How many previous frames in temporal context window."
448
  )
449
# Memory-window size. Its minimum is 4, so it carries no "Recommended: 1" hint.
slider_memory_condition_length = gr.Slider(
    minimum=4, maximum=16, value=worldmem.memory_condition_length, step=1,
    label="Memory Length",
    info="How many previous frames in memory window."
)
# Frames generated per step. The stability note is about multi-frame
# generation, so it is shown on this slider.
slider_next_frame_length = gr.Slider(
    minimum=1, maximum=5, value=worldmem.next_frame_length, step=1,
    label="Next Frame Length",
    info="How many next frames to generate at once. (Recommended: 1, multi-frame generation is not stable yet)"
)
 
 
459
  sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
460
  sampling_context_length_state = gr.State(worldmem.n_tokens)
461
+ sampling_memory_condition_length_state = gr.State(worldmem.memory_condition_length)
462
+ sampling_next_frame_length_state = gr.State(worldmem.next_frame_length)
463
+ video_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
464
+ memory_latent_frames = gr.State()
465
+ memory_actions = gr.State()
466
+ memory_poses = gr.State()
467
+ memory_c2w = gr.State()
468
+ memory_frame_idx = gr.State()
469
 
470
  def set_action(action):
471
  return action
 
 
472
 
473
  for button, action_key in zip(buttons, list(example_actions.keys())):
474
+ button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
475
 
476
  gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
 
477
  example_case = gr.Textbox(label="Case", visible=False)
478
+ image_output = gr.Image(visible=False)
479
 
480
  examples = gr.Examples(
481
  examples=example_images,
482
+ inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_condition_length],
483
  cache_examples=False
484
  )
 
485
  example_case.change(
486
  fn=set_memory,
487
+ inputs=[example_case],
488
+ outputs=[log_output, image_display, video_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx]
489
  )
490
+ submit_button.click(generate, inputs=[input_box, log_output, video_frames,
491
+ memory_latent_frames, memory_actions, memory_poses,
492
+ memory_c2w, memory_frame_idx],
493
+ outputs=[image_display, video_display, log_output,
494
+ video_frames, memory_latent_frames, memory_actions, memory_poses,
495
+ memory_c2w, memory_frame_idx])
496
+ reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
497
+ image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
498
+ image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
499
+ image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
500
+ image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
501
+ image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
502
+ image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
503
  slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
504
  slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
505
+ slider_memory_condition_length.change(fn=set_memory_condition_length, inputs=[slider_memory_condition_length, sampling_memory_condition_length_state], outputs=sampling_memory_condition_length_state)
506
+ slider_next_frame_length.change(fn=set_next_frame_length, inputs=[slider_next_frame_length, sampling_next_frame_length_state], outputs=sampling_next_frame_length_state)
507
+
508
+ # --- NEW: Custom image upload triggers reset ---
509
+ custom_image_upload.change(
510
+ on_custom_image_upload,
511
+ inputs=[custom_image_upload],
512
+ outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx],
513
+ )
514
 
515
+ demo.launch(share=True)