hanszhu committed on
Commit
12b213c
·
1 Parent(s): eeb48d1

feat(medsam): prompt-only segmentation (bboxes/points JSON); skip if none; polygons by default; optional raw masks

Browse files
Files changed (1) hide show
  1. app.py +54 -8
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import numpy as np
7
  import cv2
8
  import time
 
9
  import traceback
10
 
11
  # Simple timestamped logger
@@ -643,7 +644,7 @@ def analyze(image):
643
  return {"error": "Internal error in analyze"}
644
 
645
 
646
- def analyze_with_medsam(base_result, image, include_raw_masks=False):
647
  try:
648
  log("analyze_with_medsam: start")
649
  if not isinstance(base_result, dict):
@@ -664,20 +665,36 @@ def analyze_with_medsam(base_result, image, include_raw_masks=False):
664
  img_path = tmp_path
665
  _medsam.load_image(img_path)
666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667
  segmentations = []
668
  masks_for_overlay = []
669
 
670
- # MedSAM over candidate boxes (original behavior)
671
- cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=20, min_area=100)
672
- log(f"analyze_with_medsam: candidate boxes={len(cand_bboxes)}")
673
- for bbox in cand_bboxes:
674
  m = _medsam.segment_with_box(bbox)
675
  if m is None or not isinstance(m.get('mask'), np.ndarray):
676
  continue
677
  mask_np = m['mask'].astype(np.uint8)
678
  seg_entry = {
679
  "confidence": float(m.get('confidence', 1.0)),
680
- "method": m.get("method", "medsam_box_auto"),
681
  "polygons": _mask_to_polygons(mask_np)
682
  }
683
  if include_raw_masks:
@@ -685,6 +702,33 @@ def analyze_with_medsam(base_result, image, include_raw_masks=False):
685
  segmentations.append(seg_entry)
686
  masks_for_overlay.append(m)
687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
688
  W, H = pil_img.size
689
  base_result["medsam"] = {
690
  "available": True,
@@ -745,6 +789,8 @@ with gr.Blocks(
745
  elem_id="image-input"
746
  )
747
  include_raw_masks_cb = gr.Checkbox(value=False, visible=False, elem_id="include-raw-masks")
 
 
748
 
749
  # Analyze button (single)
750
  analyze_btn = gr.Button(
@@ -776,10 +822,10 @@ with gr.Blocks(
776
  api_name="/predict" # ✅ Standard API name that gradio_client expects
777
  )
778
 
779
- # Automatic overlay generation step for medical images
780
  analyze_event.then(
781
  fn=analyze_with_medsam,
782
- inputs=[result_output, image_input, include_raw_masks_cb],
783
  outputs=[result_output, overlay_output],
784
  )
785
 
 
6
  import numpy as np
7
  import cv2
8
  import time
9
+ import json
10
  import traceback
11
 
12
  # Simple timestamped logger
 
644
  return {"error": "Internal error in analyze"}
645
 
646
 
647
+ def analyze_with_medsam(base_result, image, include_raw_masks=False, bboxes_json="", points_json=""):
648
  try:
649
  log("analyze_with_medsam: start")
650
  if not isinstance(base_result, dict):
 
665
  img_path = tmp_path
666
  _medsam.load_image(img_path)
667
 
668
+ # Parse prompts
669
+ parsed_bboxes = []
670
+ parsed_points = []
671
+ try:
672
+ if bboxes_json:
673
+ parsed_bboxes = json.loads(bboxes_json)
674
+ if points_json:
675
+ parsed_points = json.loads(points_json)
676
+ except Exception:
677
+ log("analyze_with_medsam: failed to parse prompts JSON")
678
+
679
+ # If no prompts provided, skip (follow original behavior)
680
+ if not parsed_bboxes and not parsed_points:
681
+ log("analyze_with_medsam: no prompts provided; skipping segmentation")
682
+ return base_result, None
683
+
684
  segmentations = []
685
  masks_for_overlay = []
686
 
687
+ # Run MedSAM for provided boxes
688
+ for bbox in parsed_bboxes:
689
+ if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
690
+ continue
691
  m = _medsam.segment_with_box(bbox)
692
  if m is None or not isinstance(m.get('mask'), np.ndarray):
693
  continue
694
  mask_np = m['mask'].astype(np.uint8)
695
  seg_entry = {
696
  "confidence": float(m.get('confidence', 1.0)),
697
+ "method": m.get("method", "medsam_box"),
698
  "polygons": _mask_to_polygons(mask_np)
699
  }
700
  if include_raw_masks:
 
702
  segmentations.append(seg_entry)
703
  masks_for_overlay.append(m)
704
 
705
+ # Run MedSAM for provided points by converting to bbox
706
+ for item in parsed_points:
707
+ try:
708
+ # Expect item like {"points": [[x,y],...]} or [ [x,y], ... ]
709
+ pts = item.get("points") if isinstance(item, dict) else item
710
+ pts_np = np.array(pts)
711
+ x_min, y_min = pts_np.min(axis=0)
712
+ x_max, y_max = pts_np.max(axis=0)
713
+ pad = 20
714
+ H, W = _medsam.current_image.shape[:2]
715
+ bbox = [max(0, x_min - pad), max(0, y_min - pad), min(W - 1, x_max + pad), min(H - 1, y_max + pad)]
716
+ m = _medsam.segment_with_box(bbox)
717
+ if m is None or not isinstance(m.get('mask'), np.ndarray):
718
+ continue
719
+ mask_np = m['mask'].astype(np.uint8)
720
+ seg_entry = {
721
+ "confidence": float(m.get('confidence', 1.0)),
722
+ "method": m.get("method", "medsam_points_box"),
723
+ "polygons": _mask_to_polygons(mask_np)
724
+ }
725
+ if include_raw_masks:
726
+ seg_entry["mask"] = mask_np.tolist()
727
+ segmentations.append(seg_entry)
728
+ masks_for_overlay.append(m)
729
+ except Exception:
730
+ continue
731
+
732
  W, H = pil_img.size
733
  base_result["medsam"] = {
734
  "available": True,
 
789
  elem_id="image-input"
790
  )
791
  include_raw_masks_cb = gr.Checkbox(value=False, visible=False, elem_id="include-raw-masks")
792
+ bboxes_tb = gr.Textbox(value="", visible=False, elem_id="bboxes-json")
793
+ points_tb = gr.Textbox(value="", visible=False, elem_id="points-json")
794
 
795
  # Analyze button (single)
796
  analyze_btn = gr.Button(
 
822
  api_name="/predict" # ✅ Standard API name that gradio_client expects
823
  )
824
 
825
+ # MedSAM step (prompt-only). If no prompts, it will skip
826
  analyze_event.then(
827
  fn=analyze_with_medsam,
828
+ inputs=[result_output, image_input, include_raw_masks_cb, bboxes_tb, points_tb],
829
  outputs=[result_output, overlay_output],
830
  )
831