Sm0kyWu committed
Commit bc59ff5 (verified) · Parent: 054fa18

Update app.py

Files changed (1): app.py (+642 −640)
app.py CHANGED
@@ -1,641 +1,643 @@
1
- import gradio as gr
2
- import spaces
3
-
4
- import os
5
-
6
- import shutil
7
- os.environ['SPCONV_ALGO'] = 'native'
8
- from typing import *
9
- import torch
10
- import numpy as np
11
- import imageio
12
- from easydict import EasyDict as edict
13
- from PIL import Image
14
- from Amodal3R.pipelines import Amodal3RImageTo3DPipeline
15
- from Amodal3R.representations import Gaussian, MeshExtractResult
16
- from Amodal3R.utils import render_utils, postprocessing_utils
17
- from segment_anything import sam_model_registry, SamPredictor
18
- from huggingface_hub import hf_hub_download
19
- import cv2
20
-
21
-
22
- MAX_SEED = np.iinfo(np.int32).max
23
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
24
- os.makedirs(TMP_DIR, exist_ok=True)
25
-
26
- def start_session(req: gr.Request):
27
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
28
- os.makedirs(user_dir, exist_ok=True)
29
-
30
- def end_session(req: gr.Request):
31
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
- shutil.rmtree(user_dir)
33
-
34
- def change_message():
35
- return "Please wait for a few seconds after uploading the image."
36
-
37
- def reset_image(predictor, img):
38
- img = np.array(img)
39
- predictor.set_image(img)
40
- original_img = img.copy()
41
- return predictor, original_img, "The models are ready.", [], [], [], original_img
42
-
43
- def button_clickable(selected_points):
44
- if len(selected_points) > 0:
45
- return gr.Button(interactive=True)
46
- else:
47
- return gr.Button(interactive=False)
48
-
49
- def run_sam(img, predictor, selected_points):
50
- if len(selected_points) == 0:
51
- return np.zeros(img.shape[:2], dtype=np.uint8)
52
- input_points = [p for p in selected_points]
53
- input_labels = [1 for _ in range(len(selected_points))]
54
- masks, _, _ = predictor.predict(
55
- point_coords=np.array(input_points),
56
- point_labels=np.array(input_labels),
57
- multimask_output=False,
58
- )
59
- best_mask = masks[0].astype(np.uint8)
60
- # close small gaps (dilate then erode)
61
- if len(selected_points) > 1:
62
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
63
- best_mask = cv2.dilate(best_mask, kernel, iterations=1)
64
- best_mask = cv2.erode(best_mask, kernel, iterations=1)
65
- return best_mask
66
-
67
-
68
- @spaces.GPU
69
- def image_to_3d(
70
- image: np.ndarray,
71
- mask: np.ndarray,
72
- seed: int,
73
- ss_guidance_strength: float,
74
- ss_sampling_steps: int,
75
- slat_guidance_strength: float,
76
- slat_sampling_steps: int,
77
- erode_kernel_size: int,
78
- req: gr.Request,
79
- ) -> Tuple[dict, str]:
80
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
81
- outputs = pipeline.run_multi_image(
82
- [image],
83
- [mask],
84
- seed=seed,
85
- formats=["gaussian", "mesh"],
86
- sparse_structure_sampler_params={
87
- "steps": ss_sampling_steps,
88
- "cfg_strength": ss_guidance_strength,
89
- },
90
- slat_sampler_params={
91
- "steps": slat_sampling_steps,
92
- "cfg_strength": slat_guidance_strength,
93
- },
94
- mode="stochastic",
95
- erode_kernel_size=erode_kernel_size,
96
- )
97
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120, bg_color=(1,1,1))['color']
98
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
99
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
100
- video_path = os.path.join(user_dir, 'sample.mp4')
101
- imageio.mimsave(video_path, video, fps=15)
102
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
103
- torch.cuda.empty_cache()
104
- return state, video_path
105
-
106
-
107
- @spaces.GPU(duration=90)
108
- def extract_glb(
109
- state: dict,
110
- mesh_simplify: float,
111
- texture_size: int,
112
- req: gr.Request,
113
- ) -> tuple:
114
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
115
- gs, mesh = unpack_state(state)
116
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
117
- glb_path = os.path.join(user_dir, 'sample.glb')
118
- glb.export(glb_path)
119
- torch.cuda.empty_cache()
120
- return glb_path, glb_path
121
-
122
-
123
- @spaces.GPU
124
- def extract_gaussian(state: dict, req: gr.Request) -> tuple:
125
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
126
- gs, _ = unpack_state(state)
127
- gaussian_path = os.path.join(user_dir, 'sample.ply')
128
- gs.save_ply(gaussian_path)
129
- torch.cuda.empty_cache()
130
- return gaussian_path, gaussian_path
131
-
132
-
133
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
134
- return {
135
- 'gaussian': {
136
- **gs.init_params,
137
- '_xyz': gs._xyz.cpu().numpy(),
138
- '_features_dc': gs._features_dc.cpu().numpy(),
139
- '_scaling': gs._scaling.cpu().numpy(),
140
- '_rotation': gs._rotation.cpu().numpy(),
141
- '_opacity': gs._opacity.cpu().numpy(),
142
- },
143
- 'mesh': {
144
- 'vertices': mesh.vertices.cpu().numpy(),
145
- 'faces': mesh.faces.cpu().numpy(),
146
- },
147
- }
148
-
149
-
150
- def unpack_state(state: dict) -> tuple:
151
- gs = Gaussian(
152
- aabb=state['gaussian']['aabb'],
153
- sh_degree=state['gaussian']['sh_degree'],
154
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
155
- scaling_bias=state['gaussian']['scaling_bias'],
156
- opacity_bias=state['gaussian']['opacity_bias'],
157
- scaling_activation=state['gaussian']['scaling_activation'],
158
- )
159
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
160
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
161
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
162
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
163
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
164
-
165
- mesh = edict(
166
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
167
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
168
- )
169
-
170
- return gs, mesh
171
-
172
- def get_sam_predictor():
173
- sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
174
- model_type = "vit_h"
175
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
176
- sam_predictor = SamPredictor(sam)
177
- return sam_predictor
178
-
179
-
180
- def draw_points_on_image(image, point):
181
- image_with_points = image.copy()
182
- x, y = point
183
- color = (255, 0, 0)
184
- cv2.circle(image_with_points, (int(x), int(y)), radius=10, color=color, thickness=-1)
185
- return image_with_points
186
-
187
-
188
- def see_point(image, x, y):
189
- updated_image = draw_points_on_image(image, [x,y])
190
- return updated_image
191
-
192
- def add_point(x, y, visible_points):
193
- if [x, y] not in visible_points:
194
- visible_points.append([x, y])
195
- return visible_points
196
-
197
- def delete_point(visible_points):
198
- visible_points.pop()
199
- return visible_points
200
-
201
-
202
- def clear_all_points(image):
203
- updated_image = image.copy()
204
- return updated_image
205
-
206
- def see_visible_points(image, visible_points):
207
- updated_image = image.copy()
208
- for p in visible_points:
209
- cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
210
- return updated_image
211
-
212
- def see_occlusion_points(image, occlusion_points):
213
- updated_image = image.copy()
214
- for p in occlusion_points:
215
- cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(0, 255, 0), thickness=-1)
216
- return updated_image
217
-
218
- def update_all_points(points):
219
- text = f"Points: {points}"
220
- dropdown_choices = [f"({p[0]}, {p[1]})" for p in points]
221
- return text, gr.Dropdown(show_label=False, choices=dropdown_choices, value=None, interactive=True)
222
-
223
- def delete_selected(image, visible_points, occlusion_points, occlusion_mask_list, selected_value, point_type):
224
- if point_type == "visibility":
225
- try:
226
- selected_index = [f"({p[0]}, {p[1]})" for p in visible_points].index(selected_value)
227
- except ValueError:
228
- selected_index = None
229
- if selected_index is not None and 0 <= selected_index < len(visible_points):
230
- visible_points.pop(selected_index)
231
- else:
232
- try:
233
- selected_index = [f"({p[0]}, {p[1]})" for p in occlusion_points].index(selected_value)
234
- except ValueError:
235
- selected_index = None
236
- if selected_index is not None and 0 <= selected_index < len(occlusion_points):
237
- occlusion_points.pop(selected_index)
238
- occlusion_mask_list.pop(selected_index)
239
- updated_image = image.copy()
240
- updated_image = see_visible_points(updated_image, visible_points)
241
- updated_image = see_occlusion_points(updated_image, occlusion_points)
242
- if point_type == "visibility":
243
- updated_text, dropdown = update_all_points(visible_points)
244
- else:
245
- updated_text, dropdown = update_all_points(occlusion_points)
246
- return updated_image, visible_points, occlusion_points, updated_text, dropdown
247
-
248
- def add_current_mask(visibility_mask, visibility_mask_list, point_type):
249
- if point_type == "visibility":
250
- if len(visibility_mask_list) > 0:
251
- if np.array_equal(visibility_mask, visibility_mask_list[-1]):
252
- return visibility_mask_list
253
- visibility_mask_list.append(visibility_mask)
254
- return visibility_mask_list
255
- else: # the occlusion mask will be automatically added, so do nothing here
256
- return visibility_mask_list
257
-
258
- def apply_mask_overlay(image, mask, color=(255, 0, 0)):
259
- img_arr = image
260
- overlay = img_arr.copy()
261
- gray_color = np.array([200, 200, 200], dtype=np.uint8)
262
- non_mask = mask == 0
263
- overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
264
- contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
265
- cv2.drawContours(overlay, contours, -1, color, 2)
266
- return overlay
267
-
268
- def vis_mask(image, mask_list):
269
- updated_image = image.copy()
270
- combined_mask = np.zeros_like(updated_image[:, :, 0])
271
- for mask in mask_list:
272
- combined_mask = cv2.bitwise_or(combined_mask, mask)
273
- updated_image = apply_mask_overlay(updated_image, combined_mask)
274
- return updated_image
275
-
276
- def segment_and_overlay(image, points, sam_predictor, mask_list, point_type):
277
- if point_type == "visibility":
278
- visible_mask = run_sam(image, sam_predictor, points)
279
- for mask in mask_list:
280
- visible_mask = cv2.bitwise_or(visible_mask, mask)
281
- overlaid = apply_mask_overlay(image, visible_mask * 255)
282
- return overlaid, visible_mask, mask_list
283
- else:
284
- combined_occlusion_mask = np.zeros_like(image[:, :, 0])
285
- mask_list = []
286
- if len(points) != 0:
287
- for point in points:
288
- mask = run_sam(image, sam_predictor, [point])
289
- mask_list.append(mask)
290
- combined_occlusion_mask = cv2.bitwise_or(combined_occlusion_mask, mask)
291
- overlaid = apply_mask_overlay(image, combined_occlusion_mask * 255, color=(0, 255, 0))
292
- return overlaid, combined_occlusion_mask, mask_list
293
-
294
- def delete_mask(visibility_mask_list, occlusion_mask_list, occlusion_points_state, point_type):
295
- if point_type == "visibility":
296
- if len(visibility_mask_list) > 0:
297
- visibility_mask_list.pop()
298
- else:
299
- if len(occlusion_mask_list) > 0:
300
- occlusion_mask_list.pop()
301
- occlusion_points_state.pop()
302
- return visibility_mask_list, occlusion_mask_list, occlusion_points_state
303
-
304
- def check_combined_mask(image, visibility_mask, visibility_mask_list, occlusion_mask_list, scale=0.68):
305
- if visibility_mask.sum() == 0:
306
- return np.zeros_like(image), np.zeros_like(image[:, :, 0])
307
- updated_image = image.copy()
308
- combined_mask = np.zeros_like(updated_image[:, :, 0])
309
- occluded_mask = np.zeros_like(updated_image[:, :, 0])
310
- binary_visibility_masks = [(m > 0).astype(np.uint8) for m in visibility_mask_list]
311
- combined_mask = np.zeros_like(binary_visibility_masks[0]) if binary_visibility_masks else (visibility_mask > 0).astype(np.uint8)
312
- for m in binary_visibility_masks:
313
- combined_mask = cv2.bitwise_or(combined_mask, m)
314
-
315
- if len(binary_visibility_masks) > 1:
316
- kernel = np.ones((5, 5), np.uint8)
317
- combined_mask = cv2.dilate(combined_mask, kernel, iterations=1)
318
-
319
- binary_occlusion_masks = [(m > 0).astype(np.uint8) for m in occlusion_mask_list]
320
- occluded_mask = np.zeros_like(binary_occlusion_masks[0]) if binary_occlusion_masks else np.zeros_like(combined_mask)
321
- for m in binary_occlusion_masks:
322
- occluded_mask = cv2.bitwise_or(occluded_mask, m)
323
-
324
- kernel_small = np.ones((3, 3), np.uint8)
325
- if len(binary_occlusion_masks) > 0:
326
- dilated = cv2.dilate(combined_mask, kernel_small, iterations=1)
327
- boundary_mask = dilated - combined_mask
328
- occluded_mask = cv2.bitwise_or(occluded_mask, boundary_mask)
329
- occluded_mask = (occluded_mask > 0).astype(np.uint8)
330
- occluded_mask = cv2.dilate(occluded_mask, kernel_small, iterations=1)
331
- occluded_mask = (occluded_mask > 0).astype(np.uint8)
332
- else:
333
- occluded_mask = 1 - combined_mask
334
-
335
- combined_mask[occluded_mask == 1] = 0
336
-
337
- occluded_mask = (1-occluded_mask) * 255
338
-
339
- masked_img = updated_image * combined_mask[:, :, None]
340
- occluded_mask[combined_mask == 1] = 127
341
-
342
- x, y, w, h = cv2.boundingRect(combined_mask.astype(np.uint8))
343
-
344
- ori_h, ori_w = masked_img.shape[:2]
345
- target_size = 512
346
- scale_factor = target_size / max(w, h)
347
- final_scale = scale_factor * scale
348
- new_w = int(round(ori_w * final_scale))
349
- new_h = int(round(ori_h * final_scale))
350
-
351
- resized_occluded_mask = cv2.resize(occluded_mask.astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST)
352
- resized_img = cv2.resize(masked_img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
353
-
354
- final_img = np.zeros((target_size, target_size, 3), dtype=updated_image.dtype)
355
- final_occluded_mask = np.ones((target_size, target_size), dtype=np.uint8) * 255
356
-
357
- new_x = int(round(x * final_scale))
358
- new_y = int(round(y * final_scale))
359
- new_w_box = int(round(w * final_scale))
360
- new_h_box = int(round(h * final_scale))
361
-
362
- new_cx = new_x + new_w_box // 2
363
- new_cy = new_y + new_h_box // 2
364
-
365
- final_cx, final_cy = target_size // 2, target_size // 2
366
- x_offset = final_cx - new_cx
367
- y_offset = final_cy - new_cy
368
-
369
- final_x_start = max(0, x_offset)
370
- final_y_start = max(0, y_offset)
371
- final_x_end = min(target_size, x_offset + new_w)
372
- final_y_end = min(target_size, y_offset + new_h)
373
-
374
- img_x_start = max(0, -x_offset)
375
- img_y_start = max(0, -y_offset)
376
- img_x_end = min(new_w, target_size - x_offset)
377
- img_y_end = min(new_h, target_size - y_offset)
378
-
379
- final_img[final_y_start:final_y_end, final_x_start:final_x_end] = resized_img[img_y_start:img_y_end, img_x_start:img_x_end]
380
- final_occluded_mask[final_y_start:final_y_end, final_x_start:final_x_end] = resized_occluded_mask[img_y_start:img_y_end, img_x_start:img_x_end]
381
-
382
- return final_img, final_occluded_mask
383
-
384
-
385
- def get_point(img, point_type, visible_points_state, occlusion_points_state, evt: gr.SelectData):
386
- updated_img = np.array(img).copy()
387
- if point_type == "visibility":
388
- visible_points_state = add_point(evt.index[0], evt.index[1], visible_points_state)
389
- else:
390
- occlusion_points_state = add_point(evt.index[0], evt.index[1], occlusion_points_state)
391
- updated_img = see_visible_points(updated_img, visible_points_state)
392
- updated_img = see_occlusion_points(updated_img, occlusion_points_state)
393
- return updated_img, visible_points_state, occlusion_points_state
394
-
395
-
396
- def change_point_type(point_type, visible_points_state, occlusion_points_state):
397
- if point_type == "visibility":
398
- text = f"Points: {visible_points_state}"
399
- dropdown_choices = [f"({p[0]}, {p[1]})" for p in visible_points_state]
400
- else:
401
- text = f"Points: {occlusion_points_state}"
402
- dropdown_choices = [f"({p[0]}, {p[1]})" for p in occlusion_points_state]
403
- return text, gr.Dropdown(show_label=False, choices=dropdown_choices, value=None, interactive=True)
404
-
405
-
406
- def get_seed(randomize_seed: bool, seed: int) -> int:
407
- """
408
- Get the random seed.
409
- """
410
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
411
-
412
-
413
- with gr.Blocks(delete_cache=(600, 600)) as demo:
414
- gr.Markdown("""
415
- ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
416
- """)
417
-
418
- predictor = gr.State(value=get_sam_predictor())
419
- visible_points_state = gr.State(value=[])
420
- occlusion_points_state = gr.State(value=[])
421
- occlusion_mask = gr.State(value=None)
422
- occlusion_mask_list = gr.State(value=[])
423
- original_image = gr.State(value=None)
424
- visibility_mask = gr.State(value=None)
425
- visibility_mask_list = gr.State(value=[])
426
-
427
- occluded_mask = gr.State(value=None)
428
- output_buf = gr.State()
429
-
430
-
431
- with gr.Row():
432
- with gr.Column():
433
- gr.Markdown("""
434
- ### Step 1 - Generate Visibility and Occlusion Masks.
435
- * Please click "Load Example Image" first when using the provided example images (bottom).
436
- * Please wait for a few seconds after uploading the image. Segment Anything is getting ready.
437
- * **Click to add the point prompts** to indicate the target object (multiple points supported) and occluders (one point per occluder for better usability).
438
- * Click "Add Mask" to save the current mask when the input needs to be built from several masks sequentially.
439
- * The scale of the target object can be adjusted for better reconstruction; we suggest 0.4 to 0.7 for most cases.
440
- """)
441
- with gr.Row():
442
- input_image = gr.Image(interactive=True, type='pil', label='Input Occlusion Image', show_label=True, sources="upload", height=300)
443
- input_with_prompt = gr.Image(type="numpy", label='Input with Prompt', interactive=False, height=300)
444
- with gr.Row():
445
- apply_example_btn = gr.Button("Load Example Image")
446
- message = gr.Markdown("Please wait a few seconds after uploading the image.", label="Message")
447
- with gr.Row():
448
- point_type = gr.Radio(["visibility", "occlusion"], label="Point Prompt Type", value="visibility")
449
- with gr.Row():
450
- with gr.Column():
451
- points_text = gr.Textbox(show_label=False, interactive=False)
452
- with gr.Column():
453
- points_dropdown = gr.Dropdown(show_label=False, choices=[], value=None, interactive=True)
454
- delete_button = gr.Button("Delete Selected Point")
455
- with gr.Row():
456
- with gr.Column():
457
- render_mask = gr.Image(label='Render Mask', interactive=False, height=300)
458
- with gr.Row():
459
- add_mask = gr.Button("Add Mask")
460
- undo_mask = gr.Button("Undo Last Mask")
461
- with gr.Column():
462
- vis_input = gr.Image(label='Visible Input', interactive=False, height=300)
463
- with gr.Row():
464
- zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.68, step=0.1)
465
- with gr.Row():
466
- check_visible_input = gr.Button("Generate Occluded Input")
467
-
468
- with gr.Column():
469
- gr.Markdown("""
470
- ### Step 2 - 3D Amodal Reconstruction. (Thanks to [TRELLIS](https://huggingface.co/spaces/JeffreyXiang/TRELLIS) for the 3D rendering component!)
471
- * Different random seeds can be tried in "Generation Settings" if the results are not ideal.
472
- * The segmentation boundary may not be accurate, so we provide an option to erode the visible area (try a kernel size of 0, 3, or 5).
473
- * If the reconstructed 3D asset is satisfactory, an interactive GLB file can be extracted and downloaded (it may look dull due to the absence of a light source).
474
- """)
475
- with gr.Row():
476
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
477
- with gr.Row():
478
- with gr.Accordion(label="Generation Settings", open=False):
479
- with gr.Row():
480
- with gr.Column():
481
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
482
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
483
- with gr.Column():
484
- erode_kernel_size = gr.Slider(0, 5, label="Erode Kernel Size", value=3, step=1)
485
- gr.Markdown("Stage 1: Sparse Structure Generation")
486
- with gr.Row():
487
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
488
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
489
- gr.Markdown("Stage 2: Structured Latent Generation")
490
- with gr.Row():
491
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
492
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
493
- with gr.Row():
494
- generate_btn = gr.Button("Amodal 3D Reconstruction")
495
- with gr.Row():
496
- model_output = gr.Model3D(label="Extracted GLB", pan_speed=0.5, height=300, clear_color=(0.9,0.9,0.9,1))
497
- with gr.Row():
498
- with gr.Accordion(label="GLB Extraction Settings", open=False):
499
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
500
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
501
- with gr.Row():
502
- extract_glb_btn = gr.Button("Extract GLB")
503
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
504
-
505
- with gr.Row():
506
- examples = gr.Examples(
507
- examples=[
508
- f'assets/example_image/{image}'
509
- for image in os.listdir("assets/example_image")
510
- ],
511
- inputs=[input_image],
512
- fn=lambda x: x,
513
- outputs=[input_image],
514
- run_on_click=True,
515
- examples_per_page=12,
516
- )
517
-
518
-
519
- # Handlers
520
- demo.load(start_session)
521
- demo.unload(end_session)
522
-
523
- input_image.upload(
524
- change_message,
525
- [],
526
- [message]
527
- ).then(
528
- reset_image,
529
- [predictor, input_image],
530
- [predictor, original_image, message, visible_points_state, occlusion_points_state, occlusion_mask_list, input_with_prompt],
531
- )
532
-
533
- apply_example_btn.click(
534
- change_message,
535
- [],
536
- [message]
537
- ).then(
538
- reset_image,
539
- inputs=[predictor, input_image],
540
- outputs=[predictor, original_image, message, visible_points_state, occlusion_points_state, occlusion_mask_list, input_with_prompt]
541
- )
542
- input_image.select(
543
- get_point,
544
- inputs=[input_image, point_type, visible_points_state, occlusion_points_state],
545
- outputs=[input_with_prompt, visible_points_state, occlusion_points_state]
546
- )
547
-
548
- point_type.change(
549
- change_point_type,
550
- inputs=[point_type, visible_points_state, occlusion_points_state],
551
- outputs=[points_text, points_dropdown]
552
- )
553
-
554
- visible_points_state.change(
555
- update_all_points,
556
- inputs=[visible_points_state],
557
- outputs=[points_text, points_dropdown]
558
- ).then(
559
- segment_and_overlay,
560
- inputs=[original_image, visible_points_state, predictor, visibility_mask_list, point_type],
561
- outputs=[render_mask, visibility_mask, visibility_mask_list]
562
- ).then(
563
- check_combined_mask,
564
- inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
565
- outputs=[vis_input, occluded_mask]
566
- )
567
-
568
- occlusion_points_state.change(
569
- update_all_points,
570
- inputs=[occlusion_points_state],
571
- outputs=[points_text, points_dropdown]
572
- ).then(
573
- segment_and_overlay,
574
- inputs=[original_image, occlusion_points_state, predictor, occlusion_mask_list, point_type],
575
- outputs=[render_mask, occlusion_mask, occlusion_mask_list]
576
- ).then(
577
- check_combined_mask,
578
- inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
579
- outputs=[vis_input, occluded_mask]
580
- )
581
-
582
- delete_button.click(
583
- delete_selected,
584
- inputs=[original_image, visible_points_state, occlusion_points_state, occlusion_mask_list, points_dropdown, point_type],
585
- outputs=[input_with_prompt, visible_points_state, occlusion_points_state, points_text, points_dropdown]
586
- )
587
-
588
- add_mask.click(
589
- add_current_mask,
590
- inputs=[visibility_mask, visibility_mask_list, point_type],
591
- outputs=[visibility_mask_list]
592
- )
593
-
594
- undo_mask.click(
595
- delete_mask,
596
- inputs=[visibility_mask_list, occlusion_mask_list, occlusion_points_state, point_type],
597
- outputs=[visibility_mask_list, occlusion_mask_list, occlusion_points_state]
598
- )
599
-
600
- check_visible_input.click(
601
- check_combined_mask,
602
- inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
603
- outputs=[vis_input, occluded_mask]
604
- )
605
-
606
-
607
- # 3D Amodal Reconstruction
608
- generate_btn.click(
609
- get_seed,
610
- inputs=[randomize_seed, seed],
611
- outputs=[seed],
612
- ).then(
613
- image_to_3d,
614
- inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, erode_kernel_size],
615
- outputs=[output_buf, video_output],
616
- )
617
-
618
- extract_glb_btn.click(
619
- extract_glb,
620
- inputs=[output_buf, mesh_simplify, texture_size],
621
- outputs=[model_output, download_glb],
622
- ).then(
623
- lambda: gr.Button(interactive=True),
624
- outputs=[download_glb],
625
- )
626
-
627
- model_output.clear(
628
- lambda: gr.Button(interactive=False),
629
- outputs=[download_glb],
630
- )
631
-
632
-
633
-
634
- if __name__ == "__main__":
635
- pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
636
- pipeline.cuda()
637
- try:
638
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
639
- except Exception:
640
- pass
 
 
641
  demo.launch()
 
1
+ import gradio as gr
2
+ import spaces
3
+
4
+ import os
5
+
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from PIL import Image
14
+ from Amodal3R.pipelines import Amodal3RImageTo3DPipeline
15
+ from Amodal3R.representations import Gaussian, MeshExtractResult
16
+ from Amodal3R.utils import render_utils, postprocessing_utils
17
+ from segment_anything import sam_model_registry, SamPredictor
18
+ from huggingface_hub import hf_hub_download
19
+ import cv2
20
+
21
+
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
24
+ os.makedirs(TMP_DIR, exist_ok=True)
25
+ os.environ['MASTER_ADDR'] = 'localhost'
26
+ os.environ['MASTER_PORT'] = '12355'
27
+
28
+ def start_session(req: gr.Request):
29
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
30
+ os.makedirs(user_dir, exist_ok=True)
31
+
32
+ def end_session(req: gr.Request):
33
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
+ shutil.rmtree(user_dir)
35
+
36
+ def change_message():
37
+ return "Please wait for a few seconds after uploading the image."
38
+
39
+ def reset_image(predictor, img):
40
+ img = np.array(img)
41
+ predictor.set_image(img)
42
+ original_img = img.copy()
43
+ return predictor, original_img, "The models are ready.", [], [], [], original_img
44
+
45
+ def button_clickable(selected_points):
46
+ if len(selected_points) > 0:
47
+ return gr.Button(interactive=True)
48
+ else:
49
+ return gr.Button(interactive=False)
50
+
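+ # Prompt SAM with every selected point as a positive label; with more than one point, smooth the mask with a morphological close (dilate then erode).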
51
+ def run_sam(img, predictor, selected_points):
52
+ if len(selected_points) == 0:
53
+ return np.zeros(img.shape[:2], dtype=np.uint8)
54
+ input_points = [p for p in selected_points]
55
+ input_labels = [1 for _ in range(len(selected_points))]
56
+ masks, _, _ = predictor.predict(
57
+ point_coords=np.array(input_points),
58
+ point_labels=np.array(input_labels),
59
+ multimask_output=False,
60
+ )
61
+ best_mask = masks[0].astype(np.uint8)
62
+ # close small gaps (dilate then erode)
63
+ if len(selected_points) > 1:
64
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
65
+ best_mask = cv2.dilate(best_mask, kernel, iterations=1)
66
+ best_mask = cv2.erode(best_mask, kernel, iterations=1)
67
+ return best_mask
68
+
69
+
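+ # Two-stage generation: a sparse-structure sampler, then a structured-latent (SLAT) sampler, each with its own step count and CFG strength; the result is rendered to a preview video.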
70
+ @spaces.GPU
71
+ def image_to_3d(
72
+ image: np.ndarray,
73
+ mask: np.ndarray,
74
+ seed: int,
75
+ ss_guidance_strength: float,
76
+ ss_sampling_steps: int,
77
+ slat_guidance_strength: float,
78
+ slat_sampling_steps: int,
79
+ erode_kernel_size: int,
80
+ req: gr.Request,
81
+ ) -> Tuple[dict, str]:
82
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
83
+ outputs = pipeline.run_multi_image(
84
+ [image],
85
+ [mask],
86
+ seed=seed,
87
+ formats=["gaussian", "mesh"],
88
+ sparse_structure_sampler_params={
89
+ "steps": ss_sampling_steps,
90
+ "cfg_strength": ss_guidance_strength,
91
+ },
92
+ slat_sampler_params={
93
+ "steps": slat_sampling_steps,
94
+ "cfg_strength": slat_guidance_strength,
95
+ },
96
+ mode="stochastic",
97
+ erode_kernel_size=erode_kernel_size,
98
+ )
99
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120, bg_color=(1,1,1))['color']
100
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
101
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
102
+ video_path = os.path.join(user_dir, 'sample.mp4')
103
+ imageio.mimsave(video_path, video, fps=15)
104
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
105
+ torch.cuda.empty_cache()
106
+ return state, video_path
107
+
108
+
109
+ @spaces.GPU(duration=90)
110
+ def extract_glb(
111
+ state: dict,
112
+ mesh_simplify: float,
113
+ texture_size: int,
114
+ req: gr.Request,
115
+ ) -> tuple:
116
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
117
+ gs, mesh = unpack_state(state)
118
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
119
+ glb_path = os.path.join(user_dir, 'sample.glb')
120
+ glb.export(glb_path)
121
+ torch.cuda.empty_cache()
122
+ return glb_path, glb_path
123
+
124
+
125
+ @spaces.GPU
126
+ def extract_gaussian(state: dict, req: gr.Request) -> tuple:
127
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
128
+ gs, _ = unpack_state(state)
129
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
130
+ gs.save_ply(gaussian_path)
131
+ torch.cuda.empty_cache()
132
+ return gaussian_path, gaussian_path
133
+
134
+
135
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
136
+ return {
137
+ 'gaussian': {
138
+ **gs.init_params,
139
+ '_xyz': gs._xyz.cpu().numpy(),
140
+ '_features_dc': gs._features_dc.cpu().numpy(),
141
+ '_scaling': gs._scaling.cpu().numpy(),
142
+ '_rotation': gs._rotation.cpu().numpy(),
143
+ '_opacity': gs._opacity.cpu().numpy(),
144
+ },
145
+ 'mesh': {
146
+ 'vertices': mesh.vertices.cpu().numpy(),
147
+ 'faces': mesh.faces.cpu().numpy(),
148
+ },
149
+ }
150
+
151
+
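+ # Rebuild the Gaussian and mesh on the GPU from the stored numpy arrays.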
152
+ def unpack_state(state: dict) -> tuple:
153
+ gs = Gaussian(
154
+ aabb=state['gaussian']['aabb'],
155
+ sh_degree=state['gaussian']['sh_degree'],
156
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
157
+ scaling_bias=state['gaussian']['scaling_bias'],
158
+ opacity_bias=state['gaussian']['opacity_bias'],
159
+ scaling_activation=state['gaussian']['scaling_activation'],
160
+ )
161
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
162
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
163
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
164
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
165
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
166
+
167
+ mesh = edict(
168
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
169
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
170
+ )
171
+
172
+ return gs, mesh
173
+
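+ # Download the ViT-H SAM checkpoint from the Hub and wrap it in a point-promptable predictor.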
174
+ def get_sam_predictor():
175
+ sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
176
+ model_type = "vit_h"
177
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
178
+ sam_predictor = SamPredictor(sam)
179
+ return sam_predictor
180
+
181
+
182
+ def draw_points_on_image(image, point):
183
+ image_with_points = image.copy()
184
+ x, y = point
185
+ color = (255, 0, 0)
186
+ cv2.circle(image_with_points, (int(x), int(y)), radius=10, color=color, thickness=-1)
187
+ return image_with_points
188
+
189
+
190
+ def see_point(image, x, y):
191
+ updated_image = draw_points_on_image(image, [x,y])
192
+ return updated_image
193
+
194
+ def add_point(x, y, visible_points):
195
+ if [x, y] not in visible_points:
196
+ visible_points.append([x, y])
197
+ return visible_points
198
+
199
+ def delete_point(visible_points):
200
+ visible_points.pop()
201
+ return visible_points
202
+
203
+
204
+ def clear_all_points(image):
205
+ updated_image = image.copy()
206
+ return updated_image
207
+
208
+ def see_visible_points(image, visible_points):
209
+ updated_image = image.copy()
210
+ for p in visible_points:
211
+ cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
212
+ return updated_image
213
+
214
+ def see_occlusion_points(image, occlusion_points):
215
+ updated_image = image.copy()
216
+ for p in occlusion_points:
217
+ cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(0, 255, 0), thickness=-1)
218
+ return updated_image
219
+
220
+ def update_all_points(points):
221
+ text = f"Points: {points}"
222
+ dropdown_choices = [f"({p[0]}, {p[1]})" for p in points]
223
+ return text, gr.Dropdown(show_label=False, choices=dropdown_choices, value=None, interactive=True)
224
+
225
+ def delete_selected(image, visible_points, occlusion_points, occlusion_mask_list, selected_value, point_type):
226
+ if point_type == "visibility":
227
+ try:
228
+ selected_index = [f"({p[0]}, {p[1]})" for p in visible_points].index(selected_value)
229
+ except ValueError:
230
+ selected_index = None
231
+ if selected_index is not None and 0 <= selected_index < len(visible_points):
232
+ visible_points.pop(selected_index)
233
+ else:
234
+ try:
235
+ selected_index = [f"({p[0]}, {p[1]})" for p in occlusion_points].index(selected_value)
236
+ except ValueError:
237
+ selected_index = None
238
+ if selected_index is not None and 0 <= selected_index < len(occlusion_points):
239
+ occlusion_points.pop(selected_index)
240
+ occlusion_mask_list.pop(selected_index)
241
+ updated_image = image.copy()
242
+ updated_image = see_visible_points(updated_image, visible_points)
243
+ updated_image = see_occlusion_points(updated_image, occlusion_points)
244
+ if point_type == "visibility":
245
+ updated_text, dropdown = update_all_points(visible_points)
246
+ else:
247
+ updated_text, dropdown = update_all_points(occlusion_points)
248
+ return updated_image, visible_points, occlusion_points, updated_text, dropdown
249
+
250
+ def add_current_mask(visibility_mask, visibility_mask_list, point_type):
251
+ if point_type == "visibility":
252
+ if len(visibility_mask_list) > 0:
253
+ if np.array_equal(visibility_mask, visibility_mask_list[-1]):
254
+ return visibility_mask_list
255
+ visibility_mask_list.append(visibility_mask)
256
+ return visibility_mask_list
257
+ else: # the occlusion mask will be automatically added, so do nothing here
258
+ return visibility_mask_list
259
+
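+ # Dim pixels outside the mask toward gray and outline the mask boundary in the given color.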
260
+ def apply_mask_overlay(image, mask, color=(255, 0, 0)):
261
+ img_arr = image
262
+ overlay = img_arr.copy()
263
+ gray_color = np.array([200, 200, 200], dtype=np.uint8)
264
+ non_mask = mask == 0
265
+ overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
266
+ contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
267
+ cv2.drawContours(overlay, contours, -1, color, 2)
268
+ return overlay
269
+
270
+ def vis_mask(image, mask_list):
271
+ updated_image = image.copy()
272
+ combined_mask = np.zeros_like(updated_image[:, :, 0])
273
+ for mask in mask_list:
274
+ combined_mask = cv2.bitwise_or(combined_mask, mask)
275
+ updated_image = apply_mask_overlay(updated_image, combined_mask)
276
+ return updated_image
277
+
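+ # Visibility mode: one SAM call over all points, OR-merged with the saved masks. Occlusion mode: one SAM call per point, each kept as its own mask.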
278
+ def segment_and_overlay(image, points, sam_predictor, mask_list, point_type):
279
+ if point_type == "visibility":
280
+ visible_mask = run_sam(image, sam_predictor, points)
281
+ for mask in mask_list:
282
+ visible_mask = cv2.bitwise_or(visible_mask, mask)
283
+ overlaid = apply_mask_overlay(image, visible_mask * 255)
284
+ return overlaid, visible_mask, mask_list
285
+ else:
286
+ combined_occlusion_mask = np.zeros_like(image[:, :, 0])
287
+ mask_list = []
288
+ if len(points) != 0:
289
+ for point in points:
290
+ mask = run_sam(image, sam_predictor, [point])
291
+ mask_list.append(mask)
292
+ combined_occlusion_mask = cv2.bitwise_or(combined_occlusion_mask, mask)
293
+ overlaid = apply_mask_overlay(image, combined_occlusion_mask * 255, color=(0, 255, 0))
294
+ return overlaid, combined_occlusion_mask, mask_list
295
+
296
+ def delete_mask(visibility_mask_list, occlusion_mask_list, occlusion_points_state, point_type):
297
+ if point_type == "visibility":
298
+ if len(visibility_mask_list) > 0:
299
+ visibility_mask_list.pop()
300
+ else:
301
+ if len(occlusion_mask_list) > 0:
302
+ occlusion_mask_list.pop()
303
+ occlusion_points_state.pop()
304
+ return visibility_mask_list, occlusion_mask_list, occlusion_points_state
305
+
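+ # Merge the saved visibility masks, carve out occluders plus a thin boundary, then center and rescale the target onto a 512x512 canvas; visible target pixels are marked 127 in the returned mask.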
306
+ def check_combined_mask(image, visibility_mask, visibility_mask_list, occlusion_mask_list, scale=0.68):
307
+ if visibility_mask.sum() == 0:
308
+ return np.zeros_like(image), np.zeros_like(image[:, :, 0])
309
+ updated_image = image.copy()
310
+ combined_mask = np.zeros_like(updated_image[:, :, 0])
311
+ occluded_mask = np.zeros_like(updated_image[:, :, 0])
312
+ binary_visibility_masks = [(m > 0).astype(np.uint8) for m in visibility_mask_list]
313
+ combined_mask = np.zeros_like(binary_visibility_masks[0]) if binary_visibility_masks else (visibility_mask > 0).astype(np.uint8)
314
+ for m in binary_visibility_masks:
315
+ combined_mask = cv2.bitwise_or(combined_mask, m)
316
+
317
+ if len(binary_visibility_masks) > 1:
318
+ kernel = np.ones((5, 5), np.uint8)
319
+ combined_mask = cv2.dilate(combined_mask, kernel, iterations=1)
320
+
321
+ binary_occlusion_masks = [(m > 0).astype(np.uint8) for m in occlusion_mask_list]
322
+ occluded_mask = np.zeros_like(binary_occlusion_masks[0]) if binary_occlusion_masks else np.zeros_like(combined_mask)
323
+ for m in binary_occlusion_masks:
324
+ occluded_mask = cv2.bitwise_or(occluded_mask, m)
325
+
326
+ kernel_small = np.ones((3, 3), np.uint8)
327
+ if len(binary_occlusion_masks) > 0:
328
+ dilated = cv2.dilate(combined_mask, kernel_small, iterations=1)
329
+ boundary_mask = dilated - combined_mask
330
+ occluded_mask = cv2.bitwise_or(occluded_mask, boundary_mask)
331
+ occluded_mask = (occluded_mask > 0).astype(np.uint8)
332
+ occluded_mask = cv2.dilate(occluded_mask, kernel_small, iterations=1)
333
+ occluded_mask = (occluded_mask > 0).astype(np.uint8)
334
+ else:
335
+ occluded_mask = 1 - combined_mask
336
+
337
+ combined_mask[occluded_mask == 1] = 0
338
+
339
+ occluded_mask = (1-occluded_mask) * 255
340
+
341
+ masked_img = updated_image * combined_mask[:, :, None]
342
+ occluded_mask[combined_mask == 1] = 127
343
+
344
+ x, y, w, h = cv2.boundingRect(combined_mask.astype(np.uint8))
345
+
346
+ ori_h, ori_w = masked_img.shape[:2]
347
+ target_size = 512
348
+ scale_factor = target_size / max(w, h)
349
+ final_scale = scale_factor * scale
350
+ new_w = int(round(ori_w * final_scale))
351
+ new_h = int(round(ori_h * final_scale))
352
+
353
+ resized_occluded_mask = cv2.resize(occluded_mask.astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST)
354
+ resized_img = cv2.resize(masked_img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
355
+
356
+ final_img = np.zeros((target_size, target_size, 3), dtype=updated_image.dtype)
357
+ final_occluded_mask = np.ones((target_size, target_size), dtype=np.uint8) * 255
358
+
359
+ new_x = int(round(x * final_scale))
360
+ new_y = int(round(y * final_scale))
361
+ new_w_box = int(round(w * final_scale))
362
+ new_h_box = int(round(h * final_scale))
363
+
364
+ new_cx = new_x + new_w_box // 2
365
+ new_cy = new_y + new_h_box // 2
366
+
367
+ final_cx, final_cy = target_size // 2, target_size // 2
368
+ x_offset = final_cx - new_cx
369
+ y_offset = final_cy - new_cy
370
+
371
+ final_x_start = max(0, x_offset)
372
+ final_y_start = max(0, y_offset)
373
+ final_x_end = min(target_size, x_offset + new_w)
374
+ final_y_end = min(target_size, y_offset + new_h)
375
+
376
+ img_x_start = max(0, -x_offset)
377
+ img_y_start = max(0, -y_offset)
378
+ img_x_end = min(new_w, target_size - x_offset)
379
+ img_y_end = min(new_h, target_size - y_offset)
380
+
381
+ final_img[final_y_start:final_y_end, final_x_start:final_x_end] = resized_img[img_y_start:img_y_end, img_x_start:img_x_end]
382
+ final_occluded_mask[final_y_start:final_y_end, final_x_start:final_x_end] = resized_occluded_mask[img_y_start:img_y_end, img_x_start:img_x_end]
383
+
384
+ return final_img, final_occluded_mask
385
+
386
+
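+ # Click handler: evt.index holds the (x, y) pixel of the click, which is appended to the visibility or occlusion point list.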
387
+ def get_point(img, point_type, visible_points_state, occlusion_points_state, evt: gr.SelectData):
388
+ updated_img = np.array(img).copy()
389
+ if point_type == "visibility":
390
+ visible_points_state = add_point(evt.index[0], evt.index[1], visible_points_state)
391
+ else:
392
+ occlusion_points_state = add_point(evt.index[0], evt.index[1], occlusion_points_state)
393
+ updated_img = see_visible_points(updated_img, visible_points_state)
394
+ updated_img = see_occlusion_points(updated_img, occlusion_points_state)
395
+ return updated_img, visible_points_state, occlusion_points_state
396
+
397
+
398
+ def change_point_type(point_type, visible_points_state, occlusion_points_state):
399
+ if point_type == "visibility":
400
+ text = f"Points: {visible_points_state}"
401
+ dropdown_choices = [f"({p[0]}, {p[1]})" for p in visible_points_state]
402
+ else:
403
+ text = f"Points: {occlusion_points_state}"
404
+ dropdown_choices = [f"({p[0]}, {p[1]})" for p in occlusion_points_state]
405
+ return text, gr.Dropdown(show_label=False, choices=dropdown_choices, value=None, interactive=True)
406
+
407
+
408
+ def get_seed(randomize_seed: bool, seed: int) -> int:
409
+ """
410
+ Get the random seed.
411
+ """
412
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
413
+
414
+
415
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
416
+ gr.Markdown("""
417
+ ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
418
+ """)
419
+
420
+ predictor = gr.State(value=get_sam_predictor())
421
+ visible_points_state = gr.State(value=[])
422
+ occlusion_points_state = gr.State(value=[])
423
+ occlusion_mask = gr.State(value=None)
424
+ occlusion_mask_list = gr.State(value=[])
425
+ original_image = gr.State(value=None)
426
+ visibility_mask = gr.State(value=None)
427
+ visibility_mask_list = gr.State(value=[])
428
+
429
+ occluded_mask = gr.State(value=None)
430
+ output_buf = gr.State()
431
+
432
+
433
+ with gr.Row():
434
+ with gr.Column():
435
+ gr.Markdown("""
436
+ ### Step 1 - Generate Visibility and Occlusion Masks.
437
+ * Please click "Load Example Image" when using the provided example images (bottom).
438
+ * Please wait for a few seconds after uploading the image. Segment Anything is getting ready.
439
+ * **Click to add the point prompts** to indicate the target object (multiple points supported) and occluders (one point per occluder for better usability).
440
+ * Click "Add Mask" to save the current mask when the input needs to be built from several masks sequentially.
441
+ * The scale of the target object can be adjusted for better reconstruction; we suggest 0.4 to 0.7 for most cases.
442
+ """)
443
+ with gr.Row():
444
+ input_image = gr.Image(interactive=True, type='pil', label='Input Occlusion Image', show_label=True, sources="upload", height=300)
445
+ input_with_prompt = gr.Image(type="numpy", label='Input with Prompt', interactive=False, height=300)
446
+ with gr.Row():
447
+ apply_example_btn = gr.Button("Load Example Image")
448
+ message = gr.Markdown("Please wait a few seconds after uploading the image.", label="Message")
449
+ with gr.Row():
450
+ point_type = gr.Radio(["visibility", "occlusion"], label="Point Prompt Type", value="visibility")
451
+ with gr.Row():
452
+ with gr.Column():
453
+ points_text = gr.Textbox(show_label=False, interactive=False)
454
+ with gr.Column():
455
+ points_dropdown = gr.Dropdown(show_label=False, choices=[], value=None, interactive=True)
456
+ delete_button = gr.Button("Delete Selected Point")
457
+ with gr.Row():
458
+ with gr.Column():
459
+ render_mask = gr.Image(label='Render Mask', interactive=False, height=300)
460
+ with gr.Row():
461
+ add_mask = gr.Button("Add Mask")
462
+ undo_mask = gr.Button("Undo Last Mask")
463
+ with gr.Column():
464
+ vis_input = gr.Image(label='Visible Input', interactive=False, height=300)
465
+ with gr.Row():
466
+ zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.68, step=0.1)
467
+ with gr.Row():
468
+ check_visible_input = gr.Button("Generate Occluded Input")
469
+
470
+ with gr.Column():
471
+ gr.Markdown("""
472
+ ### Step 2 - 3D Amodal Reconstruction. (Thanks to [TRELLIS](https://huggingface.co/spaces/JeffreyXiang/TRELLIS) for the 3D rendering component!)
473
+ * Different random seeds can be tried in "Generation Settings" if the results are not ideal.
474
+ * The segmentation boundary may not be accurate, so we provide an option to erode the visible area (try a kernel size of 0, 3, or 5).
475
+ * If the reconstructed 3D asset is satisfactory, an interactive GLB file can be extracted and downloaded (it may look dull due to the absence of a light source).
476
+ """)
477
+ with gr.Row():
478
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
479
+ with gr.Row():
480
+ with gr.Accordion(label="Generation Settings", open=False):
481
+ with gr.Row():
482
+ with gr.Column():
483
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
484
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
485
+ with gr.Column():
486
+ erode_kernel_size = gr.Slider(0, 5, label="Erode Kernel Size", value=3, step=1)
487
+ gr.Markdown("Stage 1: Sparse Structure Generation")
488
+ with gr.Row():
489
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
490
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
491
+ gr.Markdown("Stage 2: Structured Latent Generation")
492
+ with gr.Row():
493
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
494
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
495
+ with gr.Row():
496
+ generate_btn = gr.Button("Amodal 3D Reconstruction")
497
+ with gr.Row():
498
+ model_output = gr.Model3D(label="Extracted GLB", pan_speed=0.5, height=300, clear_color=(0.9,0.9,0.9,1))
499
+ with gr.Row():
500
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
501
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
502
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
503
+ with gr.Row():
504
+ extract_glb_btn = gr.Button("Extract GLB")
505
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
506
+
507
+ with gr.Row():
508
+ examples = gr.Examples(
509
+ examples=[
510
+ f'assets/example_image/{image}'
511
+ for image in os.listdir("assets/example_image")
512
+ ],
513
+ inputs=[input_image],
514
+ fn=lambda x: x,
515
+ outputs=[input_image],
516
+ run_on_click=True,
517
+ examples_per_page=12,
518
+ )
519
+
520
+
521
+ # Handlers
522
+ demo.load(start_session)
523
+ demo.unload(end_session)
524
+
525
+ input_image.upload(
526
+ change_message,
527
+ [],
528
+ [message]
529
+ ).then(
530
+ reset_image,
531
+ [predictor, input_image],
532
+ [predictor, original_image, message, visible_points_state, occlusion_points_state, occlusion_mask_list, input_with_prompt],
533
+ )
534
+
535
+ apply_example_btn.click(
536
+ change_message,
537
+ [],
538
+ [message]
539
+ ).then(
540
+ reset_image,
541
+ inputs=[predictor, input_image],
542
+ outputs=[predictor, original_image, message, visible_points_state, occlusion_points_state, occlusion_mask_list, input_with_prompt]
543
+ )
544
+ input_image.select(
545
+ get_point,
546
+ inputs=[input_image, point_type, visible_points_state, occlusion_points_state],
547
+ outputs=[input_with_prompt, visible_points_state, occlusion_points_state]
548
+ )
549
+
550
+ point_type.change(
551
+ change_point_type,
552
+ inputs=[point_type, visible_points_state, occlusion_points_state],
553
+ outputs=[points_text, points_dropdown]
554
+ )
555
+
556
+ visible_points_state.change(
557
+ update_all_points,
558
+ inputs=[visible_points_state],
559
+ outputs=[points_text, points_dropdown]
560
+ ).then(
561
+ segment_and_overlay,
562
+ inputs=[original_image, visible_points_state, predictor, visibility_mask_list, point_type],
563
+ outputs=[render_mask, visibility_mask, visibility_mask_list]
564
+ ).then(
565
+ check_combined_mask,
566
+ inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
567
+ outputs=[vis_input, occluded_mask]
568
+ )
569
+
570
+ occlusion_points_state.change(
571
+ update_all_points,
572
+ inputs=[occlusion_points_state],
573
+ outputs=[points_text, points_dropdown]
574
+ ).then(
575
+ segment_and_overlay,
576
+ inputs=[original_image, occlusion_points_state, predictor, occlusion_mask_list, point_type],
577
+ outputs=[render_mask, occlusion_mask, occlusion_mask_list]
578
+ ).then(
579
+ check_combined_mask,
580
+ inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
581
+ outputs=[vis_input, occluded_mask]
582
+ )
583
+
584
+ delete_button.click(
585
+ delete_selected,
586
+ inputs=[original_image, visible_points_state, occlusion_points_state, occlusion_mask_list, points_dropdown, point_type],
587
+ outputs=[input_with_prompt, visible_points_state, occlusion_points_state, points_text, points_dropdown]
588
+ )
589
+
590
+ add_mask.click(
591
+ add_current_mask,
592
+ inputs=[visibility_mask, visibility_mask_list, point_type],
593
+ outputs=[visibility_mask_list]
594
+ )
595
+
596
+ undo_mask.click(
597
+ delete_mask,
598
+ inputs=[visibility_mask_list, occlusion_mask_list, occlusion_points_state, point_type],
599
+ outputs=[visibility_mask_list, occlusion_mask_list, occlusion_points_state]
600
+ )
601
+
602
+ check_visible_input.click(
603
+ check_combined_mask,
604
+ inputs=[original_image, visibility_mask, visibility_mask_list, occlusion_mask_list, zoom_scale],
605
+ outputs=[vis_input, occluded_mask]
606
+ )
607
+
608
+
609
+ # 3D Amodal Reconstruction
610
+ generate_btn.click(
611
+ get_seed,
612
+ inputs=[randomize_seed, seed],
613
+ outputs=[seed],
614
+ ).then(
615
+ image_to_3d,
616
+ inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, erode_kernel_size],
617
+ outputs=[output_buf, video_output],
618
+ )
619
+
620
+ extract_glb_btn.click(
621
+ extract_glb,
622
+ inputs=[output_buf, mesh_simplify, texture_size],
623
+ outputs=[model_output, download_glb],
624
+ ).then(
625
+ lambda: gr.Button(interactive=True),
626
+ outputs=[download_glb],
627
+ )
628
+
629
+ model_output.clear(
630
+ lambda: gr.Button(interactive=False),
631
+ outputs=[download_glb],
632
+ )
633
+
634
+
635
+
636
+ if __name__ == "__main__":
637
+ pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
638
+ pipeline.cuda()
639
+ try:
640
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
641
+ except Exception:
642
+ pass
643
  demo.launch()