Jensin commited on
Commit
9c89ba7
·
verified ·
1 Parent(s): 22f9b0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -252
app.py CHANGED
@@ -21,121 +21,121 @@ import spaces
21
  from algorithms.worldmem import WorldMemMinecraft
22
  from huggingface_hub import hf_hub_download
23
  import tempfile
24
- import os
25
- import requests
26
- from huggingface_hub import model_info
27
-
28
- def is_huggingface_model(path: str) -> bool:
29
- hf_ckpt = str(path).split('/')
30
- repo_id = '/'.join(hf_ckpt[:2])
31
- try:
32
- model_info(repo_id)
33
- return True
34
- except:
35
- return False
36
 
37
  torch.set_float32_matmul_precision("high")
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def load_custom_checkpoint(algo, checkpoint_path):
40
- if is_huggingface_model(str(checkpoint_path)):
41
- hf_ckpt = str(checkpoint_path).split('/')
42
- repo_id = '/'.join(hf_ckpt[:2])
43
- file_name = '/'.join(hf_ckpt[2:])
44
- model_path = hf_hub_download(repo_id=repo_id, filename=file_name)
45
- ckpt = torch.load(model_path, map_location=torch.device('cpu'))
46
-
47
- filtered_state_dict = {}
48
- for k, v in ckpt['state_dict'].items():
49
- if "frame_timestep_embedder" in k:
50
- new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
51
- filtered_state_dict[new_k] = v
52
- else:
53
- filtered_state_dict[k] = v
54
-
55
- algo.load_state_dict(filtered_state_dict, strict=True)
56
- print("Load: ", model_path)
57
- else:
58
- ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
59
-
60
- filtered_state_dict = {}
61
- for k, v in ckpt['state_dict'].items():
62
- if "frame_timestep_embedder" in k:
63
- new_k = k.replace("frame_timestep_embedder", "timestamp_embedding")
64
- filtered_state_dict[new_k] = v
65
- else:
66
- filtered_state_dict[k] = v
67
-
68
- algo.load_state_dict(filtered_state_dict, strict=True)
69
- print("Load: ", checkpoint_path)
70
-
71
- def download_assets_if_needed():
72
- ASSETS_URL_BASE = "https://huggingface.co/spaces/yslan/worldmem/resolve/main/assets/examples"
73
- ASSETS_DIR = "assets/examples"
74
- ASSETS = ['case1.npz', 'case2.npz', 'case3.npz', 'case4.npz']
75
-
76
- if not os.path.exists(ASSETS_DIR):
77
- os.makedirs(ASSETS_DIR)
78
-
79
- for filename in ASSETS:
80
- filepath = os.path.join(ASSETS_DIR, filename)
81
- if not os.path.exists(filepath):
82
- print(f"Downloading {filename}...")
83
- url = f"{ASSETS_URL_BASE}/{filename}"
84
- response = requests.get(url)
85
- if response.status_code == 200:
86
- with open(filepath, "wb") as f:
87
- f.write(response.content)
88
- else:
89
- print(f"Failed to download {filename}: {response.status_code}")
90
 
91
  def parse_input_to_tensor(input_str):
 
 
 
 
 
 
 
 
 
92
  seq_len = len(input_str)
 
 
93
  action_tensor = torch.zeros((seq_len, 25))
 
 
94
  for i, char in enumerate(input_str):
95
- action, value = KEY_TO_ACTION.get(char.upper(), (None, None))
96
  if action and action in ACTION_KEYS:
97
  index = ACTION_KEYS.index(action)
98
- action_tensor[i, index] = value
 
99
  return action_tensor
100
 
101
- def load_image_as_tensor(image_path):
 
 
 
 
 
 
 
 
 
102
  if isinstance(image_path, str):
103
- image = Image.open(image_path).convert("RGB")
104
  else:
105
  image = image_path
106
  transform = transforms.Compose([
107
- transforms.ToTensor(),
108
  ])
109
  return transform(image)
110
 
111
  def enable_amp(model, precision="16-mixed"):
112
  original_forward = model.forward
 
113
  def amp_forward(*args, **kwargs):
114
  with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
115
  return original_forward(*args, **kwargs)
 
116
  model.forward = amp_forward
117
  return model
118
 
119
- download_assets_if_needed()
120
-
121
- ACTION_KEYS = [
122
- "inventory", "ESC", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4",
123
- "hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", "forward",
124
- "back", "left", "right", "cameraY", "cameraX", "jump", "sneak", "sprint",
125
- "swapHands", "attack", "use", "pickItem", "drop",
126
- ]
127
- KEY_TO_ACTION = {
128
- "Q": ("forward", 1), "E": ("back", 1), "W": ("cameraY", -1),
129
- "S": ("cameraY", 1), "A": ("cameraX", -1), "D": ("cameraX", 1),
130
- "U": ("drop", 1), "N": ("noop", 1), "1": ("hotbar.1", 1),
131
- }
132
- example_images = [
133
- ["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
134
- ["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
135
- ["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
136
- ["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
137
- ]
138
- video_frames = []
139
  input_history = ""
140
  ICE_PLAINS_IMAGE = "assets/ice_plains.png"
141
  DESERT_IMAGE = "assets/desert.png"
@@ -144,22 +144,21 @@ PLAINS_IMAGE = "assets/plans.png"
144
  PLACE_IMAGE = "assets/place.png"
145
  SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
146
  SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
 
147
  device = torch.device('cuda')
148
 
149
  def save_video(frames, path="output.mp4", fps=10):
150
- temp_path = path[:-4] + "_temp.mp4"
151
  h, w, _ = frames[0].shape
152
- out = cv2.VideoWriter(temp_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
153
  for frame in frames:
154
  out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
155
  out.release()
 
156
  ffmpeg_cmd = [
157
- "ffmpeg", "-y", "-i", temp_path,
158
- "-c:v", "libx264", "-crf", "23", "-preset", "medium",
159
- path
160
  ]
161
  subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
162
- os.remove(temp_path)
163
 
164
  cfg = OmegaConf.load("configurations/huggingface.yaml")
165
  worldmem = WorldMemMinecraft(cfg)
@@ -167,26 +166,31 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
167
  load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
168
  load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
169
  worldmem.to("cuda").eval()
 
 
170
  actions = np.zeros((1, 25), dtype=np.float32)
171
  poses = np.zeros((1, 5), dtype=np.float32)
172
 
173
- def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
174
- memory_poses, memory_c2w, memory_frame_idx):
175
- return 5 * len(action) if memory_actions is not None else 5
 
 
176
 
177
  @spaces.GPU(duration=get_duration_single_image_to_long_video)
178
- def run_interactive(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
179
- memory_poses, memory_c2w, memory_frame_idx):
180
- new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = worldmem.interactive(first_frame,
181
  action,
182
  first_pose,
183
  device=device,
184
- memory_latent_frames=memory_latent_frames,
185
- memory_actions=memory_actions,
186
- memory_poses=memory_poses,
187
- memory_c2w=memory_c2w,
188
- memory_frame_idx=memory_frame_idx)
189
- return new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
 
190
 
191
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
192
  worldmem.sampling_timesteps = denoising_steps
@@ -201,137 +205,144 @@ def set_context_length(context_length, sampling_context_length_state):
201
  print("set context length to", worldmem.n_tokens)
202
  return sampling_context_length_state
203
 
204
- def set_memory_condition_length(memory_condition_length, sampling_memory_condition_length_state):
205
- worldmem.memory_condition_length = memory_condition_length
206
- sampling_memory_condition_length_state = memory_condition_length
207
- print("set memory length to", worldmem.memory_condition_length)
208
- return sampling_memory_condition_length_state
209
-
210
- def set_next_frame_length(next_frame_length, sampling_next_frame_length_state):
211
- worldmem.next_frame_length = next_frame_length
212
- sampling_next_frame_length_state = next_frame_length
213
- print("set next frame length to", worldmem.next_frame_length)
214
- return sampling_next_frame_length_state
215
 
216
- def generate(keys, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
217
  input_actions = parse_input_to_tensor(keys)
218
- if memory_latent_frames is None:
219
- new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
 
220
  actions[0],
221
  poses[0],
222
  device=device,
223
- memory_latent_frames=memory_latent_frames,
224
- memory_actions=memory_actions,
225
- memory_poses=memory_poses,
226
- memory_c2w=memory_c2w,
227
- memory_frame_idx=memory_frame_idx)
228
- new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
 
229
  input_actions,
230
  None,
231
  device=device,
232
- memory_latent_frames=memory_latent_frames,
233
- memory_actions=memory_actions,
234
- memory_poses=memory_poses,
235
- memory_c2w=memory_c2w,
236
- memory_frame_idx=memory_frame_idx)
237
- video_frames = np.concatenate([video_frames, new_frame[:,0]])
238
- out_video = video_frames.transpose(0,2,3,1).copy()
 
 
 
239
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
240
  out_video = (out_video * 255).astype(np.uint8)
 
241
  last_frame = out_video[-1].copy()
242
  border_thickness = 2
243
  out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
244
  out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
245
  out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
246
  out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
 
247
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
248
  save_video(out_video, temporal_video_path)
249
  input_history += keys
250
- return last_frame, temporal_video_path, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  def reset(selected_image):
253
- memory_latent_frames = None
254
- memory_poses = None
255
- memory_actions = None
256
- memory_c2w = None
257
- memory_frame_idx = None
258
- video_frames = load_image_as_tensor(selected_image).numpy()[None]
259
  input_history = ""
260
- new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
 
261
  actions[0],
262
  poses[0],
263
  device=device,
264
- memory_latent_frames=memory_latent_frames,
265
- memory_actions=memory_actions,
266
- memory_poses=memory_poses,
267
- memory_c2w=memory_c2w,
268
- memory_frame_idx=memory_frame_idx,
269
- )
270
- return input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
271
 
272
  def on_image_click(selected_image):
273
- input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(selected_image)
274
- return input_history, selected_image, selected_image, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
275
-
276
- # === NEW: Custom image upload handler ===
277
- def on_custom_image_upload(custom_image):
278
- if custom_image is None:
279
- return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
280
- image = Image.fromarray(custom_image.astype(np.uint8))
281
- input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(image)
282
- return (
283
- input_history, # log_output
284
- image, # selected_image
285
- image, # image_display
286
- video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
287
- )
288
 
289
- def set_memory(examples_case):
290
  if examples_case == '1':
291
  data_bundle = np.load("assets/examples/case1.npz")
292
  input_history = data_bundle['input_history'].item()
293
- video_frames = data_bundle['memory_frames']
294
- memory_latent_frames = data_bundle['self_frames']
295
- memory_actions = data_bundle['self_actions']
296
- memory_poses = data_bundle['self_poses']
297
- memory_c2w = data_bundle['self_memory_c2w']
298
- memory_frame_idx = data_bundle['self_frame_idx']
299
  elif examples_case == '2':
300
  data_bundle = np.load("assets/examples/case2.npz")
301
  input_history = data_bundle['input_history'].item()
302
- video_frames = data_bundle['memory_frames']
303
- memory_latent_frames = data_bundle['self_frames']
304
- memory_actions = data_bundle['self_actions']
305
- memory_poses = data_bundle['self_poses']
306
- memory_c2w = data_bundle['self_memory_c2w']
307
- memory_frame_idx = data_bundle['self_frame_idx']
308
  elif examples_case == '3':
309
  data_bundle = np.load("assets/examples/case3.npz")
310
  input_history = data_bundle['input_history'].item()
311
- video_frames = data_bundle['memory_frames']
312
- memory_latent_frames = data_bundle['self_frames']
313
- memory_actions = data_bundle['self_actions']
314
- memory_poses = data_bundle['self_poses']
315
- memory_c2w = data_bundle['self_memory_c2w']
316
- memory_frame_idx = data_bundle['self_frame_idx']
317
  elif examples_case == '4':
318
  data_bundle = np.load("assets/examples/case4.npz")
319
  input_history = data_bundle['input_history'].item()
320
- video_frames = data_bundle['memory_frames']
321
- memory_latent_frames = data_bundle['self_frames']
322
- memory_actions = data_bundle['self_actions']
323
- memory_poses = data_bundle['self_poses']
324
- memory_c2w = data_bundle['self_memory_c2w']
325
- memory_frame_idx = data_bundle['self_frame_idx']
326
-
327
- out_video = video_frames.transpose(0,2,3,1)
328
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
329
  out_video = (out_video * 255).astype(np.uint8)
330
 
331
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
332
  save_video(out_video, temporal_video_path)
333
 
334
- return input_history, out_video[-1], temporal_video_path, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
335
 
336
  css = """
337
  h1 {
@@ -344,27 +355,36 @@ with gr.Blocks(css=css) as demo:
344
  gr.Markdown(
345
  """
346
  # WORLDMEM: Long-term Consistent World Simulation with Memory
 
 
 
 
 
 
 
 
 
 
 
347
  """
348
- )
 
349
  gr.Markdown(
350
  """
351
  ## 🚀 How to Explore WorldMem
352
-
353
  Follow these simple steps to get started:
354
-
355
- 1. **Choose a scene** or **upload your own**.
356
  2. **Input your action sequence**.
357
  3. **Click "Generate"**.
358
-
359
  - You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
360
  - For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
361
- - ⭐️ If you like this project, please [give it a star on GitHub]()!
362
  - 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
363
-
364
  Happy exploring! 🌍
365
  """
366
  )
367
 
 
368
  example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
369
  "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
370
  "turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
@@ -374,15 +394,6 @@ with gr.Blocks(css=css) as demo:
374
 
375
  selected_image = gr.State(ICE_PLAINS_IMAGE)
376
 
377
- # --- NEW: Custom image upload UI ---
378
- with gr.Row():
379
- gr.Markdown("🖼️ Or upload your own scene (PNG/JPG, 256x256 recommended):")
380
- custom_image_upload = gr.Image(
381
- label="Upload a Custom Scene",
382
- type="numpy",
383
- tool=None
384
- )
385
-
386
  with gr.Row(variant="panel"):
387
  with gr.Column():
388
  gr.Markdown("🖼️ Start from this frame.")
@@ -399,7 +410,8 @@ with gr.Blocks(css=css) as demo:
399
  image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
400
  image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
401
  image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
402
- image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
 
403
 
404
  with gr.Row(variant="panel"):
405
  with gr.Column(scale=2):
@@ -409,7 +421,6 @@ with gr.Blocks(css=css) as demo:
409
  gr.Markdown(
410
  """
411
  ### 💡 Action Key Guide
412
-
413
  <pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
414
  W: Turn up S: Turn down A: Turn left D: Turn right
415
  Q: Go forward E: Go backward N: No-op U: Use item
@@ -435,7 +446,9 @@ with gr.Blocks(css=css) as demo:
435
  submit_button = gr.Button("🎬 Generate!", variant="primary")
436
  reset_btn = gr.Button("🔄 Reset")
437
 
 
438
  gr.Markdown("### ⚙️ Advanced Settings")
 
439
  slider_denoising_step = gr.Slider(
440
  minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
441
  label="Denoising Steps",
@@ -446,70 +459,60 @@ with gr.Blocks(css=css) as demo:
446
  label="Context Length",
447
  info="How many previous frames in temporal context window."
448
  )
449
- slider_memory_condition_length = gr.Slider(
450
- minimum=4, maximum=16, value=worldmem.memory_condition_length, step=1,
451
  label="Memory Length",
452
- info="How many previous frames in memory window. (Recommended: 1, multi-frame generation is not stable yet)"
453
- )
454
- slider_next_frame_length = gr.Slider(
455
- minimum=1, maximum=5, value=worldmem.next_frame_length, step=1,
456
- label="Next Frame Length",
457
- info="How many next frames to generate at once."
458
  )
 
 
459
  sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
460
  sampling_context_length_state = gr.State(worldmem.n_tokens)
461
- sampling_memory_condition_length_state = gr.State(worldmem.memory_condition_length)
462
- sampling_next_frame_length_state = gr.State(worldmem.next_frame_length)
463
- video_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
464
- memory_latent_frames = gr.State()
465
- memory_actions = gr.State()
466
- memory_poses = gr.State()
467
- memory_c2w = gr.State()
468
- memory_frame_idx = gr.State()
469
 
470
  def set_action(action):
471
  return action
 
 
472
 
473
  for button, action_key in zip(buttons, list(example_actions.keys())):
474
- button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
475
 
476
  gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
 
477
  example_case = gr.Textbox(label="Case", visible=False)
478
- image_output = gr.Image(visible=False)
479
 
480
  examples = gr.Examples(
481
  examples=example_images,
482
- inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_condition_length],
483
  cache_examples=False
484
  )
 
485
  example_case.change(
486
  fn=set_memory,
487
- inputs=[example_case],
488
- outputs=[log_output, image_display, video_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx]
489
  )
490
- submit_button.click(generate, inputs=[input_box, log_output, video_frames,
491
- memory_latent_frames, memory_actions, memory_poses,
492
- memory_c2w, memory_frame_idx],
493
- outputs=[image_display, video_display, log_output,
494
- video_frames, memory_latent_frames, memory_actions, memory_poses,
495
- memory_c2w, memory_frame_idx])
496
- reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
497
- image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
498
- image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
499
- image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
500
- image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
501
- image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
502
- image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
503
  slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
504
  slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
505
- slider_memory_condition_length.change(fn=set_memory_condition_length, inputs=[slider_memory_condition_length, sampling_memory_condition_length_state], outputs=sampling_memory_condition_length_state)
506
- slider_next_frame_length.change(fn=set_next_frame_length, inputs=[slider_next_frame_length, sampling_next_frame_length_state], outputs=sampling_next_frame_length_state)
507
-
508
- # --- NEW: Custom image upload triggers reset ---
509
- custom_image_upload.change(
510
- on_custom_image_upload,
511
- inputs=[custom_image_upload],
512
- outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx],
513
- )
514
 
515
- demo.launch(share=True)
 
21
  from algorithms.worldmem import WorldMemMinecraft
22
  from huggingface_hub import hf_hub_download
23
  import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  torch.set_float32_matmul_precision("high")
26
 
27
+ ACTION_KEYS = [
28
+ "inventory",
29
+ "ESC",
30
+ "hotbar.1",
31
+ "hotbar.2",
32
+ "hotbar.3",
33
+ "hotbar.4",
34
+ "hotbar.5",
35
+ "hotbar.6",
36
+ "hotbar.7",
37
+ "hotbar.8",
38
+ "hotbar.9",
39
+ "forward",
40
+ "back",
41
+ "left",
42
+ "right",
43
+ "cameraY",
44
+ "cameraX",
45
+ "jump",
46
+ "sneak",
47
+ "sprint",
48
+ "swapHands",
49
+ "attack",
50
+ "use",
51
+ "pickItem",
52
+ "drop",
53
+ ]
54
+
55
+ # Mapping of input keys to action names
56
+ KEY_TO_ACTION = {
57
+ "Q": ("forward", 1),
58
+ "E": ("back", 1),
59
+ "W": ("cameraY", -1),
60
+ "S": ("cameraY", 1),
61
+ "A": ("cameraX", -1),
62
+ "D": ("cameraX", 1),
63
+ "U": ("drop", 1),
64
+ "N": ("noop", 1),
65
+ "1": ("hotbar.1", 1),
66
+ }
67
+
68
+ example_images = [
69
+ ["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
70
+ ["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
71
+ ["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
72
+ ["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
73
+ ]
74
+
75
  def load_custom_checkpoint(algo, checkpoint_path):
76
+ hf_ckpt = str(checkpoint_path).split('/')
77
+ repo_id = '/'.join(hf_ckpt[:2])
78
+ file_name = '/'.join(hf_ckpt[2:])
79
+ model_path = hf_hub_download(repo_id=repo_id,
80
+ filename=file_name)
81
+ ckpt = torch.load(model_path, map_location=torch.device('cpu'))
82
+ algo.load_state_dict(ckpt['state_dict'], strict=False)
83
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def parse_input_to_tensor(input_str):
86
+ """
87
+ Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
88
+ of the corresponding action key.
89
+ Args:
90
+ input_str (str): A string consisting of "WASD" characters (e.g., "WASDWS").
91
+ Returns:
92
+ torch.Tensor: A tensor of shape (sequence_length, 25), where each row is a one-hot encoded action.
93
+ """
94
+ # Get the length of the input sequence
95
  seq_len = len(input_str)
96
+
97
+ # Initialize a zero tensor of shape (seq_len, 25)
98
  action_tensor = torch.zeros((seq_len, 25))
99
+
100
+ # Iterate through the input string and update the corresponding positions
101
  for i, char in enumerate(input_str):
102
+ action, value = KEY_TO_ACTION.get(char.upper()) # Convert to uppercase to handle case insensitivity
103
  if action and action in ACTION_KEYS:
104
  index = ACTION_KEYS.index(action)
105
+ action_tensor[i, index] = value # Set the corresponding action index to 1
106
+
107
  return action_tensor
108
 
109
+ def load_image_as_tensor(image_path: str) -> torch.Tensor:
110
+ """
111
+ Load an image and convert it to a 0-1 normalized tensor.
112
+
113
+ Args:
114
+ image_path (str): Path to the image file.
115
+
116
+ Returns:
117
+ torch.Tensor: Image tensor of shape (C, H, W), normalized to [0,1].
118
+ """
119
  if isinstance(image_path, str):
120
+ image = Image.open(image_path).convert("RGB") # Ensure it's RGB
121
  else:
122
  image = image_path
123
  transform = transforms.Compose([
124
+ transforms.ToTensor(), # Converts to tensor and normalizes to [0,1]
125
  ])
126
  return transform(image)
127
 
128
  def enable_amp(model, precision="16-mixed"):
129
  original_forward = model.forward
130
+
131
  def amp_forward(*args, **kwargs):
132
  with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
133
  return original_forward(*args, **kwargs)
134
+
135
  model.forward = amp_forward
136
  return model
137
 
138
+ memory_frames = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  input_history = ""
140
  ICE_PLAINS_IMAGE = "assets/ice_plains.png"
141
  DESERT_IMAGE = "assets/desert.png"
 
144
  PLACE_IMAGE = "assets/place.png"
145
  SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
146
  SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
147
+
148
  device = torch.device('cuda')
149
 
150
  def save_video(frames, path="output.mp4", fps=10):
 
151
  h, w, _ = frames[0].shape
152
+ out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
153
  for frame in frames:
154
  out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
155
  out.release()
156
+
157
  ffmpeg_cmd = [
158
+ "ffmpeg", "-y", "-i", path, "-c:v", "libx264", "-crf", "23", "-preset", "medium", path
 
 
159
  ]
160
  subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
161
+ return path
162
 
163
  cfg = OmegaConf.load("configurations/huggingface.yaml")
164
  worldmem = WorldMemMinecraft(cfg)
 
166
  load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
167
  load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
168
  worldmem.to("cuda").eval()
169
+ # worldmem = enable_amp(worldmem, precision="16-mixed")
170
+
171
  actions = np.zeros((1, 25), dtype=np.float32)
172
  poses = np.zeros((1, 5), dtype=np.float32)
173
 
174
+
175
+
176
+ def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
177
+ self_poses, self_memory_c2w, self_frame_idx):
178
+ return 5 * len(action) if self_actions is not None else 5
179
 
180
  @spaces.GPU(duration=get_duration_single_image_to_long_video)
181
+ def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
182
+ self_poses, self_memory_c2w, self_frame_idx):
183
+ new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
184
  action,
185
  first_pose,
186
  device=device,
187
+ self_frames=self_frames,
188
+ self_actions=self_actions,
189
+ self_poses=self_poses,
190
+ self_memory_c2w=self_memory_c2w,
191
+ self_frame_idx=self_frame_idx)
192
+
193
+ return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
194
 
195
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
196
  worldmem.sampling_timesteps = denoising_steps
 
205
  print("set context length to", worldmem.n_tokens)
206
  return sampling_context_length_state
207
 
208
+ def set_memory_length(memory_length, sampling_memory_length_state):
209
+ worldmem.condition_similar_length = memory_length
210
+ sampling_memory_length_state = memory_length
211
+ print("set memory length to", worldmem.condition_similar_length)
212
+ return sampling_memory_length_state
 
 
 
 
 
 
213
 
214
+ def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
215
  input_actions = parse_input_to_tensor(keys)
216
+
217
+ if self_frames is None:
218
+ new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
219
  actions[0],
220
  poses[0],
221
  device=device,
222
+ self_frames=self_frames,
223
+ self_actions=self_actions,
224
+ self_poses=self_poses,
225
+ self_memory_c2w=self_memory_c2w,
226
+ self_frame_idx=self_frame_idx)
227
+
228
+ new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
229
  input_actions,
230
  None,
231
  device=device,
232
+ self_frames=self_frames,
233
+ self_actions=self_actions,
234
+ self_poses=self_poses,
235
+ self_memory_c2w=self_memory_c2w,
236
+ self_frame_idx=self_frame_idx)
237
+
238
+ memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
239
+
240
+
241
+ out_video = memory_frames.transpose(0,2,3,1).copy()
242
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
243
  out_video = (out_video * 255).astype(np.uint8)
244
+
245
  last_frame = out_video[-1].copy()
246
  border_thickness = 2
247
  out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
248
  out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
249
  out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
250
  out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
251
+
252
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
253
  save_video(out_video, temporal_video_path)
254
  input_history += keys
255
+
256
+
257
+ # now = datetime.now()
258
+ # folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
259
+ # folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
260
+ # os.makedirs(folder_path, exist_ok=True)
261
+ # data_dict = {
262
+ # "input_history": input_history,
263
+ # "memory_frames": memory_frames,
264
+ # "self_frames": self_frames,
265
+ # "self_actions": self_actions,
266
+ # "self_poses": self_poses,
267
+ # "self_memory_c2w": self_memory_c2w,
268
+ # "self_frame_idx": self_frame_idx,
269
+ # }
270
+
271
+ # np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
272
+
273
+ return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
274
 
275
  def reset(selected_image):
276
+ self_frames = None
277
+ self_poses = None
278
+ self_actions = None
279
+ self_memory_c2w = None
280
+ self_frame_idx = None
281
+ memory_frames = load_image_as_tensor(selected_image).numpy()[None]
282
  input_history = ""
283
+
284
+ new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
285
  actions[0],
286
  poses[0],
287
  device=device,
288
+ self_frames=self_frames,
289
+ self_actions=self_actions,
290
+ self_poses=self_poses,
291
+ self_memory_c2w=self_memory_c2w,
292
+ self_frame_idx=self_frame_idx)
293
+
294
+ return input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
295
 
296
  def on_image_click(selected_image):
297
+ input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = reset(selected_image)
298
+ return input_history, selected_image, selected_image, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
+ def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
301
  if examples_case == '1':
302
  data_bundle = np.load("assets/examples/case1.npz")
303
  input_history = data_bundle['input_history'].item()
304
+ memory_frames = data_bundle['memory_frames']
305
+ self_frames = data_bundle['self_frames']
306
+ self_actions = data_bundle['self_actions']
307
+ self_poses = data_bundle['self_poses']
308
+ self_memory_c2w = data_bundle['self_memory_c2w']
309
+ self_frame_idx = data_bundle['self_frame_idx']
310
  elif examples_case == '2':
311
  data_bundle = np.load("assets/examples/case2.npz")
312
  input_history = data_bundle['input_history'].item()
313
+ memory_frames = data_bundle['memory_frames']
314
+ self_frames = data_bundle['self_frames']
315
+ self_actions = data_bundle['self_actions']
316
+ self_poses = data_bundle['self_poses']
317
+ self_memory_c2w = data_bundle['self_memory_c2w']
318
+ self_frame_idx = data_bundle['self_frame_idx']
319
  elif examples_case == '3':
320
  data_bundle = np.load("assets/examples/case3.npz")
321
  input_history = data_bundle['input_history'].item()
322
+ memory_frames = data_bundle['memory_frames']
323
+ self_frames = data_bundle['self_frames']
324
+ self_actions = data_bundle['self_actions']
325
+ self_poses = data_bundle['self_poses']
326
+ self_memory_c2w = data_bundle['self_memory_c2w']
327
+ self_frame_idx = data_bundle['self_frame_idx']
328
  elif examples_case == '4':
329
  data_bundle = np.load("assets/examples/case4.npz")
330
  input_history = data_bundle['input_history'].item()
331
+ memory_frames = data_bundle['memory_frames']
332
+ self_frames = data_bundle['self_frames']
333
+ self_actions = data_bundle['self_actions']
334
+ self_poses = data_bundle['self_poses']
335
+ self_memory_c2w = data_bundle['self_memory_c2w']
336
+ self_frame_idx = data_bundle['self_frame_idx']
337
+
338
+ out_video = memory_frames.transpose(0,2,3,1)
339
  out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
340
  out_video = (out_video * 255).astype(np.uint8)
341
 
342
  temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
343
  save_video(out_video, temporal_video_path)
344
 
345
+ return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
346
 
347
  css = """
348
  h1 {
 
355
  gr.Markdown(
356
  """
357
  # WORLDMEM: Long-term Consistent World Simulation with Memory
358
+ <div style="text-align: center;">
359
+ <a style="display:inline-block; margin: 0 10px;" href="https://github.com/xizaoqu/WorldMem">
360
+ <img src="https://img.shields.io/badge/GitHub-Repository-black?logo=github"/>
361
+ </a>
362
+ <a style="display:inline-block; margin: 0 10px;" href="https://xizaoqu.github.io/worldmem/">
363
+ <img src="https://img.shields.io/badge/Project_Page-blue"/>
364
+ </a>
365
+ <a style="display:inline-block; margin: 0 10px;" href="https://arxiv.org/abs/2504.12369">
366
+ <img src="https://img.shields.io/badge/arXiv-Paper-red"/>
367
+ </a>
368
+ </div>
369
  """
370
+ )
371
+
372
  gr.Markdown(
373
  """
374
  ## 🚀 How to Explore WorldMem
 
375
  Follow these simple steps to get started:
376
+ 1. **Choose a scene**.
 
377
  2. **Input your action sequence**.
378
  3. **Click "Generate"**.
 
379
  - You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
380
  - For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
381
+ - ⭐�� If you like this project, please [give it a star on GitHub](https://github.com/xizaoqu/WorldMem)!
382
  - 💬 For questions or feedback, feel free to open an issue or email me at **[email protected]**.
 
383
  Happy exploring! 🌍
384
  """
385
  )
386
 
387
+
388
  example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
389
  "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
390
  "turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
 
394
 
395
  selected_image = gr.State(ICE_PLAINS_IMAGE)
396
 
 
 
 
 
 
 
 
 
 
397
  with gr.Row(variant="panel"):
398
  with gr.Column():
399
  gr.Markdown("🖼️ Start from this frame.")
 
410
  image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
411
  image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
412
  image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
413
+ image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
414
+
415
 
416
  with gr.Row(variant="panel"):
417
  with gr.Column(scale=2):
 
421
  gr.Markdown(
422
  """
423
  ### 💡 Action Key Guide
 
424
  <pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
425
  W: Turn up S: Turn down A: Turn left D: Turn right
426
  Q: Go forward E: Go backward N: No-op U: Use item
 
446
  submit_button = gr.Button("🎬 Generate!", variant="primary")
447
  reset_btn = gr.Button("🔄 Reset")
448
 
449
+
450
  gr.Markdown("### ⚙️ Advanced Settings")
451
+
452
  slider_denoising_step = gr.Slider(
453
  minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
454
  label="Denoising Steps",
 
459
  label="Context Length",
460
  info="How many previous frames in temporal context window."
461
  )
462
+ slider_memory_length = gr.Slider(
463
+ minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1,
464
  label="Memory Length",
465
+ info="How many previous frames in memory window."
 
 
 
 
 
466
  )
467
+
468
+
469
  sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
470
  sampling_context_length_state = gr.State(worldmem.n_tokens)
471
+ sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
472
+
473
+ memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
474
+ self_frames = gr.State()
475
+ self_actions = gr.State()
476
+ self_poses = gr.State()
477
+ self_memory_c2w = gr.State()
478
+ self_frame_idx = gr.State()
479
 
480
  def set_action(action):
481
  return action
482
+
483
+
484
 
485
  for button, action_key in zip(buttons, list(example_actions.keys())):
486
+ button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
487
 
488
  gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
489
+
490
  example_case = gr.Textbox(label="Case", visible=False)
491
+ image_output = gr.Image(visible=False)
492
 
493
  examples = gr.Examples(
494
  examples=example_images,
495
+ inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
496
  cache_examples=False
497
  )
498
+
499
  example_case.change(
500
  fn=set_memory,
501
+ inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
502
+ outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
503
  )
504
+
505
+ submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
506
+ reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
507
+ image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
508
+ image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
509
+ image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
510
+ image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
511
+ image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
512
+ image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
513
+
 
 
 
514
  slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
515
  slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
516
+ slider_memory_length.change(fn=set_memory_length, inputs=[slider_memory_length, sampling_memory_length_state], outputs=sampling_memory_length_state)
 
 
 
 
 
 
 
 
517
 
518
+ demo.launch()