Sm0kyWu committed on
Commit 05802f8 · verified · 1 Parent(s): 49bda9b

Upload app.py

Files changed (1): app.py +180 -101
app.py CHANGED
@@ -34,62 +34,110 @@ def end_session(req: gr.Request):
     shutil.rmtree(user_dir)
 
 
-def select_point_callback(image, points, evt):
+def select_point(predictor: SamPredictor,
+                 annotated_img: np.ndarray,
+                 orig_img: np.ndarray,
+                 sel_pix: list,
+                 point_type: str,
+                 evt: gr.SelectData):
     """
-    When the user clicks the image, record the clicked point and draw a cross marker on the image.
-    Inputs:
-    - image: the current image (numpy array).
-    - points: the list of recorded points.
-    - evt: Gradio click-event data (contains .index, the click coordinates).
-    Returns:
-    - the updated image (with markers).
-    - the updated point list.
-    - the point list as a string (for display in a textbox).
+    When the user clicks on the annotation image:
+    - append the click coordinates to sel_pix (as a positive/negative prompt per point_type),
+    - call SAM with sel_pix to obtain a mask,
+    - draw markers for all selected points on annotated_img,
+    - return the updated annotated image, the SAM output (for display), and the generated visible_mask (for the later pix2gestalt step).
     """
-    if points is None:
-        points = []
-    annotated_img = image.copy()
-    # If there is no click event, return the original image and the current point list unchanged
-    if evt is None or evt.index is None:
-        return image, points, str(points)
-    coord = evt.index  # expected to be (x, y)
-    points.append((tuple(coord), 1))  # record as a positive prompt
-    # Draw a cross marker in red
-    cv2.drawMarker(annotated_img, tuple(coord), (255, 0, 0),
-                   markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
-    return annotated_img, points
-
-
-def mark_point_on_image(image, points, evt):
+    # Copy the image used for annotation
+    img = annotated_img.copy()
+    h_original, w_original, _ = orig_img.shape
+    h_new, w_new = 256, 256
+    scale_x = w_new / w_original
+    scale_y = h_new / h_original
+
+    # Add the clicked point according to the prompt type (evt.index is (x, y))
+    if point_type == 'positive_prompt':
+        sel_pix.append((evt.index, 1))
+    elif point_type == 'negative_prompt':
+        sel_pix.append((evt.index, 0))
+    else:
+        sel_pix.append((evt.index, 1))
+
+    # Convert the points from the original resolution to 256x256 (the SAM input size)
+    processed_sel_pix = []
+    for point, label in sel_pix:
+        x, y = point
+        new_x = int(x * scale_x)
+        new_y = int(y * scale_y)
+        processed_sel_pix.append(([new_x, new_y], label))
+
+    visible_mask, overlay_mask = run_sam(predictor, processed_sel_pix)
+    # overlay_mask is the SAM output mask (256x256); resize it to the original resolution for display
+    mask = np.squeeze(overlay_mask[0][0])  # (256, 256)
+    resized_mask = cv2.resize(mask.astype(np.uint8) * 255, (w_original, h_original), interpolation=cv2.INTER_AREA)
+    resized_mask = resized_mask > 127
+    # Build the overlay annotation (consumed by output_mask)
+    resized_overlay_mask = [(resized_mask, 'visible_mask')]
+
+    # Draw markers for all selected points
+    COLORS = [(255, 0, 0), (0, 255, 0)]
+    MARKERS = [1, 4]
+    scaling_factor = min(h_original / 256, w_original / 256)
+    marker_size = int(6 * scaling_factor)
+    marker_thickness = int(2 * scaling_factor)
+    for point, label in sel_pix:
+        cv2.drawMarker(img, tuple(point), COLORS[label], markerType=MARKERS[label],
+                       markerSize=marker_size, thickness=marker_thickness)
+
+    return img, (orig_img, resized_overlay_mask), visible_mask
+
+def undo_points(predictor, orig_img, sel_pix):
     """
-    When the user clicks image_prompt, draw a marker directly on the image and update the clicked-point state.
-    :param image: the current image (numpy array, RGB).
-    :param points: the list of recorded points.
-    :param evt: Gradio click-event data with an .index attribute (the click coordinates).
-    :return: the updated image, the point list, and the text to display.
+    Undo the last click:
+    - pop the last point from sel_pix,
+    - re-run SAM with the remaining points to obtain a mask,
+    - return the updated image and mask.
     """
-    if image is None:
-        return None, points, str(points)
-
-    # If there are no points yet, copy the image to preserve the original
-    # (it could be kept in another State for later processing)
-    annotated_image = image.copy()
-    if points is None:
-        points = []
-
-    # Check whether the event data contains click coordinates
-    if evt is None or evt.index is None:
-        return annotated_image, points, str(points)
-
-    # Get the click coordinates (a list or tuple)
-    pt = tuple(evt.index)
-    points.append((pt, 1))  # 1 marks a positive prompt (adjust as needed)
-
-    # Draw markers for all points on the image
-    for p, _ in points:
-        cv2.drawMarker(annotated_image, p, (255, 0, 0),
-                       markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
-
-    return annotated_image, points, str(points)
+    temp = orig_img.copy()
+    h_original, w_original, _ = orig_img.shape
+    COLORS = [(255, 0, 0), (0, 255, 0)]
+    MARKERS = [0, 5]
+    scaling_factor = min(h_original / 256, w_original / 256)
+    marker_size = int(6 * scaling_factor)
+    marker_thickness = int(2 * scaling_factor)
+    if len(sel_pix) > 0:
+        sel_pix.pop()
+        # Redraw the remaining points
+        for point, label in sel_pix:
+            cv2.drawMarker(temp, tuple(point), COLORS[label],
+                           markerType=MARKERS[label], markerSize=marker_size, thickness=marker_thickness)
+    else:
+        dummy_overlay_mask = [(np.zeros((h_original, w_original), dtype=np.uint8), 'visible_mask')]
+        return orig_img, (orig_img, dummy_overlay_mask), []
+
+    visible_mask, overlay_mask = run_sam(predictor, sel_pix)
+    mask = np.squeeze(overlay_mask[0][0])
+    resized_mask = cv2.resize(mask.astype(np.uint8) * 255, (w_original, h_original), interpolation=cv2.INTER_AREA)
+    resized_mask = resized_mask > 127
+    resized_overlay_mask = [(resized_mask, 'visible_mask')]
+    return temp, (orig_img, resized_overlay_mask), visible_mask
+
+def reset_image(predictor, img):
+    """
+    Called when an image is uploaded:
+    - reset the predictor,
+    - set the predictor's input image,
+    - return the original image, the preprocessed image, a cleared sel_pix, and the initial output (no mask).
+    """
+    preprocessed_image = img
+    predictor.set_image(preprocessed_image)
+    # Return the original image, the preprocessed image, an empty point list,
+    # and the initial output (displayed as the SAM mask; initially a copy of the original)
+    return img, preprocessed_image, [], (img.copy(), [(np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8), 'visible_mask')])
+
+def button_clickable(selected_points):
+    if len(selected_points) > 0:
+        return gr.Button.update(interactive=True)
+    else:
+        return gr.Button.update(interactive=False)
 
 
  @spaces.GPU
@@ -131,10 +179,12 @@ def apply_mask_overlay(image: Image.Image, mask: np.ndarray) -> Image.Image:
     return Image.fromarray(overlay)
 
 
-def segment_and_overlay(image: Image.Image, points):
+def segment_and_overlay(image: np.ndarray, points):
     """
     Call run_sam to obtain a mask, then overlay the segmentation result for display.
     """
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
     if image.mode != "RGB":
         image = image.convert("RGB")
     mask, _ = run_sam(sam_predictor, image, points)
@@ -310,6 +360,15 @@ def split_image(image: Image.Image) -> list:
     return [image for image in images]
 
 
+def get_sam_predictor():
+    sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
+    model_type = "vit_h"
+    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+    sam.cuda()
+    sam_predictor = SamPredictor(sam)
+    return sam_predictor
+
+
 with gr.Blocks(delete_cache=(600, 600)) as demo:
     gr.Markdown("""
     ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
@@ -319,72 +378,92 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
     * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
     * If the reconstruction 3D asset is satisfactory, you can extract the GLB file and download it.
     """)
+
+    # Define the state variables
+    predictor = gr.State(value=get_sam_predictor())
+    selected_points = gr.State(value=[])
+    original_image = gr.State(value=None)
+    preprocessed_image = gr.State(value=None)
+    visible_mask = gr.State(value=None)
+
+
     with gr.Row():
         with gr.Column():
-            # Show the uploaded image as-is, without preprocessing
-            image_prompt = gr.Image(type="numpy", label="Input Occlusion Image", interactive=True, height=512)
-            # Image used for interactive annotation; markers update on click
-            # image_annotation = gr.Image(type="numpy", label="Select Point Prompts for Target Object", interactive=True, height=512)
-            # State for the clicked points and a display of their coordinates
-            points_state = gr.State([])
-            segment_button = gr.Button("Run Segmentation")
-            # points_output = gr.Textbox(label="Target Object Prompts", interactive=False)
-            # Show the SAM segmentation result (display only, uploads disabled)
-            segmented_output = gr.Image(label="Segmented Result", height=512, interactive=False)
+            # # Show the uploaded image as-is, without preprocessing
+            # image_prompt = gr.Image(type="numpy", label="Input Occlusion Image", interactive=True, height=512)
+            # # Image used for interactive annotation; markers update on click
+            # # image_annotation = gr.Image(type="numpy", label="Select Point Prompts for Target Object", interactive=True, height=512)
+            # # State for the clicked points and a display of their coordinates
+            # points_state = gr.State([])
+            # segment_button = gr.Button("Run Segmentation")
+            # # points_output = gr.Textbox(label="Target Object Prompts", interactive=False)
+            # # Show the SAM segmentation result (display only, uploads disabled)
+            # segmented_output = gr.Image(label="Segmented Result", height=512, interactive=False)
 
-            with gr.Accordion(label="Generation Settings", open=False):
-                seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
-                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                gr.Markdown("Stage 1: Sparse Structure Generation")
-                with gr.Row():
-                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
-                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-                gr.Markdown("Stage 2: Structured Latent Generation")
-                with gr.Row():
-                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
-                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-            # Other components (generation button, video display, GLB extraction, etc.) can be added as needed
+            # with gr.Accordion(label="Generation Settings", open=False):
+            #     seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
+            #     randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+            #     gr.Markdown("Stage 1: Sparse Structure Generation")
+            #     with gr.Row():
+            #         ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+            #         ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+            #     gr.Markdown("Stage 2: Structured Latent Generation")
+            #     with gr.Row():
+            #         slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
+            #         slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+            # # Other components (generation button, video display, GLB extraction, etc.) can be added as needed
+            input_image = gr.Image(type="numpy", label='Input Occlusion Image', height=500)
+            annotation_image = gr.Image(type="numpy", label='Annotate Image', interactive=True, height=500)
+            undo_button = gr.Button('Undo Prompt')
+            fg_bg_radio = gr.Radio(['positive_prompt', 'negative_prompt'], label='Point Prompt Type')
+            gr.Markdown('''
+            ### Instructions:
+            - First, upload an image.
+            - Then, click on the "Annotate Image" to select visible regions.
+            - Use "Undo Prompt" to remove the last point.
+            - Once the SAM mask is satisfactory, click "Run pix2gestalt" to perform amodal completion.
+            ''')
+        with gr.Column():
+            # Display the SAM segmentation result (with overlay); AnnotatedImage makes this clearer
+            output_mask = gr.AnnotatedImage(label='SAM Generated Visible (Modal) Mask', height=500)
+
 
     # Session start and end
     demo.load(start_session)
     demo.unload(end_session)
 
-    # Show the uploaded image directly, without preprocessing
-    image_prompt.upload(
-        lambda x: x,
-        inputs=[image_prompt],
-        outputs=[image_prompt]
+    # On image upload: reset the predictor and write the image into original_image, preprocessed_image, selected_points, and output_mask
+    input_image.upload(
+        reset_image,
+        [predictor, input_image],
+        [original_image, preprocessed_image, selected_points, output_mask]
     )
-
-    # On click in image_annotation, call select_point_callback to
-    # update the displayed image, the point state, and the click-point text
-    image_prompt.select(
-        select_point_callback,
-        inputs=[image_prompt, points_state],
-        outputs=[image_prompt, points_state]
+    # Also update annotation_image (keep it in sync with the uploaded image)
+    input_image.upload(
+        lambda x: x,
+        inputs=[input_image],
+        outputs=[annotation_image]
     )
-
-    # A button to run SAM segmentation and show the overlaid result
-    segment_button.click(
-        segment_and_overlay,
-        inputs=[image_prompt, points_state],
-        outputs=[segmented_output]
+    # Undo button: revert the most recent click
+    undo_button.click(
+        undo_points,
+        [predictor, original_image, selected_points],
+        [annotation_image, output_mask, visible_mask]
+    )
+    # On click in annotation_image: call select_point to update the annotated image and the SAM result
+    annotation_image.select(
+        select_point,
+        [predictor, annotation_image, original_image, selected_points, fg_bg_radio],
+        [annotation_image, output_mask, visible_mask]
     )
 
-    # Later stages (3D model generation, etc.) can be added...
 
 # Launch the Gradio app
 if __name__ == "__main__":
-    sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
-    model_type = "vit_h"
-    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
-    sam.cuda()
-    sam_predictor = SamPredictor(sam)
-
     pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
     pipeline.cuda()
     try:
         pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
     except:
         pass
-    demo.launch()
+    demo.launch()
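`run_sam` is referenced throughout the new code but is not part of this diff. Below is a minimal sketch of a compatible implementation, assuming it wraps `SamPredictor.predict` from the `segment_anything` package and that `predictor.set_image(...)` has already been called (as `reset_image` does); the nested `overlay_mask` structure is an assumption, chosen so that the `np.squeeze(overlay_mask[0][0])` indexing in `select_point` recovers a 2D mask.

```python
import numpy as np
from segment_anything import SamPredictor


def run_sam(predictor: SamPredictor, sel_pix: list):
    """Hypothetical helper: turn (point, label) prompts into a SAM mask."""
    # sel_pix is a list of ((x, y), label) pairs; label 1 = positive, 0 = negative
    coords = np.array([list(p) for p, _ in sel_pix], dtype=np.float32)  # (N, 2)
    labels = np.array([l for _, l in sel_pix], dtype=np.int32)          # (N,)
    masks, scores, _ = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=False,
    )
    # masks has shape (1, H, W); keep the single returned mask
    visible_mask = masks[0]
    # Nest the raw array so that overlay_mask[0][0] recovers it, matching
    # the np.squeeze(overlay_mask[0][0]) call in select_point
    overlay_mask = [(masks, 'visible_mask')]
    return visible_mask, overlay_mask
```

One thing to check against the real helper: `select_point` rescales clicks into a 256×256 frame before calling `run_sam`, while `reset_image` hands the upload to `set_image` at its native resolution and `undo_points` passes unscaled points. All three call sites have to agree on a single coordinate frame for the prompts to land on the intended object.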
 
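The `(orig_img, [(mask, 'visible_mask')])` tuples returned to `output_mask` follow the value format that `gr.AnnotatedImage` expects: a base image plus a list of `(mask, label)` pairs, where each mask is a numpy array with the same height and width as the base. A self-contained sketch (the variable names are illustrative, not from the app):

```python
import numpy as np
import gradio as gr

# A gray base image with a mask covering its upper-left quadrant
base = np.full((256, 256, 3), 128, dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)
mask[:128, :128] = 1

with gr.Blocks() as demo:
    # Value format: (image, [(mask_or_bounding_box, label), ...])
    gr.AnnotatedImage(value=(base, [(mask, 'visible_mask')]), label='Demo Overlay')

demo.launch()
```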
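One compatibility caveat in the new `button_clickable` helper: `gr.Button.update` is the Gradio 3.x idiom and was removed in Gradio 4.x, where component updates are written with `gr.update(...)`. If the Space runs Gradio 4, the helper would need something along these lines (a sketch under that version assumption, not part of this commit):

```python
import gradio as gr

def button_clickable(selected_points):
    # gr.update(...) applies the given properties to the event's output component
    return gr.update(interactive=len(selected_points) > 0)
```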