hanszhu committed on
Commit c4937ea · 1 Parent(s): e25aace

chore(logging): add per-request timestamped logs and exception tracebacks for analyze paths
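In short, `analyze` and `analyze_with_medsam` are wrapped in try/except, each processing stage catches its own exceptions and prints a full traceback instead of folding the error text into the JSON result, and a small `log()` helper stamps every message with the wall-clock time so a single request can be followed in the Space logs. A minimal sketch of the pattern this commit applies (the per-stage bodies are elided here):

import time
import traceback

def log(msg: str) -> None:
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)

def analyze(image):
    try:
        log("analyze: start")
        start_time = time.time()
        result = {}
        # ... classification, element detection and datapoint segmentation each
        # run in their own try/except, logging and printing a traceback on
        # failure instead of aborting the whole request ...
        result["processing_time"] = round(time.time() - start_time, 3)
        log(f"analyze: end in {result['processing_time']}s")
        return result
    except Exception:
        log("analyze: fatal error")
        traceback.print_exc()
        return {"error": "Internal error in analyze"}

One behavioural consequence: per-stage failures no longer overwrite result fields with "Error: ..." strings; the fields keep their defaults and the traceback goes to the log instead.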

Files changed (1)
  1. app.py +119 -148
app.py CHANGED
@@ -5,6 +5,12 @@ from PIL import Image
 import torch
 import numpy as np
 import cv2
+import time
+import traceback
+
+# Simple timestamped logger
+def log(msg: str) -> None:
+    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
 
 # Writable cache directory for HF downloads
 HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/data/hf-cache")
@@ -533,153 +539,118 @@ print(f"🔍 datapoint_model: {datapoint_model is not None}")
 
 # === Main prediction function ===
 def analyze(image):
-    """
-    Analyze a chart image and return comprehensive results.
-
-    Args:
-        image: Input chart image (filepath string or PIL.Image)
-
-    Returns:
-        dict: Analysis results containing:
-            - chart_type_id (int): Numeric chart type identifier (0-27)
-            - chart_type_label (str): Human-readable chart type name
-            - element_result (str): Detected chart elements (titles, axes, legends, etc.)
-            - datapoint_result (str): Segmented data points and regions
-            - status (str): Processing status message
-            - processing_time (float): Time taken for analysis in seconds
-    """
-    import time
-    from PIL import Image
-
-    start_time = time.time()
-
-    # Handle filepath input (convert to PIL Image)
-    if isinstance(image, str):
-        # It's a filepath, load the image
-        image = Image.open(image).convert("RGB")
-    elif image is None:
-        return {"error": "No image provided"}
-
-    # Ensure we have a PIL Image
-    if not isinstance(image, Image.Image):
-        return {"error": "Invalid image format"}
-
-    result = {
-        "chart_type_id": "Model not available",
-        "chart_type_label": "Model not available",
-        "element_result": "MMDetection models not available",
-        "datapoint_result": "MMDetection models not available",
-        "status": "Basic chart classification only",
-        "processing_time": 0.0,
-        "medsam": {"available": False}
-    }
+    try:
+        log("analyze: start")
+        start_time = time.time()
+        # Handle filepath input
+        if isinstance(image, str):
+            image = Image.open(image).convert("RGB")
+        elif image is None:
+            return {"error": "No image provided"}
+        if not isinstance(image, Image.Image):
+            return {"error": "Invalid image format"}
+
+        result = {
+            "chart_type_id": "Model not available",
+            "chart_type_label": "Model not available",
+            "element_result": "MMDetection models not available",
+            "datapoint_result": "MMDetection models not available",
+            "status": "Basic chart classification only",
+            "processing_time": 0.0,
+            "medsam": {"available": False}
+        }
 
-    # Chart Type Classification
-    if CHART_TYPE_AVAILABLE:
-        try:
-            # Preprocess image for PyTorch model
-            processed_image = chart_type_processor(image).unsqueeze(0)  # Add batch dimension
-
-            # Get prediction
-            with torch.no_grad():
-                outputs = chart_type_model(processed_image)
-            # Handle different output formats
-            if isinstance(outputs, torch.Tensor):
-                logits = outputs
-            elif hasattr(outputs, 'logits'):
-                logits = outputs.logits
+        # Chart Type Classification
+        if CHART_TYPE_AVAILABLE:
+            try:
+                processed_image = chart_type_processor(image).unsqueeze(0)
+                with torch.no_grad():
+                    outputs = chart_type_model(processed_image)
+                logits = outputs if isinstance(outputs, torch.Tensor) else getattr(outputs, 'logits', outputs)
+                predicted_class = logits.argmax(dim=-1).item()
+                result["chart_type_id"] = predicted_class
+                result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
+                result["status"] = "Chart classification completed"
+                log(f"analyze: chart_type={result['chart_type_label']} ({result['chart_type_id']})")
+            except Exception:
+                log("analyze: chart classification error")
+                traceback.print_exc()
+
+        # Element Detection
+        if element_model is not None:
+            try:
+                np_img = np.array(image.convert("RGB"))[:, :, ::-1]
+                element_result = inference_detector(element_model, np_img)
+                if isinstance(element_result, tuple):
+                    bbox_result, segm_result = element_result
+                    element_data = {
+                        "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
+                        "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
+                    }
                 else:
-                logits = outputs
-
-            predicted_class = logits.argmax(dim=-1).item()
-
-            result["chart_type_id"] = predicted_class
-            result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
-            result["status"] = "Chart classification completed"
-
-        except Exception as e:
-            result["chart_type_id"] = f"Error: {str(e)}"
-            result["chart_type_label"] = f"Error: {str(e)}"
-            result["status"] = "Error in chart classification"
-
-    # Chart Element Detection (Cascade R-CNN)
-    if element_model is not None:
-        try:
-            # Convert PIL image to numpy array for MMDetection
-            np_img = np.array(image.convert("RGB"))[:, :, ::-1]  # PIL BGR
-
-            element_result = inference_detector(element_model, np_img)
-
-            # Convert result to more API-friendly format
-            if isinstance(element_result, tuple):
-                bbox_result, segm_result = element_result
-                element_data = {
-                    "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
-                    "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
-                }
-            else:
-                element_data = str(element_result)
-
-            result["element_result"] = element_data
-            result["status"] = "Chart classification + element detection completed"
-        except Exception as e:
-            result["element_result"] = f"Error: {str(e)}"
-
-    # Chart Data Point Segmentation (Mask R-CNN)
-    if datapoint_model is not None:
+                    element_data = str(element_result)
+                result["element_result"] = element_data
+                result["status"] = "Chart classification + element detection completed"
+                log("analyze: element detection done")
+            except Exception:
+                log("analyze: element detection error")
+                traceback.print_exc()
+
+        # Datapoint Segmentation
+        if datapoint_model is not None:
+            try:
+                np_img = np.array(image.convert("RGB"))[:, :, ::-1]
+                datapoint_result = inference_detector(datapoint_model, np_img)
+                if isinstance(datapoint_result, tuple):
+                    bbox_result, segm_result = datapoint_result
+                    datapoint_data = {
+                        "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
+                        "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
+                    }
+                else:
+                    datapoint_data = str(datapoint_result)
+                result["datapoint_result"] = datapoint_data
+                result["status"] = "Full analysis completed"
+                log("analyze: datapoint segmentation done")
+            except Exception:
+                log("analyze: datapoint segmentation error")
+                traceback.print_exc()
+
+        # MedSAM availability info
         try:
-            # Convert PIL image to numpy array for MMDetection
-            np_img = np.array(image.convert("RGB"))[:, :, ::-1]  # PIL → BGR
-
-            datapoint_result = inference_detector(datapoint_model, np_img)
-
-            # Convert result to more API-friendly format
-            if isinstance(datapoint_result, tuple):
-                bbox_result, segm_result = datapoint_result
-                datapoint_data = {
-                    "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
-                    "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
-                }
-            else:
-                datapoint_data = str(datapoint_result)
-
-            result["datapoint_result"] = datapoint_data
-            result["status"] = "Full analysis completed"
-        except Exception as e:
-            result["datapoint_result"] = f"Error: {str(e)}"
-
-    # If predicted as medical image and MedSAM is available, include mask data (polygons)
-    try:
-        label_lower = str(result.get("chart_type_label", "")).strip().lower()
-        if label_lower == "medical image":
-            if _medsam.is_available():
-                # Indicate availability; masks are generated in then-chain
-                result["medsam"] = {"available": True}
-            else:
-                # Not available; include reason
-                result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
-    except Exception as e:
-        print(f"MedSAM JSON augmentation failed: {e}")
-
-    result["processing_time"] = round(time.time() - start_time, 3)
-    return result
+            label_lower = str(result.get("chart_type_label", "")).strip().lower()
+            if label_lower == "medical image":
+                if _medsam.is_available():
+                    result["medsam"] = {"available": True}
+                else:
+                    result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
+        except Exception:
+            log("analyze: medsam availability annotation error")
+            traceback.print_exc()
+
+        result["processing_time"] = round(time.time() - start_time, 3)
+        log(f"analyze: end in {result['processing_time']}s")
+        return result
+    except Exception:
+        log("analyze: fatal error")
+        traceback.print_exc()
+        return {"error": "Internal error in analyze"}


 def analyze_with_medsam(base_result, image):
-    """Auto-generate segmentations for medical images using SAM ViT-H if available,
-    otherwise fallback to MedSAM over top-K foreground boxes. Returns updated JSON and overlay image."""
     try:
+        log("analyze_with_medsam: start")
         if not isinstance(base_result, dict):
            return base_result, None
         label = str(base_result.get("chart_type_label", "")).strip().lower()
         if label != "medical image" or not _medsam.is_available():
+            log("analyze_with_medsam: skip (non-medical or MedSAM unavailable)")
            return base_result, None

         pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
         if pil_img is None:
            return base_result, None

-        # Prepare embedding
         img_path = image if isinstance(image, str) else None
         if img_path is None:
            tmp_path = "./_tmp_input_image.png"
@@ -690,22 +661,19 @@ def analyze_with_medsam(base_result, image):
         segmentations = []
         masks_for_overlay = []
 
-        # AUTO segmentation path
-        try:
-            # Follow original behavior: use MedSAM with box prompts; no SAM auto in main path
-            cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=8, min_area=200)
-            for bbox in cand_bboxes:
-                m = _medsam.segment_with_box(bbox)
-                if m is None or not isinstance(m.get('mask'), np.ndarray):
-                    continue
-                segmentations.append({
-                    "mask": m['mask'].astype(np.uint8).tolist(),
-                    "confidence": float(m.get('confidence', 1.0)),
-                    "method": m.get("method", "medsam_box_auto")
-                })
-                masks_for_overlay.append(m)
-        except Exception as auto_e:
-            print(f"Automatic MedSAM segmentation failed: {auto_e}")
+        # MedSAM over candidate boxes (original behavior)
+        cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=8, min_area=200)
+        log(f"analyze_with_medsam: candidate boxes={len(cand_bboxes)}")
+        for bbox in cand_bboxes:
+            m = _medsam.segment_with_box(bbox)
+            if m is None or not isinstance(m.get('mask'), np.ndarray):
+                continue
+            segmentations.append({
+                "mask": m['mask'].astype(np.uint8).tolist(),
+                "confidence": float(m.get('confidence', 1.0)),
+                "method": m.get("method", "medsam_box_auto")
+            })
+            masks_for_overlay.append(m)
 
         W, H = pil_img.size
         base_result["medsam"] = {
@@ -715,11 +683,14 @@ def analyze_with_medsam(base_result, image):
             "segmentations": segmentations,
             "num_segments": len(segmentations)
         }
+        log(f"analyze_with_medsam: segments={len(segmentations)}")
 
         overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
+        log("analyze_with_medsam: end")
         return base_result, overlay_img
-    except Exception as e:
-        print(f"analyze_with_medsam failed: {e}")
+    except Exception:
+        log("analyze_with_medsam: fatal error")
+        traceback.print_exc()
         return base_result, None
 
 # === Gradio UI with API enhancements ===