Sm0kyWu committed on
Commit
540680a
·
verified ·
1 Parent(s): 05802f8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -60
app.py CHANGED
@@ -126,12 +126,11 @@ def reset_image(predictor, img):
126
  上传图像后调用:
127
  - 重置 predictor,
128
  - 设置 predictor 的输入图像,
129
- - 返回原图、预处理图像、清空 sel_pix、以及初始输出(无 mask)。
130
  """
131
- preprocessed_image = img
132
- predictor.set_image(preprocessed_image)
133
- # 返回原始图像、预处理图像、清空点列表、初始输出(作为 SAM mask 显示,初始为原图复制)
134
- return img, preprocessed_image, [], (img.copy(), [(np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8), 'visible_mask')])
135
 
136
  def button_clickable(selected_points):
137
  if len(selected_points) > 0:
@@ -369,6 +368,47 @@ def get_sam_predictor():
369
  return sam_predictor
370
 
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  with gr.Blocks(delete_cache=(600, 600)) as demo:
373
  gr.Markdown("""
374
  ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
@@ -381,51 +421,25 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
381
 
382
  # 定义各状态变量
383
  predictor = gr.State(value=get_sam_predictor())
384
- selected_points = gr.State(value=[])
385
- original_image = gr.State(value=None)
386
- preprocessed_image = gr.State(value=None)
387
- visible_mask = gr.State(value=None)
388
 
389
 
390
  with gr.Row():
391
  with gr.Column():
392
- # # 上传的图像不经过预处理,直接展示原始图像
393
- # image_prompt = gr.Image(type="numpy", label="Input Occlusion Image", interactive=True, height=512)
394
- # # 用于交互标注的图像,点击时更新显示标记
395
- # # image_annotation = gr.Image(type="numpy", label="Select Point Prompts for Target Object", interactive=True, height=512)
396
- # # 存储点击点状态以及显示点击点坐标
397
- # points_state = gr.State([])
398
- # segment_button = gr.Button("Run Segmentation")
399
- # # points_output = gr.Textbox(label="Target Object Prompts", interactive=False)
400
- # # 展示 SAM 分割结果(只用于显示,不允许上传)
401
- # segmented_output = gr.Image(label="Segmented Result", height=512, interactive=False)
402
-
403
- # with gr.Accordion(label="Generation Settings", open=False):
404
- # seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
405
- # randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
406
- # gr.Markdown("Stage 1: Sparse Structure Generation")
407
- # with gr.Row():
408
- # ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
409
- # ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
410
- # gr.Markdown("Stage 2: Structured Latent Generation")
411
- # with gr.Row():
412
- # slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
413
- # slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
414
- # # 其他组件(如生成按钮、视频展示、GLB 提取等)可根据需要添加\
415
- input_image = gr.Image(type="numpy", label='Input Occlusion Image', height=500)
416
- annotation_image = gr.Image(type="numpy", label='Annotate Image', interactive=True, height=500)
417
- undo_button = gr.Button('Undo Prompt')
418
  fg_bg_radio = gr.Radio(['positive_prompt', 'negative_prompt'], label='Point Prompt Type')
419
- gr.Markdown('''
420
- ### Instructions:
421
- - First, upload an image.
422
- - Then, click on the "Annotate Image" to select visible regions.
423
- - Use "Undo Prompt" to remove the last point.
424
- - Once the SAM mask is satisfactory, click "Run pix2gestalt" to perform amodal completion.
425
- ''')
 
426
  with gr.Column():
427
  # 显示 SAM 分割结果(带 overlay)—— 使用 AnnotatedImage 显示更直观
428
- output_mask = gr.AnnotatedImage(label='SAM Generated Visible (Modal) Mask', height=500)
429
 
430
 
431
  # 会话启动与结束
@@ -436,25 +450,24 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
436
  input_image.upload(
437
  reset_image,
438
  [predictor, input_image],
439
- [original_image, preprocessed_image, selected_points, output_mask]
440
  )
441
- # 同时更新 annotation_image(使其与上传图像保持一致)
442
- input_image.upload(
443
- lambda x: x,
444
- inputs=[input_image],
445
- outputs=[annotation_image]
446
  )
447
- # 撤销按钮:撤销最近一次点击
448
- undo_button.click(
449
- undo_points,
450
- [predictor, original_image, selected_points],
451
- [annotation_image, output_mask, visible_mask]
452
  )
453
- # 在 annotation_image 上点击:调用 select_point 更新标注图像和 SAM 分割结果
454
- annotation_image.select(
455
- select_point,
456
- [predictor, annotation_image, original_image, selected_points, fg_bg_radio],
457
- [annotation_image, output_mask, visible_mask]
458
  )
459
 
460
 
@@ -466,4 +479,4 @@ if __name__ == "__main__":
466
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
467
  except:
468
  pass
469
- demo.launch()
 
126
  上传图像后调用:
127
  - 重置 predictor,
128
  - 设置 predictor 的输入图像,
129
+ - 返回原图
130
  """
131
+ predictor.set_image(img)
132
+ # 返回predictor,原始图像
133
+ return predictor, img
 
134
 
135
  def button_clickable(selected_points):
136
  if len(selected_points) > 0:
 
368
  return sam_predictor
369
 
370
 
371
def draw_points_on_image(image, point, point_type):
    """Return a copy of *image* with a single point marker drawn on it.

    Args:
        image: numpy image array (as delivered by gr.Image(type="numpy")).
        point: (x, y) coordinate pair of the one point to draw.
            NOTE(review): the original docstring claimed this drew "all
            points" from a ``[(x, y, point_type), ...]`` list — it draws
            exactly one point.
        point_type: "visible" selects color (0, 0, 255); anything else
            selects (0, 255, 0).
            NOTE(review): whether these read as red/green or blue/green
            depends on RGB vs BGR channel order of the array — gr.Image
            yields RGB, so confirm the intended colors.

    Returns:
        A new image array with a filled radius-5 circle at (x, y); the
        input image is not modified.
    """
    image_with_points = image.copy()
    x, y = point
    color = (0, 0, 255) if point_type == "visible" else (0, 255, 0)
    cv2.circle(image_with_points, (int(x), int(y)), radius=5, color=color, thickness=-1)
    return image_with_points
378
+
379
+
380
def see_point(image, x, y, point_type):
    """Preview a point on the image without committing it to any point list.

    Draws (x, y) on a copy of *image* and returns that image only, which
    matches the single output (``input_image``) wired to ``see_button.click``.

    Args:
        image: current numpy image shown in the input component.
        x: x coordinate of the point to preview.
        y: y coordinate of the point to preview.
        point_type: "visible" or "occlusion"; selects the marker color.

    Returns:
        The image copy with the preview marker drawn.
    """
    # Bug fix: the original returned ``updated_image, points`` where
    # ``points`` was never defined (NameError at runtime), and the click
    # handler only expects one output value anyway.
    return draw_points_on_image(image, [x, y], point_type)
388
+
389
def add_point(x, y, point_type, visible_points, occlusion_points):
    """Append the point (x, y) to the list selected by *point_type*.

    Args:
        x: x coordinate of the new point.
        y: y coordinate of the new point.
        point_type: "visible" appends to ``visible_points``; any other value
            appends to ``occlusion_points``.
        visible_points: current list of visible-region points (mutated in
            place, as a gr.State value).
        occlusion_points: current list of occlusion-region points (mutated
            in place, as a gr.State value).

    Returns:
        ``(visible_points, occlusion_points)`` — the same list objects,
        updated. (The original docstring wrongly claimed an updated image
        was returned.)
    """
    target = visible_points if point_type == "visible" else occlusion_points
    target.append([x, y])
    return visible_points, occlusion_points
399
+
400
def delete_point(point_type, visible_points, occlusion_points):
    """Remove the most recently added point of the given type.

    Args:
        point_type: "visible" pops from ``visible_points``; any other value
            pops from ``occlusion_points``.
        visible_points: current list of visible-region points (mutated in
            place, as a gr.State value).
        occlusion_points: current list of occlusion-region points (mutated
            in place, as a gr.State value).

    Returns:
        ``(visible_points, occlusion_points)`` with the last point of the
        chosen list removed. If that list is already empty it is returned
        unchanged — the original raised IndexError when "Delete" was
        clicked before any point was added.
    """
    target = visible_points if point_type == "visible" else occlusion_points
    if target:
        target.pop()
    return visible_points, occlusion_points
410
+
411
+
412
  with gr.Blocks(delete_cache=(600, 600)) as demo:
413
  gr.Markdown("""
414
  ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
 
421
 
422
  # 定义各状态变量
423
  predictor = gr.State(value=get_sam_predictor())
424
+ visible_points_state = gr.State(value=[])
425
+ occlusion_points_state = gr.State(value=[])
 
 
426
 
427
 
428
  with gr.Row():
429
  with gr.Column():
430
+ input_image = gr.Image(type="numpy", label='Input Occlusion Image', height=300)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  fg_bg_radio = gr.Radio(['positive_prompt', 'negative_prompt'], label='Point Prompt Type')
432
+ with gr.Row():
433
+ x_input = gr.Number(label="X Coordinate", value=0)
434
+ y_input = gr.Number(label="Y Coordinate", value=0)
435
+ point_type = gr.Radio(choices=["visible", "occlusion"], label="Point Type", value="visible")
436
+ with gr.Row():
437
+ see_button = gr.Button("See")
438
+ add_button = gr.Button("Add")
439
+ delete_button = gr.Button("Delete")
440
  with gr.Column():
441
  # 显示 SAM 分割结果(带 overlay)—— 使用 AnnotatedImage 显示更直观
442
+ sam_image = gr.Image(label='SAM Generated Mask', interactive=False, height=300)
443
 
444
 
445
  # 会话启动与结束
 
450
  input_image.upload(
451
  reset_image,
452
  [predictor, input_image],
453
+ [predictor, sam_image]
454
  )
455
+ # 如果点击see按钮,应该在input图片上生成对应的点,
456
+ see_button.click(
457
+ see_point,
458
+ inputs=[input_image, x_input, y_input, point_type],
459
+ outputs=[input_image]
460
  )
461
+ # 如果点击add按钮,应该将对应的点添加到visible_points_state中
462
+ add_button.click(
463
+ add_point,
464
+ inputs=[x_input, y_input, point_type, visible_points_state, occlusion_points_state],
465
+ outputs=[visible_points_state, occlusion_points_state]
466
  )
467
+ delete_button.click(
468
+ delete_point,
469
+ inputs=[point_type, visible_points_state, occlusion_points_state],
470
+ outputs=[visible_points_state, occlusion_points_state]
 
471
  )
472
 
473
 
 
479
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
480
  except:
481
  pass
482
+ demo.launch()