Sm0kyWu commited on
Commit
e19c4e5
·
verified ·
1 Parent(s): 7394bda

Upload app_old.py

Browse files
Files changed (1) hide show
  1. app_old.py +659 -0
app_old.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+
5
+ import os
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from PIL import Image
14
+ from Amodal3R.pipelines import Amodal3RImageTo3DPipeline
15
+ from Amodal3R.representations import Gaussian, MeshExtractResult
16
+ from Amodal3R.utils import render_utils, postprocessing_utils
17
+ from segment_anything import sam_model_registry, SamPredictor
18
+ from huggingface_hub import hf_hub_download
19
+ import cv2
20
+
21
+
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
24
+ os.makedirs(TMP_DIR, exist_ok=True)
25
+
26
+
27
+ def start_session(req: gr.Request):
28
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
29
+ os.makedirs(user_dir, exist_ok=True)
30
+
31
+
32
+ def end_session(req: gr.Request):
33
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
+ shutil.rmtree(user_dir)
35
+
36
+ def reset_image(predictor, img):
37
+ """
38
+ 上传图像后调用:
39
+ - 重置 predictor,
40
+ - 设置 predictor 的输入图像,
41
+ - 返回原图
42
+ """
43
+ predictor.set_image(img)
44
+ original_img = img.copy()
45
+ # 返回predictor,visible occlusion mask初始化, 原始图像
46
+ return predictor, original_img, "The models are ready."
47
+
48
+ def button_clickable(selected_points):
49
+ if len(selected_points) > 0:
50
+ return gr.Button.update(interactive=True)
51
+ else:
52
+ return gr.Button.update(interactive=False)
53
+
54
+ def run_sam(predictor, selected_points):
55
+ """
56
+ 调用 SAM 模型进行分割。
57
+ """
58
+ # predictor.set_image(image)
59
+ if len(selected_points) == 0:
60
+ return [], None
61
+ input_points = [p for p in selected_points]
62
+ input_labels = [1 for _ in range(len(selected_points))]
63
+ # input_points = np.array([[210, 300]])
64
+ # input_labels = np.array([1])
65
+ masks, _, _ = predictor.predict(
66
+ point_coords=np.array(input_points),
67
+ point_labels=np.array(input_labels),
68
+ multimask_output=False, # 单对象输出
69
+ )
70
+ best_mask = masks[0].astype(np.uint8)
71
+ # dilate
72
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
73
+ best_mask = cv2.dilate(best_mask, kernel, iterations=1)
74
+ best_mask = cv2.erode(best_mask, kernel, iterations=1)
75
+ return best_mask
76
+
77
+ def apply_mask_overlay(image, mask, color=(255, 0, 0)):
78
+ """
79
+ 在原图上叠加 mask:使用红色绘制 mask 的轮廓,非 mask 区域叠加浅灰色半透明遮罩。
80
+ """
81
+ img_arr = image
82
+ overlay = img_arr.copy()
83
+ gray_color = np.array([200, 200, 200], dtype=np.uint8)
84
+ non_mask = mask == 0
85
+ overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
86
+ contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
87
+ cv2.drawContours(overlay, contours, -1, color, 2)
88
+ return overlay
89
+
90
+ def segment_and_overlay(image, points, sam_predictor):
91
+ """
92
+ 调用 run_sam 获得 mask,然后叠加显示分割结果。
93
+ """
94
+ visible_mask = run_sam(sam_predictor, points)
95
+ overlaid = apply_mask_overlay(image, visible_mask * 255)
96
+ return overlaid, visible_mask
97
+
98
+
99
+ def reset_points():
100
+ """
101
+ 清空点击点提示。
102
+ """
103
+ return [], ""
104
+
105
+
106
+ @spaces.GPU
107
+ def image_to_3d(
108
+ image: List[tuple],
109
+ masks: List[np.ndarray],
110
+ seed: int,
111
+ ss_guidance_strength: float,
112
+ ss_sampling_steps: int,
113
+ slat_guidance_strength: float,
114
+ slat_sampling_steps: int,
115
+ multiimage_algo: str,
116
+ req: gr.Request,
117
+ ) -> tuple:
118
+ """
119
+ 将图像转换为 3D 模型。
120
+ """
121
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
122
+ outputs = pipeline.run_multi_image(
123
+ [img[0] for img in image],
124
+ [mask[0] for mask in masks],
125
+ seed=seed,
126
+ formats=["gaussian", "mesh"],
127
+ preprocess_image=False,
128
+ sparse_structure_sampler_params={
129
+ "steps": ss_sampling_steps,
130
+ "cfg_strength": ss_guidance_strength,
131
+ },
132
+ slat_sampler_params={
133
+ "steps": slat_sampling_steps,
134
+ "cfg_strength": slat_guidance_strength,
135
+ },
136
+ mode=multiimage_algo,
137
+ )
138
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
139
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
140
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
141
+ video_path = os.path.join(user_dir, 'sample.mp4')
142
+ imageio.mimsave(video_path, video, fps=15)
143
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
144
+ torch.cuda.empty_cache()
145
+ return state, video_path
146
+
147
+
148
+ @spaces.GPU(duration=90)
149
+ def extract_glb(
150
+ state: dict,
151
+ mesh_simplify: float,
152
+ texture_size: int,
153
+ req: gr.Request,
154
+ ) -> tuple:
155
+ """
156
+ 从生成的 3D 模型中提取 GLB 文件。
157
+ """
158
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
159
+ gs, mesh = unpack_state(state)
160
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
161
+ glb_path = os.path.join(user_dir, 'sample.glb')
162
+ glb.export(glb_path)
163
+ torch.cuda.empty_cache()
164
+ return glb_path, glb_path
165
+
166
+
167
+ @spaces.GPU
168
+ def extract_gaussian(state: dict, req: gr.Request) -> tuple:
169
+ """
170
+ 从生成的 3D 模型中提取 Gaussian 文件。
171
+ """
172
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
173
+ gs, _ = unpack_state(state)
174
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
175
+ gs.save_ply(gaussian_path)
176
+ torch.cuda.empty_cache()
177
+ return gaussian_path, gaussian_path
178
+
179
+
180
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
181
+ return {
182
+ 'gaussian': {
183
+ **gs.init_params,
184
+ '_xyz': gs._xyz.cpu().numpy(),
185
+ '_features_dc': gs._features_dc.cpu().numpy(),
186
+ '_scaling': gs._scaling.cpu().numpy(),
187
+ '_rotation': gs._rotation.cpu().numpy(),
188
+ '_opacity': gs._opacity.cpu().numpy(),
189
+ },
190
+ 'mesh': {
191
+ 'vertices': mesh.vertices.cpu().numpy(),
192
+ 'faces': mesh.faces.cpu().numpy(),
193
+ },
194
+ }
195
+
196
+
197
+ def unpack_state(state: dict) -> tuple:
198
+ gs = Gaussian(
199
+ aabb=state['gaussian']['aabb'],
200
+ sh_degree=state['gaussian']['sh_degree'],
201
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
202
+ scaling_bias=state['gaussian']['scaling_bias'],
203
+ opacity_bias=state['gaussian']['opacity_bias'],
204
+ scaling_activation=state['gaussian']['scaling_activation'],
205
+ )
206
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
207
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
208
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
209
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
210
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
211
+
212
+ mesh = edict(
213
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
214
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
215
+ )
216
+
217
+ return gs, mesh
218
+
219
+
220
+ def prepare_multi_example() -> list:
221
+ multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
222
+ images = []
223
+ for case in multi_case:
224
+ _images = []
225
+ for i in range(1, 4):
226
+ img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
227
+ W, H = img.size
228
+ img = img.resize((int(W / H * 512), 512))
229
+ _images.append(np.array(img))
230
+ images.append(Image.fromarray(np.concatenate(_images, axis=1)))
231
+ return images
232
+
233
+
234
+ def split_image(image: Image.Image) -> list:
235
+ """
236
+ 将图像拆分为多个视图(不进行预处理)。
237
+ """
238
+ image = np.array(image)
239
+ alpha = image[..., 3]
240
+ alpha = np.any(alpha > 0, axis=0)
241
+ start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
242
+ end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
243
+ images = []
244
+ for s, e in zip(start_pos, end_pos):
245
+ images.append(Image.fromarray(image[:, s:e+1]))
246
+ return [image for image in images]
247
+
248
+ def get_sam_predictor():
249
+ sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
250
+ model_type = "vit_h"
251
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
252
+ # sam.cuda()
253
+ sam_predictor = SamPredictor(sam)
254
+ return sam_predictor
255
+
256
+
257
+ def draw_points_on_image(image, point, point_type):
258
+ """在图像上绘制所有点,points 为 [(x, y, point_type), ...]"""
259
+ image_with_points = image.copy()
260
+ x, y = point
261
+ color = (255, 0, 0) if point_type == "vis" else (0, 255, 0)
262
+ cv2.circle(image_with_points, (int(x), int(y)), radius=10, color=color, thickness=-1)
263
+ return image_with_points
264
+
265
+
266
+ def see_point(image, x, y, point_type):
267
+ """
268
+ see操作:不修改 points 列表,仅在图像上临时显示这个点,
269
+ 并返回更新后的图像和当前列表(不更新)。
270
+ """
271
+ # 复制当前列表,并在副本中加上新点(仅用于显示)
272
+ updated_image = draw_points_on_image(image, [x,y], point_type)
273
+ return updated_image
274
+
275
+ def add_point(x, y, point_type, visible_points, occlusion_points):
276
+ """
277
+ add操作:将新点添加到 points 列表中,
278
+ 并返回更新后的图像和新的点列表。
279
+ """
280
+ if point_type == "vis":
281
+ # check duplicate
282
+ if [x, y] not in visible_points:
283
+ visible_points.append([x, y])
284
+ else:
285
+ if [x, y] not in occlusion_points:
286
+ occlusion_points.append([x, y])
287
+ return visible_points, occlusion_points
288
+
289
+ def delete_point(point_type, visible_points, occlusion_points):
290
+ """
291
+ delete操作:删除 points 列表中的最后一个点,
292
+ 并返回更新后的图像和新的点列表。
293
+ """
294
+ if point_type == "vis":
295
+ visible_points.pop()
296
+ else:
297
+ occlusion_points.pop()
298
+ return visible_points, occlusion_points
299
+
300
+
301
+ def clear_all_points(image):
302
+ """
303
+ 清除所有点:返回原图、空的 visible 和 occlusion 列表,
304
+ 以及更新后的点文本信息和空下拉菜单列表。
305
+ """
306
+ updated_image = image.copy()
307
+ return updated_image
308
+
309
+ def see_visible_points(image, visible_points):
310
+ """
311
+ 在图像上绘制所有 visible 点(红色)。
312
+ """
313
+ updated_image = image.copy()
314
+ for p in visible_points:
315
+ cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
316
+ return updated_image
317
+
318
+ def see_occlusion_points(image, occlusion_points):
319
+ """
320
+ 在图像上绘制所有 occlusion 点(绿色)。
321
+ """
322
+ updated_image = image.copy()
323
+ for p in occlusion_points:
324
+ cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(0, 255, 0), thickness=-1)
325
+ return updated_image
326
+
327
+ def update_all_points(visible_points, occlusion_points):
328
+ text = f"Visible Points: {visible_points}\nOcclusion Points: {occlusion_points}"
329
+ visible_dropdown_choices = [f"({p[0]}, {p[1]})" for p in visible_points]
330
+ occlusion_dropdown_choices = [f"({p[0]}, {p[1]})" for p in occlusion_points]
331
+ # 返回更新字典来明确设置 choices 和 value
332
+ return text, gr.Dropdown(label="Select Visible Point to Delete", choices=visible_dropdown_choices, value=None, interactive=True), gr.Dropdown(label="Select Occlusion Point to Delete", choices=occlusion_dropdown_choices, value=None, interactive=True)
333
+
334
+ def delete_selected_visible(image, visible_points, occlusion_points, selected_value):
335
+ # selected_value 是类似 "(x, y)" 的字符串
336
+ try:
337
+ selected_index = [f"({p[0]}, {p[1]})" for p in visible_points].index(selected_value)
338
+ except ValueError:
339
+ selected_index = None
340
+ if selected_index is not None and 0 <= selected_index < len(visible_points):
341
+ visible_points.pop(selected_index)
342
+ updated_image = image.copy()
343
+ # 重新绘制所有 visible 点(红色)
344
+ for p in visible_points:
345
+ cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
346
+ updated_text, vis_dropdown, occ_dropdown = update_all_points(visible_points, occlusion_points)
347
+ return updated_image, visible_points, occlusion_points, updated_text, vis_dropdown, occ_dropdown
348
+
349
+ def delete_selected_occlusion(image, visible_points, occlusion_points, selected_value):
350
+ try:
351
+ selected_index = [f"({p[0]}, {p[1]})" for p in occlusion_points].index(selected_value)
352
+ except ValueError:
353
+ selected_index = None
354
+ if selected_index is not None and 0 <= selected_index < len(occlusion_points):
355
+ occlusion_points.pop(selected_index)
356
+ updated_image = image.copy()
357
+ # 重新绘制所有 occlusion 点(绿色)
358
+ for p in occlusion_points:
359
+ cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(0, 255, 0), thickness=-1)
360
+ updated_text, vis_dropdown, occ_dropdown = update_all_points(visible_points, occlusion_points)
361
+ return updated_image, visible_points, occlusion_points, updated_text, vis_dropdown, occ_dropdown
362
+
363
+ def add_mask(mask, mask_list):
364
+ # check if the mask if same as the last mask in the list
365
+ if len(mask_list) > 0:
366
+ if np.array_equal(mask, mask_list[-1]):
367
+ return mask_list
368
+ mask_list.append(mask)
369
+ return mask_list
370
+
371
+ def vis_mask(image, mask_list):
372
+ updated_image = image.copy()
373
+ # combine all the mask:
374
+ combined_mask = np.zeros_like(updated_image[:, :, 0])
375
+ for mask in mask_list:
376
+ combined_mask = cv2.bitwise_or(combined_mask, mask)
377
+ # overlay the mask on the image
378
+ updated_image = apply_mask_overlay(updated_image, combined_mask)
379
+ return updated_image
380
+
381
+ def delete_mask(mask_list):
382
+ if len(mask_list) > 0:
383
+ mask_list.pop()
384
+ return mask_list
385
+
386
+
387
+ def apply_combined_mask_overlay(image, vis_mask, occ_mask):
388
+ """
389
+ 在原图上叠加 mask:使用红色绘制 mask 的轮廓,非 mask 区域叠加浅灰色半透明遮罩。
390
+ """
391
+ img_arr = image
392
+ overlay = img_arr.copy()
393
+ gray_color = np.array([200, 200, 200], dtype=np.uint8)
394
+ non_mask = (vis_mask == 0) & (occ_mask == 0)
395
+ overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
396
+ contours_occ, _ = cv2.findContours(occ_mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
397
+ cv2.drawContours(overlay, contours_occ, -1, (0,0,255), 2)
398
+ contours_vis, _ = cv2.findContours(vis_mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
399
+ cv2.drawContours(overlay, contours_vis, -1, (255,0,0), 2)
400
+ return overlay
401
+
402
+
403
+ def combine_mask(image, visible_mask_list, occlusion_mask_list):
404
+ combined_vis_mask = np.zeros_like(image[:, :, 0])
405
+ combined_occ_mask = np.zeros_like(image[:, :, 0])
406
+ combined_mask = np.zeros_like(image[:, :, 0])
407
+ for mask in visible_mask_list:
408
+ combined_vis_mask = cv2.bitwise_or(combined_mask, mask)
409
+ for mask in occlusion_mask_list:
410
+ combined_occ_mask = cv2.bitwise_or(combined_mask, mask)
411
+ # 添加 visible mask 边缘作为 occlusion mask 的一部分
412
+
413
+ overlay = apply_combined_mask_overlay(image, combined_vis_mask, combined_occ_mask)
414
+ # 5*5 kernel dilate for occlusion mask
415
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
416
+ combined_occ_mask = cv2.dilate(combined_occ_mask, kernel, iterations=1)
417
+ combined_mask[combined_occ_mask > 0] = 128
418
+ combined_mask[combined_vis_mask > 0] = 255
419
+ # concat the mask and overlay to be a single image
420
+ print(overlay.shape, combined_mask.shape)
421
+ result = cv2.hconcat([overlay, combined_mask[..., None].repeat(3, axis=-1)])
422
+ return result, combined_mask, occluded_image
423
+
424
+
425
+ def get_seed(randomize_seed: bool, seed: int) -> int:
426
+ """
427
+ Get the random seed.
428
+ """
429
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
430
+
431
+
432
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
433
+ gr.Markdown("""
434
+ ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
435
+ """)
436
+
437
+ # 定义各状态变量
438
+ predictor = gr.State(value=get_sam_predictor())
439
+ visible_points_state = gr.State(value=[])
440
+ occlusion_points_state = gr.State(value=[])
441
+ original_image = gr.State(value=None)
442
+ visibility_mask = gr.State(value=None)
443
+ occlusion_mask = gr.State(value=None)
444
+ visibility_mask_list = gr.State(value=[])
445
+ occlusion_mask_list = gr.State(value=[])
446
+
447
+ combined_mask = gr.State(value=None)
448
+ occluded_image = gr.State(value=None)
449
+
450
+
451
+ with gr.Row():
452
+ gr.Markdown("""* Step 1 - Generate Visibility Mask and Occlusion Mask.
453
+ * Please wait for a few seconds after uploading the image. The 2D segmenter is getting ready.
454
+ * Add the point prompts to indicate the target object and occluders separately.
455
+ * "Render Point", see the position of the point to be added.
456
+ * "Add Point", the point will be added to the list.
457
+ * "Generate mask", see the segmented area corresponding to current point list.
458
+ * "Add mask", current mask will be added for 3D amodal completion.
459
+ """)
460
+ with gr.Row():
461
+ with gr.Column():
462
+ input_image = gr.Image(type="numpy", label='Input Occlusion Image', sources="upload", height=300)
463
+ with gr.Row():
464
+ message = gr.Markdown("Please wait a few seconds after uploading the image.", label="Message") # 用于显示提示信息
465
+ with gr.Row():
466
+ x_input = gr.Number(label="X Coordinate", value=0)
467
+ y_input = gr.Number(label="Y Coordinate", value=0)
468
+ point_type = gr.Radio(["vis", "occ"], label="Point Prompt Type", value="vis")
469
+ with gr.Row():
470
+ see_button = gr.Button("Render Point")
471
+ add_button = gr.Button("Add Point")
472
+ with gr.Row():
473
+ # 新增按钮:Clear、分别查看 visible/occlusion
474
+ clear_button = gr.Button("Clear Points")
475
+ see_visible_button = gr.Button("Visible Points")
476
+ see_occlusion_button = gr.Button("Occluded Points")
477
+ with gr.Row():
478
+ # 新增文本框实时显示点列表
479
+ points_text = gr.Textbox(label="Points List", interactive=False)
480
+ with gr.Row():
481
+ # 新增下拉菜单,用户可选择需要删除的点
482
+ visible_points_dropdown = gr.Dropdown(label="Select Visible Point to Delete", choices=[], value=None, interactive=True)
483
+ occlusion_points_dropdown = gr.Dropdown(label="Select Occlusion Point to Delete", choices=[], value=None, interactive=True)
484
+ with gr.Row():
485
+ delete_visible_button = gr.Button("Delete Selected Visible")
486
+ delete_occlusion_button = gr.Button("Delete Selected Occlusion")
487
+ with gr.Column():
488
+ # 用于显示 SAM 分割结果
489
+ visible_mask = gr.Image(label='Visible Mask', interactive=False, height=300)
490
+ with gr.Row():
491
+ gen_vis_mask = gr.Button("Generate Mask")
492
+ add_vis_mask = gr.Button("Add Mask")
493
+ with gr.Row():
494
+ render_vis_mask = gr.Button("Render Mask")
495
+ undo_vis_mask = gr.Button("Undo Last Mask")
496
+ occluded_mask = gr.Image(label='Occlusion Mask', interactive=False, height=300)
497
+ with gr.Row():
498
+ gen_occ_mask = gr.Button("Generate Mask")
499
+ add_occ_mask = gr.Button("Add Mask")
500
+ with gr.Row():
501
+ render_occ_mask = gr.Button("Render Mask")
502
+ undo_occ_mask = gr.Button("Undo Last Mask")
503
+ with gr.Row():
504
+ with gr.Column():
505
+ mask_check = gr.Image(label='Combined Mask', interactive=False, height=300)
506
+ with gr.Row():
507
+ check_combine_button = gr.Button("Check Combined Mask, make sure there is no GAP between the visible area (white) and occluded area (gray)")
508
+ with gr.Row():
509
+ gr.Markdown("""* Step 2 - 3D Amodal Completion.
510
+ * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
511
+ * If the reconstruction 3D asset is satisfactory, you can extract the GLB file and download it.
512
+ """)
513
+ with gr.Row():
514
+ with gr.Column():
515
+ with gr.Accordion(label="Generation Settings", open=True):
516
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
517
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
518
+ gr.Markdown("Stage 1: Sparse Structure Generation")
519
+ with gr.Row():
520
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
521
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
522
+ gr.Markdown("Stage 2: Structured Latent Generation")
523
+ with gr.Row():
524
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
525
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
526
+ generate_btn = gr.Button("Generate")
527
+ with gr.Column():
528
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
529
+
530
+
531
+ # ---------------------------
532
+ # 原有交互逻辑(略)
533
+ # ---------------------------
534
+ input_image.upload(
535
+ reset_image,
536
+ [predictor, input_image],
537
+ [predictor, original_image, message],
538
+ )
539
+ see_button.click(
540
+ see_point,
541
+ inputs=[original_image, x_input, y_input, point_type],
542
+ outputs=[input_image]
543
+ )
544
+ add_button.click(
545
+ add_point,
546
+ inputs=[x_input, y_input, point_type, visible_points_state, occlusion_points_state],
547
+ outputs=[visible_points_state, occlusion_points_state]
548
+ )
549
+
550
+ # ---------------------------
551
+ # 新增的交互逻辑
552
+ # ---------------------------
553
+ clear_button.click(
554
+ clear_all_points,
555
+ inputs=[original_image],
556
+ outputs=[input_image]
557
+ )
558
+ see_visible_button.click(
559
+ see_visible_points,
560
+ inputs=[input_image, visible_points_state],
561
+ outputs=input_image
562
+ )
563
+ see_occlusion_button.click(
564
+ see_occlusion_points,
565
+ inputs=[input_image, occlusion_points_state],
566
+ outputs=input_image
567
+ )
568
+ # 当 visible_points_state 或 occlusion_points_state 变化时,更新文本框和下拉菜单
569
+ visible_points_state.change(
570
+ update_all_points,
571
+ inputs=[visible_points_state, occlusion_points_state],
572
+ outputs=[points_text, visible_points_dropdown, occlusion_points_dropdown]
573
+ )
574
+ occlusion_points_state.change(
575
+ update_all_points,
576
+ inputs=[visible_points_state, occlusion_points_state],
577
+ outputs=[points_text, visible_points_dropdown, occlusion_points_dropdown]
578
+ )
579
+ delete_visible_button.click(
580
+ delete_selected_visible,
581
+ inputs=[input_image, visible_points_state, occlusion_points_state, visible_points_dropdown],
582
+ outputs=[input_image, visible_points_state, occlusion_points_state, points_text, visible_points_dropdown, occlusion_points_dropdown]
583
+ )
584
+ delete_occlusion_button.click(
585
+ delete_selected_occlusion,
586
+ inputs=[input_image, visible_points_state, occlusion_points_state, occlusion_points_dropdown],
587
+ outputs=[input_image, visible_points_state, occlusion_points_state, points_text, visible_points_dropdown, occlusion_points_dropdown]
588
+ )
589
+
590
+ # 生成mask的逻辑
591
+ gen_vis_mask.click(
592
+ segment_and_overlay,
593
+ inputs=[original_image, visible_points_state, predictor],
594
+ outputs=[visible_mask, visibility_mask]
595
+ )
596
+ add_vis_mask.click(
597
+ add_mask,
598
+ inputs=[visibility_mask, visibility_mask_list],
599
+ outputs=[visibility_mask_list]
600
+ )
601
+ render_vis_mask.click(
602
+ vis_mask,
603
+ inputs=[original_image, visibility_mask_list],
604
+ outputs=[visible_mask]
605
+ )
606
+ undo_vis_mask.click(
607
+ delete_mask,
608
+ inputs=[visibility_mask_list],
609
+ outputs=[visibility_mask_list]
610
+ )
611
+ gen_occ_mask.click(
612
+ segment_and_overlay,
613
+ inputs=[original_image, occlusion_points_state, predictor],
614
+ outputs=[occluded_mask, occlusion_mask]
615
+ )
616
+ add_occ_mask.click(
617
+ add_mask,
618
+ inputs=[occlusion_mask, occlusion_mask_list],
619
+ outputs=[occlusion_mask_list]
620
+ )
621
+ render_occ_mask.click(
622
+ vis_mask,
623
+ inputs=[original_image, occlusion_mask_list],
624
+ outputs=[occluded_mask]
625
+ )
626
+ undo_occ_mask.click(
627
+ delete_mask,
628
+ inputs=[occlusion_mask_list],
629
+ outputs=[occlusion_mask_list]
630
+ )
631
+
632
+ # check combined mask
633
+ check_combine_button.click(
634
+ combine_mask,
635
+ inputs=[original_image, visibility_mask_list, occlusion_mask_list],
636
+ outputs=[mask_check, combined_mask]
637
+ )
638
+
639
+ # 3D Amodal Reconstruction
640
+ generate_btn.click(
641
+ get_seed,
642
+ inputs=[randomize_seed, seed],
643
+ outputs=[seed],
644
+ ).then(
645
+ image_to_3d,
646
+ inputs=[original_image, [combined_mask], seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, "multiimage"],
647
+ outputs=[visibility_mask]
648
+ )
649
+
650
+
651
+ # 启动 Gradio App
652
+ if __name__ == "__main__":
653
+ pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
654
+ pipeline.cuda()
655
+ try:
656
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
657
+ except:
658
+ pass
659
+ demo.launch()