Spaces:
Sleeping
Sleeping
chore(logging): add per-request timestamped logs and exception tracebacks for analyze paths
Browse files
app.py
CHANGED
@@ -5,6 +5,12 @@ from PIL import Image
|
|
5 |
import torch
|
6 |
import numpy as np
|
7 |
import cv2
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
# Writable cache directory for HF downloads
|
10 |
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/data/hf-cache")
|
@@ -533,153 +539,118 @@ print(f"🔍 datapoint_model: {datapoint_model is not None}")
|
|
533 |
|
534 |
# === Main prediction function ===
|
535 |
def analyze(image):
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
# Handle filepath input (convert to PIL Image)
|
557 |
-
if isinstance(image, str):
|
558 |
-
# It's a filepath, load the image
|
559 |
-
image = Image.open(image).convert("RGB")
|
560 |
-
elif image is None:
|
561 |
-
return {"error": "No image provided"}
|
562 |
-
|
563 |
-
# Ensure we have a PIL Image
|
564 |
-
if not isinstance(image, Image.Image):
|
565 |
-
return {"error": "Invalid image format"}
|
566 |
-
|
567 |
-
result = {
|
568 |
-
"chart_type_id": "Model not available",
|
569 |
-
"chart_type_label": "Model not available",
|
570 |
-
"element_result": "MMDetection models not available",
|
571 |
-
"datapoint_result": "MMDetection models not available",
|
572 |
-
"status": "Basic chart classification only",
|
573 |
-
"processing_time": 0.0,
|
574 |
-
"medsam": {"available": False}
|
575 |
-
}
|
576 |
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
591 |
else:
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
element_data = str(element_result)
|
622 |
-
|
623 |
-
result["element_result"] = element_data
|
624 |
-
result["status"] = "Chart classification + element detection completed"
|
625 |
-
except Exception as e:
|
626 |
-
result["element_result"] = f"Error: {str(e)}"
|
627 |
-
|
628 |
-
# Chart Data Point Segmentation (Mask R-CNN)
|
629 |
-
if datapoint_model is not None:
|
630 |
try:
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
except Exception as e:
|
649 |
-
result["datapoint_result"] = f"Error: {str(e)}"
|
650 |
-
|
651 |
-
# If predicted as medical image and MedSAM is available, include mask data (polygons)
|
652 |
-
try:
|
653 |
-
label_lower = str(result.get("chart_type_label", "")).strip().lower()
|
654 |
-
if label_lower == "medical image":
|
655 |
-
if _medsam.is_available():
|
656 |
-
# Indicate availability; masks are generated in then-chain
|
657 |
-
result["medsam"] = {"available": True}
|
658 |
-
else:
|
659 |
-
# Not available; include reason
|
660 |
-
result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
|
661 |
-
except Exception as e:
|
662 |
-
print(f"MedSAM JSON augmentation failed: {e}")
|
663 |
-
|
664 |
-
result["processing_time"] = round(time.time() - start_time, 3)
|
665 |
-
return result
|
666 |
|
667 |
|
668 |
def analyze_with_medsam(base_result, image):
|
669 |
-
"""Auto-generate segmentations for medical images using SAM ViT-H if available,
|
670 |
-
otherwise fallback to MedSAM over top-K foreground boxes. Returns updated JSON and overlay image."""
|
671 |
try:
|
|
|
672 |
if not isinstance(base_result, dict):
|
673 |
return base_result, None
|
674 |
label = str(base_result.get("chart_type_label", "")).strip().lower()
|
675 |
if label != "medical image" or not _medsam.is_available():
|
|
|
676 |
return base_result, None
|
677 |
|
678 |
pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
|
679 |
if pil_img is None:
|
680 |
return base_result, None
|
681 |
|
682 |
-
# Prepare embedding
|
683 |
img_path = image if isinstance(image, str) else None
|
684 |
if img_path is None:
|
685 |
tmp_path = "./_tmp_input_image.png"
|
@@ -690,22 +661,19 @@ def analyze_with_medsam(base_result, image):
|
|
690 |
segmentations = []
|
691 |
masks_for_overlay = []
|
692 |
|
693 |
-
#
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
masks_for_overlay.append(m)
|
707 |
-
except Exception as auto_e:
|
708 |
-
print(f"Automatic MedSAM segmentation failed: {auto_e}")
|
709 |
|
710 |
W, H = pil_img.size
|
711 |
base_result["medsam"] = {
|
@@ -715,11 +683,14 @@ def analyze_with_medsam(base_result, image):
|
|
715 |
"segmentations": segmentations,
|
716 |
"num_segments": len(segmentations)
|
717 |
}
|
|
|
718 |
|
719 |
overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
|
|
|
720 |
return base_result, overlay_img
|
721 |
-
except Exception
|
722 |
-
|
|
|
723 |
return base_result, None
|
724 |
|
725 |
# === Gradio UI with API enhancements ===
|
|
|
5 |
import torch
|
6 |
import numpy as np
|
7 |
import cv2
|
8 |
+
import time
|
9 |
+
import traceback
|
10 |
+
|
11 |
+
# Simple timestamped logger
|
12 |
+
def log(msg: str) -> None:
|
13 |
+
print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
|
14 |
|
15 |
# Writable cache directory for HF downloads
|
16 |
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/data/hf-cache")
|
|
|
539 |
|
540 |
# === Main prediction function ===
|
541 |
def analyze(image):
|
542 |
+
try:
|
543 |
+
log("analyze: start")
|
544 |
+
start_time = time.time()
|
545 |
+
# Handle filepath input
|
546 |
+
if isinstance(image, str):
|
547 |
+
image = Image.open(image).convert("RGB")
|
548 |
+
elif image is None:
|
549 |
+
return {"error": "No image provided"}
|
550 |
+
if not isinstance(image, Image.Image):
|
551 |
+
return {"error": "Invalid image format"}
|
552 |
+
|
553 |
+
result = {
|
554 |
+
"chart_type_id": "Model not available",
|
555 |
+
"chart_type_label": "Model not available",
|
556 |
+
"element_result": "MMDetection models not available",
|
557 |
+
"datapoint_result": "MMDetection models not available",
|
558 |
+
"status": "Basic chart classification only",
|
559 |
+
"processing_time": 0.0,
|
560 |
+
"medsam": {"available": False}
|
561 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
|
563 |
+
# Chart Type Classification
|
564 |
+
if CHART_TYPE_AVAILABLE:
|
565 |
+
try:
|
566 |
+
processed_image = chart_type_processor(image).unsqueeze(0)
|
567 |
+
with torch.no_grad():
|
568 |
+
outputs = chart_type_model(processed_image)
|
569 |
+
logits = outputs if isinstance(outputs, torch.Tensor) else getattr(outputs, 'logits', outputs)
|
570 |
+
predicted_class = logits.argmax(dim=-1).item()
|
571 |
+
result["chart_type_id"] = predicted_class
|
572 |
+
result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
|
573 |
+
result["status"] = "Chart classification completed"
|
574 |
+
log(f"analyze: chart_type={result['chart_type_label']} ({result['chart_type_id']})")
|
575 |
+
except Exception:
|
576 |
+
log("analyze: chart classification error")
|
577 |
+
traceback.print_exc()
|
578 |
+
|
579 |
+
# Element Detection
|
580 |
+
if element_model is not None:
|
581 |
+
try:
|
582 |
+
np_img = np.array(image.convert("RGB"))[:, :, ::-1]
|
583 |
+
element_result = inference_detector(element_model, np_img)
|
584 |
+
if isinstance(element_result, tuple):
|
585 |
+
bbox_result, segm_result = element_result
|
586 |
+
element_data = {
|
587 |
+
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
588 |
+
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
589 |
+
}
|
590 |
else:
|
591 |
+
element_data = str(element_result)
|
592 |
+
result["element_result"] = element_data
|
593 |
+
result["status"] = "Chart classification + element detection completed"
|
594 |
+
log("analyze: element detection done")
|
595 |
+
except Exception:
|
596 |
+
log("analyze: element detection error")
|
597 |
+
traceback.print_exc()
|
598 |
+
|
599 |
+
# Datapoint Segmentation
|
600 |
+
if datapoint_model is not None:
|
601 |
+
try:
|
602 |
+
np_img = np.array(image.convert("RGB"))[:, :, ::-1]
|
603 |
+
datapoint_result = inference_detector(datapoint_model, np_img)
|
604 |
+
if isinstance(datapoint_result, tuple):
|
605 |
+
bbox_result, segm_result = datapoint_result
|
606 |
+
datapoint_data = {
|
607 |
+
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
608 |
+
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
609 |
+
}
|
610 |
+
else:
|
611 |
+
datapoint_data = str(datapoint_result)
|
612 |
+
result["datapoint_result"] = datapoint_data
|
613 |
+
result["status"] = "Full analysis completed"
|
614 |
+
log("analyze: datapoint segmentation done")
|
615 |
+
except Exception:
|
616 |
+
log("analyze: datapoint segmentation error")
|
617 |
+
traceback.print_exc()
|
618 |
+
|
619 |
+
# MedSAM availability info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
620 |
try:
|
621 |
+
label_lower = str(result.get("chart_type_label", "")).strip().lower()
|
622 |
+
if label_lower == "medical image":
|
623 |
+
if _medsam.is_available():
|
624 |
+
result["medsam"] = {"available": True}
|
625 |
+
else:
|
626 |
+
result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
|
627 |
+
except Exception:
|
628 |
+
log("analyze: medsam availability annotation error")
|
629 |
+
traceback.print_exc()
|
630 |
+
|
631 |
+
result["processing_time"] = round(time.time() - start_time, 3)
|
632 |
+
log(f"analyze: end in {result['processing_time']}s")
|
633 |
+
return result
|
634 |
+
except Exception:
|
635 |
+
log("analyze: fatal error")
|
636 |
+
traceback.print_exc()
|
637 |
+
return {"error": "Internal error in analyze"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
638 |
|
639 |
|
640 |
def analyze_with_medsam(base_result, image):
|
|
|
|
|
641 |
try:
|
642 |
+
log("analyze_with_medsam: start")
|
643 |
if not isinstance(base_result, dict):
|
644 |
return base_result, None
|
645 |
label = str(base_result.get("chart_type_label", "")).strip().lower()
|
646 |
if label != "medical image" or not _medsam.is_available():
|
647 |
+
log("analyze_with_medsam: skip (non-medical or MedSAM unavailable)")
|
648 |
return base_result, None
|
649 |
|
650 |
pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
|
651 |
if pil_img is None:
|
652 |
return base_result, None
|
653 |
|
|
|
654 |
img_path = image if isinstance(image, str) else None
|
655 |
if img_path is None:
|
656 |
tmp_path = "./_tmp_input_image.png"
|
|
|
661 |
segmentations = []
|
662 |
masks_for_overlay = []
|
663 |
|
664 |
+
# MedSAM over candidate boxes (original behavior)
|
665 |
+
cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=8, min_area=200)
|
666 |
+
log(f"analyze_with_medsam: candidate boxes={len(cand_bboxes)}")
|
667 |
+
for bbox in cand_bboxes:
|
668 |
+
m = _medsam.segment_with_box(bbox)
|
669 |
+
if m is None or not isinstance(m.get('mask'), np.ndarray):
|
670 |
+
continue
|
671 |
+
segmentations.append({
|
672 |
+
"mask": m['mask'].astype(np.uint8).tolist(),
|
673 |
+
"confidence": float(m.get('confidence', 1.0)),
|
674 |
+
"method": m.get("method", "medsam_box_auto")
|
675 |
+
})
|
676 |
+
masks_for_overlay.append(m)
|
|
|
|
|
|
|
677 |
|
678 |
W, H = pil_img.size
|
679 |
base_result["medsam"] = {
|
|
|
683 |
"segmentations": segmentations,
|
684 |
"num_segments": len(segmentations)
|
685 |
}
|
686 |
+
log(f"analyze_with_medsam: segments={len(segmentations)}")
|
687 |
|
688 |
overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
|
689 |
+
log("analyze_with_medsam: end")
|
690 |
return base_result, overlay_img
|
691 |
+
except Exception:
|
692 |
+
log("analyze_with_medsam: fatal error")
|
693 |
+
traceback.print_exc()
|
694 |
return base_result, None
|
695 |
|
696 |
# === Gradio UI with API enhancements ===
|