Spaces:
Sleeping
Sleeping
deploy(space): push working Gradio app with API /predict, MedSAM auto-overlay, HF model downloads, cleaned requirements
Browse files- Dockerfile +0 -29
- README.md +1 -1
- app.py +77 -119
Dockerfile
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
FROM python:3.10-slim
|
2 |
-
|
3 |
-
ENV DEBIAN_FRONTEND=noninteractive \
|
4 |
-
PIP_NO_CACHE_DIR=1 \
|
5 |
-
MPLBACKEND=Agg \
|
6 |
-
MIM_IGNORE_INSTALL_PYTORCH=1
|
7 |
-
|
8 |
-
RUN apt-get update && apt-get install -y --no-install-recommends \
|
9 |
-
libgl1 libglib2.0-0 git && \
|
10 |
-
rm -rf /var/lib/apt/lists/*
|
11 |
-
|
12 |
-
WORKDIR /app
|
13 |
-
|
14 |
-
COPY requirements.txt /app/requirements.txt
|
15 |
-
|
16 |
-
# Install pip deps and the mm stack with openmim
|
17 |
-
RUN python -m pip install -U pip openmim && \
|
18 |
-
pip install -r requirements.txt && \
|
19 |
-
mim install "mmengine==0.10.4" && \
|
20 |
-
mim install "mmcv==2.1.0" && \
|
21 |
-
mim install "mmdet==3.3.0" && \
|
22 |
-
pip install git+https://github.com/facebookresearch/segment-anything.git
|
23 |
-
|
24 |
-
# Copy the rest of the application
|
25 |
-
COPY . /app
|
26 |
-
|
27 |
-
EXPOSE 7860
|
28 |
-
|
29 |
-
CMD ["python", "app.py"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🐢
|
|
4 |
colorFrom: purple
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
4 |
colorFrom: purple
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.38.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
app.py
CHANGED
@@ -47,8 +47,16 @@ class MedSAMIntegrator:
|
|
47 |
import segment_anything # noqa: F401
|
48 |
return True
|
49 |
except Exception as e:
|
50 |
-
print(f"⚠ segment_anything not available: {e}.
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
def _load_medsam_model(self):
|
54 |
try:
|
@@ -199,48 +207,6 @@ class MedSAMIntegrator:
|
|
199 |
# Single global instance
|
200 |
_medsam = MedSAMIntegrator()
|
201 |
|
202 |
-
# Cache for SAM automatic mask generator
|
203 |
-
_sam_auto_generator = None
|
204 |
-
_sam_auto_ckpt_path = None
|
205 |
-
|
206 |
-
|
207 |
-
def _get_sam_generator():
|
208 |
-
"""Load and cache SAM ViT-H automatic mask generator with faster params if checkpoint exists."""
|
209 |
-
global _sam_auto_generator, _sam_auto_ckpt_path
|
210 |
-
if _sam_auto_generator is not None:
|
211 |
-
return _sam_auto_generator
|
212 |
-
try:
|
213 |
-
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
214 |
-
ckpt = "models/sam_vit_h_4b8939.pth"
|
215 |
-
if not os.path.exists(ckpt):
|
216 |
-
try:
|
217 |
-
from huggingface_hub import hf_hub_download
|
218 |
-
ckpt = hf_hub_download(
|
219 |
-
repo_id="Aniketg6/SAM",
|
220 |
-
filename="sam_vit_h_4b8939.pth",
|
221 |
-
cache_dir="./models"
|
222 |
-
)
|
223 |
-
print(f"✅ Downloaded SAM ViT-H checkpoint to: {ckpt}")
|
224 |
-
except Exception as e:
|
225 |
-
print(f"⚠ Failed to download SAM ViT-H checkpoint: {e}")
|
226 |
-
return None
|
227 |
-
_sam_auto_ckpt_path = ckpt
|
228 |
-
sam = sam_model_registry["vit_h"](checkpoint=ckpt)
|
229 |
-
# Speed-tuned generator params
|
230 |
-
_sam_auto_generator = SamAutomaticMaskGenerator(
|
231 |
-
sam,
|
232 |
-
points_per_side=16,
|
233 |
-
pred_iou_thresh=0.88,
|
234 |
-
stability_score_thresh=0.9,
|
235 |
-
crop_n_layers=0,
|
236 |
-
box_nms_thresh=0.7,
|
237 |
-
min_mask_region_area=512 # filter tiny masks
|
238 |
-
)
|
239 |
-
return _sam_auto_generator
|
240 |
-
except Exception as e:
|
241 |
-
print(f"_get_sam_generator failed: {e}")
|
242 |
-
return None
|
243 |
-
|
244 |
|
245 |
def _extract_bboxes_from_mmdet_result(det_result):
|
246 |
"""Extract Nx4 xyxy bboxes from various MMDet result formats."""
|
@@ -668,54 +634,46 @@ def analyze(image):
|
|
668 |
# Chart Element Detection (Cascade R-CNN)
|
669 |
if element_model is not None:
|
670 |
try:
|
671 |
-
#
|
672 |
-
|
673 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
674 |
else:
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
# Convert result to more API-friendly format
|
681 |
-
if isinstance(element_result, tuple):
|
682 |
-
bbox_result, segm_result = element_result
|
683 |
-
element_data = {
|
684 |
-
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
685 |
-
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
686 |
-
}
|
687 |
-
else:
|
688 |
-
element_data = str(element_result)
|
689 |
-
|
690 |
-
result["element_result"] = element_data
|
691 |
-
result["status"] = "Chart classification + element detection completed"
|
692 |
except Exception as e:
|
693 |
result["element_result"] = f"Error: {str(e)}"
|
694 |
|
695 |
# Chart Data Point Segmentation (Mask R-CNN)
|
696 |
if datapoint_model is not None:
|
697 |
try:
|
698 |
-
#
|
699 |
-
|
700 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
else:
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
# Convert result to more API-friendly format
|
708 |
-
if isinstance(datapoint_result, tuple):
|
709 |
-
bbox_result, segm_result = datapoint_result
|
710 |
-
datapoint_data = {
|
711 |
-
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
712 |
-
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
713 |
-
}
|
714 |
-
else:
|
715 |
-
datapoint_data = str(datapoint_result)
|
716 |
-
|
717 |
-
result["datapoint_result"] = datapoint_data
|
718 |
-
result["status"] = "Full analysis completed"
|
719 |
except Exception as e:
|
720 |
result["datapoint_result"] = f"Error: {str(e)}"
|
721 |
|
@@ -744,35 +702,46 @@ def analyze_with_medsam(base_result, image):
|
|
744 |
if not isinstance(base_result, dict):
|
745 |
return base_result, None
|
746 |
label = str(base_result.get("chart_type_label", "")).strip().lower()
|
747 |
-
if label != "medical image":
|
748 |
return base_result, None
|
749 |
|
750 |
pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
|
751 |
if pil_img is None:
|
752 |
return base_result, None
|
753 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
754 |
segmentations = []
|
755 |
masks_for_overlay = []
|
756 |
|
757 |
-
#
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
767 |
img_bgr = _cv2.imread(img_path)
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
s = float(m.get('stability_score', 0.0))
|
772 |
-
seg = m.get('segmentation', None)
|
773 |
-
area = int(seg.sum()) if isinstance(seg, np.ndarray) else 0
|
774 |
-
return (s, area)
|
775 |
-
masks = sorted(masks, key=_score, reverse=True)[:8]
|
776 |
for m in masks:
|
777 |
seg = m.get('segmentation', None)
|
778 |
if seg is None:
|
@@ -784,20 +753,9 @@ def analyze_with_medsam(base_result, image):
|
|
784 |
"method": "sam_auto"
|
785 |
})
|
786 |
masks_for_overlay.append({"mask": seg_u8})
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
# Fallback to MedSAM boxes only if nothing produced
|
791 |
-
if not segmentations and _medsam.is_available():
|
792 |
-
try:
|
793 |
-
# Prepare embedding once
|
794 |
-
img_path = image if isinstance(image, str) else None
|
795 |
-
if img_path is None:
|
796 |
-
tmp_path = "./_tmp_input_image.png"
|
797 |
-
pil_img.save(tmp_path)
|
798 |
-
img_path = tmp_path
|
799 |
-
_medsam.load_image(img_path)
|
800 |
-
cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=5, min_area=400)
|
801 |
for bbox in cand_bboxes:
|
802 |
m = _medsam.segment_with_box(bbox)
|
803 |
if m is None or not isinstance(m.get('mask'), np.ndarray):
|
@@ -808,8 +766,8 @@ def analyze_with_medsam(base_result, image):
|
|
808 |
"method": m.get("method", "medsam_box_auto")
|
809 |
})
|
810 |
masks_for_overlay.append(m)
|
811 |
-
|
812 |
-
|
813 |
|
814 |
W, H = pil_img.size
|
815 |
base_result["medsam"] = {
|
|
|
47 |
import segment_anything # noqa: F401
|
48 |
return True
|
49 |
except Exception as e:
|
50 |
+
print(f"⚠ segment_anything not available: {e}. Attempting install from Git...")
|
51 |
+
try:
|
52 |
+
import subprocess, sys
|
53 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"])
|
54 |
+
import segment_anything # noqa: F401
|
55 |
+
print("✓ segment_anything installed")
|
56 |
+
return True
|
57 |
+
except Exception as install_err:
|
58 |
+
print(f"❌ Failed to install segment_anything: {install_err}")
|
59 |
+
return False
|
60 |
|
61 |
def _load_medsam_model(self):
|
62 |
try:
|
|
|
207 |
# Single global instance
|
208 |
_medsam = MedSAMIntegrator()
|
209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
def _extract_bboxes_from_mmdet_result(det_result):
|
212 |
"""Extract Nx4 xyxy bboxes from various MMDet result formats."""
|
|
|
634 |
# Chart Element Detection (Cascade R-CNN)
|
635 |
if element_model is not None:
|
636 |
try:
|
637 |
+
# Convert PIL image to numpy array for MMDetection
|
638 |
+
np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL → BGR
|
639 |
+
|
640 |
+
element_result = inference_detector(element_model, np_img)
|
641 |
+
|
642 |
+
# Convert result to more API-friendly format
|
643 |
+
if isinstance(element_result, tuple):
|
644 |
+
bbox_result, segm_result = element_result
|
645 |
+
element_data = {
|
646 |
+
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
647 |
+
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
648 |
+
}
|
649 |
else:
|
650 |
+
element_data = str(element_result)
|
651 |
+
|
652 |
+
result["element_result"] = element_data
|
653 |
+
result["status"] = "Chart classification + element detection completed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
except Exception as e:
|
655 |
result["element_result"] = f"Error: {str(e)}"
|
656 |
|
657 |
# Chart Data Point Segmentation (Mask R-CNN)
|
658 |
if datapoint_model is not None:
|
659 |
try:
|
660 |
+
# Convert PIL image to numpy array for MMDetection
|
661 |
+
np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL → BGR
|
662 |
+
|
663 |
+
datapoint_result = inference_detector(datapoint_model, np_img)
|
664 |
+
|
665 |
+
# Convert result to more API-friendly format
|
666 |
+
if isinstance(datapoint_result, tuple):
|
667 |
+
bbox_result, segm_result = datapoint_result
|
668 |
+
datapoint_data = {
|
669 |
+
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
670 |
+
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
671 |
+
}
|
672 |
else:
|
673 |
+
datapoint_data = str(datapoint_result)
|
674 |
+
|
675 |
+
result["datapoint_result"] = datapoint_data
|
676 |
+
result["status"] = "Full analysis completed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
except Exception as e:
|
678 |
result["datapoint_result"] = f"Error: {str(e)}"
|
679 |
|
|
|
702 |
if not isinstance(base_result, dict):
|
703 |
return base_result, None
|
704 |
label = str(base_result.get("chart_type_label", "")).strip().lower()
|
705 |
+
if label != "medical image" or not _medsam.is_available():
|
706 |
return base_result, None
|
707 |
|
708 |
pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
|
709 |
if pil_img is None:
|
710 |
return base_result, None
|
711 |
|
712 |
+
# Prepare embedding
|
713 |
+
img_path = image if isinstance(image, str) else None
|
714 |
+
if img_path is None:
|
715 |
+
tmp_path = "./_tmp_input_image.png"
|
716 |
+
pil_img.save(tmp_path)
|
717 |
+
img_path = tmp_path
|
718 |
+
_medsam.load_image(img_path)
|
719 |
+
|
720 |
segmentations = []
|
721 |
masks_for_overlay = []
|
722 |
|
723 |
+
# AUTO segmentation path
|
724 |
+
try:
|
725 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
726 |
+
import cv2 as _cv2
|
727 |
+
# If ViT-H checkpoint present, use SAM automatic mask generator (download if missing)
|
728 |
+
vit_h_ckpt = "models/sam_vit_h_4b8939.pth"
|
729 |
+
if not os.path.exists(vit_h_ckpt):
|
730 |
+
try:
|
731 |
+
from huggingface_hub import hf_hub_download
|
732 |
+
vit_h_ckpt = hf_hub_download(
|
733 |
+
repo_id="Aniketg6/SAM",
|
734 |
+
filename="sam_vit_h_4b8939.pth",
|
735 |
+
cache_dir="./models"
|
736 |
+
)
|
737 |
+
print(f"✅ Downloaded SAM ViT-H checkpoint to: {vit_h_ckpt}")
|
738 |
+
except Exception as dlh:
|
739 |
+
print(f"⚠ Failed to download SAM ViT-H checkpoint: {dlh}")
|
740 |
+
if os.path.exists(vit_h_ckpt):
|
741 |
img_bgr = _cv2.imread(img_path)
|
742 |
+
sam = sam_model_registry["vit_h"](checkpoint=vit_h_ckpt)
|
743 |
+
mask_generator = SamAutomaticMaskGenerator(sam)
|
744 |
+
masks = mask_generator.generate(img_bgr)
|
|
|
|
|
|
|
|
|
|
|
745 |
for m in masks:
|
746 |
seg = m.get('segmentation', None)
|
747 |
if seg is None:
|
|
|
753 |
"method": "sam_auto"
|
754 |
})
|
755 |
masks_for_overlay.append({"mask": seg_u8})
|
756 |
+
else:
|
757 |
+
# Fallback: derive candidate boxes and run MedSAM per box
|
758 |
+
cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=20, min_area=200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
759 |
for bbox in cand_bboxes:
|
760 |
m = _medsam.segment_with_box(bbox)
|
761 |
if m is None or not isinstance(m.get('mask'), np.ndarray):
|
|
|
766 |
"method": m.get("method", "medsam_box_auto")
|
767 |
})
|
768 |
masks_for_overlay.append(m)
|
769 |
+
except Exception as auto_e:
|
770 |
+
print(f"Automatic MedSAM segmentation failed: {auto_e}")
|
771 |
|
772 |
W, H = pil_img.size
|
773 |
base_result["medsam"] = {
|