Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
@@ -126,12 +126,11 @@ def reset_image(predictor, img):
|
|
126 |
上传图像后调用:
|
127 |
- 重置 predictor,
|
128 |
- 设置 predictor 的输入图像,
|
129 |
-
-
|
130 |
"""
|
131 |
-
|
132 |
-
predictor
|
133 |
-
|
134 |
-
return img, preprocessed_image, [], (img.copy(), [(np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8), 'visible_mask')])
|
135 |
|
136 |
def button_clickable(selected_points):
|
137 |
if len(selected_points) > 0:
|
@@ -369,6 +368,47 @@ def get_sam_predictor():
|
|
369 |
return sam_predictor
|
370 |
|
371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
373 |
gr.Markdown("""
|
374 |
## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
|
@@ -381,51 +421,25 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
381 |
|
382 |
# 定义各状态变量
|
383 |
predictor = gr.State(value=get_sam_predictor())
|
384 |
-
|
385 |
-
|
386 |
-
preprocessed_image = gr.State(value=None)
|
387 |
-
visible_mask = gr.State(value=None)
|
388 |
|
389 |
|
390 |
with gr.Row():
|
391 |
with gr.Column():
|
392 |
-
|
393 |
-
# image_prompt = gr.Image(type="numpy", label="Input Occlusion Image", interactive=True, height=512)
|
394 |
-
# # 用于交互标注的图像,点击时更新显示标记
|
395 |
-
# # image_annotation = gr.Image(type="numpy", label="Select Point Prompts for Target Object", interactive=True, height=512)
|
396 |
-
# # 存储点击点状态以及显示点击点坐标
|
397 |
-
# points_state = gr.State([])
|
398 |
-
# segment_button = gr.Button("Run Segmentation")
|
399 |
-
# # points_output = gr.Textbox(label="Target Object Prompts", interactive=False)
|
400 |
-
# # 展示 SAM 分割结果(只用于显示,不允许上传)
|
401 |
-
# segmented_output = gr.Image(label="Segmented Result", height=512, interactive=False)
|
402 |
-
|
403 |
-
# with gr.Accordion(label="Generation Settings", open=False):
|
404 |
-
# seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
|
405 |
-
# randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
406 |
-
# gr.Markdown("Stage 1: Sparse Structure Generation")
|
407 |
-
# with gr.Row():
|
408 |
-
# ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
|
409 |
-
# ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
410 |
-
# gr.Markdown("Stage 2: Structured Latent Generation")
|
411 |
-
# with gr.Row():
|
412 |
-
# slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
|
413 |
-
# slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
414 |
-
# # 其他组件(如生成按钮、视频展示、GLB 提取等)可根据需要添加\
|
415 |
-
input_image = gr.Image(type="numpy", label='Input Occlusion Image', height=500)
|
416 |
-
annotation_image = gr.Image(type="numpy", label='Annotate Image', interactive=True, height=500)
|
417 |
-
undo_button = gr.Button('Undo Prompt')
|
418 |
fg_bg_radio = gr.Radio(['positive_prompt', 'negative_prompt'], label='Point Prompt Type')
|
419 |
-
gr.
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
|
|
426 |
with gr.Column():
|
427 |
# 显示 SAM 分割结果(带 overlay)—— 使用 AnnotatedImage 显示更直观
|
428 |
-
|
429 |
|
430 |
|
431 |
# 会话启动与结束
|
@@ -436,25 +450,24 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
436 |
input_image.upload(
|
437 |
reset_image,
|
438 |
[predictor, input_image],
|
439 |
-
[
|
440 |
)
|
441 |
-
#
|
442 |
-
|
443 |
-
|
444 |
-
inputs=[input_image],
|
445 |
-
outputs=[
|
446 |
)
|
447 |
-
#
|
448 |
-
|
449 |
-
|
450 |
-
[
|
451 |
-
[
|
452 |
)
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
[
|
457 |
-
[annotation_image, output_mask, visible_mask]
|
458 |
)
|
459 |
|
460 |
|
@@ -466,4 +479,4 @@ if __name__ == "__main__":
|
|
466 |
pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
|
467 |
except:
|
468 |
pass
|
469 |
-
demo.launch()
|
|
|
126 |
上传图像后调用:
|
127 |
- 重置 predictor,
|
128 |
- 设置 predictor 的输入图像,
|
129 |
+
- 返回原图
|
130 |
"""
|
131 |
+
predictor.set_image(img)
|
132 |
+
# 返回predictor,原始图像
|
133 |
+
return predictor, img
|
|
|
134 |
|
135 |
def button_clickable(selected_points):
|
136 |
if len(selected_points) > 0:
|
|
|
368 |
return sam_predictor
|
369 |
|
370 |
|
371 |
+
def draw_points_on_image(image, point, point_type):
|
372 |
+
"""在图像上绘制所有点,points 为 [(x, y, point_type), ...]"""
|
373 |
+
image_with_points = image.copy()
|
374 |
+
x, y = point
|
375 |
+
color = (0, 0, 255) if point_type == "visible" else (0, 255, 0)
|
376 |
+
cv2.circle(image_with_points, (int(x), int(y)), radius=5, color=color, thickness=-1)
|
377 |
+
return image_with_points
|
378 |
+
|
379 |
+
|
380 |
+
def see_point(image, x, y, point_type):
|
381 |
+
"""
|
382 |
+
see操作:不修改 points 列表,仅在图像上临时显示这个点,
|
383 |
+
并返回更新后的图像和当前列表(不更新)。
|
384 |
+
"""
|
385 |
+
# 复制当前列表,并在副本中加上新点(仅用于显示)
|
386 |
+
updated_image = draw_points_on_image(image, [x,y], point_type)
|
387 |
+
return updated_image, points
|
388 |
+
|
389 |
+
def add_point(x, y, point_type, visible_points, occlusion_points):
|
390 |
+
"""
|
391 |
+
add操作:将新点添加到 points 列表中,
|
392 |
+
并返回更新后的图像和新的点列表。
|
393 |
+
"""
|
394 |
+
if point_type == "visible":
|
395 |
+
visible_points.append([x, y])
|
396 |
+
else:
|
397 |
+
occlusion_points.append([x, y])
|
398 |
+
return visible_points, occlusion_points
|
399 |
+
|
400 |
+
def delete_point(point_type, visible_points, occlusion_points):
|
401 |
+
"""
|
402 |
+
delete操作:删除 points 列表中的最后一个点,
|
403 |
+
并返回更新后的图像和新的点列表。
|
404 |
+
"""
|
405 |
+
if point_type == "visible":
|
406 |
+
visible_points.pop()
|
407 |
+
else:
|
408 |
+
occlusion_points.pop()
|
409 |
+
return visible_points, occlusion_points
|
410 |
+
|
411 |
+
|
412 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
413 |
gr.Markdown("""
|
414 |
## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
|
|
|
421 |
|
422 |
# 定义各状态变量
|
423 |
predictor = gr.State(value=get_sam_predictor())
|
424 |
+
visible_points_state = gr.State(value=[])
|
425 |
+
occlusion_points_state = gr.State(value=[])
|
|
|
|
|
426 |
|
427 |
|
428 |
with gr.Row():
|
429 |
with gr.Column():
|
430 |
+
input_image = gr.Image(type="numpy", label='Input Occlusion Image', height=300)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
fg_bg_radio = gr.Radio(['positive_prompt', 'negative_prompt'], label='Point Prompt Type')
|
432 |
+
with gr.Row():
|
433 |
+
x_input = gr.Number(label="X Coordinate", value=0)
|
434 |
+
y_input = gr.Number(label="Y Coordinate", value=0)
|
435 |
+
point_type = gr.Radio(choices=["visible", "occlusion"], label="Point Type", value="visible")
|
436 |
+
with gr.Row():
|
437 |
+
see_button = gr.Button("See")
|
438 |
+
add_button = gr.Button("Add")
|
439 |
+
delete_button = gr.Button("Delete")
|
440 |
with gr.Column():
|
441 |
# 显示 SAM 分割结果(带 overlay)—— 使用 AnnotatedImage 显示更直观
|
442 |
+
sam_image = gr.Image(label='SAM Generated Mask', interactive=False, height=300)
|
443 |
|
444 |
|
445 |
# 会话启动与结束
|
|
|
450 |
input_image.upload(
|
451 |
reset_image,
|
452 |
[predictor, input_image],
|
453 |
+
[predictor, sam_image]
|
454 |
)
|
455 |
+
# 如果点击see按钮,应该在input图片上生成对应的点,
|
456 |
+
see_button.click(
|
457 |
+
see_point,
|
458 |
+
inputs=[input_image, x_input, y_input, point_type],
|
459 |
+
outputs=[input_image]
|
460 |
)
|
461 |
+
# 如果点击add按钮,应该将对应的点添加到visible_points_state中
|
462 |
+
add_button.click(
|
463 |
+
add_point,
|
464 |
+
inputs=[x_input, y_input, point_type, visible_points_state, occlusion_points_state],
|
465 |
+
outputs=[visible_points_state, occlusion_points_state]
|
466 |
)
|
467 |
+
delete_button.click(
|
468 |
+
delete_point,
|
469 |
+
inputs=[point_type, visible_points_state, occlusion_points_state],
|
470 |
+
outputs=[visible_points_state, occlusion_points_state]
|
|
|
471 |
)
|
472 |
|
473 |
|
|
|
479 |
pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
|
480 |
except:
|
481 |
pass
|
482 |
+
demo.launch()
|