Spaces:
Sleeping
Sleeping
feat(medsam): prompt-only segmentation (bboxes/points JSON); skip if none; polygons by default; optional raw masks
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
import numpy as np
|
7 |
import cv2
|
8 |
import time
|
|
|
9 |
import traceback
|
10 |
|
11 |
# Simple timestamped logger
|
@@ -643,7 +644,7 @@ def analyze(image):
|
|
643 |
return {"error": "Internal error in analyze"}
|
644 |
|
645 |
|
646 |
-
def analyze_with_medsam(base_result, image, include_raw_masks=False):
|
647 |
try:
|
648 |
log("analyze_with_medsam: start")
|
649 |
if not isinstance(base_result, dict):
|
@@ -664,20 +665,36 @@ def analyze_with_medsam(base_result, image, include_raw_masks=False):
|
|
664 |
img_path = tmp_path
|
665 |
_medsam.load_image(img_path)
|
666 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
667 |
segmentations = []
|
668 |
masks_for_overlay = []
|
669 |
|
670 |
-
# MedSAM
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
m = _medsam.segment_with_box(bbox)
|
675 |
if m is None or not isinstance(m.get('mask'), np.ndarray):
|
676 |
continue
|
677 |
mask_np = m['mask'].astype(np.uint8)
|
678 |
seg_entry = {
|
679 |
"confidence": float(m.get('confidence', 1.0)),
|
680 |
-
"method": m.get("method", "
|
681 |
"polygons": _mask_to_polygons(mask_np)
|
682 |
}
|
683 |
if include_raw_masks:
|
@@ -685,6 +702,33 @@ def analyze_with_medsam(base_result, image, include_raw_masks=False):
|
|
685 |
segmentations.append(seg_entry)
|
686 |
masks_for_overlay.append(m)
|
687 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
688 |
W, H = pil_img.size
|
689 |
base_result["medsam"] = {
|
690 |
"available": True,
|
@@ -745,6 +789,8 @@ with gr.Blocks(
|
|
745 |
elem_id="image-input"
|
746 |
)
|
747 |
include_raw_masks_cb = gr.Checkbox(value=False, visible=False, elem_id="include-raw-masks")
|
|
|
|
|
748 |
|
749 |
# Analyze button (single)
|
750 |
analyze_btn = gr.Button(
|
@@ -776,10 +822,10 @@ with gr.Blocks(
|
|
776 |
api_name="/predict" # ✅ Standard API name that gradio_client expects
|
777 |
)
|
778 |
|
779 |
-
#
|
780 |
analyze_event.then(
|
781 |
fn=analyze_with_medsam,
|
782 |
-
inputs=[result_output, image_input, include_raw_masks_cb],
|
783 |
outputs=[result_output, overlay_output],
|
784 |
)
|
785 |
|
|
|
6 |
import numpy as np
|
7 |
import cv2
|
8 |
import time
|
9 |
+
import json
|
10 |
import traceback
|
11 |
|
12 |
# Simple timestamped logger
|
|
|
644 |
return {"error": "Internal error in analyze"}
|
645 |
|
646 |
|
647 |
+
def analyze_with_medsam(base_result, image, include_raw_masks=False, bboxes_json="", points_json=""):
|
648 |
try:
|
649 |
log("analyze_with_medsam: start")
|
650 |
if not isinstance(base_result, dict):
|
|
|
665 |
img_path = tmp_path
|
666 |
_medsam.load_image(img_path)
|
667 |
|
668 |
+
# Parse prompts
|
669 |
+
parsed_bboxes = []
|
670 |
+
parsed_points = []
|
671 |
+
try:
|
672 |
+
if bboxes_json:
|
673 |
+
parsed_bboxes = json.loads(bboxes_json)
|
674 |
+
if points_json:
|
675 |
+
parsed_points = json.loads(points_json)
|
676 |
+
except Exception:
|
677 |
+
log("analyze_with_medsam: failed to parse prompts JSON")
|
678 |
+
|
679 |
+
# If no prompts provided, skip (follow original behavior)
|
680 |
+
if not parsed_bboxes and not parsed_points:
|
681 |
+
log("analyze_with_medsam: no prompts provided; skipping segmentation")
|
682 |
+
return base_result, None
|
683 |
+
|
684 |
segmentations = []
|
685 |
masks_for_overlay = []
|
686 |
|
687 |
+
# Run MedSAM for provided boxes
|
688 |
+
for bbox in parsed_bboxes:
|
689 |
+
if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
|
690 |
+
continue
|
691 |
m = _medsam.segment_with_box(bbox)
|
692 |
if m is None or not isinstance(m.get('mask'), np.ndarray):
|
693 |
continue
|
694 |
mask_np = m['mask'].astype(np.uint8)
|
695 |
seg_entry = {
|
696 |
"confidence": float(m.get('confidence', 1.0)),
|
697 |
+
"method": m.get("method", "medsam_box"),
|
698 |
"polygons": _mask_to_polygons(mask_np)
|
699 |
}
|
700 |
if include_raw_masks:
|
|
|
702 |
segmentations.append(seg_entry)
|
703 |
masks_for_overlay.append(m)
|
704 |
|
705 |
+
# Run MedSAM for provided points by converting to bbox
|
706 |
+
for item in parsed_points:
|
707 |
+
try:
|
708 |
+
# Expect item like {"points": [[x,y],...]} or [ [x,y], ... ]
|
709 |
+
pts = item.get("points") if isinstance(item, dict) else item
|
710 |
+
pts_np = np.array(pts)
|
711 |
+
x_min, y_min = pts_np.min(axis=0)
|
712 |
+
x_max, y_max = pts_np.max(axis=0)
|
713 |
+
pad = 20
|
714 |
+
H, W = _medsam.current_image.shape[:2]
|
715 |
+
bbox = [max(0, x_min - pad), max(0, y_min - pad), min(W - 1, x_max + pad), min(H - 1, y_max + pad)]
|
716 |
+
m = _medsam.segment_with_box(bbox)
|
717 |
+
if m is None or not isinstance(m.get('mask'), np.ndarray):
|
718 |
+
continue
|
719 |
+
mask_np = m['mask'].astype(np.uint8)
|
720 |
+
seg_entry = {
|
721 |
+
"confidence": float(m.get('confidence', 1.0)),
|
722 |
+
"method": m.get("method", "medsam_points_box"),
|
723 |
+
"polygons": _mask_to_polygons(mask_np)
|
724 |
+
}
|
725 |
+
if include_raw_masks:
|
726 |
+
seg_entry["mask"] = mask_np.tolist()
|
727 |
+
segmentations.append(seg_entry)
|
728 |
+
masks_for_overlay.append(m)
|
729 |
+
except Exception:
|
730 |
+
continue
|
731 |
+
|
732 |
W, H = pil_img.size
|
733 |
base_result["medsam"] = {
|
734 |
"available": True,
|
|
|
789 |
elem_id="image-input"
|
790 |
)
|
791 |
include_raw_masks_cb = gr.Checkbox(value=False, visible=False, elem_id="include-raw-masks")
|
792 |
+
bboxes_tb = gr.Textbox(value="", visible=False, elem_id="bboxes-json")
|
793 |
+
points_tb = gr.Textbox(value="", visible=False, elem_id="points-json")
|
794 |
|
795 |
# Analyze button (single)
|
796 |
analyze_btn = gr.Button(
|
|
|
822 |
api_name="/predict" # ✅ Standard API name that gradio_client expects
|
823 |
)
|
824 |
|
825 |
+
# MedSAM step (prompt-only). If no prompts, it will skip
|
826 |
analyze_event.then(
|
827 |
fn=analyze_with_medsam,
|
828 |
+
inputs=[result_output, image_input, include_raw_masks_cb, bboxes_tb, points_tb],
|
829 |
outputs=[result_output, overlay_output],
|
830 |
)
|
831 |
|