hanszhu commited on
Commit
3b0d9c5
·
1 Parent(s): d1c88af

feat(medsam): follow original behavior; remove SAM auto from inference path; MedSAM boxes only (max 8)

Browse files
Files changed (1) hide show
  1. app.py +12 -46
app.py CHANGED
@@ -692,52 +692,18 @@ def analyze_with_medsam(base_result, image):
692
 
693
  # AUTO segmentation path
694
  try:
695
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
696
- import cv2 as _cv2
697
- # If ViT-H checkpoint present, use SAM automatic mask generator (download if missing)
698
- vit_h_ckpt = os.path.join(HF_CACHE_DIR, "sam_vit_h_4b8939.pth")
699
- if not os.path.exists(vit_h_ckpt):
700
- try:
701
- from huggingface_hub import hf_hub_download
702
- vit_h_ckpt = hf_hub_download(
703
- repo_id="Aniketg6/SAM",
704
- filename="sam_vit_h_4b8939.pth",
705
- cache_dir=HF_CACHE_DIR
706
- )
707
- print(f"✅ Downloaded SAM ViT-H checkpoint to: {vit_h_ckpt}")
708
- except Exception as dlh:
709
- print(f"⚠ Failed to download SAM ViT-H checkpoint: {dlh}")
710
- if os.path.exists(vit_h_ckpt):
711
- img_bgr = _cv2.imread(img_path)
712
- sam = sam_model_registry["vit_h"](checkpoint=vit_h_ckpt)
713
- mask_generator = SamAutomaticMaskGenerator(sam)
714
- masks = mask_generator.generate(img_bgr)
715
- # Keep top-12 masks by stability_score
716
- masks = sorted(masks, key=lambda m: m.get('stability_score', 0), reverse=True)[:12]
717
- for m in masks:
718
- seg = m.get('segmentation', None)
719
- if seg is None:
720
- continue
721
- seg_u8 = seg.astype(np.uint8)
722
- segmentations.append({
723
- "mask": seg_u8.tolist(),
724
- "confidence": float(m.get('stability_score', 1.0)),
725
- "method": "sam_auto"
726
- })
727
- masks_for_overlay.append({"mask": seg_u8})
728
- else:
729
- # Fallback: derive candidate boxes and run MedSAM per box
730
- cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=8, min_area=200)
731
- for bbox in cand_bboxes:
732
- m = _medsam.segment_with_box(bbox)
733
- if m is None or not isinstance(m.get('mask'), np.ndarray):
734
- continue
735
- segmentations.append({
736
- "mask": m['mask'].astype(np.uint8).tolist(),
737
- "confidence": float(m.get('confidence', 1.0)),
738
- "method": m.get("method", "medsam_box_auto")
739
- })
740
- masks_for_overlay.append(m)
741
  except Exception as auto_e:
742
  print(f"Automatic MedSAM segmentation failed: {auto_e}")
743
 
 
692
 
693
  # AUTO segmentation path
694
  try:
695
+ # Follow original behavior: use MedSAM with box prompts; no SAM auto in main path
696
+ cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=8, min_area=200)
697
+ for bbox in cand_bboxes:
698
+ m = _medsam.segment_with_box(bbox)
699
+ if m is None or not isinstance(m.get('mask'), np.ndarray):
700
+ continue
701
+ segmentations.append({
702
+ "mask": m['mask'].astype(np.uint8).tolist(),
703
+ "confidence": float(m.get('confidence', 1.0)),
704
+ "method": m.get("method", "medsam_box_auto")
705
+ })
706
+ masks_for_overlay.append(m)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  except Exception as auto_e:
708
  print(f"Automatic MedSAM segmentation failed: {auto_e}")
709