hanszhu committed on
Commit
659aa40
·
1 Parent(s): a39d1c3

chore(space): switch to Docker SDK; add Dockerfile; minimal FastAPI app; trim requirements

Browse files
Files changed (4) hide show
  1. Dockerfile +26 -0
  2. README.md +1 -3
  3. app.py +5 -881
  4. requirements.txt +2 -12
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ RUN apt-get update && apt-get install -y --no-install-recommends \
4
+ libgl1 libglib2.0-0 git libsm6 libxext6 libxrender1 \
5
+ && rm -rf /var/lib/apt/lists/*
6
+
7
+ ENV PIP_NO_CACHE_DIR=1 \
8
+ MPLBACKEND=Agg \
9
+ MIM_IGNORE_INSTALL_PYTORCH=1
10
+
11
+ WORKDIR /app
12
+
13
+ COPY requirements.txt /app/requirements.txt
14
+
15
+ RUN python -m pip install --upgrade pip wheel setuptools openmim \
16
+ && pip install --no-cache-dir -r requirements.txt \
17
+ && pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cpu torch==2.1.0 torchvision==0.16.0 \
18
+ && mim install "mmengine==0.10.4" \
19
+ && mim install "mmcv==2.1.0" \
20
+ && mim install "mmdet==3.3.0"
21
+
22
+ COPY . /app
23
+
24
+ EXPOSE 7860
25
+
26
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -3,9 +3,7 @@ title: Dense Captioning Platform
3
  emoji: 🐢
4
  colorFrom: purple
5
  colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.38.2
8
- app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
 
3
  emoji: 🐢
4
  colorFrom: purple
5
  colorTo: purple
6
+ sdk: docker
 
 
7
  pinned: false
8
  license: apache-2.0
9
  ---
app.py CHANGED
@@ -1,884 +1,8 @@
1
- import os
2
- import sys
3
- import gradio as gr
4
- from PIL import Image
5
- import torch
6
- import numpy as np
7
- import cv2
8
 
9
- # Add custom modules to path - try multiple possible locations
10
- possible_paths = [
11
- "./custom_models",
12
- "../custom_models",
13
- "./Dense-Captioning-Platform/custom_models"
14
- ]
15
 
16
- for path in possible_paths:
17
- if os.path.exists(path):
18
- sys.path.insert(0, os.path.abspath(path))
19
- break
20
-
21
- # Add mmcv to path if it exists
22
- if os.path.exists('./mmcv'):
23
- sys.path.insert(0, os.path.abspath('./mmcv'))
24
- print("✅ Added local mmcv to path")
25
-
26
- # Import and register custom modules
27
- try:
28
- from custom_models import register
29
- print("✅ Custom modules registered successfully")
30
- except Exception as e:
31
- print(f"⚠️ Warning: Could not register custom modules: {e}")
32
-
33
- # ----------------------
34
- # Optional MedSAM integration
35
- # ----------------------
36
- class MedSAMIntegrator:
37
- def __init__(self):
38
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
- self.medsam_model = None
40
- self.current_image = None
41
- self.current_image_path = None
42
- self.embedding = None
43
- self._load_medsam_model()
44
-
45
- def _ensure_segment_anything(self):
46
- try:
47
- import segment_anything # noqa: F401
48
- return True
49
- except Exception as e:
50
- print(f"⚠ segment_anything not available: {e}. Attempting install from Git...")
51
- try:
52
- import subprocess, sys
53
- subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"])
54
- import segment_anything # noqa: F401
55
- print("✓ segment_anything installed")
56
- return True
57
- except Exception as install_err:
58
- print(f"❌ Failed to install segment_anything: {install_err}")
59
- return False
60
-
61
- def _load_medsam_model(self):
62
- try:
63
- # Ensure library is present
64
- if not self._ensure_segment_anything():
65
- print("MedSAM features disabled (segment_anything not available)")
66
- return
67
-
68
- from segment_anything import sam_model_registry as _reg
69
- import torch as _torch
70
-
71
- # Preferred local path
72
- medsam_ckpt_path = "models/medsam_vit_b.pth"
73
-
74
- # If not present, fetch from HF Hub using provided repo or default
75
- if not os.path.exists(medsam_ckpt_path):
76
- try:
77
- from huggingface_hub import hf_hub_download, list_repo_files
78
- repo_id = os.environ.get("HF_MEDSAM_REPO", "Aniketg6/Fine-Tuned-MedSAM")
79
- # Try to find a .pth/.pt in the repo
80
- print(f"🔄 Trying to download MedSAM checkpoint from {repo_id} ...")
81
- files = list_repo_files(repo_id)
82
- candidate = None
83
- for f in files:
84
- lf = f.lower()
85
- if lf.endswith(".pth") or lf.endswith(".pt"):
86
- candidate = f
87
- break
88
- if candidate is None:
89
- # Fallback to a common name
90
- candidate = "medsam_vit_b.pth"
91
- ckpt_path = hf_hub_download(repo_id=repo_id, filename=candidate, cache_dir="./models")
92
- medsam_ckpt_path = ckpt_path
93
- print(f"✅ Downloaded MedSAM checkpoint: {medsam_ckpt_path}")
94
- except Exception as dl_err:
95
- print(f"⚠ Could not fetch MedSAM checkpoint from HF Hub: {dl_err}")
96
- print("MedSAM features disabled (no checkpoint)")
97
- return
98
-
99
- # Load checkpoint
100
- checkpoint = _torch.load(medsam_ckpt_path, map_location='cpu')
101
- self.medsam_model = _reg["vit_b"](checkpoint=None)
102
- self.medsam_model.load_state_dict(checkpoint)
103
- self.medsam_model.to(self.device)
104
- self.medsam_model.eval()
105
- print("✓ MedSAM model loaded successfully")
106
- except Exception as e:
107
- print(f"⚠ MedSAM model not available: {e}. MedSAM features disabled.")
108
-
109
- def is_available(self):
110
- return self.medsam_model is not None
111
-
112
- def load_image(self, image_path, precomputed_embedding=None):
113
- try:
114
- from skimage import transform, io # local import to avoid hard dep if unused
115
- img_np = io.imread(image_path)
116
- if len(img_np.shape) == 2:
117
- img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
118
- else:
119
- img_3c = img_np
120
- self.current_image = img_3c
121
- self.current_image_path = image_path
122
- if precomputed_embedding is not None:
123
- if not self.set_precomputed_embedding(precomputed_embedding):
124
- self.get_embeddings()
125
- else:
126
- self.get_embeddings()
127
- return True
128
- except Exception as e:
129
- print(f"Error loading image for MedSAM: {e}")
130
- return False
131
-
132
- @torch.no_grad()
133
- def get_embeddings(self):
134
- if self.current_image is None or self.medsam_model is None:
135
- return None
136
- from skimage import transform
137
- img_1024 = transform.resize(
138
- self.current_image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
139
- ).astype(np.uint8)
140
- img_1024 = (img_1024 - img_1024.min()) / np.clip(img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None)
141
- img_1024_tensor = (
142
- torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
143
- )
144
- self.embedding = self.medsam_model.image_encoder(img_1024_tensor)
145
- return self.embedding
146
-
147
- def set_precomputed_embedding(self, embedding_array):
148
- try:
149
- if isinstance(embedding_array, np.ndarray):
150
- embedding_tensor = torch.tensor(embedding_array).to(self.device)
151
- self.embedding = embedding_tensor
152
- return True
153
- return False
154
- except Exception as e:
155
- print(f"Error setting precomputed embedding: {e}")
156
- return False
157
-
158
- @torch.no_grad()
159
- def medsam_inference(self, box_1024, height, width):
160
- if self.embedding is None or self.medsam_model is None:
161
- return None
162
- box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=self.embedding.device)
163
- if len(box_torch.shape) == 2:
164
- box_torch = box_torch[:, None, :]
165
- sparse_embeddings, dense_embeddings = self.medsam_model.prompt_encoder(
166
- points=None, boxes=box_torch, masks=None,
167
- )
168
- low_res_logits, _ = self.medsam_model.mask_decoder(
169
- image_embeddings=self.embedding,
170
- image_pe=self.medsam_model.prompt_encoder.get_dense_pe(),
171
- sparse_prompt_embeddings=sparse_embeddings,
172
- dense_prompt_embeddings=dense_embeddings,
173
- multimask_output=False,
174
- )
175
- low_res_pred = torch.sigmoid(low_res_logits)
176
- low_res_pred = torch.nn.functional.interpolate(
177
- low_res_pred, size=(height, width), mode="bilinear", align_corners=False,
178
- )
179
- low_res_pred = low_res_pred.squeeze().cpu().numpy()
180
- medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
181
- return medsam_seg
182
-
183
- def segment_with_box(self, bbox):
184
- if self.embedding is None or self.current_image is None:
185
- return None
186
- try:
187
- H, W, _ = self.current_image.shape
188
- x1, y1, x2, y2 = bbox
189
- x1 = max(0, min(int(x1), W - 1))
190
- y1 = max(0, min(int(y1), H - 1))
191
- x2 = max(0, min(int(x2), W - 1))
192
- y2 = max(0, min(int(y2), H - 1))
193
- if x2 <= x1:
194
- x2 = min(x1 + 10, W - 1)
195
- if y2 <= y1:
196
- y2 = min(y1 + 10, H - 1)
197
- box_np = np.array([[x1, y1, x2, y2]], dtype=float)
198
- box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
199
- medsam_mask = self.medsam_inference(box_1024, H, W)
200
- if medsam_mask is not None:
201
- return {"mask": medsam_mask, "confidence": 1.0, "method": "medsam_box"}
202
- return None
203
- except Exception as e:
204
- print(f"Error in MedSAM box-based segmentation: {e}")
205
- return None
206
-
207
- # Single global instance
208
- _medsam = MedSAMIntegrator()
209
-
210
-
211
- def _extract_bboxes_from_mmdet_result(det_result):
212
- """Extract Nx4 xyxy bboxes from various MMDet result formats."""
213
- boxes = []
214
- try:
215
- # MMDet 3.x: list of DetDataSample
216
- if isinstance(det_result, list) and len(det_result) > 0:
217
- sample = det_result[0]
218
- if hasattr(sample, 'pred_instances'):
219
- inst = sample.pred_instances
220
- if hasattr(inst, 'bboxes'):
221
- b = inst.bboxes
222
- # mmengine structures may use .tensor for boxes
223
- if hasattr(b, 'tensor'):
224
- b = b.tensor
225
- boxes = b.detach().cpu().numpy().tolist()
226
- # Single DetDataSample
227
- elif hasattr(det_result, 'pred_instances'):
228
- inst = det_result.pred_instances
229
- if hasattr(inst, 'bboxes'):
230
- b = inst.bboxes
231
- if hasattr(b, 'tensor'):
232
- b = b.tensor
233
- boxes = b.detach().cpu().numpy().tolist()
234
- # MMDet 2.x: tuple of (bbox_result, segm_result)
235
- elif isinstance(det_result, tuple) and len(det_result) >= 1:
236
- bbox_result = det_result[0]
237
- # bbox_result is list per class, each Nx5 [x1,y1,x2,y2,score]
238
- if isinstance(bbox_result, (list, tuple)):
239
- for arr in bbox_result:
240
- try:
241
- arr_np = np.array(arr)
242
- if arr_np.ndim == 2 and arr_np.shape[1] >= 4:
243
- boxes.extend(arr_np[:, :4].tolist())
244
- except Exception:
245
- continue
246
- except Exception as e:
247
- print(f"Failed to parse MMDet result for boxes: {e}")
248
- return boxes
249
-
250
-
251
- def _overlay_masks_on_image(image_pil, mask_list, alpha=0.4):
252
- """Overlay binary masks on an image with random colors."""
253
- if image_pil is None or not mask_list:
254
- return image_pil
255
- img = np.array(image_pil.convert('RGB'))
256
- overlay = img.copy()
257
- for idx, m in enumerate(mask_list):
258
- if m is None or 'mask' not in m or m['mask'] is None:
259
- continue
260
- mask = m['mask'].astype(bool)
261
- color = np.random.RandomState(seed=idx + 1234).randint(0, 255, size=3)
262
- overlay[mask] = (0.5 * overlay[mask] + 0.5 * color).astype(np.uint8)
263
- blended = (alpha * overlay + (1 - alpha) * img).astype(np.uint8)
264
- return Image.fromarray(blended)
265
-
266
-
267
- def _mask_to_polygons(mask: np.ndarray):
268
- """Convert a binary mask (H,W) to a list of polygons ([[x,y], ...]) using OpenCV contours."""
269
- try:
270
- mask_u8 = (mask.astype(np.uint8) * 255)
271
- contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
272
- polygons = []
273
- for cnt in contours:
274
- if cnt is None or len(cnt) < 3:
275
- continue
276
- # Simplify contour slightly
277
- epsilon = 0.002 * cv2.arcLength(cnt, True)
278
- approx = cv2.approxPolyDP(cnt, epsilon, True)
279
- poly = approx.reshape(-1, 2).tolist()
280
- polygons.append(poly)
281
- return polygons
282
- except Exception as e:
283
- print(f"_mask_to_polygons failed: {e}")
284
- return []
285
-
286
-
287
- def _find_largest_foreground_bbox(pil_img: Image.Image):
288
- """Heuristic: find largest foreground region bbox via Otsu threshold on grayscale.
289
- Returns [x1, y1, x2, y2] or full-image bbox if none found."""
290
- try:
291
- img = np.array(pil_img.convert('RGB'))
292
- gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
293
- # Otsu threshold (invert if needed by checking mean)
294
- _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
295
- # Assume foreground is darker; invert if threshold yields background as white majority
296
- if th.mean() > 127:
297
- th = 255 - th
298
- # Morph close to connect regions
299
- kernel = np.ones((5, 5), np.uint8)
300
- th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
301
- contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
302
- if not contours:
303
- W, H = pil_img.size
304
- return [0, 0, W - 1, H - 1]
305
- # Largest contour by area
306
- cnt = max(contours, key=cv2.contourArea)
307
- x, y, w, h = cv2.boundingRect(cnt)
308
- # Pad a little
309
- pad = int(0.02 * max(w, h))
310
- x1 = max(0, x - pad)
311
- y1 = max(0, y - pad)
312
- x2 = min(img.shape[1] - 1, x + w + pad)
313
- y2 = min(img.shape[0] - 1, y + h + pad)
314
- return [x1, y1, x2, y2]
315
- except Exception as e:
316
- print(f"_find_largest_foreground_bbox failed: {e}")
317
- W, H = pil_img.size
318
- return [0, 0, W - 1, H - 1]
319
-
320
-
321
- def _find_topk_foreground_bboxes(pil_img: Image.Image, max_regions: int = 20, min_area: int = 100):
322
- """Find top-K foreground bboxes via Otsu threshold + morphology. Returns list of [x1,y1,x2,y2]."""
323
- try:
324
- img = np.array(pil_img.convert('RGB'))
325
- gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
326
- _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
327
- if th.mean() > 127:
328
- th = 255 - th
329
- kernel = np.ones((3, 3), np.uint8)
330
- th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
331
- th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
332
- contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
333
- if not contours:
334
- return []
335
- contours = sorted(contours, key=cv2.contourArea, reverse=True)
336
- bboxes = []
337
- H, W = img.shape[:2]
338
- for cnt in contours:
339
- area = cv2.contourArea(cnt)
340
- if area < min_area:
341
- continue
342
- x, y, w, h = cv2.boundingRect(cnt)
343
- # Filter very thin shapes
344
- if w < 5 or h < 5:
345
- continue
346
- pad = int(0.01 * max(w, h))
347
- x1 = max(0, x - pad)
348
- y1 = max(0, y - pad)
349
- x2 = min(W - 1, x + w + pad)
350
- y2 = min(H - 1, y + h + pad)
351
- bboxes.append([x1, y1, x2, y2])
352
- if len(bboxes) >= max_regions:
353
- break
354
- return bboxes
355
- except Exception as e:
356
- print(f"_find_topk_foreground_bboxes failed: {e}")
357
- return []
358
-
359
- # Try to import mmdet for inference
360
- try:
361
- from mmdet.apis import init_detector, inference_detector
362
- MM_DET_AVAILABLE = True
363
- print("✅ MMDetection available for inference")
364
- except ImportError as e:
365
- print(f"⚠️ MMDetection import failed: {e}")
366
- print("🔄 Attempting to install MMDetection dependencies...")
367
- try:
368
- import subprocess
369
- import sys
370
-
371
- # Use the working solution with mim install
372
- print("🔄 Installing MMDetection dependencies with mim...")
373
-
374
- # Install openmim if not already installed
375
- subprocess.check_call([sys.executable, "-m", "pip", "install", "openmim"])
376
-
377
- # Install mmengine
378
- subprocess.check_call([sys.executable, "-m", "mim", "install", "mmengine"])
379
-
380
- # Install mmcv with mim (this handles compilation properly)
381
- subprocess.check_call([sys.executable, "-m", "mim", "install", "mmcv==2.1.0"])
382
-
383
- # Install mmdet
384
- subprocess.check_call([sys.executable, "-m", "mim", "install", "mmdet"])
385
-
386
- # Try importing again
387
- from mmdet.apis import init_detector, inference_detector
388
- MM_DET_AVAILABLE = True
389
- print("✅ MMDetection installed and available for inference")
390
- except Exception as install_error:
391
- print(f"❌ Failed to install MMDetection: {install_error}")
392
- MM_DET_AVAILABLE = False
393
-
394
- # === Chart Type Classification (DocFigure) ===
395
- print("🔄 Loading Chart Classification Model...")
396
-
397
- # Chart type labels from DocFigure dataset (28 classes)
398
- CHART_TYPE_LABELS = [
399
- 'Line graph', 'Natural image', 'Table', '3D object', 'Bar plot', 'Scatter plot',
400
- 'Medical image', 'Sketch', 'Geographic map', 'Flow chart', 'Heat map', 'Mask',
401
- 'Block diagram', 'Venn diagram', 'Confusion matrix', 'Histogram', 'Box plot',
402
- 'Vector plot', 'Pie chart', 'Surface plot', 'Algorithm', 'Contour plot',
403
- 'Tree diagram', 'Bubble chart', 'Polar plot', 'Area chart', 'Pareto chart', 'Radar chart'
404
- ]
405
-
406
- try:
407
- # Load the chart_type.pth model file from Hugging Face Hub
408
- from huggingface_hub import hf_hub_download
409
- import torch
410
- from torchvision import transforms
411
-
412
- print("🔄 Downloading chart_type.pth from Hugging Face Hub...")
413
- chart_type_path = hf_hub_download(
414
- repo_id="hanszhu/ChartTypeNet-DocFigure",
415
- filename="chart_type.pth",
416
- cache_dir="./models"
417
- )
418
- print(f"✅ Downloaded to: {chart_type_path}")
419
-
420
- # Load the PyTorch model
421
- loaded_data = torch.load(chart_type_path, map_location='cpu')
422
-
423
- # Check if it's a state dict or a complete model
424
- if isinstance(loaded_data, dict):
425
- # Check if it's a checkpoint with model_state_dict
426
- if "model_state_dict" in loaded_data:
427
- print("🔄 Loading checkpoint, extracting model_state_dict...")
428
- state_dict = loaded_data["model_state_dict"]
429
- else:
430
- # It's a direct state dict
431
- print("🔄 Loading state dict, creating model architecture...")
432
- state_dict = loaded_data
433
-
434
- # Strip "backbone." prefix from state dict keys if present
435
- cleaned_state_dict = {}
436
- for key, value in state_dict.items():
437
- if key.startswith("backbone."):
438
- # Remove "backbone." prefix
439
- new_key = key[9:]
440
- cleaned_state_dict[new_key] = value
441
- else:
442
- cleaned_state_dict[key] = value
443
-
444
- print(f"🔄 Cleaned state dict: {len(cleaned_state_dict)} keys")
445
-
446
- # Create the model architecture
447
- from torchvision.models import resnet50
448
- chart_type_model = resnet50(pretrained=False)
449
-
450
- # Create the correct classifier structure to match the state dict
451
- import torch.nn as nn
452
- in_features = chart_type_model.fc.in_features
453
- dropout = nn.Dropout(0.5)
454
-
455
- chart_type_model.fc = nn.Sequential(
456
- nn.Linear(in_features, 512),
457
- nn.ReLU(inplace=True),
458
- dropout,
459
- nn.Linear(512, 28)
460
- )
461
-
462
- # Load the cleaned state dict
463
- chart_type_model.load_state_dict(cleaned_state_dict)
464
- else:
465
- # It's a complete model
466
- chart_type_model = loaded_data
467
-
468
- chart_type_model.eval()
469
-
470
- # Create a simple processor for the model
471
- chart_type_processor = transforms.Compose([
472
- transforms.Resize((224, 224)),
473
- transforms.ToTensor(),
474
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
475
- ])
476
-
477
- CHART_TYPE_AVAILABLE = True
478
- print("✅ Chart classification model loaded")
479
- except Exception as e:
480
- print(f"⚠️ Failed to load chart classification model: {e}")
481
- import traceback
482
- print("🔍 Full traceback:")
483
- traceback.print_exc()
484
- CHART_TYPE_AVAILABLE = False
485
-
486
- # === Chart Element Detection (Cascade R-CNN) ===
487
- element_model = None
488
- datapoint_model = None
489
-
490
- print(f"🔍 MM_DET_AVAILABLE: {MM_DET_AVAILABLE}")
491
-
492
- if MM_DET_AVAILABLE:
493
- # Check if config files exist
494
- element_config = "models/chart_elementnet_swin.py"
495
- point_config = "models/chart_pointnet_swin.py"
496
-
497
- print(f"🔍 Checking config files...")
498
- print(f"🔍 Element config exists: {os.path.exists(element_config)}")
499
- print(f"🔍 Point config exists: {os.path.exists(point_config)}")
500
- print(f"🔍 Current working directory: {os.getcwd()}")
501
- print(f"🔍 Files in models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")
502
-
503
- try:
504
- print("🔄 Loading ChartElementNet-MultiClass (Cascade R-CNN)...")
505
- print(f"🔍 Config path: {element_config}")
506
- print(f"🔍 Weights path: hanszhu/ChartElementNet-MultiClass")
507
- print(f"🔍 About to call init_detector...")
508
-
509
- # Download model from Hugging Face Hub
510
- from huggingface_hub import hf_hub_download
511
- print("🔄 Downloading ChartElementNet weights from Hugging Face Hub...")
512
- element_checkpoint = hf_hub_download(
513
- repo_id="hanszhu/ChartElementNet-MultiClass",
514
- filename="chart_label+.pth",
515
- cache_dir="./models"
516
- )
517
- print(f"✅ Downloaded to: {element_checkpoint}")
518
-
519
- # Use local config with downloaded weights
520
- element_model = init_detector(element_config, element_checkpoint, device="cpu")
521
- print("✅ ChartElementNet loaded successfully")
522
- except Exception as e:
523
- print(f"❌ Failed to load ChartElementNet: {e}")
524
- print(f"🔍 Error type: {type(e).__name__}")
525
- print(f"🔍 Error details: {str(e)}")
526
- import traceback
527
- print("🔍 Full traceback:")
528
- traceback.print_exc()
529
-
530
- try:
531
- print("🔄 Loading ChartPointNet-InstanceSeg (Mask R-CNN)...")
532
- print(f"🔍 Config path: {point_config}")
533
- print(f"🔍 Weights path: hanszhu/ChartPointNet-InstanceSeg")
534
- print(f"🔍 About to call init_detector...")
535
-
536
- # Download model from Hugging Face Hub
537
- print("🔄 Downloading ChartPointNet weights from Hugging Face Hub...")
538
- datapoint_checkpoint = hf_hub_download(
539
- repo_id="hanszhu/ChartPointNet-InstanceSeg",
540
- filename="chart_datapoint.pth",
541
- cache_dir="./models"
542
- )
543
- print(f"✅ Downloaded to: {datapoint_checkpoint}")
544
-
545
- # Use local config with downloaded weights
546
- datapoint_model = init_detector(point_config, datapoint_checkpoint, device="cpu")
547
- print("✅ ChartPointNet loaded successfully")
548
- except Exception as e:
549
- print(f"❌ Failed to load ChartPointNet: {e}")
550
- print(f"🔍 Error type: {type(e).__name__}")
551
- print(f"🔍 Error details: {str(e)}")
552
- import traceback
553
- print("🔍 Full traceback:")
554
- traceback.print_exc()
555
- else:
556
- print("❌ MMDetection not available - cannot load custom models")
557
- print(f"🔍 MM_DET_AVAILABLE was False")
558
-
559
- print(f"🔍 Final model status:")
560
- print(f"🔍 element_model: {element_model is not None}")
561
- print(f"🔍 datapoint_model: {datapoint_model is not None}")
562
-
563
- # === Main prediction function ===
564
- def analyze(image):
565
- """
566
- Analyze a chart image and return comprehensive results.
567
-
568
- Args:
569
- image: Input chart image (filepath string or PIL.Image)
570
-
571
- Returns:
572
- dict: Analysis results containing:
573
- - chart_type_id (int): Numeric chart type identifier (0-27)
574
- - chart_type_label (str): Human-readable chart type name
575
- - element_result (str): Detected chart elements (titles, axes, legends, etc.)
576
- - datapoint_result (str): Segmented data points and regions
577
- - status (str): Processing status message
578
- - processing_time (float): Time taken for analysis in seconds
579
- """
580
- import time
581
- from PIL import Image
582
-
583
- start_time = time.time()
584
-
585
- # Handle filepath input (convert to PIL Image)
586
- if isinstance(image, str):
587
- # It's a filepath, load the image
588
- image = Image.open(image).convert("RGB")
589
- elif image is None:
590
- return {"error": "No image provided"}
591
-
592
- # Ensure we have a PIL Image
593
- if not isinstance(image, Image.Image):
594
- return {"error": "Invalid image format"}
595
-
596
- result = {
597
- "chart_type_id": "Model not available",
598
- "chart_type_label": "Model not available",
599
- "element_result": "MMDetection models not available",
600
- "datapoint_result": "MMDetection models not available",
601
- "status": "Basic chart classification only",
602
- "processing_time": 0.0,
603
- "medsam": {"available": False}
604
- }
605
-
606
- # Chart Type Classification
607
- if CHART_TYPE_AVAILABLE:
608
- try:
609
- # Preprocess image for PyTorch model
610
- processed_image = chart_type_processor(image).unsqueeze(0) # Add batch dimension
611
-
612
- # Get prediction
613
- with torch.no_grad():
614
- outputs = chart_type_model(processed_image)
615
- # Handle different output formats
616
- if isinstance(outputs, torch.Tensor):
617
- logits = outputs
618
- elif hasattr(outputs, 'logits'):
619
- logits = outputs.logits
620
- else:
621
- logits = outputs
622
-
623
- predicted_class = logits.argmax(dim=-1).item()
624
-
625
- result["chart_type_id"] = predicted_class
626
- result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
627
- result["status"] = "Chart classification completed"
628
-
629
- except Exception as e:
630
- result["chart_type_id"] = f"Error: {str(e)}"
631
- result["chart_type_label"] = f"Error: {str(e)}"
632
- result["status"] = "Error in chart classification"
633
-
634
- # Chart Element Detection (Cascade R-CNN)
635
- if element_model is not None:
636
- try:
637
- # Convert PIL image to numpy array for MMDetection
638
- np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL → BGR
639
-
640
- element_result = inference_detector(element_model, np_img)
641
-
642
- # Convert result to more API-friendly format
643
- if isinstance(element_result, tuple):
644
- bbox_result, segm_result = element_result
645
- element_data = {
646
- "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
647
- "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
648
- }
649
- else:
650
- element_data = str(element_result)
651
-
652
- result["element_result"] = element_data
653
- result["status"] = "Chart classification + element detection completed"
654
- except Exception as e:
655
- result["element_result"] = f"Error: {str(e)}"
656
-
657
- # Chart Data Point Segmentation (Mask R-CNN)
658
- if datapoint_model is not None:
659
- try:
660
- # Convert PIL image to numpy array for MMDetection
661
- np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL → BGR
662
-
663
- datapoint_result = inference_detector(datapoint_model, np_img)
664
-
665
- # Convert result to more API-friendly format
666
- if isinstance(datapoint_result, tuple):
667
- bbox_result, segm_result = datapoint_result
668
- datapoint_data = {
669
- "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
670
- "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
671
- }
672
- else:
673
- datapoint_data = str(datapoint_result)
674
-
675
- result["datapoint_result"] = datapoint_data
676
- result["status"] = "Full analysis completed"
677
- except Exception as e:
678
- result["datapoint_result"] = f"Error: {str(e)}"
679
-
680
- # If predicted as medical image and MedSAM is available, include mask data (polygons)
681
- try:
682
- label_lower = str(result.get("chart_type_label", "")).strip().lower()
683
- if label_lower == "medical image":
684
- if _medsam.is_available():
685
- # Do not run heuristics here. Prompts are required and handled in the UI then-chain.
686
- # Indicate availability and that prompts are needed for segmentation.
687
- result["medsam"] = {"available": True, "reason": "provide bbox/points prompts to generate segmentations"}
688
- else:
689
- # Not available; include reason
690
- result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
691
- except Exception as e:
692
- print(f"MedSAM JSON augmentation failed: {e}")
693
-
694
- result["processing_time"] = round(time.time() - start_time, 3)
695
- return result
696
-
697
-
698
- def analyze_with_medsam(base_result, image):
699
- """Auto-generate segmentations for medical images using SAM ViT-H if available,
700
- otherwise fallback to MedSAM over top-K foreground boxes. Returns updated JSON and overlay image."""
701
- try:
702
- if not isinstance(base_result, dict):
703
- return base_result, None
704
- label = str(base_result.get("chart_type_label", "")).strip().lower()
705
- if label != "medical image" or not _medsam.is_available():
706
- return base_result, None
707
-
708
- pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
709
- if pil_img is None:
710
- return base_result, None
711
-
712
- # Prepare embedding
713
- img_path = image if isinstance(image, str) else None
714
- if img_path is None:
715
- tmp_path = "./_tmp_input_image.png"
716
- pil_img.save(tmp_path)
717
- img_path = tmp_path
718
- _medsam.load_image(img_path)
719
-
720
- segmentations = []
721
- masks_for_overlay = []
722
-
723
- # AUTO segmentation path
724
- try:
725
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
726
- import cv2 as _cv2
727
- # If ViT-H checkpoint present, use SAM automatic mask generator (download if missing)
728
- vit_h_ckpt = "models/sam_vit_h_4b8939.pth"
729
- if not os.path.exists(vit_h_ckpt):
730
- try:
731
- from huggingface_hub import hf_hub_download
732
- vit_h_ckpt = hf_hub_download(
733
- repo_id="Aniketg6/SAM",
734
- filename="sam_vit_h_4b8939.pth",
735
- cache_dir="./models"
736
- )
737
- print(f"✅ Downloaded SAM ViT-H checkpoint to: {vit_h_ckpt}")
738
- except Exception as dlh:
739
- print(f"⚠ Failed to download SAM ViT-H checkpoint: {dlh}")
740
- if os.path.exists(vit_h_ckpt):
741
- img_bgr = _cv2.imread(img_path)
742
- sam = sam_model_registry["vit_h"](checkpoint=vit_h_ckpt)
743
- mask_generator = SamAutomaticMaskGenerator(sam)
744
- masks = mask_generator.generate(img_bgr)
745
- for m in masks:
746
- seg = m.get('segmentation', None)
747
- if seg is None:
748
- continue
749
- seg_u8 = seg.astype(np.uint8)
750
- segmentations.append({
751
- "mask": seg_u8.tolist(),
752
- "confidence": float(m.get('stability_score', 1.0)),
753
- "method": "sam_auto"
754
- })
755
- masks_for_overlay.append({"mask": seg_u8})
756
- else:
757
- # Fallback: derive candidate boxes and run MedSAM per box
758
- cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=20, min_area=200)
759
- for bbox in cand_bboxes:
760
- m = _medsam.segment_with_box(bbox)
761
- if m is None or not isinstance(m.get('mask'), np.ndarray):
762
- continue
763
- segmentations.append({
764
- "mask": m['mask'].astype(np.uint8).tolist(),
765
- "confidence": float(m.get('confidence', 1.0)),
766
- "method": m.get("method", "medsam_box_auto")
767
- })
768
- masks_for_overlay.append(m)
769
- except Exception as auto_e:
770
- print(f"Automatic MedSAM segmentation failed: {auto_e}")
771
-
772
- W, H = pil_img.size
773
- base_result["medsam"] = {
774
- "available": True,
775
- "height": H,
776
- "width": W,
777
- "segmentations": segmentations,
778
- "num_segments": len(segmentations)
779
- }
780
-
781
- overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
782
- return base_result, overlay_img
783
- except Exception as e:
784
- print(f"analyze_with_medsam failed: {e}")
785
- return base_result, None
786
-
787
- # === Gradio UI with API enhancements ===
788
- # Create Blocks interface with explicit API name for stable API surface
789
- with gr.Blocks(
790
- title="📊 Dense Captioning Platform"
791
- ) as demo:
792
-
793
- gr.Markdown("# 📊 Dense Captioning Platform")
794
- gr.Markdown("""
795
- **Comprehensive Chart Analysis API**
796
-
797
- Upload a chart image to get:
798
- - **Chart Type Classification**: Identifies the type of chart (line, bar, scatter, etc.)
799
- - **Element Detection**: Detects chart elements like titles, axes, legends, data points
800
- - **Data Point Segmentation**: Segments individual data points and regions
801
-
802
- Masks will be automatically generated for medical images when supported.
803
-
804
- **API Usage:**
805
- ```python
806
- from gradio_client import Client, handle_file
807
-
808
- client = Client("hanszhu/Dense-Captioning-Platform")
809
- result = client.predict(
810
- image=handle_file('path/to/your/chart.png'),
811
- api_name="/predict"
812
- )
813
- print(result)
814
- ```
815
-
816
- **Supported Chart Types:** Line graphs, Bar plots, Scatter plots, Pie charts, Heat maps, and 23+ more
817
- """)
818
-
819
- with gr.Row():
820
- with gr.Column():
821
- # Input
822
- image_input = gr.Image(
823
- type="filepath", # ✅ REQUIRED for gradio_client
824
- label="Upload Chart Image",
825
- height=400
826
- )
827
-
828
- # Analyze button (single)
829
- analyze_btn = gr.Button(
830
- "🔍 Analyze",
831
- variant="primary",
832
- size="lg"
833
- )
834
-
835
- with gr.Column():
836
- # Output JSON
837
- result_output = gr.JSON(
838
- label="Analysis Results",
839
- height=400
840
- )
841
- # Overlay image output (populated only for medical images)
842
- overlay_output = gr.Image(
843
- label="MedSAM Overlay (Medical images)",
844
- height=400
845
- )
846
-
847
- # Single API endpoint for JSON
848
- analyze_event = analyze_btn.click(
849
- fn=analyze,
850
- inputs=image_input,
851
- outputs=result_output,
852
- api_name="/predict" # ✅ Standard API name that gradio_client expects
853
- )
854
-
855
- # Automatic overlay generation step for medical images
856
- analyze_event.then(
857
- fn=analyze_with_medsam,
858
- inputs=[result_output, image_input],
859
- outputs=[result_output, overlay_output],
860
- )
861
-
862
- # Add some examples
863
- gr.Examples(
864
- examples=[
865
- ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
866
- ],
867
- inputs=image_input,
868
- label="Try with this example"
869
- )
870
-
871
- # Launch with API-friendly settings
872
- if __name__ == "__main__":
873
- launch_kwargs = {
874
- "server_name": "0.0.0.0", # Allow external connections
875
- "server_port": 7860,
876
- "share": False, # Set to True if you want a public link
877
- "show_error": True, # Show detailed errors for debugging
878
- "quiet": False, # Show startup messages
879
- "show_api": True # Enable API documentation
880
- }
881
-
882
- # Enable queue for gradio_client compatibility
883
- demo.queue().launch(**launch_kwargs) # ✅ required for gradio_client to work
884
 
 
1
+ from fastapi import FastAPI
 
 
 
 
 
 
2
 
3
+ app = FastAPI()
 
 
 
 
 
4
 
5
+ @app.get("/")
6
+ def greet_json():
7
+ return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
requirements.txt CHANGED
@@ -1,12 +1,2 @@
1
- gradio==5.39.0
2
- torch>=2.0.0
3
- torchvision>=0.15.0
4
- transformers>=4.30.0
5
- Pillow>=9.0.0
6
- numpy>=1.21.0
7
- opencv-python>=4.8.0
8
- huggingface-hub>=0.16.0
9
- openmim
10
- mmdet
11
- mmengine
12
- scikit-image>=0.21.0
 
1
+ fastapi==0.115.0
2
+ uvicorn[standard]==0.30.6