import os
import sys
import gradio as gr
from PIL import Image
import torch
import numpy as np
import cv2

# Writable cache directory for HF downloads
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/data/hf-cache")
try:
    os.makedirs(HF_CACHE_DIR, exist_ok=True)
except Exception:
    pass
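# The cache location can be overridden at startup via the environment, e.g. (illustrative;
# the entry-point file name is assumed):
#   HF_CACHE_DIR=/tmp/hf-cache python app.py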
# Add custom modules to path - try multiple possible locations
possible_paths = [
    "./custom_models",
    "../custom_models",
    "./Dense-Captioning-Platform/custom_models"
]
for path in possible_paths:
    if os.path.exists(path):
        sys.path.insert(0, os.path.abspath(path))
        break

# Add mmcv to path if it exists
if os.path.exists('./mmcv'):
    sys.path.insert(0, os.path.abspath('./mmcv'))
    print("✅ Added local mmcv to path")

# Import and register custom modules
try:
    from custom_models import register
    print("✅ Custom modules registered successfully")
except Exception as e:
    print(f"⚠️ Warning: Could not register custom modules: {e}")
# ----------------------
# Optional MedSAM integration
# ----------------------
class MedSAMIntegrator:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.medsam_model = None
        self.current_image = None
        self.current_image_path = None
        self.embedding = None
        self._load_medsam_model()
    def _ensure_segment_anything(self):
        try:
            import segment_anything  # noqa: F401
            return True
        except Exception as e:
            print(f"❌ segment_anything not available: {e}. Install it in Dockerfile to enable MedSAM.")
            return False
    def _load_medsam_model(self):
        try:
            # Ensure library is present
            if not self._ensure_segment_anything():
                print("MedSAM features disabled (segment_anything not available)")
                return
            from segment_anything import sam_model_registry as _reg
            import torch as _torch
            # Preferred local path in HF cache
            medsam_ckpt_path = os.path.join(HF_CACHE_DIR, "medsam_vit_b.pth")
            # If not present, fetch from HF Hub using provided repo or default
            if not os.path.exists(medsam_ckpt_path):
                try:
                    from huggingface_hub import hf_hub_download, list_repo_files
                    repo_id = os.environ.get("HF_MEDSAM_REPO", "Aniketg6/Fine-Tuned-MedSAM")
                    print(f"Trying to download MedSAM checkpoint from {repo_id} ...")
                    files = list_repo_files(repo_id)
                    candidate = None
                    for f in files:
                        lf = f.lower()
                        if lf.endswith(".pth") or lf.endswith(".pt"):
                            candidate = f
                            break
                    if candidate is None:
                        candidate = "medsam_vit_b.pth"
                    ckpt_path = hf_hub_download(repo_id=repo_id, filename=candidate, cache_dir=HF_CACHE_DIR)
                    medsam_ckpt_path = ckpt_path
                    print(f"✅ Downloaded MedSAM checkpoint: {medsam_ckpt_path}")
                except Exception as dl_err:
                    print(f"❌ Could not fetch MedSAM checkpoint from HF Hub: {dl_err}")
                    print("MedSAM features disabled (no checkpoint)")
                    return
            # Load checkpoint
            checkpoint = _torch.load(medsam_ckpt_path, map_location='cpu')
            self.medsam_model = _reg["vit_b"](checkpoint=None)
            self.medsam_model.load_state_dict(checkpoint)
            self.medsam_model.to(self.device)
            self.medsam_model.eval()
            print("✅ MedSAM model loaded successfully")
        except Exception as e:
            print(f"❌ MedSAM model not available: {e}. MedSAM features disabled.")
    def is_available(self):
        return self.medsam_model is not None

    def load_image(self, image_path, precomputed_embedding=None):
        try:
            from skimage import transform, io  # local import to avoid hard dep if unused
            img_np = io.imread(image_path)
            if len(img_np.shape) == 2:
                img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
            else:
                img_3c = img_np
            self.current_image = img_3c
            self.current_image_path = image_path
            if precomputed_embedding is not None:
                if not self.set_precomputed_embedding(precomputed_embedding):
                    self.get_embeddings()
            else:
                self.get_embeddings()
            return True
        except Exception as e:
            print(f"Error loading image for MedSAM: {e}")
            return False
    def get_embeddings(self):
        if self.current_image is None or self.medsam_model is None:
            return None
        from skimage import transform
        img_1024 = transform.resize(
            self.current_image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        img_1024 = (img_1024 - img_1024.min()) / np.clip(img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None)
        img_1024_tensor = (
            torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
        )
        # Inference only: avoid building an autograd graph for the encoder pass
        with torch.no_grad():
            self.embedding = self.medsam_model.image_encoder(img_1024_tensor)
        return self.embedding
    def set_precomputed_embedding(self, embedding_array):
        try:
            if isinstance(embedding_array, np.ndarray):
                embedding_tensor = torch.tensor(embedding_array).to(self.device)
                self.embedding = embedding_tensor
                return True
            return False
        except Exception as e:
            print(f"Error setting precomputed embedding: {e}")
            return False
    @torch.no_grad()  # prevents "numpy() on a tensor that requires grad" when converting predictions
    def medsam_inference(self, box_1024, height, width):
        if self.embedding is None or self.medsam_model is None:
            return None
        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=self.embedding.device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]
        sparse_embeddings, dense_embeddings = self.medsam_model.prompt_encoder(
            points=None, boxes=box_torch, masks=None,
        )
        low_res_logits, _ = self.medsam_model.mask_decoder(
            image_embeddings=self.embedding,
            image_pe=self.medsam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        low_res_pred = torch.sigmoid(low_res_logits)
        low_res_pred = torch.nn.functional.interpolate(
            low_res_pred, size=(height, width), mode="bilinear", align_corners=False,
        )
        low_res_pred = low_res_pred.squeeze().cpu().numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
        return medsam_seg
    def segment_with_box(self, bbox):
        if self.embedding is None or self.current_image is None:
            return None
        try:
            H, W, _ = self.current_image.shape
            x1, y1, x2, y2 = bbox
            x1 = max(0, min(int(x1), W - 1))
            y1 = max(0, min(int(y1), H - 1))
            x2 = max(0, min(int(x2), W - 1))
            y2 = max(0, min(int(y2), H - 1))
            if x2 <= x1:
                x2 = min(x1 + 10, W - 1)
            if y2 <= y1:
                y2 = min(y1 + 10, H - 1)
            box_np = np.array([[x1, y1, x2, y2]], dtype=float)
            box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
            medsam_mask = self.medsam_inference(box_1024, H, W)
            if medsam_mask is not None:
                return {"mask": medsam_mask, "confidence": 1.0, "method": "medsam_box"}
            return None
        except Exception as e:
            print(f"Error in MedSAM box-based segmentation: {e}")
            return None
# Single global instance
_medsam = MedSAMIntegrator()
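# Minimal usage sketch for the integrator above (illustrative only; assumes the MedSAM
# checkpoint loaded successfully and that "example.png" exists on disk; the box is
# [x1, y1, x2, y2] in original pixel coordinates):
#
#   if _medsam.is_available() and _medsam.load_image("example.png"):
#       seg = _medsam.segment_with_box([30, 40, 200, 220])
#       if seg is not None:
#           print(seg["mask"].shape, seg["confidence"], seg["method"])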
def _extract_bboxes_from_mmdet_result(det_result):
    """Extract Nx4 xyxy bboxes from various MMDet result formats."""
    boxes = []
    try:
        # MMDet 3.x: list of DetDataSample
        if isinstance(det_result, list) and len(det_result) > 0:
            sample = det_result[0]
            if hasattr(sample, 'pred_instances'):
                inst = sample.pred_instances
                if hasattr(inst, 'bboxes'):
                    b = inst.bboxes
                    # mmengine structures may use .tensor for boxes
                    if hasattr(b, 'tensor'):
                        b = b.tensor
                    boxes = b.detach().cpu().numpy().tolist()
        # Single DetDataSample
        elif hasattr(det_result, 'pred_instances'):
            inst = det_result.pred_instances
            if hasattr(inst, 'bboxes'):
                b = inst.bboxes
                if hasattr(b, 'tensor'):
                    b = b.tensor
                boxes = b.detach().cpu().numpy().tolist()
        # MMDet 2.x: tuple of (bbox_result, segm_result)
        elif isinstance(det_result, tuple) and len(det_result) >= 1:
            bbox_result = det_result[0]
            # bbox_result is list per class, each Nx5 [x1,y1,x2,y2,score]
            if isinstance(bbox_result, (list, tuple)):
                for arr in bbox_result:
                    try:
                        arr_np = np.array(arr)
                        if arr_np.ndim == 2 and arr_np.shape[1] >= 4:
                            boxes.extend(arr_np[:, :4].tolist())
                    except Exception:
                        continue
    except Exception as e:
        print(f"Failed to parse MMDet result for boxes: {e}")
    return boxes
def _overlay_masks_on_image(image_pil, mask_list, alpha=0.4):
    """Overlay binary masks on an image with random colors."""
    if image_pil is None or not mask_list:
        return image_pil
    img = np.array(image_pil.convert('RGB'))
    overlay = img.copy()
    for idx, m in enumerate(mask_list):
        if m is None or 'mask' not in m or m['mask'] is None:
            continue
        mask = m['mask'].astype(bool)
        color = np.random.RandomState(seed=idx + 1234).randint(0, 255, size=3)
        overlay[mask] = (0.5 * overlay[mask] + 0.5 * color).astype(np.uint8)
    blended = (alpha * overlay + (1 - alpha) * img).astype(np.uint8)
    return Image.fromarray(blended)
def _mask_to_polygons(mask: np.ndarray):
    """Convert a binary mask (H,W) to a list of polygons ([[x,y], ...]) using OpenCV contours."""
    try:
        mask_u8 = (mask.astype(np.uint8) * 255)
        contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        polygons = []
        for cnt in contours:
            if cnt is None or len(cnt) < 3:
                continue
            # Simplify contour slightly
            epsilon = 0.002 * cv2.arcLength(cnt, True)
            approx = cv2.approxPolyDP(cnt, epsilon, True)
            poly = approx.reshape(-1, 2).tolist()
            polygons.append(poly)
        return polygons
    except Exception as e:
        print(f"_mask_to_polygons failed: {e}")
        return []
def _find_largest_foreground_bbox(pil_img: Image.Image):
    """Heuristic: find largest foreground region bbox via Otsu threshold on grayscale.
    Returns [x1, y1, x2, y2] or full-image bbox if none found."""
    try:
        img = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # Otsu threshold (invert if needed by checking mean)
        _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Assume foreground is darker; invert if threshold yields background as white majority
        if th.mean() > 127:
            th = 255 - th
        # Morph close to connect regions
        kernel = np.ones((5, 5), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            W, H = pil_img.size
            return [0, 0, W - 1, H - 1]
        # Largest contour by area
        cnt = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(cnt)
        # Pad a little
        pad = int(0.02 * max(w, h))
        x1 = max(0, x - pad)
        y1 = max(0, y - pad)
        x2 = min(img.shape[1] - 1, x + w + pad)
        y2 = min(img.shape[0] - 1, y + h + pad)
        return [x1, y1, x2, y2]
    except Exception as e:
        print(f"_find_largest_foreground_bbox failed: {e}")
        W, H = pil_img.size
        return [0, 0, W - 1, H - 1]
def _find_topk_foreground_bboxes(pil_img: Image.Image, max_regions: int = 20, min_area: int = 100):
    """Find top-K foreground bboxes via Otsu threshold + morphology. Returns list of [x1,y1,x2,y2]."""
    try:
        img = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        if th.mean() > 127:
            th = 255 - th
        kernel = np.ones((3, 3), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
        th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        bboxes = []
        H, W = img.shape[:2]
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < min_area:
                continue
            x, y, w, h = cv2.boundingRect(cnt)
            # Filter very thin shapes
            if w < 5 or h < 5:
                continue
            pad = int(0.01 * max(w, h))
            x1 = max(0, x - pad)
            y1 = max(0, y - pad)
            x2 = min(W - 1, x + w + pad)
            y2 = min(H - 1, y + h + pad)
            bboxes.append([x1, y1, x2, y2])
            if len(bboxes) >= max_regions:
                break
        return bboxes
    except Exception as e:
        print(f"_find_topk_foreground_bboxes failed: {e}")
        return []
# Try to import mmdet for inference
try:
    from mmdet.apis import init_detector, inference_detector
    MM_DET_AVAILABLE = True
    print("✅ MMDetection available for inference")
except ImportError as e:
    print(f"⚠️ MMDetection import failed: {e}")
    print("❌ MMDetection not available - install in Dockerfile")
    MM_DET_AVAILABLE = False
# === Chart Type Classification (DocFigure) ===
print("Loading Chart Classification Model...")
# Chart type labels from DocFigure dataset (28 classes)
CHART_TYPE_LABELS = [
    'Line graph', 'Natural image', 'Table', '3D object', 'Bar plot', 'Scatter plot',
    'Medical image', 'Sketch', 'Geographic map', 'Flow chart', 'Heat map', 'Mask',
    'Block diagram', 'Venn diagram', 'Confusion matrix', 'Histogram', 'Box plot',
    'Vector plot', 'Pie chart', 'Surface plot', 'Algorithm', 'Contour plot',
    'Tree diagram', 'Bubble chart', 'Polar plot', 'Area chart', 'Pareto chart', 'Radar chart'
]
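# Note: the position of each label in this list is the numeric chart_type_id returned by
# analyze(), e.g. 0 -> 'Line graph', 4 -> 'Bar plot', 6 -> 'Medical image'.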
try:
    # Load the chart_type.pth model file from Hugging Face Hub
    from huggingface_hub import hf_hub_download
    from torchvision import transforms
    print("Downloading chart_type.pth from Hugging Face Hub...")
    chart_type_path = hf_hub_download(
        repo_id="hanszhu/ChartTypeNet-DocFigure",
        filename="chart_type.pth",
        cache_dir=HF_CACHE_DIR
    )
    print(f"✅ Downloaded to: {chart_type_path}")
    # Load the PyTorch model
    loaded_data = torch.load(chart_type_path, map_location='cpu')
    # Check if it's a state dict or a complete model
    if isinstance(loaded_data, dict):
        # Check if it's a checkpoint with model_state_dict
        if "model_state_dict" in loaded_data:
            print("Loading checkpoint, extracting model_state_dict...")
            state_dict = loaded_data["model_state_dict"]
        else:
            # It's a direct state dict
            print("Loading state dict, creating model architecture...")
            state_dict = loaded_data
        # Strip "backbone." prefix from state dict keys if present
        cleaned_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("backbone."):
                # Remove "backbone." prefix
                new_key = key[9:]
                cleaned_state_dict[new_key] = value
            else:
                cleaned_state_dict[key] = value
        print(f"Cleaned state dict: {len(cleaned_state_dict)} keys")
        # Create the model architecture
        from torchvision.models import resnet50
        chart_type_model = resnet50(pretrained=False)
        # Create the correct classifier structure to match the state dict
        import torch.nn as nn
        in_features = chart_type_model.fc.in_features
        dropout = nn.Dropout(0.5)
        chart_type_model.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            dropout,
            nn.Linear(512, 28)
        )
        # Load the cleaned state dict
        chart_type_model.load_state_dict(cleaned_state_dict)
    else:
        # It's a complete model
        chart_type_model = loaded_data
    chart_type_model.eval()
    # Create a simple processor for the model
    chart_type_processor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    CHART_TYPE_AVAILABLE = True
    print("✅ Chart classification model loaded")
except Exception as e:
    print(f"⚠️ Failed to load chart classification model: {e}")
    import traceback
    print("Full traceback:")
    traceback.print_exc()
    CHART_TYPE_AVAILABLE = False
# === Chart Element Detection (Cascade R-CNN) ===
element_model = None
datapoint_model = None
print(f"MM_DET_AVAILABLE: {MM_DET_AVAILABLE}")
if MM_DET_AVAILABLE:
    # Check if config files exist
    element_config = "models/chart_elementnet_swin.py"
    point_config = "models/chart_pointnet_swin.py"
    print("Checking config files...")
    print(f"Element config exists: {os.path.exists(element_config)}")
    print(f"Point config exists: {os.path.exists(point_config)}")
    print(f"Current working directory: {os.getcwd()}")
    print(f"Files in models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")
    try:
        print("Loading ChartElementNet-MultiClass (Cascade R-CNN)...")
        print(f"Config path: {element_config}")
        print("Weights path: hanszhu/ChartElementNet-MultiClass")
        print("About to call init_detector...")
        # Download model from Hugging Face Hub
        from huggingface_hub import hf_hub_download
        print("Downloading ChartElementNet weights from Hugging Face Hub...")
        element_checkpoint = hf_hub_download(
            repo_id="hanszhu/ChartElementNet-MultiClass",
            filename="chart_label+.pth",
            cache_dir=HF_CACHE_DIR
        )
        print(f"✅ Downloaded to: {element_checkpoint}")
        # Use local config with downloaded weights
        element_model = init_detector(element_config, element_checkpoint, device="cpu")
        print("✅ ChartElementNet loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load ChartElementNet: {e}")
        print(f"Error type: {type(e).__name__}")
        print(f"Error details: {str(e)}")
        import traceback
        print("Full traceback:")
        traceback.print_exc()
    try:
        print("Loading ChartPointNet-InstanceSeg (Mask R-CNN)...")
        print(f"Config path: {point_config}")
        print("Weights path: hanszhu/ChartPointNet-InstanceSeg")
        print("About to call init_detector...")
        # Download model from Hugging Face Hub
        from huggingface_hub import hf_hub_download  # re-import so this block does not depend on the previous try succeeding
        print("Downloading ChartPointNet weights from Hugging Face Hub...")
        datapoint_checkpoint = hf_hub_download(
            repo_id="hanszhu/ChartPointNet-InstanceSeg",
            filename="chart_datapoint.pth",
            cache_dir=HF_CACHE_DIR
        )
        print(f"✅ Downloaded to: {datapoint_checkpoint}")
        # Use local config with downloaded weights
        datapoint_model = init_detector(point_config, datapoint_checkpoint, device="cpu")
        print("✅ ChartPointNet loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load ChartPointNet: {e}")
        print(f"Error type: {type(e).__name__}")
        print(f"Error details: {str(e)}")
        import traceback
        print("Full traceback:")
        traceback.print_exc()
else:
    print("❌ MMDetection not available - cannot load custom models")
    print("MM_DET_AVAILABLE was False")

print("Final model status:")
print(f"element_model: {element_model is not None}")
print(f"datapoint_model: {datapoint_model is not None}")
# === Main prediction function ===
def analyze(image):
    """
    Analyze a chart image and return comprehensive results.

    Args:
        image: Input chart image (filepath string or PIL.Image)

    Returns:
        dict: Analysis results containing:
            - chart_type_id (int): Numeric chart type identifier (0-27)
            - chart_type_label (str): Human-readable chart type name
            - element_result (str): Detected chart elements (titles, axes, legends, etc.)
            - datapoint_result (str): Segmented data points and regions
            - status (str): Processing status message
            - processing_time (float): Time taken for analysis in seconds
    """
    import time
    from PIL import Image
    start_time = time.time()
    # Handle filepath input (convert to PIL Image)
    if isinstance(image, str):
        # It's a filepath, load the image
        image = Image.open(image).convert("RGB")
    elif image is None:
        return {"error": "No image provided"}
    # Ensure we have a PIL Image
    if not isinstance(image, Image.Image):
        return {"error": "Invalid image format"}
    result = {
        "chart_type_id": "Model not available",
        "chart_type_label": "Model not available",
        "element_result": "MMDetection models not available",
        "datapoint_result": "MMDetection models not available",
        "status": "Basic chart classification only",
        "processing_time": 0.0,
        "medsam": {"available": False}
    }
    # Chart Type Classification
    if CHART_TYPE_AVAILABLE:
        try:
            # Preprocess image for PyTorch model
            processed_image = chart_type_processor(image).unsqueeze(0)  # Add batch dimension
            # Get prediction
            with torch.no_grad():
                outputs = chart_type_model(processed_image)
            # Handle different output formats
            if isinstance(outputs, torch.Tensor):
                logits = outputs
            elif hasattr(outputs, 'logits'):
                logits = outputs.logits
            else:
                logits = outputs
            predicted_class = logits.argmax(dim=-1).item()
            result["chart_type_id"] = predicted_class
            result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
            result["status"] = "Chart classification completed"
        except Exception as e:
            result["chart_type_id"] = f"Error: {str(e)}"
            result["chart_type_label"] = f"Error: {str(e)}"
            result["status"] = "Error in chart classification"
    # Chart Element Detection (Cascade R-CNN)
    if element_model is not None:
        try:
            # Convert PIL image to numpy array for MMDetection (RGB -> BGR)
            np_img = np.array(image.convert("RGB"))[:, :, ::-1]
            element_result = inference_detector(element_model, np_img)
            # Convert result to more API-friendly format
            if isinstance(element_result, tuple):
                bbox_result, segm_result = element_result
                element_data = {
                    "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
                    "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
                }
            else:
                element_data = str(element_result)
            result["element_result"] = element_data
            result["status"] = "Chart classification + element detection completed"
        except Exception as e:
            result["element_result"] = f"Error: {str(e)}"
    # Chart Data Point Segmentation (Mask R-CNN)
    if datapoint_model is not None:
        try:
            # Convert PIL image to numpy array for MMDetection (RGB -> BGR)
            np_img = np.array(image.convert("RGB"))[:, :, ::-1]
            datapoint_result = inference_detector(datapoint_model, np_img)
            # Convert result to more API-friendly format
            if isinstance(datapoint_result, tuple):
                bbox_result, segm_result = datapoint_result
                datapoint_data = {
                    "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
                    "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
                }
            else:
                datapoint_data = str(datapoint_result)
            result["datapoint_result"] = datapoint_data
            result["status"] = "Full analysis completed"
        except Exception as e:
            result["datapoint_result"] = f"Error: {str(e)}"
    # If predicted as medical image and MedSAM is available, include mask data (polygons)
    try:
        label_lower = str(result.get("chart_type_label", "")).strip().lower()
        if label_lower == "medical image":
            if _medsam.is_available():
                # Indicate availability; masks are generated in the then-chain
                result["medsam"] = {"available": True}
            else:
                # Not available; include reason
                result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
    except Exception as e:
        print(f"MedSAM JSON augmentation failed: {e}")
    result["processing_time"] = round(time.time() - start_time, 3)
    return result
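# Illustrative shape of the JSON that analyze() returns when every model is loaded
# (values are made up; element/datapoint payloads abbreviated):
#
#   {
#       "chart_type_id": 4,
#       "chart_type_label": "Bar plot",
#       "element_result": {"bboxes": [...], "segments": [...]},
#       "datapoint_result": {"bboxes": [...], "segments": [...]},
#       "status": "Full analysis completed",
#       "processing_time": 1.234,
#       "medsam": {"available": False}
#   }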
def analyze_with_medsam(base_result, image):
    """Auto-generate MedSAM segmentations for medical images by prompting with top-K
    foreground boxes. Returns the updated JSON and an overlay image (or None)."""
    try:
        if not isinstance(base_result, dict):
            return base_result, None
        label = str(base_result.get("chart_type_label", "")).strip().lower()
        if label != "medical image" or not _medsam.is_available():
            return base_result, None
        pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
        if pil_img is None:
            return base_result, None
        # Prepare embedding
        img_path = image if isinstance(image, str) else None
        if img_path is None:
            tmp_path = "./_tmp_input_image.png"
            pil_img.save(tmp_path)
            img_path = tmp_path
        _medsam.load_image(img_path)
        segmentations = []
        masks_for_overlay = []
        # AUTO segmentation path
        try:
            # Follow original behavior: use MedSAM with box prompts; no SAM auto in main path
            cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=8, min_area=200)
            for bbox in cand_bboxes:
                m = _medsam.segment_with_box(bbox)
                if m is None or not isinstance(m.get('mask'), np.ndarray):
                    continue
                segmentations.append({
                    "mask": m['mask'].astype(np.uint8).tolist(),
                    "confidence": float(m.get('confidence', 1.0)),
                    "method": m.get("method", "medsam_box_auto")
                })
                masks_for_overlay.append(m)
        except Exception as auto_e:
            print(f"Automatic MedSAM segmentation failed: {auto_e}")
        W, H = pil_img.size
        base_result["medsam"] = {
            "available": True,
            "height": H,
            "width": W,
            "segmentations": segmentations,
            "num_segments": len(segmentations)
        }
        overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
        return base_result, overlay_img
    except Exception as e:
        print(f"analyze_with_medsam failed: {e}")
        return base_result, None
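# Illustrative "medsam" payload attached to the JSON by the function above (values made up):
#
#   {"available": True, "height": 512, "width": 512, "num_segments": 2,
#    "segmentations": [{"mask": [[0, 1, ...], ...], "confidence": 1.0, "method": "medsam_box"}, ...]}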
# === Gradio UI with API enhancements ===
# Create Blocks interface with explicit API name for stable API surface
with gr.Blocks(
    title="Dense Captioning Platform"
) as demo:
    gr.Markdown("# Dense Captioning Platform")
    gr.Markdown("""
**Comprehensive Chart Analysis API**

Upload a chart image to get:
- **Chart Type Classification**: Identifies the type of chart (line, bar, scatter, etc.)
- **Element Detection**: Detects chart elements like titles, axes, legends, and data points
- **Data Point Segmentation**: Segments individual data points and regions

Masks are generated automatically for medical images when MedSAM is available.

**API Usage:**
```python
from gradio_client import Client, handle_file

client = Client("hanszhu/Dense-Captioning-Platform")
result = client.predict(
    image=handle_file('path/to/your/chart.png'),
    api_name="/predict"
)
print(result)
```

**Supported Chart Types:** Line graphs, Bar plots, Scatter plots, Pie charts, Heat maps, and 23+ more
""")
    with gr.Row():
        with gr.Column():
            # Input
            image_input = gr.Image(
                type="filepath",  # REQUIRED for gradio_client
                label="Upload Chart Image",
                height=400,
                elem_id="image-input"
            )
            # Analyze button (single)
            analyze_btn = gr.Button(
                "Analyze",
                variant="primary",
                size="lg",
                elem_id="analyze-btn"
            )
        with gr.Column():
            # Output JSON
            result_output = gr.JSON(
                label="Analysis Results",
                height=400,
                elem_id="result-output"
            )
            # Overlay image output (populated only for medical images)
            overlay_output = gr.Image(
                label="MedSAM Overlay (Medical images)",
                height=400,
                elem_id="overlay-output"
            )

    # Single API endpoint for JSON
    analyze_event = analyze_btn.click(
        fn=analyze,
        inputs=image_input,
        outputs=result_output,
        api_name="/predict"  # Standard API name that gradio_client expects
    )
    # Automatic overlay generation step for medical images
    analyze_event.then(
        fn=analyze_with_medsam,
        inputs=[result_output, image_input],
        outputs=[result_output, overlay_output],
    )
    # Add some examples
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
        ],
        inputs=image_input,
        label="Try with this example"
    )
# Launch with API-friendly settings
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",  # Allow external connections
        "server_port": 7860,
        "share": False,            # Set to True if you want a public link
        "show_error": True,        # Show detailed errors for debugging
        "quiet": False,            # Show startup messages
        "show_api": True,          # Enable API documentation
        "ssr_mode": False          # Disable experimental SSR in Docker env
    }
    # Lightweight keepalive (self-ping) to avoid idle shutdowns
    try:
        import threading, time, requests

        def _keepalive():
            url = "http://127.0.0.1:7860/"
            while True:
                try:
                    requests.get(url, timeout=3)
                except Exception:
                    pass
                time.sleep(60)

        threading.Thread(target=_keepalive, daemon=True).start()
    except Exception:
        pass
    # Enable queue for gradio_client compatibility (required for gradio_client to work)
    demo.queue().launch(**launch_kwargs)