hanszhu commited on
Commit
a39d1c3
·
1 Parent(s): eb4d305

deploy(space): push working Gradio app with API /predict, MedSAM auto-overlay, HF model downloads, cleaned requirements

Browse files
Files changed (3) hide show
  1. Dockerfile +0 -29
  2. README.md +1 -1
  3. app.py +77 -119
Dockerfile DELETED
@@ -1,29 +0,0 @@
1
- FROM python:3.10-slim
2
-
3
- ENV DEBIAN_FRONTEND=noninteractive \
4
- PIP_NO_CACHE_DIR=1 \
5
- MPLBACKEND=Agg \
6
- MIM_IGNORE_INSTALL_PYTORCH=1
7
-
8
- RUN apt-get update && apt-get install -y --no-install-recommends \
9
- libgl1 libglib2.0-0 git && \
10
- rm -rf /var/lib/apt/lists/*
11
-
12
- WORKDIR /app
13
-
14
- COPY requirements.txt /app/requirements.txt
15
-
16
- # Install pip deps and the mm stack with openmim
17
- RUN python -m pip install -U pip openmim && \
18
- pip install -r requirements.txt && \
19
- mim install "mmengine==0.10.4" && \
20
- mim install "mmcv==2.1.0" && \
21
- mim install "mmdet==3.3.0" && \
22
- pip install git+https://github.com/facebookresearch/segment-anything.git
23
-
24
- # Copy the rest of the application
25
- COPY . /app
26
-
27
- EXPOSE 7860
28
-
29
- CMD ["python", "app.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐢
4
  colorFrom: purple
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.39.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: purple
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -47,8 +47,16 @@ class MedSAMIntegrator:
47
  import segment_anything # noqa: F401
48
  return True
49
  except Exception as e:
50
- print(f"⚠ segment_anything not available: {e}. It must be installed at build time (Dockerfile).")
51
- return False
 
 
 
 
 
 
 
 
52
 
53
  def _load_medsam_model(self):
54
  try:
@@ -199,48 +207,6 @@ class MedSAMIntegrator:
199
  # Single global instance
200
  _medsam = MedSAMIntegrator()
201
 
202
- # Cache for SAM automatic mask generator
203
- _sam_auto_generator = None
204
- _sam_auto_ckpt_path = None
205
-
206
-
207
- def _get_sam_generator():
208
- """Load and cache SAM ViT-H automatic mask generator with faster params if checkpoint exists."""
209
- global _sam_auto_generator, _sam_auto_ckpt_path
210
- if _sam_auto_generator is not None:
211
- return _sam_auto_generator
212
- try:
213
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
214
- ckpt = "models/sam_vit_h_4b8939.pth"
215
- if not os.path.exists(ckpt):
216
- try:
217
- from huggingface_hub import hf_hub_download
218
- ckpt = hf_hub_download(
219
- repo_id="Aniketg6/SAM",
220
- filename="sam_vit_h_4b8939.pth",
221
- cache_dir="./models"
222
- )
223
- print(f"✅ Downloaded SAM ViT-H checkpoint to: {ckpt}")
224
- except Exception as e:
225
- print(f"⚠ Failed to download SAM ViT-H checkpoint: {e}")
226
- return None
227
- _sam_auto_ckpt_path = ckpt
228
- sam = sam_model_registry["vit_h"](checkpoint=ckpt)
229
- # Speed-tuned generator params
230
- _sam_auto_generator = SamAutomaticMaskGenerator(
231
- sam,
232
- points_per_side=16,
233
- pred_iou_thresh=0.88,
234
- stability_score_thresh=0.9,
235
- crop_n_layers=0,
236
- box_nms_thresh=0.7,
237
- min_mask_region_area=512 # filter tiny masks
238
- )
239
- return _sam_auto_generator
240
- except Exception as e:
241
- print(f"_get_sam_generator failed: {e}")
242
- return None
243
-
244
 
245
  def _extract_bboxes_from_mmdet_result(det_result):
246
  """Extract Nx4 xyxy bboxes from various MMDet result formats."""
@@ -668,54 +634,46 @@ def analyze(image):
668
  # Chart Element Detection (Cascade R-CNN)
669
  if element_model is not None:
670
  try:
671
- # If medical image, skip heavy MMDet to speed up
672
- if isinstance(result.get("chart_type_label"), str) and result["chart_type_label"].lower() == "medical image":
673
- result["element_result"] = "skipped_for_medical"
 
 
 
 
 
 
 
 
 
674
  else:
675
- # Convert PIL image to numpy array for MMDetection
676
- np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL → BGR
677
-
678
- element_result = inference_detector(element_model, np_img)
679
-
680
- # Convert result to more API-friendly format
681
- if isinstance(element_result, tuple):
682
- bbox_result, segm_result = element_result
683
- element_data = {
684
- "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
685
- "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
686
- }
687
- else:
688
- element_data = str(element_result)
689
-
690
- result["element_result"] = element_data
691
- result["status"] = "Chart classification + element detection completed"
692
  except Exception as e:
693
  result["element_result"] = f"Error: {str(e)}"
694
 
695
  # Chart Data Point Segmentation (Mask R-CNN)
696
  if datapoint_model is not None:
697
  try:
698
- # If medical image, skip heavy MMDet to speed up
699
- if isinstance(result.get("chart_type_label"), str) and result["chart_type_label"].lower() == "medical image":
700
- result["datapoint_result"] = "skipped_for_medical"
 
 
 
 
 
 
 
 
 
701
  else:
702
- # Convert PIL image to numpy array for MMDetection
703
- np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL → BGR
704
-
705
- datapoint_result = inference_detector(datapoint_model, np_img)
706
-
707
- # Convert result to more API-friendly format
708
- if isinstance(datapoint_result, tuple):
709
- bbox_result, segm_result = datapoint_result
710
- datapoint_data = {
711
- "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
712
- "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
713
- }
714
- else:
715
- datapoint_data = str(datapoint_result)
716
-
717
- result["datapoint_result"] = datapoint_data
718
- result["status"] = "Full analysis completed"
719
  except Exception as e:
720
  result["datapoint_result"] = f"Error: {str(e)}"
721
 
@@ -744,35 +702,46 @@ def analyze_with_medsam(base_result, image):
744
  if not isinstance(base_result, dict):
745
  return base_result, None
746
  label = str(base_result.get("chart_type_label", "")).strip().lower()
747
- if label != "medical image":
748
  return base_result, None
749
 
750
  pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
751
  if pil_img is None:
752
  return base_result, None
753
 
 
 
 
 
 
 
 
 
754
  segmentations = []
755
  masks_for_overlay = []
756
 
757
- # Try fast SAM generator first; avoid MedSAM embedding when SAM is available
758
- gen = _get_sam_generator()
759
- if gen is not None and _sam_auto_ckpt_path is not None and os.path.exists(_sam_auto_ckpt_path):
760
- try:
761
- import cv2 as _cv2
762
- img_path = image if isinstance(image, str) else None
763
- if img_path is None:
764
- tmp_path = "./_tmp_input_image.png"
765
- pil_img.save(tmp_path)
766
- img_path = tmp_path
 
 
 
 
 
 
 
 
767
  img_bgr = _cv2.imread(img_path)
768
- masks = gen.generate(img_bgr)
769
- # Keep top-K by stability_score or area
770
- def _score(m):
771
- s = float(m.get('stability_score', 0.0))
772
- seg = m.get('segmentation', None)
773
- area = int(seg.sum()) if isinstance(seg, np.ndarray) else 0
774
- return (s, area)
775
- masks = sorted(masks, key=_score, reverse=True)[:8]
776
  for m in masks:
777
  seg = m.get('segmentation', None)
778
  if seg is None:
@@ -784,20 +753,9 @@ def analyze_with_medsam(base_result, image):
784
  "method": "sam_auto"
785
  })
786
  masks_for_overlay.append({"mask": seg_u8})
787
- except Exception as e:
788
- print(f"SAM generator segmentation failed: {e}")
789
-
790
- # Fallback to MedSAM boxes only if nothing produced
791
- if not segmentations and _medsam.is_available():
792
- try:
793
- # Prepare embedding once
794
- img_path = image if isinstance(image, str) else None
795
- if img_path is None:
796
- tmp_path = "./_tmp_input_image.png"
797
- pil_img.save(tmp_path)
798
- img_path = tmp_path
799
- _medsam.load_image(img_path)
800
- cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=5, min_area=400)
801
  for bbox in cand_bboxes:
802
  m = _medsam.segment_with_box(bbox)
803
  if m is None or not isinstance(m.get('mask'), np.ndarray):
@@ -808,8 +766,8 @@ def analyze_with_medsam(base_result, image):
808
  "method": m.get("method", "medsam_box_auto")
809
  })
810
  masks_for_overlay.append(m)
811
- except Exception as auto_e:
812
- print(f"MedSAM fallback segmentation failed: {auto_e}")
813
 
814
  W, H = pil_img.size
815
  base_result["medsam"] = {
 
47
  import segment_anything # noqa: F401
48
  return True
49
  except Exception as e:
50
+ print(f"⚠ segment_anything not available: {e}. Attempting install from Git...")
51
+ try:
52
+ import subprocess, sys
53
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"])
54
+ import segment_anything # noqa: F401
55
+ print("✓ segment_anything installed")
56
+ return True
57
+ except Exception as install_err:
58
+ print(f"❌ Failed to install segment_anything: {install_err}")
59
+ return False
60
 
61
  def _load_medsam_model(self):
62
  try:
 
207
  # Single global instance
208
  _medsam = MedSAMIntegrator()
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  def _extract_bboxes_from_mmdet_result(det_result):
212
  """Extract Nx4 xyxy bboxes from various MMDet result formats."""
 
634
  # Chart Element Detection (Cascade R-CNN)
635
  if element_model is not None:
636
  try:
637
+ # Convert PIL image to numpy array for MMDetection
638
+ np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL BGR
639
+
640
+ element_result = inference_detector(element_model, np_img)
641
+
642
+ # Convert result to more API-friendly format
643
+ if isinstance(element_result, tuple):
644
+ bbox_result, segm_result = element_result
645
+ element_data = {
646
+ "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
647
+ "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
648
+ }
649
  else:
650
+ element_data = str(element_result)
651
+
652
+ result["element_result"] = element_data
653
+ result["status"] = "Chart classification + element detection completed"
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  except Exception as e:
655
  result["element_result"] = f"Error: {str(e)}"
656
 
657
  # Chart Data Point Segmentation (Mask R-CNN)
658
  if datapoint_model is not None:
659
  try:
660
+ # Convert PIL image to numpy array for MMDetection
661
+ np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL BGR
662
+
663
+ datapoint_result = inference_detector(datapoint_model, np_img)
664
+
665
+ # Convert result to more API-friendly format
666
+ if isinstance(datapoint_result, tuple):
667
+ bbox_result, segm_result = datapoint_result
668
+ datapoint_data = {
669
+ "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
670
+ "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
671
+ }
672
  else:
673
+ datapoint_data = str(datapoint_result)
674
+
675
+ result["datapoint_result"] = datapoint_data
676
+ result["status"] = "Full analysis completed"
 
 
 
 
 
 
 
 
 
 
 
 
 
677
  except Exception as e:
678
  result["datapoint_result"] = f"Error: {str(e)}"
679
 
 
702
  if not isinstance(base_result, dict):
703
  return base_result, None
704
  label = str(base_result.get("chart_type_label", "")).strip().lower()
705
+ if label != "medical image" or not _medsam.is_available():
706
  return base_result, None
707
 
708
  pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
709
  if pil_img is None:
710
  return base_result, None
711
 
712
+ # Prepare embedding
713
+ img_path = image if isinstance(image, str) else None
714
+ if img_path is None:
715
+ tmp_path = "./_tmp_input_image.png"
716
+ pil_img.save(tmp_path)
717
+ img_path = tmp_path
718
+ _medsam.load_image(img_path)
719
+
720
  segmentations = []
721
  masks_for_overlay = []
722
 
723
+ # AUTO segmentation path
724
+ try:
725
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
726
+ import cv2 as _cv2
727
+ # If ViT-H checkpoint present, use SAM automatic mask generator (download if missing)
728
+ vit_h_ckpt = "models/sam_vit_h_4b8939.pth"
729
+ if not os.path.exists(vit_h_ckpt):
730
+ try:
731
+ from huggingface_hub import hf_hub_download
732
+ vit_h_ckpt = hf_hub_download(
733
+ repo_id="Aniketg6/SAM",
734
+ filename="sam_vit_h_4b8939.pth",
735
+ cache_dir="./models"
736
+ )
737
+ print(f"✅ Downloaded SAM ViT-H checkpoint to: {vit_h_ckpt}")
738
+ except Exception as dlh:
739
+ print(f"⚠ Failed to download SAM ViT-H checkpoint: {dlh}")
740
+ if os.path.exists(vit_h_ckpt):
741
  img_bgr = _cv2.imread(img_path)
742
+ sam = sam_model_registry["vit_h"](checkpoint=vit_h_ckpt)
743
+ mask_generator = SamAutomaticMaskGenerator(sam)
744
+ masks = mask_generator.generate(img_bgr)
 
 
 
 
 
745
  for m in masks:
746
  seg = m.get('segmentation', None)
747
  if seg is None:
 
753
  "method": "sam_auto"
754
  })
755
  masks_for_overlay.append({"mask": seg_u8})
756
+ else:
757
+ # Fallback: derive candidate boxes and run MedSAM per box
758
+ cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=20, min_area=200)
 
 
 
 
 
 
 
 
 
 
 
759
  for bbox in cand_bboxes:
760
  m = _medsam.segment_with_box(bbox)
761
  if m is None or not isinstance(m.get('mask'), np.ndarray):
 
766
  "method": m.get("method", "medsam_box_auto")
767
  })
768
  masks_for_overlay.append(m)
769
+ except Exception as auto_e:
770
+ print(f"Automatic MedSAM segmentation failed: {auto_e}")
771
 
772
  W, H = pil_img.size
773
  base_result["medsam"] = {