Spaces:
Sleeping
Sleeping
Commit
Β·
eb4d305
0
Parent(s):
build(space): initial Docker Space with Gradio app, MMDet, SAM integration
Browse files- .gitattributes +35 -0
- .gitignore +46 -0
- Dockerfile +29 -0
- README.md +13 -0
- README_API.md +309 -0
- app.py +926 -0
- custom_models/custom_cascade_with_meta.py +152 -0
- custom_models/custom_dataset.py +537 -0
- custom_models/custom_faster_rcnn_with_meta.py +166 -0
- custom_models/custom_heads.py +267 -0
- custom_models/flexible_load_annotations.py +191 -0
- custom_models/mask_filter.py +48 -0
- custom_models/nan_recovery_hook.py +181 -0
- custom_models/progressive_loss_hook.py +286 -0
- custom_models/register.py +29 -0
- custom_models/square_fcn_mask_head.py +170 -0
- custom_models/square_mask_target.py +67 -0
- debug_api.py +123 -0
- find_api_endpoint.py +167 -0
- models/chart_elementnet_swin.py +394 -0
- models/chart_pointnet_swin.py +374 -0
- requirements.txt +12 -0
- simple_test.py +60 -0
- test_api.py +95 -0
- test_api_endpoints.py +152 -0
- web_test.py +115 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python cache files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
|
7 |
+
# Model files (downloaded automatically)
|
8 |
+
models/models--*/
|
9 |
+
models/.locks/
|
10 |
+
*.pth
|
11 |
+
*.pkl
|
12 |
+
*.h5
|
13 |
+
*.onnx
|
14 |
+
|
15 |
+
# Environment files
|
16 |
+
.env
|
17 |
+
.venv
|
18 |
+
env/
|
19 |
+
venv/
|
20 |
+
ENV/
|
21 |
+
env.bak/
|
22 |
+
venv.bak/
|
23 |
+
|
24 |
+
# IDE files
|
25 |
+
.vscode/
|
26 |
+
.idea/
|
27 |
+
*.swp
|
28 |
+
*.swo
|
29 |
+
*~
|
30 |
+
|
31 |
+
# OS files
|
32 |
+
.DS_Store
|
33 |
+
.DS_Store?
|
34 |
+
._*
|
35 |
+
.Spotlight-V100
|
36 |
+
.Trashes
|
37 |
+
ehthumbs.db
|
38 |
+
Thumbs.db
|
39 |
+
|
40 |
+
# Logs
|
41 |
+
*.log
|
42 |
+
logs/
|
43 |
+
|
44 |
+
# Temporary files
|
45 |
+
*.tmp
|
46 |
+
*.temp
|
Dockerfile
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
4 |
+
PIP_NO_CACHE_DIR=1 \
|
5 |
+
MPLBACKEND=Agg \
|
6 |
+
MIM_IGNORE_INSTALL_PYTORCH=1
|
7 |
+
|
8 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
9 |
+
libgl1 libglib2.0-0 git && \
|
10 |
+
rm -rf /var/lib/apt/lists/*
|
11 |
+
|
12 |
+
WORKDIR /app
|
13 |
+
|
14 |
+
COPY requirements.txt /app/requirements.txt
|
15 |
+
|
16 |
+
# Install pip deps and the mm stack with openmim
|
17 |
+
RUN python -m pip install -U pip openmim && \
|
18 |
+
pip install -r requirements.txt && \
|
19 |
+
mim install "mmengine==0.10.4" && \
|
20 |
+
mim install "mmcv==2.1.0" && \
|
21 |
+
mim install "mmdet==3.3.0" && \
|
22 |
+
pip install git+https://github.com/facebookresearch/segment-anything.git
|
23 |
+
|
24 |
+
# Copy the rest of the application
|
25 |
+
COPY . /app
|
26 |
+
|
27 |
+
EXPOSE 7860
|
28 |
+
|
29 |
+
CMD ["python", "app.py"]
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Dense Captioning Platform
|
3 |
+
emoji: π’
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.39.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
README_API.md
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# π Dense Captioning Platform API Documentation
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
The Dense Captioning Platform provides comprehensive chart analysis through a Gradio-based API. It can classify chart types, detect chart elements, and segment data points from uploaded images.
|
6 |
+
|
7 |
+
## API Access
|
8 |
+
|
9 |
+
**Base URL:** `https://hanszhu-dense-captioning-platform.hf.space`
|
10 |
+
|
11 |
+
**API Type:** Gradio Client API (not RESTful)
|
12 |
+
|
13 |
+
## Installation
|
14 |
+
|
15 |
+
### Prerequisites
|
16 |
+
|
17 |
+
```bash
|
18 |
+
pip install gradio-client
|
19 |
+
```
|
20 |
+
|
21 |
+
### Quick Start
|
22 |
+
|
23 |
+
### Python Client (Recommended)
|
24 |
+
|
25 |
+
```python
|
26 |
+
from gradio_client import Client, handle_file
|
27 |
+
|
28 |
+
# Initialize client with direct URL
|
29 |
+
client = Client("https://hanszhu-dense-captioning-platform.hf.space")
|
30 |
+
|
31 |
+
# Analyze a chart image using file path
|
32 |
+
result = client.predict(
|
33 |
+
image=handle_file('path/to/your/chart.png'),
|
34 |
+
fn_index=0
|
35 |
+
)
|
36 |
+
|
37 |
+
print(result)
|
38 |
+
```
|
39 |
+
|
40 |
+
### Using a URL
|
41 |
+
|
42 |
+
```python
|
43 |
+
from gradio_client import Client, handle_file
|
44 |
+
|
45 |
+
client = Client("https://hanszhu-dense-captioning-platform.hf.space")
|
46 |
+
|
47 |
+
# Use a publicly accessible image URL
|
48 |
+
result = client.predict(
|
49 |
+
image=handle_file("https://example.com/chart.png"),
|
50 |
+
fn_index=0
|
51 |
+
)
|
52 |
+
|
53 |
+
print(result)
|
54 |
+
```
|
55 |
+
|
56 |
+
## Input Parameters
|
57 |
+
|
58 |
+
| Parameter | Type | Required | Description |
|
59 |
+
|-----------|------|----------|-------------|
|
60 |
+
| `image` | File/URL | Yes | Chart image to analyze (PNG, JPG, JPEG supported) |
|
61 |
+
|
62 |
+
## Important Notes
|
63 |
+
|
64 |
+
### β
Working Approach
|
65 |
+
- **Use `fn_index=0`** instead of `api_name="/predict"`
|
66 |
+
- **Use direct URL** `"https://hanszhu-dense-captioning-platform.hf.space"`
|
67 |
+
- **Always use `handle_file()`** for both local files and URLs
|
68 |
+
- **This is a Gradio Client API**, not a RESTful API
|
69 |
+
|
70 |
+
### β What Doesn't Work
|
71 |
+
- Direct HTTP POST requests to `/predict`
|
72 |
+
- Using `api_name="/predict"` with this setup
|
73 |
+
- Using `Client("hanszhu/Dense-Captioning-Platform")` (use direct URL instead)
|
74 |
+
|
75 |
+
## Output Format
|
76 |
+
|
77 |
+
The API returns a JSON object with the following structure:
|
78 |
+
|
79 |
+
```json
|
80 |
+
{
|
81 |
+
"chart_type_id": 4,
|
82 |
+
"chart_type_label": "Bar plot",
|
83 |
+
"element_result": {
|
84 |
+
"bboxes": [...],
|
85 |
+
"segments": [...]
|
86 |
+
},
|
87 |
+
"datapoint_result": {
|
88 |
+
"bboxes": [...],
|
89 |
+
"segments": [...]
|
90 |
+
},
|
91 |
+
"status": "Full analysis completed",
|
92 |
+
"processing_time": 2.345
|
93 |
+
}
|
94 |
+
```
|
95 |
+
|
96 |
+
### Output Fields
|
97 |
+
|
98 |
+
| Field | Type | Description |
|
99 |
+
|-------|------|-------------|
|
100 |
+
| `chart_type_id` | int | Numeric identifier for chart type (0-27) |
|
101 |
+
| `chart_type_label` | string | Human-readable chart type name |
|
102 |
+
| `element_result` | object/string | Detected chart elements (titles, axes, legends, etc.) |
|
103 |
+
| `datapoint_result` | object/string | Segmented data points and regions |
|
104 |
+
| `status` | string | Processing status message |
|
105 |
+
| `processing_time` | float | Time taken for analysis in seconds |
|
106 |
+
|
107 |
+
## Supported Chart Types
|
108 |
+
|
109 |
+
The platform can classify 28 different chart types:
|
110 |
+
|
111 |
+
| ID | Chart Type | ID | Chart Type |
|
112 |
+
|----|------------|----|------------|
|
113 |
+
| 0 | Line graph | 14 | Histogram |
|
114 |
+
| 1 | Natural image | 15 | Box plot |
|
115 |
+
| 2 | Table | 16 | Vector plot |
|
116 |
+
| 3 | 3D object | 17 | Pie chart |
|
117 |
+
| 4 | Bar plot | 18 | Surface plot |
|
118 |
+
| 5 | Scatter plot | 19 | Algorithm |
|
119 |
+
| 6 | Medical image | 20 | Contour plot |
|
120 |
+
| 7 | Sketch | 21 | Tree diagram |
|
121 |
+
| 8 | Geographic map | 22 | Bubble chart |
|
122 |
+
| 9 | Flow chart | 23 | Polar plot |
|
123 |
+
| 10 | Heat map | 24 | Area chart |
|
124 |
+
| 11 | Mask | 25 | Pareto chart |
|
125 |
+
| 12 | Block diagram | 26 | Radar chart |
|
126 |
+
| 13 | Venn diagram | 27 | Confusion matrix |
|
127 |
+
|
128 |
+
## Chart Elements Detected
|
129 |
+
|
130 |
+
The element detection model identifies:
|
131 |
+
|
132 |
+
- **Titles & Labels**: Chart title, subtitle, axis labels
|
133 |
+
- **Axes**: X-axis, Y-axis, tick labels
|
134 |
+
- **Legend**: Legend title, legend items, legend text
|
135 |
+
- **Data Elements**: Data points, data lines, data bars, data areas
|
136 |
+
- **Structural Elements**: Grid lines, plot areas
|
137 |
+
|
138 |
+
## Error Handling
|
139 |
+
|
140 |
+
The API returns error messages in the response fields when issues occur:
|
141 |
+
|
142 |
+
```json
|
143 |
+
{
|
144 |
+
"chart_type_id": "Error: Model not available",
|
145 |
+
"chart_type_label": "Error: Model not available",
|
146 |
+
"element_result": "Error: Invalid image format",
|
147 |
+
"datapoint_result": "Error: Processing failed",
|
148 |
+
"status": "Error in chart classification",
|
149 |
+
"processing_time": 0.0
|
150 |
+
}
|
151 |
+
```
|
152 |
+
|
153 |
+
## Rate Limits
|
154 |
+
|
155 |
+
- **Free Tier**: Limited requests per hour
|
156 |
+
- **Processing Time**: Typically 2-5 seconds per image
|
157 |
+
- **Image Size**: Recommended max 10MB
|
158 |
+
|
159 |
+
## Complete Working Example
|
160 |
+
|
161 |
+
Here's a complete example that demonstrates all the working patterns:
|
162 |
+
|
163 |
+
```python
|
164 |
+
from gradio_client import Client, handle_file
|
165 |
+
import json
|
166 |
+
|
167 |
+
def analyze_chart(image_path_or_url):
|
168 |
+
"""
|
169 |
+
Analyze a chart image using the Dense Captioning Platform API
|
170 |
+
|
171 |
+
Args:
|
172 |
+
image_path_or_url (str): Path to local image file or URL to image
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
dict: Analysis results with chart type, elements, and data points
|
176 |
+
"""
|
177 |
+
try:
|
178 |
+
# Initialize client with direct URL
|
179 |
+
client = Client("https://hanszhu-dense-captioning-platform.hf.space")
|
180 |
+
|
181 |
+
# Make prediction using the working approach
|
182 |
+
result = client.predict(
|
183 |
+
image=handle_file(image_path_or_url),
|
184 |
+
fn_index=0
|
185 |
+
)
|
186 |
+
|
187 |
+
return result
|
188 |
+
|
189 |
+
except Exception as e:
|
190 |
+
return {
|
191 |
+
"error": f"API call failed: {str(e)}",
|
192 |
+
"status": "Error",
|
193 |
+
"processing_time": 0.0
|
194 |
+
}
|
195 |
+
|
196 |
+
# Example usage
|
197 |
+
if __name__ == "__main__":
|
198 |
+
# Test with a local file
|
199 |
+
local_result = analyze_chart("path/to/your/chart.png")
|
200 |
+
print("Local file result:", json.dumps(local_result, indent=2))
|
201 |
+
|
202 |
+
# Test with a URL
|
203 |
+
url_result = analyze_chart("https://example.com/chart.png")
|
204 |
+
print("URL result:", json.dumps(url_result, indent=2))
|
205 |
+
```
|
206 |
+
|
207 |
+
## Examples
|
208 |
+
|
209 |
+
### Example 1: Bar Chart Analysis
|
210 |
+
|
211 |
+
```python
|
212 |
+
from gradio_client import Client, handle_file
|
213 |
+
|
214 |
+
client = Client("https://hanszhu-dense-captioning-platform.hf.space")
|
215 |
+
|
216 |
+
# Analyze a bar chart
|
217 |
+
result = client.predict(
|
218 |
+
image=handle_file('bar_chart.png'),
|
219 |
+
fn_index=0
|
220 |
+
)
|
221 |
+
|
222 |
+
print(f"Chart Type: {result['chart_type_label']}")
|
223 |
+
print(f"Processing Time: {result['processing_time']}s")
|
224 |
+
```
|
225 |
+
|
226 |
+
### Example 2: Batch Processing
|
227 |
+
|
228 |
+
```python
|
229 |
+
from gradio_client import Client, handle_file
|
230 |
+
import os
|
231 |
+
|
232 |
+
client = Client("https://hanszhu-dense-captioning-platform.hf.space")
|
233 |
+
|
234 |
+
# Process multiple charts
|
235 |
+
chart_files = ['chart1.png', 'chart2.png', 'chart3.png']
|
236 |
+
results = []
|
237 |
+
|
238 |
+
for chart_file in chart_files:
|
239 |
+
if os.path.exists(chart_file):
|
240 |
+
result = client.predict(
|
241 |
+
image=handle_file(chart_file),
|
242 |
+
fn_index=0
|
243 |
+
)
|
244 |
+
results.append(result)
|
245 |
+
print(f"Processed {chart_file}: {result['chart_type_label']}")
|
246 |
+
```
|
247 |
+
|
248 |
+
### Example 3: Test with Public Image
|
249 |
+
|
250 |
+
```python
|
251 |
+
from gradio_client import Client, handle_file
|
252 |
+
|
253 |
+
client = Client("https://hanszhu-dense-captioning-platform.hf.space")
|
254 |
+
|
255 |
+
# Test with a public image URL
|
256 |
+
result = client.predict(
|
257 |
+
image=handle_file("https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"),
|
258 |
+
fn_index=0
|
259 |
+
)
|
260 |
+
|
261 |
+
print("β
API Test Successful!")
|
262 |
+
print(f"Chart Type: {result['chart_type_label']}")
|
263 |
+
print(f"Status: {result['status']}")
|
264 |
+
```
|
265 |
+
|
266 |
+
## Troubleshooting
|
267 |
+
|
268 |
+
### Common Issues
|
269 |
+
|
270 |
+
1. **"Model not available"**: The models are still loading, wait a moment and retry
|
271 |
+
2. **"Invalid image format"**: Ensure the image is in PNG, JPG, or JPEG format
|
272 |
+
3. **"Processing failed"**: The image might be corrupted or too large
|
273 |
+
4. **"Expecting value: line 1 column 1"**: Use `fn_index=0` instead of `api_name="/predict"`
|
274 |
+
5. **"Cannot find a function with api_name"**: Use direct URL and `fn_index=0`
|
275 |
+
|
276 |
+
### Best Practices
|
277 |
+
|
278 |
+
1. **Image Quality**: Use clear, high-resolution images for best results
|
279 |
+
2. **Format**: PNG or JPG formats work best
|
280 |
+
3. **Size**: Keep images under 10MB for faster processing
|
281 |
+
4. **Client Setup**: Always use direct URL and `fn_index=0`
|
282 |
+
5. **File Handling**: Always use `handle_file()` for both local files and URLs
|
283 |
+
6. **Retry Logic**: Implement retry logic for failed requests
|
284 |
+
|
285 |
+
### Quick Test
|
286 |
+
|
287 |
+
To verify the API is working, run this test:
|
288 |
+
|
289 |
+
```python
|
290 |
+
from gradio_client import Client, handle_file
|
291 |
+
|
292 |
+
try:
|
293 |
+
client = Client("https://hanszhu-dense-captioning-platform.hf.space")
|
294 |
+
result = client.predict(
|
295 |
+
image=handle_file("https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"),
|
296 |
+
fn_index=0
|
297 |
+
)
|
298 |
+
print("β
API is working!")
|
299 |
+
print(f"Chart Type: {result['chart_type_label']}")
|
300 |
+
except Exception as e:
|
301 |
+
print(f"β API test failed: {e}")
|
302 |
+
```
|
303 |
+
|
304 |
+
## Support
|
305 |
+
|
306 |
+
For issues or questions:
|
307 |
+
- Check the [Hugging Face Space](https://huggingface.co/spaces/hanszhu/Dense-Captioning-Platform)
|
308 |
+
- Review the error messages in the API response
|
309 |
+
- Ensure your image format and size are within limits
|
app.py
ADDED
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import gradio as gr
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
# Add custom modules to path - try multiple possible locations
|
10 |
+
possible_paths = [
|
11 |
+
"./custom_models",
|
12 |
+
"../custom_models",
|
13 |
+
"./Dense-Captioning-Platform/custom_models"
|
14 |
+
]
|
15 |
+
|
16 |
+
for path in possible_paths:
|
17 |
+
if os.path.exists(path):
|
18 |
+
sys.path.insert(0, os.path.abspath(path))
|
19 |
+
break
|
20 |
+
|
21 |
+
# Add mmcv to path if it exists
|
22 |
+
if os.path.exists('./mmcv'):
|
23 |
+
sys.path.insert(0, os.path.abspath('./mmcv'))
|
24 |
+
print("β
Added local mmcv to path")
|
25 |
+
|
26 |
+
# Import and register custom modules
|
27 |
+
try:
|
28 |
+
from custom_models import register
|
29 |
+
print("β
Custom modules registered successfully")
|
30 |
+
except Exception as e:
|
31 |
+
print(f"β οΈ Warning: Could not register custom modules: {e}")
|
32 |
+
|
33 |
+
# ----------------------
|
34 |
+
# Optional MedSAM integration
|
35 |
+
# ----------------------
|
36 |
+
class MedSAMIntegrator:
|
37 |
+
def __init__(self):
|
38 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
39 |
+
self.medsam_model = None
|
40 |
+
self.current_image = None
|
41 |
+
self.current_image_path = None
|
42 |
+
self.embedding = None
|
43 |
+
self._load_medsam_model()
|
44 |
+
|
45 |
+
def _ensure_segment_anything(self):
|
46 |
+
try:
|
47 |
+
import segment_anything # noqa: F401
|
48 |
+
return True
|
49 |
+
except Exception as e:
|
50 |
+
print(f"β segment_anything not available: {e}. It must be installed at build time (Dockerfile).")
|
51 |
+
return False
|
52 |
+
|
53 |
+
def _load_medsam_model(self):
|
54 |
+
try:
|
55 |
+
# Ensure library is present
|
56 |
+
if not self._ensure_segment_anything():
|
57 |
+
print("MedSAM features disabled (segment_anything not available)")
|
58 |
+
return
|
59 |
+
|
60 |
+
from segment_anything import sam_model_registry as _reg
|
61 |
+
import torch as _torch
|
62 |
+
|
63 |
+
# Preferred local path
|
64 |
+
medsam_ckpt_path = "models/medsam_vit_b.pth"
|
65 |
+
|
66 |
+
# If not present, fetch from HF Hub using provided repo or default
|
67 |
+
if not os.path.exists(medsam_ckpt_path):
|
68 |
+
try:
|
69 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
70 |
+
repo_id = os.environ.get("HF_MEDSAM_REPO", "Aniketg6/Fine-Tuned-MedSAM")
|
71 |
+
# Try to find a .pth/.pt in the repo
|
72 |
+
print(f"π Trying to download MedSAM checkpoint from {repo_id} ...")
|
73 |
+
files = list_repo_files(repo_id)
|
74 |
+
candidate = None
|
75 |
+
for f in files:
|
76 |
+
lf = f.lower()
|
77 |
+
if lf.endswith(".pth") or lf.endswith(".pt"):
|
78 |
+
candidate = f
|
79 |
+
break
|
80 |
+
if candidate is None:
|
81 |
+
# Fallback to a common name
|
82 |
+
candidate = "medsam_vit_b.pth"
|
83 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename=candidate, cache_dir="./models")
|
84 |
+
medsam_ckpt_path = ckpt_path
|
85 |
+
print(f"β
Downloaded MedSAM checkpoint: {medsam_ckpt_path}")
|
86 |
+
except Exception as dl_err:
|
87 |
+
print(f"β Could not fetch MedSAM checkpoint from HF Hub: {dl_err}")
|
88 |
+
print("MedSAM features disabled (no checkpoint)")
|
89 |
+
return
|
90 |
+
|
91 |
+
# Load checkpoint
|
92 |
+
checkpoint = _torch.load(medsam_ckpt_path, map_location='cpu')
|
93 |
+
self.medsam_model = _reg["vit_b"](checkpoint=None)
|
94 |
+
self.medsam_model.load_state_dict(checkpoint)
|
95 |
+
self.medsam_model.to(self.device)
|
96 |
+
self.medsam_model.eval()
|
97 |
+
print("β MedSAM model loaded successfully")
|
98 |
+
except Exception as e:
|
99 |
+
print(f"β MedSAM model not available: {e}. MedSAM features disabled.")
|
100 |
+
|
101 |
+
def is_available(self):
|
102 |
+
return self.medsam_model is not None
|
103 |
+
|
104 |
+
def load_image(self, image_path, precomputed_embedding=None):
|
105 |
+
try:
|
106 |
+
from skimage import transform, io # local import to avoid hard dep if unused
|
107 |
+
img_np = io.imread(image_path)
|
108 |
+
if len(img_np.shape) == 2:
|
109 |
+
img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
|
110 |
+
else:
|
111 |
+
img_3c = img_np
|
112 |
+
self.current_image = img_3c
|
113 |
+
self.current_image_path = image_path
|
114 |
+
if precomputed_embedding is not None:
|
115 |
+
if not self.set_precomputed_embedding(precomputed_embedding):
|
116 |
+
self.get_embeddings()
|
117 |
+
else:
|
118 |
+
self.get_embeddings()
|
119 |
+
return True
|
120 |
+
except Exception as e:
|
121 |
+
print(f"Error loading image for MedSAM: {e}")
|
122 |
+
return False
|
123 |
+
|
124 |
+
@torch.no_grad()
|
125 |
+
def get_embeddings(self):
|
126 |
+
if self.current_image is None or self.medsam_model is None:
|
127 |
+
return None
|
128 |
+
from skimage import transform
|
129 |
+
img_1024 = transform.resize(
|
130 |
+
self.current_image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
|
131 |
+
).astype(np.uint8)
|
132 |
+
img_1024 = (img_1024 - img_1024.min()) / np.clip(img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None)
|
133 |
+
img_1024_tensor = (
|
134 |
+
torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
|
135 |
+
)
|
136 |
+
self.embedding = self.medsam_model.image_encoder(img_1024_tensor)
|
137 |
+
return self.embedding
|
138 |
+
|
139 |
+
def set_precomputed_embedding(self, embedding_array):
|
140 |
+
try:
|
141 |
+
if isinstance(embedding_array, np.ndarray):
|
142 |
+
embedding_tensor = torch.tensor(embedding_array).to(self.device)
|
143 |
+
self.embedding = embedding_tensor
|
144 |
+
return True
|
145 |
+
return False
|
146 |
+
except Exception as e:
|
147 |
+
print(f"Error setting precomputed embedding: {e}")
|
148 |
+
return False
|
149 |
+
|
150 |
+
@torch.no_grad()
|
151 |
+
def medsam_inference(self, box_1024, height, width):
|
152 |
+
if self.embedding is None or self.medsam_model is None:
|
153 |
+
return None
|
154 |
+
box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=self.embedding.device)
|
155 |
+
if len(box_torch.shape) == 2:
|
156 |
+
box_torch = box_torch[:, None, :]
|
157 |
+
sparse_embeddings, dense_embeddings = self.medsam_model.prompt_encoder(
|
158 |
+
points=None, boxes=box_torch, masks=None,
|
159 |
+
)
|
160 |
+
low_res_logits, _ = self.medsam_model.mask_decoder(
|
161 |
+
image_embeddings=self.embedding,
|
162 |
+
image_pe=self.medsam_model.prompt_encoder.get_dense_pe(),
|
163 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
164 |
+
dense_prompt_embeddings=dense_embeddings,
|
165 |
+
multimask_output=False,
|
166 |
+
)
|
167 |
+
low_res_pred = torch.sigmoid(low_res_logits)
|
168 |
+
low_res_pred = torch.nn.functional.interpolate(
|
169 |
+
low_res_pred, size=(height, width), mode="bilinear", align_corners=False,
|
170 |
+
)
|
171 |
+
low_res_pred = low_res_pred.squeeze().cpu().numpy()
|
172 |
+
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
173 |
+
return medsam_seg
|
174 |
+
|
175 |
+
def segment_with_box(self, bbox):
|
176 |
+
if self.embedding is None or self.current_image is None:
|
177 |
+
return None
|
178 |
+
try:
|
179 |
+
H, W, _ = self.current_image.shape
|
180 |
+
x1, y1, x2, y2 = bbox
|
181 |
+
x1 = max(0, min(int(x1), W - 1))
|
182 |
+
y1 = max(0, min(int(y1), H - 1))
|
183 |
+
x2 = max(0, min(int(x2), W - 1))
|
184 |
+
y2 = max(0, min(int(y2), H - 1))
|
185 |
+
if x2 <= x1:
|
186 |
+
x2 = min(x1 + 10, W - 1)
|
187 |
+
if y2 <= y1:
|
188 |
+
y2 = min(y1 + 10, H - 1)
|
189 |
+
box_np = np.array([[x1, y1, x2, y2]], dtype=float)
|
190 |
+
box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
|
191 |
+
medsam_mask = self.medsam_inference(box_1024, H, W)
|
192 |
+
if medsam_mask is not None:
|
193 |
+
return {"mask": medsam_mask, "confidence": 1.0, "method": "medsam_box"}
|
194 |
+
return None
|
195 |
+
except Exception as e:
|
196 |
+
print(f"Error in MedSAM box-based segmentation: {e}")
|
197 |
+
return None
|
198 |
+
|
199 |
+
# Single global instance
|
200 |
+
_medsam = MedSAMIntegrator()
|
201 |
+
|
202 |
+
# Cache for SAM automatic mask generator
|
203 |
+
_sam_auto_generator = None
|
204 |
+
_sam_auto_ckpt_path = None
|
205 |
+
|
206 |
+
|
207 |
+
def _get_sam_generator():
|
208 |
+
"""Load and cache SAM ViT-H automatic mask generator with faster params if checkpoint exists."""
|
209 |
+
global _sam_auto_generator, _sam_auto_ckpt_path
|
210 |
+
if _sam_auto_generator is not None:
|
211 |
+
return _sam_auto_generator
|
212 |
+
try:
|
213 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
214 |
+
ckpt = "models/sam_vit_h_4b8939.pth"
|
215 |
+
if not os.path.exists(ckpt):
|
216 |
+
try:
|
217 |
+
from huggingface_hub import hf_hub_download
|
218 |
+
ckpt = hf_hub_download(
|
219 |
+
repo_id="Aniketg6/SAM",
|
220 |
+
filename="sam_vit_h_4b8939.pth",
|
221 |
+
cache_dir="./models"
|
222 |
+
)
|
223 |
+
print(f"β
Downloaded SAM ViT-H checkpoint to: {ckpt}")
|
224 |
+
except Exception as e:
|
225 |
+
print(f"β Failed to download SAM ViT-H checkpoint: {e}")
|
226 |
+
return None
|
227 |
+
_sam_auto_ckpt_path = ckpt
|
228 |
+
sam = sam_model_registry["vit_h"](checkpoint=ckpt)
|
229 |
+
# Speed-tuned generator params
|
230 |
+
_sam_auto_generator = SamAutomaticMaskGenerator(
|
231 |
+
sam,
|
232 |
+
points_per_side=16,
|
233 |
+
pred_iou_thresh=0.88,
|
234 |
+
stability_score_thresh=0.9,
|
235 |
+
crop_n_layers=0,
|
236 |
+
box_nms_thresh=0.7,
|
237 |
+
min_mask_region_area=512 # filter tiny masks
|
238 |
+
)
|
239 |
+
return _sam_auto_generator
|
240 |
+
except Exception as e:
|
241 |
+
print(f"_get_sam_generator failed: {e}")
|
242 |
+
return None
|
243 |
+
|
244 |
+
|
245 |
+
def _extract_bboxes_from_mmdet_result(det_result):
|
246 |
+
"""Extract Nx4 xyxy bboxes from various MMDet result formats."""
|
247 |
+
boxes = []
|
248 |
+
try:
|
249 |
+
# MMDet 3.x: list of DetDataSample
|
250 |
+
if isinstance(det_result, list) and len(det_result) > 0:
|
251 |
+
sample = det_result[0]
|
252 |
+
if hasattr(sample, 'pred_instances'):
|
253 |
+
inst = sample.pred_instances
|
254 |
+
if hasattr(inst, 'bboxes'):
|
255 |
+
b = inst.bboxes
|
256 |
+
# mmengine structures may use .tensor for boxes
|
257 |
+
if hasattr(b, 'tensor'):
|
258 |
+
b = b.tensor
|
259 |
+
boxes = b.detach().cpu().numpy().tolist()
|
260 |
+
# Single DetDataSample
|
261 |
+
elif hasattr(det_result, 'pred_instances'):
|
262 |
+
inst = det_result.pred_instances
|
263 |
+
if hasattr(inst, 'bboxes'):
|
264 |
+
b = inst.bboxes
|
265 |
+
if hasattr(b, 'tensor'):
|
266 |
+
b = b.tensor
|
267 |
+
boxes = b.detach().cpu().numpy().tolist()
|
268 |
+
# MMDet 2.x: tuple of (bbox_result, segm_result)
|
269 |
+
elif isinstance(det_result, tuple) and len(det_result) >= 1:
|
270 |
+
bbox_result = det_result[0]
|
271 |
+
# bbox_result is list per class, each Nx5 [x1,y1,x2,y2,score]
|
272 |
+
if isinstance(bbox_result, (list, tuple)):
|
273 |
+
for arr in bbox_result:
|
274 |
+
try:
|
275 |
+
arr_np = np.array(arr)
|
276 |
+
if arr_np.ndim == 2 and arr_np.shape[1] >= 4:
|
277 |
+
boxes.extend(arr_np[:, :4].tolist())
|
278 |
+
except Exception:
|
279 |
+
continue
|
280 |
+
except Exception as e:
|
281 |
+
print(f"Failed to parse MMDet result for boxes: {e}")
|
282 |
+
return boxes
|
283 |
+
|
284 |
+
|
285 |
+
def _overlay_masks_on_image(image_pil, mask_list, alpha=0.4):
|
286 |
+
"""Overlay binary masks on an image with random colors."""
|
287 |
+
if image_pil is None or not mask_list:
|
288 |
+
return image_pil
|
289 |
+
img = np.array(image_pil.convert('RGB'))
|
290 |
+
overlay = img.copy()
|
291 |
+
for idx, m in enumerate(mask_list):
|
292 |
+
if m is None or 'mask' not in m or m['mask'] is None:
|
293 |
+
continue
|
294 |
+
mask = m['mask'].astype(bool)
|
295 |
+
color = np.random.RandomState(seed=idx + 1234).randint(0, 255, size=3)
|
296 |
+
overlay[mask] = (0.5 * overlay[mask] + 0.5 * color).astype(np.uint8)
|
297 |
+
blended = (alpha * overlay + (1 - alpha) * img).astype(np.uint8)
|
298 |
+
return Image.fromarray(blended)
|
299 |
+
|
300 |
+
|
301 |
+
def _mask_to_polygons(mask: np.ndarray):
|
302 |
+
"""Convert a binary mask (H,W) to a list of polygons ([[x,y], ...]) using OpenCV contours."""
|
303 |
+
try:
|
304 |
+
mask_u8 = (mask.astype(np.uint8) * 255)
|
305 |
+
contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
306 |
+
polygons = []
|
307 |
+
for cnt in contours:
|
308 |
+
if cnt is None or len(cnt) < 3:
|
309 |
+
continue
|
310 |
+
# Simplify contour slightly
|
311 |
+
epsilon = 0.002 * cv2.arcLength(cnt, True)
|
312 |
+
approx = cv2.approxPolyDP(cnt, epsilon, True)
|
313 |
+
poly = approx.reshape(-1, 2).tolist()
|
314 |
+
polygons.append(poly)
|
315 |
+
return polygons
|
316 |
+
except Exception as e:
|
317 |
+
print(f"_mask_to_polygons failed: {e}")
|
318 |
+
return []
|
319 |
+
|
320 |
+
|
321 |
+
def _find_largest_foreground_bbox(pil_img: Image.Image):
    """Heuristic: find largest foreground region bbox via Otsu threshold on grayscale.

    Returns [x1, y1, x2, y2] (ints, padded ~2% and clamped to the image), or
    the full-image bbox if no contour is found or any step fails."""
    try:
        img = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # Otsu threshold (invert if needed by checking mean)
        _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Assume foreground is darker; invert if threshold yields background as white majority
        if th.mean() > 127:
            th = 255 - th
        # Morph close to connect regions (fills small gaps between nearby blobs)
        kernel = np.ones((5, 5), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            # Nothing segmented: degrade gracefully to the whole image.
            W, H = pil_img.size
            return [0, 0, W - 1, H - 1]
        # Largest contour by area
        cnt = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(cnt)
        # Pad a little (2% of the longer side) so a tight box doesn't clip edges
        pad = int(0.02 * max(w, h))
        x1 = max(0, x - pad)
        y1 = max(0, y - pad)
        x2 = min(img.shape[1] - 1, x + w + pad)
        y2 = min(img.shape[0] - 1, y + h + pad)
        return [x1, y1, x2, y2]
    except Exception as e:
        print(f"_find_largest_foreground_bbox failed: {e}")
        W, H = pil_img.size
        return [0, 0, W - 1, H - 1]
|
353 |
+
|
354 |
+
|
355 |
+
def _find_topk_foreground_bboxes(pil_img: Image.Image, max_regions: int = 20, min_area: int = 100):
    """Find top-K foreground bboxes via Otsu threshold + morphology. Returns list of [x1,y1,x2,y2].

    Args:
        pil_img: source image (converted to RGB internally).
        max_regions: maximum number of boxes to return (largest areas first).
        min_area: contours with a smaller area (in pixels) are skipped.

    Returns:
        List of padded, clamped integer boxes; [] when nothing is found or on error.
    """
    try:
        img = np.array(pil_img.convert('RGB'))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Assume darker pixels are foreground; invert when the binary image is mostly white.
        if th.mean() > 127:
            th = 255 - th
        # Open removes speckle noise, close then reconnects split regions.
        kernel = np.ones((3, 3), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
        th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []
        # Largest regions first so truncation at max_regions keeps the big ones.
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        bboxes = []
        H, W = img.shape[:2]
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < min_area:
                continue
            x, y, w, h = cv2.boundingRect(cnt)
            # Filter very thin shapes (likely lines/axes rather than regions)
            if w < 5 or h < 5:
                continue
            # Pad ~1% of the longer side, clamped to the image bounds.
            pad = int(0.01 * max(w, h))
            x1 = max(0, x - pad)
            y1 = max(0, y - pad)
            x2 = min(W - 1, x + w + pad)
            y2 = min(H - 1, y + h + pad)
            bboxes.append([x1, y1, x2, y2])
            if len(bboxes) >= max_regions:
                break
        return bboxes
    except Exception as e:
        print(f"_find_topk_foreground_bboxes failed: {e}")
        return []
|
392 |
+
|
393 |
+
# Try to import mmdet for inference.
# If the import fails (e.g. a fresh Space container without compiled mmcv),
# fall back to installing the whole stack at runtime via openmim, which
# resolves the correct prebuilt mmcv wheel; then retry the import once.
try:
    from mmdet.apis import init_detector, inference_detector
    MM_DET_AVAILABLE = True
    print("✅ MMDetection available for inference")
except ImportError as e:
    print(f"⚠️ MMDetection import failed: {e}")
    print("🔄 Attempting to install MMDetection dependencies...")
    try:
        import subprocess
        import sys

        # Use the working solution with mim install
        print("🔄 Installing MMDetection dependencies with mim...")

        # Install openmim if not already installed
        subprocess.check_call([sys.executable, "-m", "pip", "install", "openmim"])

        # Install mmengine
        subprocess.check_call([sys.executable, "-m", "mim", "install", "mmengine"])

        # Install mmcv with mim (this handles compilation properly)
        subprocess.check_call([sys.executable, "-m", "mim", "install", "mmcv==2.1.0"])

        # Install mmdet
        subprocess.check_call([sys.executable, "-m", "mim", "install", "mmdet"])

        # Try importing again now that the stack is installed
        from mmdet.apis import init_detector, inference_detector
        MM_DET_AVAILABLE = True
        print("✅ MMDetection installed and available for inference")
    except Exception as install_error:
        # Detection models will be skipped; classification can still run.
        print(f"❌ Failed to install MMDetection: {install_error}")
        MM_DET_AVAILABLE = False
|
427 |
+
|
428 |
+
# === Chart Type Classification (DocFigure) ===
print("🔄 Loading Chart Classification Model...")

# Chart type labels from DocFigure dataset (28 classes).
# Index order must match the classifier's output logits.
CHART_TYPE_LABELS = [
    'Line graph', 'Natural image', 'Table', '3D object', 'Bar plot', 'Scatter plot',
    'Medical image', 'Sketch', 'Geographic map', 'Flow chart', 'Heat map', 'Mask',
    'Block diagram', 'Venn diagram', 'Confusion matrix', 'Histogram', 'Box plot',
    'Vector plot', 'Pie chart', 'Surface plot', 'Algorithm', 'Contour plot',
    'Tree diagram', 'Bubble chart', 'Polar plot', 'Area chart', 'Pareto chart', 'Radar chart'
]

try:
    # Load the chart_type.pth model file from Hugging Face Hub
    from huggingface_hub import hf_hub_download
    import torch
    from torchvision import transforms

    print("🔄 Downloading chart_type.pth from Hugging Face Hub...")
    chart_type_path = hf_hub_download(
        repo_id="hanszhu/ChartTypeNet-DocFigure",
        filename="chart_type.pth",
        cache_dir="./models"
    )
    print(f"✅ Downloaded to: {chart_type_path}")

    # The .pth may contain a full model object, a raw state dict, or a
    # training checkpoint with a nested "model_state_dict" — handle all three.
    loaded_data = torch.load(chart_type_path, map_location='cpu')

    if isinstance(loaded_data, dict):
        if "model_state_dict" in loaded_data:
            print("🔄 Loading checkpoint, extracting model_state_dict...")
            state_dict = loaded_data["model_state_dict"]
        else:
            # It's a direct state dict
            print("🔄 Loading state dict, creating model architecture...")
            state_dict = loaded_data

        # Strip "backbone." prefix from state dict keys if present, so keys
        # match a plain torchvision resnet50.
        cleaned_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("backbone."):
                new_key = key[9:]  # len("backbone.") == 9
                cleaned_state_dict[new_key] = value
            else:
                cleaned_state_dict[key] = value

        print(f"🔄 Cleaned state dict: {len(cleaned_state_dict)} keys")

        # Create the model architecture (ResNet-50 with a custom 28-way head)
        from torchvision.models import resnet50
        chart_type_model = resnet50(pretrained=False)

        # The classifier structure must match the saved state dict exactly.
        import torch.nn as nn
        in_features = chart_type_model.fc.in_features
        dropout = nn.Dropout(0.5)

        chart_type_model.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            dropout,
            nn.Linear(512, 28)  # 28 DocFigure classes
        )

        chart_type_model.load_state_dict(cleaned_state_dict)
    else:
        # It's a complete serialized model object
        chart_type_model = loaded_data

    chart_type_model.eval()

    # Standard ImageNet preprocessing; matches what the backbone was trained with.
    chart_type_processor = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    CHART_TYPE_AVAILABLE = True
    print("✅ Chart classification model loaded")
except Exception as e:
    print(f"⚠️ Failed to load chart classification model: {e}")
    import traceback
    print("🔄 Full traceback:")
    traceback.print_exc()
    CHART_TYPE_AVAILABLE = False
|
519 |
+
|
520 |
+
# === Chart Element Detection (Cascade R-CNN) ===
# Both detectors are optional: each is loaded independently and failures leave
# the corresponding variable as None, which analyze() treats as "unavailable".
element_model = None
datapoint_model = None

print(f"🔍 MM_DET_AVAILABLE: {MM_DET_AVAILABLE}")

if MM_DET_AVAILABLE:
    # Check if config files exist (configs live in the repo; weights on the Hub)
    element_config = "models/chart_elementnet_swin.py"
    point_config = "models/chart_pointnet_swin.py"

    print(f"🔍 Checking config files...")
    print(f"🔍 Element config exists: {os.path.exists(element_config)}")
    print(f"🔍 Point config exists: {os.path.exists(point_config)}")
    print(f"🔍 Current working directory: {os.getcwd()}")
    print(f"🔍 Files in models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")

    try:
        print("🔄 Loading ChartElementNet-MultiClass (Cascade R-CNN)...")
        print(f"🔄 Config path: {element_config}")
        print(f"🔄 Weights path: hanszhu/ChartElementNet-MultiClass")
        print(f"🔄 About to call init_detector...")

        # Download model weights from Hugging Face Hub
        from huggingface_hub import hf_hub_download
        print("🔄 Downloading ChartElementNet weights from Hugging Face Hub...")
        element_checkpoint = hf_hub_download(
            repo_id="hanszhu/ChartElementNet-MultiClass",
            filename="chart_label+.pth",
            cache_dir="./models"
        )
        print(f"✅ Downloaded to: {element_checkpoint}")

        # Use local config with downloaded weights; CPU-only Space.
        element_model = init_detector(element_config, element_checkpoint, device="cpu")
        print("✅ ChartElementNet loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load ChartElementNet: {e}")
        print(f"🔍 Error type: {type(e).__name__}")
        print(f"🔍 Error details: {str(e)}")
        import traceback
        print("🔍 Full traceback:")
        traceback.print_exc()

    try:
        print("🔄 Loading ChartPointNet-InstanceSeg (Mask R-CNN)...")
        print(f"🔄 Config path: {point_config}")
        print(f"🔄 Weights path: hanszhu/ChartPointNet-InstanceSeg")
        print(f"🔄 About to call init_detector...")

        # Download model weights from Hugging Face Hub
        print("🔄 Downloading ChartPointNet weights from Hugging Face Hub...")
        datapoint_checkpoint = hf_hub_download(
            repo_id="hanszhu/ChartPointNet-InstanceSeg",
            filename="chart_datapoint.pth",
            cache_dir="./models"
        )
        print(f"✅ Downloaded to: {datapoint_checkpoint}")

        # Use local config with downloaded weights; CPU-only Space.
        datapoint_model = init_detector(point_config, datapoint_checkpoint, device="cpu")
        print("✅ ChartPointNet loaded successfully")
    except Exception as e:
        print(f"❌ Failed to load ChartPointNet: {e}")
        print(f"🔍 Error type: {type(e).__name__}")
        print(f"🔍 Error details: {str(e)}")
        import traceback
        print("🔍 Full traceback:")
        traceback.print_exc()
else:
    print("❌ MMDetection not available - cannot load custom models")
    print(f"🔍 MM_DET_AVAILABLE was False")

print(f"🔍 Final model status:")
print(f"🔍 element_model: {element_model is not None}")
print(f"🔍 datapoint_model: {datapoint_model is not None}")
|
597 |
+
# === Main prediction function ===
|
598 |
+
def analyze(image):
|
599 |
+
"""
|
600 |
+
Analyze a chart image and return comprehensive results.
|
601 |
+
|
602 |
+
Args:
|
603 |
+
image: Input chart image (filepath string or PIL.Image)
|
604 |
+
|
605 |
+
Returns:
|
606 |
+
dict: Analysis results containing:
|
607 |
+
- chart_type_id (int): Numeric chart type identifier (0-27)
|
608 |
+
- chart_type_label (str): Human-readable chart type name
|
609 |
+
- element_result (str): Detected chart elements (titles, axes, legends, etc.)
|
610 |
+
- datapoint_result (str): Segmented data points and regions
|
611 |
+
- status (str): Processing status message
|
612 |
+
- processing_time (float): Time taken for analysis in seconds
|
613 |
+
"""
|
614 |
+
import time
|
615 |
+
from PIL import Image
|
616 |
+
|
617 |
+
start_time = time.time()
|
618 |
+
|
619 |
+
# Handle filepath input (convert to PIL Image)
|
620 |
+
if isinstance(image, str):
|
621 |
+
# It's a filepath, load the image
|
622 |
+
image = Image.open(image).convert("RGB")
|
623 |
+
elif image is None:
|
624 |
+
return {"error": "No image provided"}
|
625 |
+
|
626 |
+
# Ensure we have a PIL Image
|
627 |
+
if not isinstance(image, Image.Image):
|
628 |
+
return {"error": "Invalid image format"}
|
629 |
+
|
630 |
+
result = {
|
631 |
+
"chart_type_id": "Model not available",
|
632 |
+
"chart_type_label": "Model not available",
|
633 |
+
"element_result": "MMDetection models not available",
|
634 |
+
"datapoint_result": "MMDetection models not available",
|
635 |
+
"status": "Basic chart classification only",
|
636 |
+
"processing_time": 0.0,
|
637 |
+
"medsam": {"available": False}
|
638 |
+
}
|
639 |
+
|
640 |
+
# Chart Type Classification
|
641 |
+
if CHART_TYPE_AVAILABLE:
|
642 |
+
try:
|
643 |
+
# Preprocess image for PyTorch model
|
644 |
+
processed_image = chart_type_processor(image).unsqueeze(0) # Add batch dimension
|
645 |
+
|
646 |
+
# Get prediction
|
647 |
+
with torch.no_grad():
|
648 |
+
outputs = chart_type_model(processed_image)
|
649 |
+
# Handle different output formats
|
650 |
+
if isinstance(outputs, torch.Tensor):
|
651 |
+
logits = outputs
|
652 |
+
elif hasattr(outputs, 'logits'):
|
653 |
+
logits = outputs.logits
|
654 |
+
else:
|
655 |
+
logits = outputs
|
656 |
+
|
657 |
+
predicted_class = logits.argmax(dim=-1).item()
|
658 |
+
|
659 |
+
result["chart_type_id"] = predicted_class
|
660 |
+
result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
|
661 |
+
result["status"] = "Chart classification completed"
|
662 |
+
|
663 |
+
except Exception as e:
|
664 |
+
result["chart_type_id"] = f"Error: {str(e)}"
|
665 |
+
result["chart_type_label"] = f"Error: {str(e)}"
|
666 |
+
result["status"] = "Error in chart classification"
|
667 |
+
|
668 |
+
# Chart Element Detection (Cascade R-CNN)
|
669 |
+
if element_model is not None:
|
670 |
+
try:
|
671 |
+
# If medical image, skip heavy MMDet to speed up
|
672 |
+
if isinstance(result.get("chart_type_label"), str) and result["chart_type_label"].lower() == "medical image":
|
673 |
+
result["element_result"] = "skipped_for_medical"
|
674 |
+
else:
|
675 |
+
# Convert PIL image to numpy array for MMDetection
|
676 |
+
np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL β BGR
|
677 |
+
|
678 |
+
element_result = inference_detector(element_model, np_img)
|
679 |
+
|
680 |
+
# Convert result to more API-friendly format
|
681 |
+
if isinstance(element_result, tuple):
|
682 |
+
bbox_result, segm_result = element_result
|
683 |
+
element_data = {
|
684 |
+
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
685 |
+
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
686 |
+
}
|
687 |
+
else:
|
688 |
+
element_data = str(element_result)
|
689 |
+
|
690 |
+
result["element_result"] = element_data
|
691 |
+
result["status"] = "Chart classification + element detection completed"
|
692 |
+
except Exception as e:
|
693 |
+
result["element_result"] = f"Error: {str(e)}"
|
694 |
+
|
695 |
+
# Chart Data Point Segmentation (Mask R-CNN)
|
696 |
+
if datapoint_model is not None:
|
697 |
+
try:
|
698 |
+
# If medical image, skip heavy MMDet to speed up
|
699 |
+
if isinstance(result.get("chart_type_label"), str) and result["chart_type_label"].lower() == "medical image":
|
700 |
+
result["datapoint_result"] = "skipped_for_medical"
|
701 |
+
else:
|
702 |
+
# Convert PIL image to numpy array for MMDetection
|
703 |
+
np_img = np.array(image.convert("RGB"))[:, :, ::-1] # PIL β BGR
|
704 |
+
|
705 |
+
datapoint_result = inference_detector(datapoint_model, np_img)
|
706 |
+
|
707 |
+
# Convert result to more API-friendly format
|
708 |
+
if isinstance(datapoint_result, tuple):
|
709 |
+
bbox_result, segm_result = datapoint_result
|
710 |
+
datapoint_data = {
|
711 |
+
"bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
|
712 |
+
"segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
|
713 |
+
}
|
714 |
+
else:
|
715 |
+
datapoint_data = str(datapoint_result)
|
716 |
+
|
717 |
+
result["datapoint_result"] = datapoint_data
|
718 |
+
result["status"] = "Full analysis completed"
|
719 |
+
except Exception as e:
|
720 |
+
result["datapoint_result"] = f"Error: {str(e)}"
|
721 |
+
|
722 |
+
# If predicted as medical image and MedSAM is available, include mask data (polygons)
|
723 |
+
try:
|
724 |
+
label_lower = str(result.get("chart_type_label", "")).strip().lower()
|
725 |
+
if label_lower == "medical image":
|
726 |
+
if _medsam.is_available():
|
727 |
+
# Do not run heuristics here. Prompts are required and handled in the UI then-chain.
|
728 |
+
# Indicate availability and that prompts are needed for segmentation.
|
729 |
+
result["medsam"] = {"available": True, "reason": "provide bbox/points prompts to generate segmentations"}
|
730 |
+
else:
|
731 |
+
# Not available; include reason
|
732 |
+
result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
|
733 |
+
except Exception as e:
|
734 |
+
print(f"MedSAM JSON augmentation failed: {e}")
|
735 |
+
|
736 |
+
result["processing_time"] = round(time.time() - start_time, 3)
|
737 |
+
return result
|
738 |
+
|
739 |
+
|
740 |
+
def analyze_with_medsam(base_result, image):
    """Auto-generate segmentations for medical images using SAM ViT-H if available,
    otherwise fallback to MedSAM over top-K foreground boxes. Returns updated JSON and overlay image.

    Args:
        base_result: dict produced by analyze(); passed through unchanged for
            non-dict input or non-medical chart types.
        image: filepath string or PIL.Image (same input as analyze()).

    Returns:
        (dict, PIL.Image | None): possibly-augmented result and an overlay
        image, or (base_result, None) when segmentation is skipped or fails.
    """
    try:
        if not isinstance(base_result, dict):
            return base_result, None
        label = str(base_result.get("chart_type_label", "")).strip().lower()
        # Only medical images get automatic segmentation.
        if label != "medical image":
            return base_result, None

        pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
        if pil_img is None:
            return base_result, None

        segmentations = []       # JSON-serializable mask records
        masks_for_overlay = []   # numpy masks for the rendered overlay

        # Try fast SAM generator first; avoid MedSAM embedding when SAM is available
        gen = _get_sam_generator()
        if gen is not None and _sam_auto_ckpt_path is not None and os.path.exists(_sam_auto_ckpt_path):
            try:
                import cv2 as _cv2
                # SAM's generator wants an on-disk image; persist PIL input if needed.
                img_path = image if isinstance(image, str) else None
                if img_path is None:
                    tmp_path = "./_tmp_input_image.png"
                    pil_img.save(tmp_path)
                    img_path = tmp_path
                img_bgr = _cv2.imread(img_path)
                masks = gen.generate(img_bgr)

                # Keep top-K by (stability_score, area); K=8 bounds payload size.
                def _score(m):
                    s = float(m.get('stability_score', 0.0))
                    seg = m.get('segmentation', None)
                    area = int(seg.sum()) if isinstance(seg, np.ndarray) else 0
                    return (s, area)
                masks = sorted(masks, key=_score, reverse=True)[:8]

                for m in masks:
                    seg = m.get('segmentation', None)
                    if seg is None:
                        continue
                    seg_u8 = seg.astype(np.uint8)
                    segmentations.append({
                        "mask": seg_u8.tolist(),
                        "confidence": float(m.get('stability_score', 1.0)),
                        "method": "sam_auto"
                    })
                    masks_for_overlay.append({"mask": seg_u8})
            except Exception as e:
                print(f"SAM generator segmentation failed: {e}")

        # Fallback to MedSAM boxes only if nothing produced
        if not segmentations and _medsam.is_available():
            try:
                # Prepare the image embedding once; box prompts reuse it.
                img_path = image if isinstance(image, str) else None
                if img_path is None:
                    tmp_path = "./_tmp_input_image.png"
                    pil_img.save(tmp_path)
                    img_path = tmp_path
                _medsam.load_image(img_path)
                # Heuristic candidate boxes from Otsu foreground detection.
                cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=5, min_area=400)
                for bbox in cand_bboxes:
                    m = _medsam.segment_with_box(bbox)
                    if m is None or not isinstance(m.get('mask'), np.ndarray):
                        continue
                    segmentations.append({
                        "mask": m['mask'].astype(np.uint8).tolist(),
                        "confidence": float(m.get('confidence', 1.0)),
                        "method": m.get("method", "medsam_box_auto")
                    })
                    masks_for_overlay.append(m)
            except Exception as auto_e:
                print(f"MedSAM fallback segmentation failed: {auto_e}")

        W, H = pil_img.size
        base_result["medsam"] = {
            "available": True,
            "height": H,
            "width": W,
            "segmentations": segmentations,
            "num_segments": len(segmentations)
        }

        overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
        return base_result, overlay_img
    except Exception as e:
        print(f"analyze_with_medsam failed: {e}")
        return base_result, None
|
828 |
+
|
829 |
+
# === Gradio UI with API enhancements ===
# Create Blocks interface with explicit API name for stable API surface
with gr.Blocks(
    title="📊 Dense Captioning Platform"
) as demo:

    gr.Markdown("# 📊 Dense Captioning Platform")
    gr.Markdown("""
    **Comprehensive Chart Analysis API**

    Upload a chart image to get:
    - **Chart Type Classification**: Identifies the type of chart (line, bar, scatter, etc.)
    - **Element Detection**: Detects chart elements like titles, axes, legends, data points
    - **Data Point Segmentation**: Segments individual data points and regions

    Masks will be automatically generated for medical images when supported.

    **API Usage:**
    ```python
    from gradio_client import Client, handle_file

    client = Client("hanszhu/Dense-Captioning-Platform")
    result = client.predict(
        image=handle_file('path/to/your/chart.png'),
        api_name="/predict"
    )
    print(result)
    ```

    **Supported Chart Types:** Line graphs, Bar plots, Scatter plots, Pie charts, Heat maps, and 23+ more
    """)

    with gr.Row():
        with gr.Column():
            # Input: filepath type is REQUIRED for gradio_client compatibility
            image_input = gr.Image(
                type="filepath",
                label="Upload Chart Image",
                height=400
            )

            # Analyze button (single entry point for the whole pipeline)
            analyze_btn = gr.Button(
                "🚀 Analyze",
                variant="primary",
                size="lg"
            )

        with gr.Column():
            # Output JSON (the analyze() result dict)
            result_output = gr.JSON(
                label="Analysis Results",
                height=400
            )
            # Overlay image output (populated only for medical images)
            overlay_output = gr.Image(
                label="MedSAM Overlay (Medical images)",
                height=400
            )

    # Single API endpoint for JSON; api_name="/predict" is the standard name
    # that gradio_client expects.
    analyze_event = analyze_btn.click(
        fn=analyze,
        inputs=image_input,
        outputs=result_output,
        api_name="/predict"
    )

    # Automatic overlay generation step for medical images: chained after the
    # main analysis so it can read the predicted chart type from result_output.
    analyze_event.then(
        fn=analyze_with_medsam,
        inputs=[result_output, image_input],
        outputs=[result_output, overlay_output],
    )

    # Add some examples
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
        ],
        inputs=image_input,
        label="Try with this example"
    )
|
912 |
+
|
913 |
+
# Launch with API-friendly settings
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",  # Allow external connections (required in a container)
        "server_port": 7860,       # Standard Hugging Face Spaces port
        "share": False,            # Set to True if you want a public link
        "show_error": True,        # Show detailed errors for debugging
        "quiet": False,            # Show startup messages
        "show_api": True           # Enable API documentation
    }

    # Enable queue for gradio_client compatibility — required for the
    # /predict endpoint and the .then() event chain to work.
    demo.queue().launch(**launch_kwargs)
|
926 |
+
|
custom_models/custom_cascade_with_meta.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmdet.models.detectors import CascadeRCNN
|
2 |
+
from mmdet.registry import MODELS
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
@MODELS.register_module()
class CustomCascadeWithMeta(CascadeRCNN):
    """Custom Cascade R-CNN with metadata prediction heads.

    Besides standard detection outputs, the model can predict chart metadata
    (chart type, plot bounding box, axes information, data series) and
    regresses the number of data points in the chart. The normalized
    predicted count is appended to the pooled global feature before it is
    fed to the metadata heads.

    NOTE(review): ``forward_train``/``simple_test`` follow the MMDetection
    2.x detector API while the registries come from MMDetection 3.x —
    confirm which runtime actually drives this model.
    """

    def __init__(self,
                 *args,
                 chart_cls_head=None,
                 plot_reg_head=None,
                 axes_info_head=None,
                 data_series_head=None,
                 data_points_count_head=None,
                 coordinate_standardization=None,
                 data_series_config=None,
                 axis_aware_feature=None,
                 **kwargs):
        """Build the detector and any configured metadata heads.

        Args:
            chart_cls_head (dict, optional): Registry config for the
                chart-type classification head.
            plot_reg_head (dict, optional): Config for the plot-bbox
                regression head.
            axes_info_head (dict, optional): Config for the axes-info head.
            data_series_head (dict, optional): Config for the data-series
                head.
            data_points_count_head (dict, optional): Config for the
                data-point-count head; a small default MLP is used when
                omitted.
            coordinate_standardization (dict, optional): Stored as-is for
                downstream coordinate handling.
            data_series_config (dict, optional): Stored as-is.
            axis_aware_feature (dict, optional): Stored as-is.
        """
        super().__init__(*args, **kwargs)

        # Metadata heads are optional; their presence is later checked with
        # hasattr(), so only build the ones that were configured.
        if chart_cls_head is not None:
            self.chart_cls_head = MODELS.build(chart_cls_head)
        if plot_reg_head is not None:
            self.plot_reg_head = MODELS.build(plot_reg_head)
        if axes_info_head is not None:
            self.axes_info_head = MODELS.build(axes_info_head)
        if data_series_head is not None:
            self.data_series_head = MODELS.build(data_series_head)
        if data_points_count_head is not None:
            self.data_points_count_head = MODELS.build(data_points_count_head)
        else:
            # Default simple regression head for the data point count.
            # NOTE(review): the input dim of 2048 assumes the last backbone
            # feature map has 2048 channels (e.g. ResNet-50 C5) — confirm
            # against the configured backbone/neck.
            self.data_points_count_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, 1))  # single scalar: predicted count

        # Store auxiliary configurations.
        self.coordinate_standardization = coordinate_standardization
        self.data_series_config = data_series_config
        self.axis_aware_feature = axis_aware_feature

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Forward function during training.

        Returns:
            dict: Detection losses plus ``data_points_count_loss`` and any
            configured metadata-head losses.
        """
        x = self.extract_feat(img)
        losses = dict()

        # RPN forward and loss.
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            # NOTE(review): ``ann_weight`` is not a standard RPNHead
            # argument; kept for compatibility with the project's custom
            # RPN head — confirm the configured head accepts it.
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                ann_weight=None,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = kwargs.get('proposals', None)

        # ROI forward and loss.
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels, **kwargs)
        losses.update(roi_losses)

        # Global average pooling over the last feature level.
        global_feat = x[-1].mean(dim=[2, 3])

        # Ground-truth data point counts come from the dataset metadata.
        gt_data_point_counts = [
            img_meta.get('img_info', {}).get('num_data_points', 0)
            for img_meta in img_metas
        ]
        gt_data_point_counts = torch.tensor(
            gt_data_point_counts, dtype=torch.float32,
            device=global_feat.device)

        # Predict data point counts and compute the regression loss.
        # (Bug fix: use the functional form instead of instantiating a new
        # nn.MSELoss module on every forward pass.)
        pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
        losses['data_points_count_loss'] = nn.functional.mse_loss(
            pred_data_point_counts, gt_data_point_counts)

        # Append the normalized predicted count (squashed to 0-1) as an
        # extra global feature for the metadata heads.
        normalized_counts = torch.sigmoid(pred_data_point_counts / 100.0)
        enhanced_global_feat = torch.cat(
            [global_feat, normalized_counts.unsqueeze(-1)], dim=-1)

        # Metadata prediction losses (each optional head returns its loss).
        if hasattr(self, 'chart_cls_head'):
            losses['chart_cls_loss'] = self.chart_cls_head(enhanced_global_feat)
        if hasattr(self, 'plot_reg_head'):
            losses['plot_reg_loss'] = self.plot_reg_head(enhanced_global_feat)
        if hasattr(self, 'axes_info_head'):
            losses['axes_info_loss'] = self.axes_info_head(enhanced_global_feat)
        if hasattr(self, 'data_series_head'):
            losses['data_series_loss'] = self.data_series_head(enhanced_global_feat)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation.

        Returns:
            list[DetDataSample]: One result per image carrying detection
            boxes, labels, the predicted data-point count, and any
            metadata-head predictions.
        """
        # Bug fix: the original referenced DetDataSample without importing
        # it anywhere, which raised NameError at inference time.
        from mmdet.structures import DetDataSample

        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg.rcnn, **kwargs)

        # Global average pooling over the last feature level.
        global_feat = x[-1].mean(dim=[2, 3])

        # Predict data point counts and build the enhanced global feature,
        # mirroring the training-time construction.
        pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
        normalized_counts = torch.sigmoid(pred_data_point_counts / 100.0)
        enhanced_global_feat = torch.cat(
            [global_feat, normalized_counts.unsqueeze(-1)], dim=-1)

        results = []
        for i, (bboxes, labels) in enumerate(zip(det_bboxes, det_labels)):
            result = DetDataSample()
            result.bboxes = bboxes
            result.labels = labels

            # Scalar count prediction for this image.
            result.predicted_data_points = pred_data_point_counts[i].item()

            # Per-image metadata predictions from the enhanced features.
            if hasattr(self, 'chart_cls_head'):
                result.chart_type = self.chart_cls_head(enhanced_global_feat[i:i + 1])
            if hasattr(self, 'plot_reg_head'):
                result.plot_bb = self.plot_reg_head(enhanced_global_feat[i:i + 1])
            if hasattr(self, 'axes_info_head'):
                result.axes_info = self.axes_info_head(enhanced_global_feat[i:i + 1])
            if hasattr(self, 'data_series_head'):
                result.data_series = self.data_series_head(enhanced_global_feat[i:i + 1])

            results.append(result)

        return results
|
custom_models/custom_dataset.py
ADDED
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os.path as osp
|
3 |
+
import numpy as np
|
4 |
+
from mmcv.transforms import BaseTransform
|
5 |
+
from mmcv.transforms import LoadImageFromFile
|
6 |
+
from mmdet.registry import DATASETS, TRANSFORMS
|
7 |
+
from mmdet.datasets.transforms import PackDetInputs
|
8 |
+
from mmdet.datasets.base_det_dataset import BaseDetDataset
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
# βββ Enhanced robust image loader for real images βββ
@TRANSFORMS.register_module()
class RobustLoadImageFromFile(LoadImageFromFile):
    """Enhanced image loader: tries real images first, falls back to dummy if needed.

    Args:
        try_real_images (bool): Attempt the normal MMDet image load first.
        fallback_to_dummy (bool): On failure, substitute a black dummy image
            instead of re-raising.
    """

    # Class-level counter shared by all pipeline instances so the total
    # number of missing images can be reported once per run.
    missing_count = 0

    def __init__(self, try_real_images=True, fallback_to_dummy=True, **kwargs):
        super().__init__(**kwargs)
        self.try_real_images = try_real_images
        self.fallback_to_dummy = fallback_to_dummy

    def transform(self, results):
        """Try to load the real image first; fall back to a dummy if not found."""
        if self.try_real_images:
            try:
                # Try standard MMDet image loading first.
                return super().transform(results)
            except Exception:
                # Bug fix: the original caught the tuple (FileNotFoundError,
                # OSError, Exception) — Exception already subsumes the
                # others, so the tuple was misleading; a single Exception
                # states the actual behavior.
                RobustLoadImageFromFile.missing_count += 1

                # Log a warning every 10 missing images to avoid spam.
                if RobustLoadImageFromFile.missing_count % 10 == 1:
                    warnings.warn(
                        f"Missing image #{RobustLoadImageFromFile.missing_count}: "
                        f"{results.get('img_path', 'unknown')}. "
                        f"Total missing so far: {RobustLoadImageFromFile.missing_count}",
                        UserWarning)

                if not self.fallback_to_dummy:
                    # Bug fix: bare `raise` preserves the original traceback
                    # (the original `raise e` re-raised from this frame).
                    raise
                # Fall through to create a dummy image.

        # Create a dummy image (either by choice or because loading failed).
        if 'img_shape' in results:
            h, w = results['img_shape'][:2]
        else:
            h = results.get('height', 800)
            w = results.get('width', 600)

        results['img'] = np.zeros((h, w, 3), dtype=np.uint8)
        results['img_shape'] = (h, w, 3)
        results['ori_shape'] = (h, w, 3)
        return results

    @classmethod
    def get_missing_count(cls):
        """Get the total count of missing images."""
        return cls.missing_count

    @classmethod
    def reset_missing_count(cls):
        """Reset the missing image counter."""
        cls.missing_count = 0
|
67 |
+
|
68 |
+
# βββ Legacy support for old transform name βββ
@TRANSFORMS.register_module()
class CreateDummyImg(RobustLoadImageFromFile):
    """Backward-compatible alias so old configs that reference
    ``CreateDummyImg`` keep working; all behavior comes from
    ``RobustLoadImageFromFile``."""
|
73 |
+
|
74 |
+
@TRANSFORMS.register_module()
class ClampBBoxes(BaseTransform):
    """Clamp ground-truth boxes to the image bounds.

    Coordinates are clipped in place; no boxes are ever removed here —
    empty-GT filtering is intentionally left to the dataset's
    ``filter_cfg``.
    """

    def __init__(self, min_size=1):
        self.min_size = min_size

    def transform(self, results):
        """Clip every gt bbox so it lies inside the image rectangle."""
        if 'gt_bboxes' not in results:
            return results

        height, width = results['img_shape'][:2]
        boxes = results['gt_bboxes']

        if hasattr(boxes, 'tensor'):
            # MMDet HorizontalBoxes object: clamp the underlying tensor.
            data = boxes.tensor
            data[:, 0].clamp_(0, width)   # x1
            data[:, 1].clamp_(0, height)  # y1
            data[:, 2].clamp_(0, width)   # x2
            data[:, 3].clamp_(0, height)  # y2
        elif len(boxes) > 0:
            # Plain numpy array of [x1, y1, x2, y2] rows: clip each column.
            np.clip(boxes[:, 0], 0, width, out=boxes[:, 0])   # x1
            np.clip(boxes[:, 1], 0, height, out=boxes[:, 1])  # y1
            np.clip(boxes[:, 2], 0, width, out=boxes[:, 2])   # x2
            np.clip(boxes[:, 3], 0, height, out=boxes[:, 3])  # y2

        # Nothing is dropped here — filter_cfg handles empty GT filtering.
        results['gt_bboxes'] = boxes
        return results
|
106 |
+
|
107 |
+
@TRANSFORMS.register_module()
class SetScaleFactor(BaseTransform):
    """Compute ``scale_factor`` from data_series & plot_bb before any Resize."""

    def __init__(self, default_scale=(1.0, 1.0)):
        self.default_scale = default_scale

    def calculate_scale_factor(self, results):
        """Derive (x_scale, y_scale) mapping data coordinates to pixels."""
        plot_bb = results.get('plot_bb', {})
        plot_w = plot_bb.get('width', 0)
        plot_h = plot_bb.get('height', 0)

        x_vals, y_vals = [], []
        for series in results.get('data_series', []):
            for point in series.get('data', []):
                px, py = point.get('x'), point.get('y')
                if isinstance(px, (int, float)):
                    x_vals.append(px)
                if isinstance(py, (int, float)):
                    y_vals.append(py)

        x_span = (max(x_vals) - min(x_vals)) if x_vals else 0
        y_span = (max(y_vals) - min(y_vals)) if y_vals else 0

        x_scale = plot_w / x_span if x_span else self.default_scale[0]
        # Negative y scale: pixel y grows downward while data y grows upward.
        y_scale = -plot_h / y_span if y_span else self.default_scale[1]
        return (x_scale, y_scale)

    def transform(self, results):
        """Attach ``scale_factor`` (falling back to the default on any error)."""
        try:
            scale = self.calculate_scale_factor(results)
            results['scale_factor'] = np.array(scale, dtype=np.float32)
        except Exception:
            results['scale_factor'] = np.array(self.default_scale, dtype=np.float32)
        img_h = results.get('height', 0)
        img_w = results.get('width', 0)
        results['img_shape'] = (img_h, img_w, 3)
        return results
|
141 |
+
|
142 |
+
@TRANSFORMS.register_module()
class EnsureScaleFactor(BaseTransform):
    """Fallback that guarantees ``scale_factor`` is present.

    Bug fix: the original unconditionally overwrote ``scale_factor`` even
    when an earlier transform (e.g. ``SetScaleFactor``) had already set a
    meaningful value, contradicting its documented "fallback" purpose.
    The identity scale is now applied only when the key is missing.
    """

    def transform(self, results):
        if 'scale_factor' not in results:
            results['scale_factor'] = np.array([1.0, 1.0], dtype=np.float32)
        return results
|
148 |
+
|
149 |
+
@TRANSFORMS.register_module()
class SetInputs(BaseTransform):
    """Copy the dummy img into ``inputs`` for DetDataPreprocessor."""

    def transform(self, results):
        """Mirror ``results['img']`` under the ``inputs`` key when present."""
        if 'img' not in results:
            return results
        results['inputs'] = results['img'].copy()
        return results
|
156 |
+
|
157 |
+
@TRANSFORMS.register_module()
class CustomPackDetInputs(PackDetInputs):
    """Final packing into a DetDataSample, ensuring ``inputs`` is present."""

    def transform(self, results):
        """Mirror ``img`` under ``inputs`` before delegating to the base packer."""
        if 'img' in results:
            results['inputs'] = results['img'].copy()
        return super().transform(results)
|
164 |
+
|
165 |
+
@DATASETS.register_module()
class ChartDataset(BaseDetDataset):
    """Enhanced dataset for comprehensive chart element detection and analysis.

    Loads COCO-style chart annotations (optionally enriched with chart-level
    metadata such as ``chart_type``, ``plot_bb`` and ``data_series_stats``),
    applies chart-type-aware filtering of data elements, and produces
    MMDet-style data infos.
    """

    # Updated METAINFO with 21 enhanced categories
    METAINFO = {
        'classes': [
            'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
            'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
            'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
            'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
        ]
    }

    # Chart-type specific element filtering based on actual dataset distribution
    # Data from analyze_chart_types.py:
    # β’ line (41.9%): 1710 images β data-line only
    # β’ scatter (18.2%): 742 images β data-point only
    # β’ vertical_bar (30.5%): 1246 images β data-bar only
    # β’ dot (9.2%): 374 images β data-point only
    # β’ horizontal_bar (0.2%): 9 images β data-bar only
    CHART_TYPE_ELEMENT_MAPPING = {
        # Line charts (41.9% - 1710 images): ONLY data-line
        'line': {
            'allowed_data_elements': {'data-line'},
            'forbidden_data_elements': {'data-point', 'data-bar', 'data-area'}
        },
        # Scatter charts (18.2% - 742 images): ONLY data-point
        'scatter': {
            'allowed_data_elements': {'data-point'},
            'forbidden_data_elements': {'data-line', 'data-bar', 'data-area'}
        },
        # Vertical bar charts (30.5% - 1246 images): ONLY data-bar
        'vertical_bar': {
            'allowed_data_elements': {'data-bar'},
            'forbidden_data_elements': {'data-point', 'data-line', 'data-area'}
        },
        # Dot charts (9.2% - 374 images): ONLY data-point
        'dot': {
            'allowed_data_elements': {'data-point'},
            'forbidden_data_elements': {'data-line', 'data-bar', 'data-area'}
        },
        # Horizontal bar charts (0.2% - 9 images): ONLY data-bar
        'horizontal_bar': {
            'allowed_data_elements': {'data-bar'},
            'forbidden_data_elements': {'data-point', 'data-line', 'data-area'}
        }
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Force this class's 21-class METAINFO over whatever the base class
        # resolved from the config.
        self.metainfo.update(self.METAINFO)

        # Print configuration info
        print(f"π ChartDataset initialized with {len(self.METAINFO['classes'])} categories:")
        for i, cls_name in enumerate(self.METAINFO['classes']):
            print(f" {i}: {cls_name}")

        # Print chart-type filtering info
        print(f"π― Chart-type specific filtering enabled:")
        for chart_type, mapping in self.CHART_TYPE_ELEMENT_MAPPING.items():
            allowed = mapping.get('allowed_data_elements', set())
            forbidden = mapping.get('forbidden_data_elements', set())
            print(f" β’ {chart_type}: β {allowed} | π« {forbidden}")

        # Debug print the data configuration
        print(f"π Dataset configuration:")
        print(f" β’ data_root: {getattr(self, 'data_root', 'None')}")
        print(f" β’ data_prefix: {getattr(self, 'data_prefix', 'None')}")
        print(f" β’ ann_file: {getattr(self, 'ann_file', 'None')}")

    def load_data_list(self):
        """Load enhanced annotation files with priority order.

        Returns:
            list[dict]: One data-info dict per image, carrying instances
            plus chart-level metadata.
        """

        # Auto-detect best annotation file (same logic as config)
        def get_best_ann_file(split):
            # Looks under <data_root>/annotations_JSON for the richest
            # available annotation variant for the given split.
            ann_dir = osp.join(self.data_root, 'annotations_JSON')

            # Priority order with flexible naming
            candidates = [
                f'{split}_enriched_with_info.json',
                f'{split}_enriched.json',
                f'{split}_with_info.json',  # Added: Handles val_with_info.json
                f'{split}.json',
                f'{split}_cleaned.json'
            ]

            for candidate in candidates:
                full_path = osp.join(ann_dir, candidate)
                if osp.exists(full_path):
                    print(f"π ChartDataset using {candidate}")
                    return full_path

            # Fallback to ann_file if specified
            if hasattr(self, 'ann_file') and self.ann_file:
                fallback_path = osp.join(self.data_root, self.ann_file)
                if osp.exists(fallback_path):
                    print(f"π Using fallback annotation file: {self.ann_file}")
                    return fallback_path

            raise FileNotFoundError(f"No annotation files found in {ann_dir}")

        # Determine file path: an explicit ann_file wins over auto-detection.
        if hasattr(self, 'ann_file') and self.ann_file:
            ann_file_path = osp.join(self.data_root, self.ann_file)
        else:
            # Try to auto-detect based on common patterns
            for split in ['train', 'val']:
                try:
                    ann_file_path = get_best_ann_file(split)
                    break
                except FileNotFoundError:
                    continue
            else:
                # for/else: runs only when no split produced a file
                raise FileNotFoundError("Could not find any annotation files")

        # Load annotation file (COCO-style JSON)
        with open(ann_file_path, 'r') as f:
            ann = json.load(f)

        print(f"π Loading from {ann_file_path}")
        print(f" β’ Images: {len(ann.get('images', []))}")
        print(f" β’ Annotations: {len(ann.get('annotations', []))}")

        # Build image lookup: image id -> image record
        img_id_to_info = {img['id']: img for img in ann['images']}

        # Group annotations by image
        img_id_to_anns = {}
        for ann_data in ann.get('annotations', []):
            img_id = ann_data['image_id']
            if img_id not in img_id_to_anns:
                img_id_to_anns[img_id] = []
            img_id_to_anns[img_id].append(ann_data)

        # Create data list with enhanced metadata
        data_list = []
        for img_id, img_info in img_id_to_info.items():
            annotations = img_id_to_anns.get(img_id, [])

            # Skip images without annotations if filter_empty_gt is enabled
            # NOTE(review): assumes self.filter_cfg is a dict — if the config
            # leaves it as None this raises; confirm against configs.
            if not annotations and self.filter_cfg.get('filter_empty_gt', False):
                continue

            # Convert annotations to instances format
            instances = []
            # NOTE(review): the loop variable `ann` shadows the loaded JSON
            # dict above; harmless because the JSON dict is not used past
            # this point, but fragile if this method is edited.
            for ann in annotations:
                bbox = ann['bbox']  # [x, y, width, height]
                # Convert to [x1, y1, x2, y2] format for MMDet
                bbox_xyxy = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]

                instance = {
                    'bbox': bbox_xyxy,
                    'bbox_label': ann['category_id'],
                    'ignore_flag': 0,
                    'annotation_id': ann.get('id', -1),
                    'area': ann.get('area', bbox[2] * bbox[3]),
                    'element_type': ann.get('element_type', 'unknown')
                }

                # Add additional annotation metadata if available
                for key in ['text', 'role', 'data_point', 'chart_type', 'total_data_points']:
                    if key in ann:
                        instance[key] = ann[key]

                instances.append(instance)

            # Create data info with enhanced metadata
            # Fix: Construct full image path using data_prefix (like standard MMDet datasets)
            filename = img_info['file_name']
            if self.data_prefix.get('img'):
                img_path = osp.join(self.data_prefix['img'], filename)
            else:
                img_path = filename  # Fallback to original filename

            data_info = {
                'img_id': img_info['id'],
                'img_path': img_path,  # Use constructed path
                'height': img_info['height'],
                'width': img_info['width'],
                'instances': instances,
                # Enhanced metadata from enriched annotations
                'chart_type': img_info.get('chart_type', ''),
                'plot_bb': img_info.get('plot_bb', {}),
                'data_series': img_info.get('data_series', []),
                'data_series_stats': img_info.get('data_series_stats', {}),
                'axes_info': img_info.get('axes_info', {}),
                'element_counts': img_info.get('element_counts', {}),
                'source': img_info.get('source', 'unknown')
            }

            data_list.append(data_info)

        print(f"β Loaded {len(data_list)} images with enhanced metadata")
        return data_list

    def parse_data_info(self, raw_data_info):
        """Parse data info with enhanced metadata support.

        Clamps/enlarges bboxes, applies chart-type-specific element
        filtering, and attaches the MMDet fields plus an ``img_info`` block.
        """
        d = raw_data_info.copy()

        # Debug logging for first few images to verify path construction
        if hasattr(self, '_debug_count'):
            self._debug_count += 1
        else:
            self._debug_count = 1

        if self._debug_count <= 3:
            print(f"π Path verification debug #{self._debug_count}:")
            print(f" β’ img_path from load_data_list: {d['img_path']}")
            print(f" β’ data_root: {getattr(self, 'data_root', 'None')}")
            full_path = osp.join(self.data_root, d['img_path']) if hasattr(self, 'data_root') else d['img_path']
            print(f" β’ Full absolute path: {full_path}")
            print(f" β’ Path exists: {osp.exists(full_path)}")

        # Create or get image information
        img_h, img_w = d['height'], d['width']

        # Get class names for class-specific filtering
        class_names = self.METAINFO['classes']

        # Get filter configuration
        # NOTE(review): assumes self.filter_cfg is a dict — see load_data_list.
        min_size = self.filter_cfg.get('min_size', 1)
        class_specific_min_sizes = self.filter_cfg.get('class_specific_min_sizes', {})

        # Handle bboxes and labels from instances with enhanced filtering
        bboxes, labels = [], []
        filtered_count = 0
        enlarged_count = 0
        chart_type_filtered_count = 0

        # Get chart type for filtering
        chart_type = d.get('chart_type', '').lower()
        chart_mapping = self.CHART_TYPE_ELEMENT_MAPPING.get(chart_type, {})
        allowed_data_elements = chart_mapping.get('allowed_data_elements', set())
        forbidden_data_elements = chart_mapping.get('forbidden_data_elements', set())

        for inst in d.get('instances', []):
            bbox = inst['bbox']
            label_id = inst['bbox_label']

            # Get class name for this label (out-of-range ids become 'unknown')
            class_name = class_names[label_id] if 0 <= label_id < len(class_names) else 'unknown'

            # Chart-type specific filtering: Skip forbidden data elements
            if chart_type and class_name in forbidden_data_elements:
                chart_type_filtered_count += 1
                if self._debug_count <= 3 and chart_type_filtered_count <= 3:
                    print(f" π« Filtered {class_name} from {chart_type} chart (inappropriate data element)")
                continue

            # Chart-type specific validation: Log allowed data elements
            if chart_type and class_name in allowed_data_elements:
                if self._debug_count <= 3:
                    print(f" β Keeping {class_name} for {chart_type} chart (appropriate data element)")

            # Validate and clamp bbox to the image rectangle
            x1, y1, x2, y2 = bbox
            x1 = max(0, min(x1, img_w))
            y1 = max(0, min(y1, img_h))
            x2 = max(x1, min(x2, img_w))
            y2 = max(y1, min(y2, img_h))

            # Skip invalid (degenerate) bboxes
            if x2 <= x1 or y2 <= y1:
                filtered_count += 1
                continue

            # Calculate current bbox dimensions
            bbox_w = x2 - x1
            bbox_h = y2 - y1
            bbox_min_dim = min(bbox_w, bbox_h)

            # Check class-specific minimum size
            required_min_size = class_specific_min_sizes.get(class_name, min_size)

            # If bbox is smaller than required, enlarge it to meet the minimum size
            if bbox_min_dim < required_min_size:
                # Calculate expansion needed (split evenly on both sides)
                expand_w = max(0, required_min_size - bbox_w) / 2
                expand_h = max(0, required_min_size - bbox_h) / 2

                # Expand bbox while keeping it within image bounds
                new_x1 = max(0, x1 - expand_w)
                new_y1 = max(0, y1 - expand_h)
                new_x2 = min(img_w, x2 + expand_w)
                new_y2 = min(img_h, y2 + expand_h)

                # Update bbox coordinates
                x1, y1, x2, y2 = new_x1, new_y1, new_x2, new_y2
                enlarged_count += 1

                if self._debug_count <= 3 and enlarged_count <= 3:
                    print(f" π Enlarged {class_name} bbox: {bbox_w:.1f}x{bbox_h:.1f} β {(x2-x1):.1f}x{(y2-y1):.1f}")

            bboxes.append([x1, y1, x2, y2])
            labels.append(label_id)

        # Log filtering and enlargement statistics for first few images
        if self._debug_count <= 3:
            print(f" π Bbox processing: {len(bboxes)} kept, {filtered_count} filtered (invalid), {chart_type_filtered_count} filtered (chart-type), {enlarged_count} enlarged")
            if chart_type:
                print(f" π Chart type: {chart_type} | Allowed data elements: {allowed_data_elements}")
                if forbidden_data_elements:
                    print(f" π« Forbidden data elements for {chart_type}: {forbidden_data_elements}")

        # Convert to arrays (empty arrays keep MMDet's expected shapes)
        d['gt_bboxes'] = np.array(bboxes, dtype=np.float32) if bboxes else np.zeros((0, 4), dtype=np.float32)
        d['gt_bboxes_labels'] = np.array(labels, dtype=np.int64) if labels else np.zeros((0,), dtype=np.int64)

        # Enhanced scale factor calculation using data_series_stats
        d['scale_factor'] = np.array([1.0, 1.0], dtype=np.float32)

        # Use enhanced metadata for better scale factor calculation
        data_series_stats = d.get('data_series_stats', {})
        plot_bb = d.get('plot_bb', {})

        if data_series_stats and plot_bb and all(k in plot_bb for k in ['width', 'height']):
            x_range = data_series_stats.get('x_range')
            y_range = data_series_stats.get('y_range')

            if x_range and len(x_range) == 2 and x_range[1] != x_range[0]:
                d['scale_factor'][0] = plot_bb['width'] / (x_range[1] - x_range[0])
            if y_range and len(y_range) == 2 and y_range[1] != y_range[0]:
                # Negative y scale: pixel y grows downward, data y upward.
                d['scale_factor'][1] = -plot_bb['height'] / (y_range[1] - y_range[0])

        # Required MMDet fields
        d.update({
            'img_shape': (img_h, img_w, 3),
            'ori_shape': (img_h, img_w, 3),
            'pad_shape': (img_h, img_w, 3),
            'flip': False,
            'flip_direction': None,
            'img_fields': ['img'],
            'bbox_fields': ['bbox'],
        })

        # Additional metadata for training (consumed e.g. by the detector's
        # data-point-count head via img_meta['img_info']['num_data_points'])
        d['img_info'] = {
            'height': img_h,
            'width': img_w,
            'img_shape': d['img_shape'],
            'ori_shape': d['ori_shape'],
            'pad_shape': d['pad_shape'],
            'scale_factor': d['scale_factor'].copy(),
            'flip': d['flip'],
            'flip_direction': d['flip_direction'],
            # Enhanced metadata
            'chart_type': d.get('chart_type', ''),
            'num_data_points': data_series_stats.get('num_data_points', 0),
            'element_counts': d.get('element_counts', {})
        }

        return d
|
518 |
+
|
519 |
+
def print_missing_image_summary():
    """Report how many images had to be replaced with dummy images."""
    missing = RobustLoadImageFromFile.get_missing_count()
    if missing <= 0:
        print("β All images loaded successfully!")
        return
    print(f"π MISSING IMAGES SUMMARY: {missing} images were not found and replaced with dummy images")
|
526 |
+
|
527 |
+
def print_dataset_summary():
    """Print summary of dataset configuration."""
    summary_lines = (
        "π ENHANCED CHART DATASET SUMMARY:",
        " β’ 21 categories supported for comprehensive chart element detection",
        " β’ Auto-detects best annotation files (enriched_with_info > enriched > regular)",
        " β’ Enhanced metadata: chart_type, data_series_stats, element_counts, axes_info",
        " β’ Robust image loading with fallback to dummy images",
        " β’ Multiple annotations per image (not just plot areas)",
    )
    for line in summary_lines:
        print(line)
|
535 |
+
|
536 |
+
# Module-import side effect: announce that the plugin's dataset and
# transforms have been registered, then print the capability summary.
print("β [PLUGIN] Enhanced ChartDataset + transforms registered!")
print_dataset_summary()
|
custom_models/custom_faster_rcnn_with_meta.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# custom_faster_rcnn_with_meta.py - Faster R-CNN with coordinate handling for chart data
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from mmdet.models.detectors.faster_rcnn import FasterRCNN
|
5 |
+
from mmdet.registry import MODELS
|
6 |
+
|
7 |
+
|
8 |
+
@MODELS.register_module()
|
9 |
+
class CustomFasterRCNNWithMeta(FasterRCNN):
|
10 |
+
"""Faster R-CNN with coordinate standardization for chart detection."""
|
11 |
+
|
12 |
+
def __init__(self,
             *args,
             coordinate_standardization=None,
             data_points_count_head=None,
             **kwargs):
    """Initialize the detector, coordinate handling and count head.

    Args:
        coordinate_standardization (dict, optional): Options read by
            ``transform_coordinates`` ('enabled', 'origin', 'normalize').
        data_points_count_head (dict, optional): Registry config for the
            count head; a default MLP is built when omitted.
    """
    super().__init__(*args, **kwargs)

    # Store coordinate standardization settings (empty dict when unset)
    self.coord_std = coordinate_standardization or {}

    # Initialize data points count head
    if data_points_count_head is not None:
        self.data_points_count_head = MODELS.build(data_points_count_head)
    else:
        # Default simple regression head for data point count
        self.data_points_count_head = nn.Sequential(
            nn.Linear(2048, 512),  # NOTE(review): assumes 2048-channel backbone features (e.g. ResNet-50 C5) — confirm
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 1)  # Single output for count
        )

    # NOTE(review): this banner defaults 'origin' to 'top_left' and
    # 'normalize' to False, while transform_coordinates defaults them to
    # 'bottom_left' and True — the printed values may not match runtime
    # behavior when those keys are unset.
    print(f"π― CustomFasterRCNNWithMeta initialized with coordinate handling:")
    print(f" β’ Enabled: {self.coord_std.get('enabled', False)}")
    print(f" β’ Origin: {self.coord_std.get('origin', 'top_left')}")
    print(f" β’ Normalize: {self.coord_std.get('normalize', False)}")
    print(f" β’ Data points count prediction: Enabled")
|
39 |
+
|
40 |
+
def transform_coordinates(self, coords, img_shape, plot_bb=None, axes_info=None):
|
41 |
+
"""Transform coordinates based on standardization settings."""
|
42 |
+
if not self.coord_std.get('enabled', False):
|
43 |
+
return coords
|
44 |
+
|
45 |
+
# Get image dimensions
|
46 |
+
img_height, img_width = img_shape[-2:]
|
47 |
+
|
48 |
+
# Convert to tensor if not already
|
49 |
+
if not isinstance(coords, torch.Tensor):
|
50 |
+
coords = torch.tensor(coords, device=img_shape.device if hasattr(img_shape, 'device') else 'cpu')
|
51 |
+
|
52 |
+
# Ensure coords is 2D
|
53 |
+
if coords.dim() == 1:
|
54 |
+
coords = coords.view(-1, 2)
|
55 |
+
|
56 |
+
# Normalize coordinates if needed
|
57 |
+
if self.coord_std.get('normalize', True):
|
58 |
+
coords = coords / torch.tensor([img_width, img_height], device=coords.device)
|
59 |
+
|
60 |
+
# Handle bottom-left to top-left origin conversion
|
61 |
+
if self.coord_std.get('origin', 'bottom_left') == 'bottom_left':
|
62 |
+
# Flip y-coordinates to convert from bottom-left to top-left origin
|
63 |
+
coords[:, 1] = 1.0 - coords[:, 1]
|
64 |
+
|
65 |
+
# Convert back to pixel coordinates
|
66 |
+
if self.coord_std.get('normalize', True):
|
67 |
+
coords = coords * torch.tensor([img_width, img_height], device=coords.device)
|
68 |
+
|
69 |
+
return coords
|
70 |
+
|
71 |
+
def forward_train(self,
|
72 |
+
img,
|
73 |
+
img_metas,
|
74 |
+
gt_bboxes,
|
75 |
+
gt_labels,
|
76 |
+
gt_bboxes_ignore=None,
|
77 |
+
**kwargs):
|
78 |
+
"""Forward function during training with coordinate transformation."""
|
79 |
+
|
80 |
+
# Transform ground truth bboxes if coordinate standardization is enabled
|
81 |
+
if self.coord_std.get('enabled', False) and gt_bboxes is not None:
|
82 |
+
transformed_gt_bboxes = []
|
83 |
+
for i, bboxes in enumerate(gt_bboxes):
|
84 |
+
if len(bboxes) > 0:
|
85 |
+
# Convert bbox format for transformation
|
86 |
+
# MMDet uses [x1, y1, x2, y2] format
|
87 |
+
centers = torch.stack([
|
88 |
+
(bboxes[:, 0] + bboxes[:, 2]) / 2, # center_x
|
89 |
+
(bboxes[:, 1] + bboxes[:, 3]) / 2 # center_y
|
90 |
+
], dim=1)
|
91 |
+
|
92 |
+
# Transform centers
|
93 |
+
img_shape = img.shape if hasattr(img, 'shape') else (img_metas[i]['img_shape'][0], img_metas[i]['img_shape'][1])
|
94 |
+
transformed_centers = self.transform_coordinates(
|
95 |
+
centers, img_shape,
|
96 |
+
plot_bb=img_metas[i].get('plot_bb'),
|
97 |
+
axes_info=img_metas[i].get('axes_info')
|
98 |
+
)
|
99 |
+
|
100 |
+
# Reconstruct bboxes with transformed centers
|
101 |
+
widths = bboxes[:, 2] - bboxes[:, 0]
|
102 |
+
heights = bboxes[:, 3] - bboxes[:, 1]
|
103 |
+
|
104 |
+
transformed_bboxes = torch.stack([
|
105 |
+
transformed_centers[:, 0] - widths / 2, # x1
|
106 |
+
transformed_centers[:, 1] - heights / 2, # y1
|
107 |
+
transformed_centers[:, 0] + widths / 2, # x2
|
108 |
+
transformed_centers[:, 1] + heights / 2 # y2
|
109 |
+
], dim=1)
|
110 |
+
|
111 |
+
transformed_gt_bboxes.append(transformed_bboxes)
|
112 |
+
else:
|
113 |
+
transformed_gt_bboxes.append(bboxes)
|
114 |
+
|
115 |
+
gt_bboxes = transformed_gt_bboxes
|
116 |
+
|
117 |
+
# Call parent forward_train with transformed coordinates to get losses
|
118 |
+
losses = super().forward_train(
|
119 |
+
img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore, **kwargs)
|
120 |
+
|
121 |
+
# Extract features for data point count prediction
|
122 |
+
x = self.extract_feat(img)
|
123 |
+
global_feat = x[-1].mean(dim=[2, 3]) # Global average pooling
|
124 |
+
|
125 |
+
# Extract ground truth data point counts from img_metas
|
126 |
+
gt_data_point_counts = []
|
127 |
+
for img_meta in img_metas:
|
128 |
+
count = img_meta.get('img_info', {}).get('num_data_points', 0)
|
129 |
+
gt_data_point_counts.append(count)
|
130 |
+
gt_data_point_counts = torch.tensor(gt_data_point_counts, dtype=torch.float32, device=global_feat.device)
|
131 |
+
|
132 |
+
# Predict data point counts and compute loss
|
133 |
+
pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
|
134 |
+
data_points_count_loss = nn.MSELoss()(pred_data_point_counts, gt_data_point_counts)
|
135 |
+
losses['data_points_count_loss'] = data_points_count_loss
|
136 |
+
|
137 |
+
return losses
|
138 |
+
|
139 |
+
def simple_test(self, img, img_metas, proposals=None, rescale=False):
|
140 |
+
"""Simple test function with coordinate inverse transformation."""
|
141 |
+
# Get predictions from parent
|
142 |
+
results = super().simple_test(img, img_metas, proposals, rescale)
|
143 |
+
|
144 |
+
# Extract features for data point count prediction
|
145 |
+
x = self.extract_feat(img)
|
146 |
+
global_feat = x[-1].mean(dim=[2, 3]) # Global average pooling
|
147 |
+
|
148 |
+
# Predict data point counts
|
149 |
+
pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
|
150 |
+
|
151 |
+
# Add data point count predictions to results
|
152 |
+
if results is not None:
|
153 |
+
for i, result in enumerate(results):
|
154 |
+
if hasattr(result, 'pred_instances'):
|
155 |
+
result.pred_instances.predicted_data_points = pred_data_point_counts[i].item()
|
156 |
+
elif hasattr(result, 'bboxes'):
|
157 |
+
# For older MMDet versions, add as additional attribute
|
158 |
+
result.predicted_data_points = pred_data_point_counts[i].item()
|
159 |
+
|
160 |
+
# Inverse transform predictions if coordinate standardization is enabled
|
161 |
+
if self.coord_std.get('enabled', False) and results is not None:
|
162 |
+
# Note: For simplicity, we're not doing inverse transform in test
|
163 |
+
# The coordinate system should be consistent during training
|
164 |
+
pass
|
165 |
+
|
166 |
+
return results
|
custom_models/custom_heads.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from mmdet.registry import MODELS
|
5 |
+
|
6 |
+
@MODELS.register_module()
class FCHead(nn.Module):
    """Enhanced fully connected head for classification tasks with attention.

    Runs self-attention over the input features, then a two-layer MLP that
    maps to ``num_classes`` logits.
    """

    def __init__(self, in_channels, num_classes, loss=None):
        super().__init__()
        # Self-attention over the incoming feature sequence.
        self.attention = nn.MultiheadAttention(in_channels, num_heads=8)
        # Two-layer classifier MLP.
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        self.fc2 = nn.Linear(in_channels // 2, num_classes)
        # Loss config is stored only; callers compute the loss externally.
        self.loss = loss

    def forward(self, x):
        # Self-attention (query = key = value = x); keep only the output.
        attended, _ = self.attention(x, x, x)
        # MLP with a ReLU bottleneck.
        hidden = F.relu(self.fc1(attended))
        return self.fc2(hidden)
|
23 |
+
|
24 |
+
@MODELS.register_module()
class RegHead(nn.Module):
    """Enhanced regression head for coordinate prediction with distance-based loss.

    A single linear projection, optionally preceded by self-attention, with
    an optional auxiliary branch predicting x/y axis orientation.
    """

    def __init__(self, in_channels, out_dims, max_points=None, loss=None, attention=False, use_axis_info=False):
        super().__init__()
        # Main regression projection.
        self.fc = nn.Linear(in_channels, out_dims)
        self.max_points = max_points
        self.loss = loss
        self.attention = attention
        self.use_axis_info = use_axis_info

        # Optional self-attention applied before regression.
        if attention:
            self.attention_layer = nn.MultiheadAttention(in_channels, num_heads=8)

        # Optional axis-orientation predictor (2 outputs: x/y orientation).
        if use_axis_info:
            self.axis_orientation = nn.Linear(in_channels, 2)

    def compute_distance_loss(self, pred_points, gt_points):
        """Smooth-L1 penalty on each predicted point's distance to its nearest GT point."""
        # Promote 2-D point sets to a batch of one.
        if pred_points.dim() == 2:
            pred_points = pred_points.unsqueeze(0)
        if gt_points.dim() == 2:
            gt_points = gt_points.unsqueeze(0)

        # Pairwise distances, then the nearest-GT distance per prediction.
        pairwise = torch.cdist(pred_points, gt_points)
        nearest, _ = pairwise.min(dim=2)

        # Smooth L1 toward zero distance for robustness to outliers.
        return F.smooth_l1_loss(nearest, torch.zeros_like(nearest))

    def forward(self, x):
        if self.attention:
            x, _ = self.attention_layer(x, x, x)

        pred = self.fc(x)

        # When axis info is enabled, also return the orientation prediction.
        if self.use_axis_info:
            return pred, self.axis_orientation(x)
        return pred
|
73 |
+
|
74 |
+
class CoordinateTransformer:
    """Helper class to transform coordinates between different spaces.

    ``to_axis_relative`` and ``to_image_coordinates`` are exact inverses of
    each other for a fixed ``axis_info`` row.
    """

    @staticmethod
    def to_axis_relative(points, axis_info):
        """Transform points to be relative to axis coordinates.

        Args:
            points (torch.Tensor): Points in image coordinates (N, 2)
            axis_info (torch.Tensor): Axis information
                [x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale]
        """
        (x_min, x_max, y_min, y_max,
         x_origin, y_origin, x_scale, y_scale) = axis_info.unbind(1)

        # Map image positions into [0, 1] relative to the plot extent.
        u = (points[..., 0] - x_min) / (x_max - x_min)
        v = (points[..., 1] - y_min) / (y_max - y_min)

        # Rescale the normalized positions into axis (data) units.
        return torch.stack([u * x_scale + x_origin,
                            v * y_scale + y_origin], dim=-1)

    @staticmethod
    def to_image_coordinates(points, axis_info):
        """Transform points from axis coordinates to image coordinates."""
        (x_min, x_max, y_min, y_max,
         x_origin, y_origin, x_scale, y_scale) = axis_info.unbind(1)

        # Undo the axis-unit scaling back to normalized [0, 1] positions.
        u = (points[..., 0] - x_origin) / x_scale
        v = (points[..., 1] - y_origin) / y_scale

        # Map normalized positions back onto the image extent.
        return torch.stack([u * (x_max - x_min) + x_min,
                            v * (y_max - y_min) + y_min], dim=-1)
|
113 |
+
|
114 |
+
@MODELS.register_module()
class DataSeriesHead(nn.Module):
    """Specialized head for data series prediction with dual attention to coordinates and axis-relative positions.

    Produces up to ``max_points`` (x, y) predictions in two frames
    (absolute image coordinates and axis-relative coordinates), plus a
    5-way chart-pattern classification and an optional monotonicity check.
    """

    def __init__(self, in_channels, max_points=50, loss=None):
        super().__init__()
        # Maximum number of (x, y) points each branch predicts.
        self.max_points = max_points
        # Loss config is stored only; compute_loss implements the actual loss.
        self.loss = loss

        # Feature extraction
        self.fc1 = nn.Linear(in_channels, in_channels // 2)

        # Separate branches for absolute and relative coordinates
        self.absolute_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2)  # 2 coordinates per point
        )

        self.relative_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2)  # 2 coordinates per point
        )

        # Attention mechanisms
        self.coord_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.axis_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.sequence_attention = nn.MultiheadAttention(in_channels, num_heads=8)

        # Sequence-aware processing
        self.sequence_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=in_channels,
                nhead=8,
                dim_feedforward=in_channels * 4,
                dropout=0.1
            ),
            num_layers=2
        )

        # Pattern recognition
        self.pattern_recognizer = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.ReLU(),
            nn.Linear(in_channels // 2, 5)  # 5 for different chart patterns
        )

        # Coordinate transformer (stateless helper, not a submodule)
        self.coord_transformer = CoordinateTransformer()

    def check_monotonicity(self, points, chart_type):
        """Check if points follow expected monotonicity based on chart type.

        Returns a truthy value when the y-values are entirely non-decreasing
        or entirely non-increasing (line/scatter), and True otherwise.
        """
        if chart_type in ['line', 'scatter']:
            # For line/scatter, check if points are generally increasing or decreasing
            diffs = points[..., 1].diff()
            # NOTE(review): `or` on 0-dim bool tensors returns one of the
            # tensors, so the result type differs between this branch (tensor)
            # and the fallthrough below (Python bool) -- confirm callers cope.
            return torch.all(diffs >= 0) or torch.all(diffs <= 0)
        return True

    def forward(self, x, axis_info=None, chart_type=None):
        """Run the head.

        Returns:
            tuple: (absolute_points, relative_points, pattern_logits,
            monotonicity) where the point tensors are (B, max_points, 2).
        """
        # Apply coordinate attention
        coord_feat = self.coord_attention(x, x, x)[0]

        # Apply axis attention if axis info is available
        if axis_info is not None:
            # NOTE(review): this attends over x only; axis_info itself is not
            # fed into the attention layer -- confirm this is intended.
            axis_feat = self.axis_attention(x, x, x)[0]
            # Combine features
            x = coord_feat + axis_feat
        else:
            x = coord_feat

        # Apply sequence attention (residual connection)
        seq_feat = self.sequence_attention(x, x, x)[0]
        x = x + seq_feat

        # Process through sequence encoder
        x = self.sequence_encoder(x.unsqueeze(0)).squeeze(0)

        # Extract base features
        x = F.relu(self.fc1(x))

        # Get predictions from both branches
        absolute_points = self.absolute_branch(x)
        relative_points = self.relative_branch(x)

        # Reshape to (batch_size, max_points, 2)
        absolute_points = absolute_points.view(-1, self.max_points, 2)
        relative_points = relative_points.view(-1, self.max_points, 2)

        # If axis information is provided, transform relative points
        if axis_info is not None:
            relative_points = self.coord_transformer.to_axis_relative(relative_points, axis_info)

        # Get pattern prediction
        pattern_logits = self.pattern_recognizer(x)

        # Check monotonicity if chart type is provided
        if chart_type is not None:
            monotonicity = self.check_monotonicity(absolute_points, chart_type)
        else:
            monotonicity = None

        return absolute_points, relative_points, pattern_logits, monotonicity

    def compute_loss(self, pred_absolute, pred_relative, gt_absolute, gt_relative,
                     pattern_logits, gt_pattern, axis_info=None, chart_type=None):
        """Compute combined loss for both absolute and relative coordinates.

        Combines nearest-point distance losses, a pattern cross-entropy, and
        an optional monotonicity penalty, with fixed weights.
        """
        # Ensure points are in the same (batched) format
        if pred_absolute.dim() == 2:
            pred_absolute = pred_absolute.unsqueeze(0)
        if pred_relative.dim() == 2:
            pred_relative = pred_relative.unsqueeze(0)
        if gt_absolute.dim() == 2:
            gt_absolute = gt_absolute.unsqueeze(0)
        if gt_relative.dim() == 2:
            gt_relative = gt_relative.unsqueeze(0)

        # Compute absolute coordinate loss
        absolute_loss = self.compute_distance_loss(pred_absolute, gt_absolute)

        # Compute relative coordinate loss
        if axis_info is not None:
            # Transform predicted absolute points to relative coordinates
            pred_absolute_relative = self.coord_transformer.to_axis_relative(pred_absolute, axis_info)
            relative_loss = self.compute_distance_loss(pred_absolute_relative, gt_relative)
        else:
            relative_loss = torch.tensor(0.0, device=pred_absolute.device)

        # Compute pattern recognition loss
        pattern_loss = F.cross_entropy(pattern_logits, gt_pattern)

        # Add monotonicity penalty if applicable
        if chart_type is not None:
            monotonicity = self.check_monotonicity(pred_absolute, chart_type)
            # NOTE(review): BCE of a hard 0/1 value against a target of 1 is
            # infinite when monotonicity is False and carries no gradient;
            # also fails if check_monotonicity returned a plain bool. A soft
            # monotonicity penalty would be needed for stable training.
            monotonicity_loss = F.binary_cross_entropy(monotonicity.float(), torch.ones_like(monotonicity.float()))
        else:
            monotonicity_loss = torch.tensor(0.0, device=pred_absolute.device)

        # Combine losses with fixed weights
        total_loss = (absolute_loss + relative_loss +
                      0.5 * pattern_loss + 0.3 * monotonicity_loss)

        return total_loss

    def compute_distance_loss(self, pred_points, gt_points):
        """Compute distance-based loss between predicted and ground truth points.

        Smooth-L1 penalty on each predicted point's distance to its nearest
        ground-truth point; expects batched (B, N, 2) inputs.
        """
        # Compute pairwise distances
        dist = torch.cdist(pred_points, gt_points)

        # Get minimum distance for each point
        min_dist, _ = torch.min(dist, dim=2)

        # Compute loss (using smooth L1 loss for robustness)
        return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))
|
custom_models/flexible_load_annotations.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import Dict, Optional
|
3 |
+
from mmcv.transforms.base import BaseTransform
|
4 |
+
from mmdet.registry import TRANSFORMS
|
5 |
+
from mmdet.datasets.transforms.loading import LoadAnnotations
|
6 |
+
import logging
|
7 |
+
from mmdet.structures.mask import BitmapMasks
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
@TRANSFORMS.register_module()
class FlexibleLoadAnnotations(LoadAnnotations):
    """
    Flexible annotation loader that handles mixed mask/bbox datasets.

    Instances without usable segmentation are retained with an empty
    ``mask`` field so they still contribute to bbox training; instances
    with COCO segmentations or 4-point ``polygon`` dicts are normalized
    and handed to the stock ``LoadAnnotations`` mask pipeline.
    """

    def __init__(self,
                 with_bbox: bool = True,
                 with_mask: bool = True,
                 with_seg: bool = False,
                 poly2mask: bool = True,
                 **kwargs):
        super().__init__(
            with_bbox=with_bbox,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            **kwargs
        )
        # Running tallies used for periodic diagnostics in ``transform``.
        self.mask_stats = {'total': 0, 'with_masks': 0, 'without_masks': 0}

    def _load_masks(self, results: dict) -> dict:
        """Load mask annotations from COCO format instances.

        Handles three input layouts in priority order: ``ann_info`` dicts,
        lists of per-instance dicts under ``instances``, and a direct
        ``segmentation`` key on ``results``. Mutates ``results`` in place
        and delegates actual mask decoding to the parent class.
        """
        if not self.with_mask or not isinstance(results, dict):
            return results

        # Check for ann_info format (what COCO dataset actually provides)
        ann_info = results.get('ann_info')
        if isinstance(ann_info, dict):
            # Check if segmentation is in ann_info
            if 'segmentation' in ann_info:
                segmentation = ann_info['segmentation']
                if segmentation and isinstance(segmentation, list) and len(segmentation) > 0:
                    # Convert to mask format
                    ann_info['masks'] = segmentation
                    return super()._load_masks(results)

            # Check for polygon data in ann_info
            if 'polygon' in ann_info:
                polygon = ann_info['polygon']
                if polygon and isinstance(polygon, dict):
                    try:
                        # Convert polygon to COCO segmentation format
                        coords = []
                        for j in range(4):  # Assuming 4-point polygons
                            x_key = f'x{j}'
                            y_key = f'y{j}'
                            if x_key in polygon and y_key in polygon:
                                coords.extend([polygon[x_key], polygon[y_key]])

                        if len(coords) >= 6:  # Need at least 3 points (6 coordinates)
                            # Convert to COCO format: [x1, y1, x2, y2, x3, y3, ...]
                            segmentation = [coords]
                            ann_info['segmentation'] = segmentation
                            ann_info['masks'] = segmentation
                            return super()._load_masks(results)
                    except Exception as e:
                        logger.debug(f"Polygon conversion failed: {e}")

        # Handle COCO format: instances with segmentation
        instances = results.get('instances')
        if isinstance(instances, list):
            # Process ALL instances - keep both with and without masks
            valid_instances = []

            for i, instance in enumerate(instances):
                self.mask_stats['total'] += 1

                # Check for segmentation in COCO format (COCO dataset stores it in 'mask' field)
                segmentation = instance.get('mask') or instance.get('segmentation')
                if segmentation and isinstance(segmentation, list) and len(segmentation) > 0:
                    # Handle nested list format: [[x1, y1, x2, y2, ...]]
                    if isinstance(segmentation[0], list):
                        # Nested format - check if inner list has enough coordinates
                        inner_seg = segmentation[0]
                        if len(inner_seg) >= 6:  # Need at least 3 points (6 coordinates)
                            instance['mask'] = segmentation  # Keep original nested format for parent
                            valid_instances.append(instance)
                            self.mask_stats['with_masks'] += 1
                        else:
                            # Keep instance for bbox training even without valid mask
                            instance['mask'] = []
                            valid_instances.append(instance)
                            self.mask_stats['without_masks'] += 1
                    else:
                        # Flat format - already correct
                        instance['mask'] = segmentation
                        valid_instances.append(instance)
                        self.mask_stats['with_masks'] += 1
                else:
                    # Check for polygon data and convert to segmentation
                    polygon = instance.get('polygon')
                    if polygon and isinstance(polygon, dict):
                        # Convert polygon to COCO segmentation format
                        try:
                            # Extract polygon coordinates
                            coords = []
                            for j in range(4):  # Assuming 4-point polygons
                                x_key = f'x{j}'
                                y_key = f'y{j}'
                                if x_key in polygon and y_key in polygon:
                                    coords.extend([polygon[x_key], polygon[y_key]])

                            if len(coords) >= 6:  # Need at least 3 points (6 coordinates)
                                # Convert to COCO format: [x1, y1, x2, y2, x3, y3, ...]
                                segmentation = [coords]
                                instance['segmentation'] = segmentation
                                instance['mask'] = segmentation
                                valid_instances.append(instance)
                                self.mask_stats['with_masks'] += 1
                            else:
                                # Keep instance for bbox training even without mask
                                # Add empty mask field to prevent KeyError in parent class
                                instance['mask'] = []
                                valid_instances.append(instance)
                                self.mask_stats['without_masks'] += 1
                        except Exception as e:
                            # Keep instance for bbox training even if polygon conversion fails
                            # Add empty mask field to prevent KeyError in parent class
                            instance['mask'] = []
                            valid_instances.append(instance)
                            self.mask_stats['without_masks'] += 1
                    else:
                        # Keep instance for bbox training even without segmentation
                        # Add empty mask field to prevent KeyError in parent class
                        instance['mask'] = []
                        valid_instances.append(instance)
                        self.mask_stats['without_masks'] += 1

            # Update results with valid instances only
            results['instances'] = valid_instances

            # Call parent method to process the filtered instances
            if valid_instances:
                super()._load_masks(results)  # Parent modifies results in place
                return results
            else:
                # No valid masks, create empty mask structure
                h, w = results.get('img_shape', (0, 0))
                results['gt_masks'] = BitmapMasks([], h, w)
                results['gt_ignore_flags'] = np.array([], dtype=bool)
                return results

        # Check for direct segmentation in results
        if 'segmentation' in results:
            segmentation = results['segmentation']
            if segmentation and isinstance(segmentation, list) and len(segmentation) > 0:
                results['masks'] = segmentation
                return super()._load_masks(results)

        return results

    def transform(self, results: dict) -> dict:
        """Transform function to load annotations.

        Runs the parent transform (bbox loading), then the custom mask
        handling above, and logs aggregate mask statistics periodically.
        """
        # Ensure we always return a dict even on malformed input.
        if not isinstance(results, dict):
            logger.error(f"Expected dict, got {type(results)}")
            return {}

        # Call parent transform to handle bbox loading
        results = super().transform(results)

        # Handle mask loading with our custom logic
        results = self._load_masks(results)

        # Periodic logging.
        # NOTE(review): fires whenever total % 1000 == 0, including while
        # total is still 0 -- consider requiring total > 0.
        if self.mask_stats['total'] % 1000 == 0:
            t = self.mask_stats['total']
            w = self.mask_stats['with_masks']
            wo = self.mask_stats['without_masks']
            logger.info(f"Mask stats - total: {t}, with_masks: {w}, without_masks: {wo}")

        return results

    def __repr__(self) -> str:
        """String representation."""
        return (f'{self.__class__.__name__}('
                f'with_bbox={self.with_bbox}, '
                f'with_mask={self.with_mask}, '
                f'with_seg={self.with_seg}, '
                f'poly2mask={self.poly2mask})')
|
custom_models/mask_filter.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from mmcv.transforms.base import BaseTransform
|
3 |
+
from mmdet.registry import TRANSFORMS
|
4 |
+
import logging
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
@TRANSFORMS.register_module()
class MaskFilter(BaseTransform):
    """Filter out images with no valid masks during training.

    This transform checks if there are any valid masks in the image and
    returns None if no masks are found, which will cause the image to be skipped.
    """

    def __init__(self, min_masks=1):
        # Minimum number of masks an image must carry to be kept.
        self.min_masks = min_masks

    def transform(self, results):
        """Filter results based on mask availability.

        Args:
            results (dict): Result dict from dataset.

        Returns:
            dict or None: Returns results if valid masks found, None otherwise.
        """
        masks = results.get('gt_masks')

        # No mask container at all -> drop the sample.
        if masks is None:
            logger.warning("MaskFilter: No gt_masks found, skipping image")
            return None

        # Count masks from whichever representation is present.
        for attr in ('masks', 'polygons'):
            if hasattr(masks, attr):
                count = len(getattr(masks, attr))
                break
        else:
            count = 0

        if count < self.min_masks:
            logger.info(f"MaskFilter: Only {count} masks found (min: {self.min_masks}), skipping image")
            return None

        logger.info(f"MaskFilter: {count} masks found, keeping image for training")
        return results
|
custom_models/nan_recovery_hook.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# nan_recovery_hook.py - Graceful NaN loss recovery for Cascade R-CNN
|
2 |
+
import torch
|
3 |
+
from mmengine.hooks import Hook
|
4 |
+
from mmengine.runner import Runner
|
5 |
+
from mmdet.registry import HOOKS
|
6 |
+
from typing import Optional, Dict, Any
|
7 |
+
|
8 |
+
|
9 |
+
@HOOKS.register_module()
|
10 |
+
class NanRecoveryHook(Hook):
|
11 |
+
"""Hook to handle NaN losses gracefully without crashing training.
|
12 |
+
|
13 |
+
This hook detects NaN losses and handles them by:
|
14 |
+
1. Replacing NaN losses with the last valid loss value
|
15 |
+
2. Skipping gradient updates for that iteration
|
16 |
+
3. Logging the recovery for monitoring
|
17 |
+
4. Allowing training to continue normally
|
18 |
+
"""
|
19 |
+
|
20 |
+
    def __init__(self,
                 fallback_loss: float = 0.5,
                 max_consecutive_nans: int = 100,  # Increased from 10
                 log_interval: int = 50):  # Log less frequently
        """Configure NaN-recovery behaviour.

        Args:
            fallback_loss: Loss value used as the "last valid loss" before
                any real valid loss has been observed.
            max_consecutive_nans: Consecutive NaN iterations tolerated before
                the recovery state is reset via ``_reset_nan_state``.
            log_interval: After the first few NaNs, recovery is only logged
                every this many consecutive NaN iterations.
        """
        self.fallback_loss = fallback_loss
        self.max_consecutive_nans = max_consecutive_nans
        self.log_interval = log_interval

        # State tracking
        self.last_valid_loss = fallback_loss   # most recent finite total loss
        self.consecutive_nans = 0              # current NaN run length
        self.total_nans = 0                    # NaNs seen over whole training
        self.nan_iterations = []               # batch indices where NaN hit
|
33 |
+
|
34 |
+
    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None) -> None:
        """Reset any state before training iteration.

        Intentionally a no-op: all bookkeeping happens in
        ``after_train_iter``; this override only satisfies the Hook API.
        """
        pass
|
40 |
+
|
41 |
+
    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[Dict[str, Any]] = None) -> None:
        """Handle NaN losses after training iteration.

        Scans the total loss and every ``*loss*`` tensor in ``outputs`` for
        NaN/Inf. On detection, delegates recovery to ``_handle_nan_loss``;
        otherwise records the latest valid loss and resets the
        consecutive-NaN counter.
        """
        if outputs is None:
            return

        # Check ALL loss components for NaN, not just the main loss
        has_nan = False

        # Check main loss
        total_loss = outputs.get('loss')
        if total_loss is not None and (torch.isnan(total_loss) or torch.isinf(total_loss)):
            has_nan = True

        # Check all individual loss components.
        # NOTE(review): assumes each loss tensor is a scalar -- `or` on a
        # multi-element isnan/isinf result would raise; confirm upstream.
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if torch.isnan(value) or torch.isinf(value):
                    has_nan = True
                    break

        if has_nan:
            self._handle_nan_loss(runner, batch_idx, outputs)
        else:
            # Valid loss - update tracking
            if total_loss is not None:
                self.last_valid_loss = float(total_loss.item())
            if self.consecutive_nans > 0:
                runner.logger.info(f"π Loss recovered after {self.consecutive_nans} NaN iterations")
            self.consecutive_nans = 0
|
74 |
+
|
75 |
+
def _handle_nan_loss(self, runner: Runner, batch_idx: int, outputs: Dict[str, Any]) -> None:
|
76 |
+
"""Handle NaN loss by replacing with detached fallback and managing state."""
|
77 |
+
self.consecutive_nans += 1
|
78 |
+
self.total_nans += 1
|
79 |
+
self.nan_iterations.append(batch_idx)
|
80 |
+
|
81 |
+
# Try to get last good state from SkipBadSamplesHook if available
|
82 |
+
last_good_iteration = batch_idx
|
83 |
+
last_good_loss = self.last_valid_loss
|
84 |
+
|
85 |
+
for hook in runner.hooks:
|
86 |
+
if hasattr(hook, 'last_good_iteration') and hasattr(hook, 'last_good_loss'):
|
87 |
+
if hook.last_good_loss is not None:
|
88 |
+
last_good_iteration = hook.last_good_iteration
|
89 |
+
last_good_loss = hook.last_good_loss
|
90 |
+
break
|
91 |
+
|
92 |
+
# Replace NaN loss with detached fallback (no gradients = true no-op)
|
93 |
+
if 'loss' in outputs and outputs['loss'] is not None:
|
94 |
+
fallback_tensor = torch.tensor(
|
95 |
+
last_good_loss,
|
96 |
+
device=outputs['loss'].device,
|
97 |
+
dtype=outputs['loss'].dtype
|
98 |
+
# NOTE: No requires_grad=True - this makes it detached
|
99 |
+
)
|
100 |
+
outputs['loss'] = fallback_tensor
|
101 |
+
|
102 |
+
# Also fix individual loss components with detached tensors
|
103 |
+
self._fix_loss_components(outputs, last_good_loss)
|
104 |
+
|
105 |
+
# Log recovery with state info
|
106 |
+
if self.consecutive_nans <= 5 or self.consecutive_nans % self.log_interval == 0:
|
107 |
+
runner.logger.warning(
|
108 |
+
f"π NaN Recovery at iteration {batch_idx}: "
|
109 |
+
f"Using last good loss {last_good_loss:.4f} from iteration {last_good_iteration}. "
|
110 |
+
f"Consecutive NaNs: {self.consecutive_nans}, Total: {self.total_nans}"
|
111 |
+
)
|
112 |
+
|
113 |
+
# Reset training state if too many consecutive NaNs
|
114 |
+
if self.consecutive_nans >= self.max_consecutive_nans:
|
115 |
+
self._reset_nan_state(runner, last_good_iteration)
|
116 |
+
|
117 |
+
def _reset_nan_state(self, runner: Runner, last_good_iteration: int) -> None:
|
118 |
+
"""Reset training state when too many consecutive NaNs."""
|
119 |
+
runner.logger.error(
|
120 |
+
f"π Too many consecutive NaN losses ({self.consecutive_nans}). "
|
121 |
+
f"Resetting to last good state from iteration {last_good_iteration}"
|
122 |
+
)
|
123 |
+
|
124 |
+
try:
|
125 |
+
# Clear model gradients
|
126 |
+
if hasattr(runner.model, 'zero_grad'):
|
127 |
+
runner.model.zero_grad()
|
128 |
+
|
129 |
+
# Clear CUDA cache
|
130 |
+
if torch.cuda.is_available():
|
131 |
+
torch.cuda.empty_cache()
|
132 |
+
|
133 |
+
# Reset consecutive counter
|
134 |
+
self.consecutive_nans = 0
|
135 |
+
|
136 |
+
runner.logger.info(f"β
NaN state reset. Resuming training...")
|
137 |
+
|
138 |
+
except Exception as e:
|
139 |
+
runner.logger.error(f"β Failed to reset NaN state: {e}")
|
140 |
+
|
141 |
+
def _fix_loss_components(self, outputs: Dict[str, Any], fallback_loss: float = None) -> None:
|
142 |
+
"""Fix ALL loss components with detached tensors (no gradients)."""
|
143 |
+
if fallback_loss is None:
|
144 |
+
fallback_loss = self.last_valid_loss
|
145 |
+
|
146 |
+
fallback_small = max(0.01, fallback_loss * 0.1) # Ensure non-zero minimum
|
147 |
+
|
148 |
+
# Fix ALL tensors with 'loss' in the key name using detached tensors
|
149 |
+
for key, value in outputs.items():
|
150 |
+
if isinstance(value, torch.Tensor) and 'loss' in key.lower():
|
151 |
+
if torch.isnan(value) or torch.isinf(value):
|
152 |
+
# Create detached replacement tensor (no gradients)
|
153 |
+
replacement = torch.tensor(
|
154 |
+
fallback_small,
|
155 |
+
device=value.device,
|
156 |
+
dtype=value.dtype
|
157 |
+
# NOTE: No requires_grad=True - detached for true no-op
|
158 |
+
)
|
159 |
+
outputs[key] = replacement
|
160 |
+
print(f" π§ Fixed {key}: {value.item():.4f} -> detached {fallback_small:.4f}")
|
161 |
+
|
162 |
+
# Also fix any scalar values that might be NaN
|
163 |
+
for key, value in list(outputs.items()):
|
164 |
+
if isinstance(value, (int, float)) and 'loss' in key.lower():
|
165 |
+
if not torch.isfinite(torch.tensor(value)):
|
166 |
+
outputs[key] = fallback_small
|
167 |
+
print(f" π§ Fixed scalar {key}: {value} -> {fallback_small:.4f}")
|
168 |
+
|
169 |
+
def after_train_epoch(self, runner: Runner) -> None:
|
170 |
+
"""Summary statistics after each epoch."""
|
171 |
+
if self.total_nans > 0:
|
172 |
+
runner.logger.info(
|
173 |
+
f"π NaN Recovery Summary for Epoch: "
|
174 |
+
f"{self.total_nans} NaN losses recovered. "
|
175 |
+
f"Training continued successfully."
|
176 |
+
)
|
177 |
+
|
178 |
+
# Reset for next epoch
|
179 |
+
self.consecutive_nans = 0
|
180 |
+
self.total_nans = 0
|
181 |
+
self.nan_iterations.clear()
|
custom_models/progressive_loss_hook.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# progressive_loss_hook.py - Progressive Loss Switching Hook for Cascade R-CNN
|
2 |
+
import torch
|
3 |
+
from mmengine.hooks import Hook
|
4 |
+
from mmdet.registry import HOOKS
|
5 |
+
from mmdet.models.losses import SmoothL1Loss, GIoULoss, CIoULoss, DIoULoss
|
6 |
+
|
7 |
+
@HOOKS.register_module()
class ProgressiveLossHook(Hook):
    """Progressive loss switching for Cascade R-CNN.

    Every stage starts training with SmoothL1Loss. Once training has had time
    to stabilise (``switch_epoch``), the final refinement stage (bbox head
    index 2) is switched to an IoU-family loss (GIoU/CIoU/DIoU). Optional NaN
    monitoring rolls the stage back to SmoothL1Loss if the new loss
    destabilises training.

    Args:
        switch_epoch (int): Epoch at which the final stage switches to
            ``target_loss_type``.
        target_loss_type (str): One of 'GIoULoss', 'CIoULoss', 'DIoULoss'.
        loss_weight (float): Weight for the new loss function.
        warmup_epochs (int): Epochs after the switch during which the loss
            configuration keeps being logged for monitoring.
        monitor_stage_weights (bool): Whether to log the per-stage loss
            configuration and loss values.
        nan_detection (bool): Enable NaN detection with automatic rollback.
        max_nan_tolerance (int): Consecutive NaN losses tolerated before the
            rollback is triggered.
    """

    # Supported IoU-style target losses. All of them regress decoded boxes,
    # so ``reg_decoded_bbox`` must be enabled alongside. Using a mapping
    # removes the triplicated if/elif construction blocks the class had.
    _IOU_LOSSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    def __init__(self,
                 switch_epoch=5,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 warmup_epochs=2,
                 monitor_stage_weights=True,
                 nan_detection=False,
                 max_nan_tolerance=5):
        super().__init__()
        self.switch_epoch = switch_epoch
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.warmup_epochs = warmup_epochs
        self.monitor_stage_weights = monitor_stage_weights
        self.nan_detection = nan_detection
        self.max_nan_tolerance = max_nan_tolerance
        self.switched = False            # has the switch happened yet?
        self.original_loss = None        # pre-switch loss, kept for logging
        self.consecutive_nans = 0        # current NaN streak (nan_detection)
        self.rollback_performed = False  # rollback is one-shot

    @staticmethod
    def _final_stage_head(model):
        """Return the bbox head of the final cascade stage (index 2),
        unwrapping a DistributedDataParallel ``module`` wrapper if present."""
        core = model.module if hasattr(model, 'module') else model
        return core.roi_head.bbox_head[2]

    def _build_target_loss(self):
        """Instantiate the configured IoU loss.

        Raises:
            ValueError: If ``target_loss_type`` is not a supported IoU loss.
        """
        loss_cls = self._IOU_LOSSES.get(self.target_loss_type)
        if loss_cls is None:
            raise ValueError(
                f"Unsupported target loss type: {self.target_loss_type}")
        return loss_cls(loss_weight=self.loss_weight)

    def before_train_epoch(self, runner):
        """Perform the one-shot switch at ``switch_epoch``; log during warmup."""
        current_epoch = runner.epoch

        if current_epoch >= self.switch_epoch and not self.switched:
            self._switch_stage2_loss(runner)
            self.switched = True
            runner.logger.info(
                f"Epoch {current_epoch}: Switched Stage 3 loss to {self.target_loss_type}")

        # During the warmup window keep the configuration visible in the log.
        elif (current_epoch >= self.switch_epoch
              and current_epoch < self.switch_epoch + self.warmup_epochs):
            if self.monitor_stage_weights:
                self._log_loss_info(runner, current_epoch)

    def _switch_stage2_loss(self, runner):
        """Switch the final stage's bbox loss from SmoothL1 to the target loss.

        NOTE: the method name is historical -- it operates on stage 3
        (bbox_head index 2), the final refinement stage.
        """
        try:
            final_head = self._final_stage_head(runner.model)

            # Keep the original loss for the switch log message.
            self.original_loss = final_head.loss_bbox

            new_loss = self._build_target_loss()
            # IoU losses compare decoded boxes rather than regression deltas.
            final_head.reg_decoded_bbox = True

            # Loss-specific motivation, mirrored into the training log.
            if self.target_loss_type == 'CIoULoss':
                runner.logger.info(f"π― CIoU Loss Benefits for Data Points:")
                runner.logger.info(f"   β’ Directly optimizes center point distance")
                runner.logger.info(f"   β’ Enforces aspect ratio consistency (square-ish data points)")
                runner.logger.info(f"   β’ Better convergence for small objects")
                runner.logger.info(f"   β’ Most complete bounding box quality metric")
            elif self.target_loss_type == 'DIoULoss':
                runner.logger.info(f"π― DIoU Loss Benefits for Data Points:")
                runner.logger.info(f"   β’ Directly optimizes center point distance")
                runner.logger.info(f"   β’ Better convergence for small objects")
                runner.logger.info(f"   β’ More precise localization for data points")
            elif self.target_loss_type == 'GIoULoss':
                runner.logger.info(f"π― GIoU Loss Benefits:")
                runner.logger.info(f"   β’ Improved IoU-based optimization")
                runner.logger.info(f"   β’ Better than standard IoU loss")

            # Replace the loss function on the live head.
            final_head.loss_bbox = new_loss

            runner.logger.info(
                f"Progressive Loss Switch: Stage 3 changed from "
                f"{type(self.original_loss).__name__} to {self.target_loss_type}")

        except Exception as e:
            # Best-effort: a failed switch leaves SmoothL1 in place.
            runner.logger.error(f"Failed to switch loss function: {e}")

    def _log_loss_info(self, runner, epoch):
        """Log the loss type and weight configured for every cascade stage."""
        try:
            model = runner.model
            core = model.module if hasattr(model, 'module') else model
            bbox_heads = core.roi_head.bbox_head

            loss_info = {}
            for i, head in enumerate(bbox_heads):
                loss_type = type(head.loss_bbox).__name__
                loss_weight = head.loss_bbox.loss_weight
                loss_info[f'stage_{i+1}'] = f"{loss_type}(w={loss_weight})"

            runner.logger.info(f"Epoch {epoch} Loss Configuration: {loss_info}")

        except Exception as e:
            runner.logger.warning(f"Could not log loss info: {e}")

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Watch losses after each iteration; trigger a rollback on NaN streaks."""
        if not (self.switched and outputs is not None and isinstance(outputs, dict)):
            return

        # NaN detection with one-shot rollback.
        if self.nan_detection and not self.rollback_performed:
            total_loss = outputs.get('loss', None)
            if total_loss is not None and torch.isnan(total_loss):
                self.consecutive_nans += 1
                runner.logger.warning(f"π¨ NaN detected in total loss! Consecutive: {self.consecutive_nans}/{self.max_nan_tolerance}")

                if self.consecutive_nans >= self.max_nan_tolerance:
                    self._rollback_loss(runner)
                    self.consecutive_nans = 0
                    self.rollback_performed = True
                    runner.logger.error(f"π EMERGENCY ROLLBACK: Switched back to SmoothL1Loss due to {self.max_nan_tolerance} consecutive NaN losses")
                    return
            elif total_loss is not None and torch.isfinite(total_loss):
                # A finite loss closes the current NaN streak.
                self.consecutive_nans = 0

        # Periodically summarise per-stage bbox losses (every 100 iters).
        log_vars = outputs.get('log_vars', {})
        stage_losses = {k: v for k, v in log_vars.items()
                        if 'loss_bbox' in k and isinstance(v, (int, float))}

        if stage_losses and self.monitor_stage_weights:
            if runner.iter % 100 == 0:
                loss_summary = ", ".join(f"{k}: {v:.4f}" for k, v in stage_losses.items())
                runner.logger.info(f"Stage Losses - {loss_summary}")

    def after_train_epoch(self, runner):
        """Log the end-of-epoch NaN status once the switch is active."""
        if self.nan_detection and self.switched:
            if self.consecutive_nans > 0:
                runner.logger.warning(f"Epoch {runner.epoch} completed with {self.consecutive_nans} NaN occurrences")
            else:
                runner.logger.info(f"Epoch {runner.epoch} completed successfully with {self.target_loss_type}")

    def _rollback_loss(self, runner):
        """Restore SmoothL1Loss on the final stage after repeated NaN losses."""
        try:
            final_head = self._final_stage_head(runner.model)

            final_head.loss_bbox = SmoothL1Loss(beta=1.0, loss_weight=1.0)
            # SmoothL1 regresses deltas again, so decoded-bbox mode goes off.
            final_head.reg_decoded_bbox = False

            runner.logger.info(f"β Successfully rolled back Stage 3 from {self.target_loss_type} to SmoothL1Loss")

        except Exception as e:
            runner.logger.error(f"β Failed to rollback loss function: {e}")
201 |
+
|
202 |
+
|
203 |
+
@HOOKS.register_module()
class AdaptiveLossHook(Hook):
    """Adaptive loss switching based on training stability metrics.

    Instead of switching at a fixed epoch, this hook watches IoU-related
    values in ``log_vars`` and switches the final cascade stage (bbox head
    index 2) to the target IoU loss once the recent average IoU clears
    ``min_avg_iou``.

    Args:
        min_epoch (int): Earliest epoch at which a switch is considered.
        min_avg_iou (float): Threshold the rolling average IoU must reach.
        target_loss_type (str): One of 'GIoULoss', 'CIoULoss', 'DIoULoss'.
        loss_weight (float): Weight for the new loss function.
        check_interval (int): Check the IoU metrics every this many iters.
    """

    # Supported IoU-style losses; all require decoded-box regression.
    # Mapping replaces the duplicated if/elif construction block.
    _IOU_LOSSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    def __init__(self,
                 min_epoch=3,
                 min_avg_iou=0.4,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 check_interval=100):
        super().__init__()
        self.min_epoch = min_epoch
        self.min_avg_iou = min_avg_iou
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.check_interval = check_interval
        self.switched = False
        self.iou_history = []  # rolling window (max 10) of recent avg IoUs

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Sample IoU metrics periodically and switch once training is stable."""
        if (self.switched
                or runner.epoch < self.min_epoch
                or runner.iter % self.check_interval != 0):
            return
        if not (outputs and isinstance(outputs, dict)):
            return

        # Collect any numeric metric whose key mentions "iou".
        log_vars = outputs.get('log_vars', {})
        iou_metrics = [v for k, v in log_vars.items()
                       if 'iou' in k.lower() and isinstance(v, (int, float))]
        if not iou_metrics:
            return

        self.iou_history.append(sum(iou_metrics) / len(iou_metrics))
        # Keep only the most recent window.
        if len(self.iou_history) > 10:
            self.iou_history.pop(0)

        # Switch once the last five samples average above the threshold.
        if (len(self.iou_history) >= 5
                and sum(self.iou_history[-5:]) / 5 >= self.min_avg_iou):
            self._switch_stage2_loss(runner)
            self.switched = True

            recent_iou = sum(self.iou_history[-5:]) / 5
            runner.logger.info(
                f"Adaptive switch at epoch {runner.epoch}, iter {runner.iter}: "
                f"avg IoU {recent_iou:.3f} >= {self.min_avg_iou}")

    def _switch_stage2_loss(self, runner):
        """Switch the final-stage (index 2) bbox loss to the target IoU loss.

        NOTE: the name is historical -- this operates on stage 3.
        """
        try:
            model = runner.model
            core = model.module if hasattr(model, 'module') else model
            final_head = core.roi_head.bbox_head[2]

            loss_cls = self._IOU_LOSSES.get(self.target_loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported target loss type: {self.target_loss_type}")

            # IoU losses compare decoded boxes rather than deltas.
            final_head.reg_decoded_bbox = True
            final_head.loss_bbox = loss_cls(loss_weight=self.loss_weight)

            runner.logger.info(f"Adaptive Loss Switch: Stage 3 β {self.target_loss_type}")

        except Exception as e:
            # Best-effort: a failed switch leaves the current loss in place.
            runner.logger.error(f"Failed to switch loss function: {e}")
|
custom_models/register.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmdet.registry import MODELS, DATASETS, TRANSFORMS, HOOKS
|
2 |
+
from .custom_heads import FCHead, RegHead, DataSeriesHead
|
3 |
+
from .custom_cascade_with_meta import CustomCascadeWithMeta
|
4 |
+
from .custom_dataset import (
|
5 |
+
ChartDataset, RobustLoadImageFromFile, CreateDummyImg,
|
6 |
+
ClampBBoxes, SetScaleFactor, EnsureScaleFactor, SetInputs, CustomPackDetInputs
|
7 |
+
)
|
8 |
+
from .flexible_load_annotations import FlexibleLoadAnnotations
|
9 |
+
from .custom_hooks import (
|
10 |
+
ChartTypeDistributionHook, SkipInvalidLossHook, RuntimeErrorHook, MissingImageReportHook, SkipBadSamplesHook, CompatibleCheckpointHook
|
11 |
+
)
|
12 |
+
from .nan_recovery_hook import NanRecoveryHook
|
13 |
+
from .progressive_loss_hook import ProgressiveLossHook, AdaptiveLossHook
|
14 |
+
from .square_fcn_mask_head import SquareFCNMaskHead
|
15 |
+
|
16 |
+
def register_all_modules():
    """Announce the enhanced chart-detection modules.

    All actual registration already happened through the
    ``@...register_module()`` decorators when the submodules above were
    imported; this function deliberately re-registers nothing and only
    prints a summary of what is available.
    """
    print("β Enhanced chart detection modules registered via decorators:")
    for summary in (
        "   π Models: FCHead, RegHead, DataSeriesHead, CustomCascadeWithMeta, SquareFCNMaskHead",
        "   π Datasets: ChartDataset (21 categories)",
        "   π Transforms: RobustLoadImageFromFile, CreateDummyImg, ClampBBoxes, SetScaleFactor, EnsureScaleFactor, SetInputs, CustomPackDetInputs, FlexibleLoadAnnotations",
        "   π― Hooks: ChartTypeDistributionHook, SkipInvalidLossHook, RuntimeErrorHook, MissingImageReportHook, SkipBadSamplesHook, NanRecoveryHook, CompatibleCheckpointHook, ProgressiveLossHook, AdaptiveLossHook",
    ):
        print(summary)


# Emit the summary at import time, as the module always has.
register_all_modules()
|
custom_models/square_fcn_mask_head.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead
|
8 |
+
from mmdet.registry import MODELS
|
9 |
+
from mmdet.structures.mask.mask_target import mask_target
|
10 |
+
from typing import Union, Dict, Any
|
11 |
+
from mmengine.config import ConfigDict
|
12 |
+
from mmengine.structures import InstanceData
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
from .square_mask_target import square_mask_target
|
16 |
+
|
17 |
+
|
18 |
+
@MODELS.register_module()
class SquareFCNMaskHead(FCNMaskHead):
    """FCN mask head that forces square mask targets.

    This head ensures that all mask targets are square regardless of the original
    aspect ratio to avoid tensor size mismatches during training. It overrides
    ``loss_and_target``/``get_targets`` of ``FCNMaskHead`` and routes target
    generation through ``square_mask_target``. The heavy ``print`` tracing is
    debug instrumentation left in on purpose.
    """

    def __init__(self, *args, **kwargs):
        # Trace construction arguments before delegating to FCNMaskHead.
        print(f"π SQUARE_FCN_MASK_HEAD: Initializing SquareFCNMaskHead")
        print(f"π SQUARE_FCN_MASK_HEAD: args: {args}")
        print(f"π SQUARE_FCN_MASK_HEAD: kwargs: {kwargs}")
        super().__init__(*args, **kwargs)
        print(f"π SQUARE_FCN_MASK_HEAD: SquareFCNMaskHead initialized successfully")

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Same computation as ``FCNMaskHead.forward`` (convs -> optional
        upsample -> logits), with per-step shape tracing added.

        Args:
            x (Tensor): Extract mask RoI features.

        Returns:
            Tensor: Predicted foreground masks.
        """
        print(f"π SQUARE_FCN_MASK_HEAD: Input shape: {x.shape}")

        # Shared conv tower inherited from FCNMaskHead.
        for i, conv in enumerate(self.convs):
            x = conv(x)
            print(f"π SQUARE_FCN_MASK_HEAD: After conv {i} shape: {x.shape}")

        if self.upsample is not None:
            print(f"π SQUARE_FCN_MASK_HEAD: Upsampling from {x.shape}")
            x = self.upsample(x)
            # Deconv upsampling is followed by a ReLU, matching FCNMaskHead.
            if self.upsample_method == 'deconv':
                x = self.relu(x)
            print(f"π SQUARE_FCN_MASK_HEAD: After upsample shape: {x.shape}")
        else:
            print(f"π SQUARE_FCN_MASK_HEAD: No upsampling, shape: {x.shape}")

        mask_preds = self.conv_logits(x)
        print(f"π SQUARE_FCN_MASK_HEAD: Final mask_preds shape: {mask_preds.shape}")
        print(f"π SQUARE_FCN_MASK_HEAD: mask_preds device: {mask_preds.device}")
        print(f"π SQUARE_FCN_MASK_HEAD: mask_preds dtype: {mask_preds.dtype}")

        return mask_preds

    def loss_and_target(self,
                        mask_preds: Tensor,
                        sampling_results: List[Any],
                        batch_gt_instances: List[InstanceData],
                        rcnn_train_cfg: Union[Dict[str, Any], ConfigDict]) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (num_pos, num_classes, mask_h, mask_w).
            sampling_results (List[:obj:`SamplingResult`]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (List[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            dict: A dictionary of loss components.
        """
        print(f"π SQUARE_FCN_MASK_HEAD: loss_and_target called")
        print(f"π SQUARE_FCN_MASK_HEAD: mask_preds shape: {mask_preds.shape}")

        # Get mask targets (square-forced, via get_targets below).
        mask_targets = self.get_targets(sampling_results, batch_gt_instances,
                                        rcnn_train_cfg)
        print(f"π SQUARE_FCN_MASK_HEAD: mask_targets shape: {mask_targets.shape}")

        # Get labels for positive proposals
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        print(f"π SQUARE_FCN_MASK_HEAD: pos_labels shape: {pos_labels.shape}")
        print(f"π SQUARE_FCN_MASK_HEAD: pos_labels: {pos_labels}")
        print(f"π SQUARE_FCN_MASK_HEAD: pos_labels min: {pos_labels.min()}")
        print(f"π SQUARE_FCN_MASK_HEAD: pos_labels max: {pos_labels.max()}")
        print(f"π SQUARE_FCN_MASK_HEAD: num_classes: {self.num_classes}")

        # Defensive clamp: out-of-range labels would index past the class
        # dimension of mask_preds inside the loss.
        if pos_labels.max() >= self.num_classes:
            print(f"π SQUARE_FCN_MASK_HEAD: ERROR! Found label {pos_labels.max()} >= num_classes {self.num_classes}")
            # Clamp labels to valid range
            pos_labels = torch.clamp(pos_labels, 0, self.num_classes - 1)
            print(f"π SQUARE_FCN_MASK_HEAD: Clamped pos_labels max: {pos_labels.max()}")

        # Diagnostic only: a mismatch here is what the square targets are
        # meant to prevent; it is reported but not corrected at this point.
        if mask_preds.shape[-2:] != mask_targets.shape[-2:]:
            print(f"π SQUARE_FCN_MASK_HEAD: SIZE MISMATCH!")
            print(f"π SQUARE_FCN_MASK_HEAD: mask_preds shape: {mask_preds.shape}")
            print(f"π SQUARE_FCN_MASK_HEAD: mask_targets shape: {mask_targets.shape}")

        # Calculate loss - use the original approach like FCNMaskHead
        print(f"π SQUARE_FCN_MASK_HEAD: About to call loss_mask")
        print(f"π SQUARE_FCN_MASK_HEAD: mask_preds shape: {mask_preds.shape}")
        print(f"π SQUARE_FCN_MASK_HEAD: mask_targets shape: {mask_targets.shape}")
        print(f"π SQUARE_FCN_MASK_HEAD: pos_labels shape: {pos_labels.shape}")
        print(f"π SQUARE_FCN_MASK_HEAD: mask_preds device: {mask_preds.device}")
        print(f"π SQUARE_FCN_MASK_HEAD: mask_targets device: {mask_targets.device}")
        print(f"π SQUARE_FCN_MASK_HEAD: pos_labels device: {pos_labels.device}")

        # Call loss function with full mask_preds and pos_labels like the
        # original FCN mask head.
        loss_mask = self.loss_mask(mask_preds, mask_targets, pos_labels)

        print(f"π SQUARE_FCN_MASK_HEAD: Loss calculated successfully: {loss_mask}")

        # only return the *nested* loss dict that StandardRoIHead.update() expects
        return dict(
            loss_mask={'loss_mask': loss_mask},
            # if you really need mask_targets downstream you can still return it under a
            # different key, but it will be ignored by the standard loss updater
            mask_targets=mask_targets
        )

    def get_targets(self,
                    sampling_results: List[Any],
                    batch_gt_instances: List[InstanceData],
                    rcnn_train_cfg: Union[Dict[str, Any], ConfigDict]) -> Tensor:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Unlike the parent implementation, target generation is delegated to
        ``square_mask_target`` so the produced masks are always square.

        Args:
            sampling_results (List[:obj:`SamplingResult`]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (List[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            Tensor: Mask targets of each positive proposals in the image,
                has shape (num_pos, mask_h, mask_w).
        """
        print(f"π SQUARE_FCN_MASK_HEAD: get_targets called")

        pos_proposals_list = [res.pos_priors for res in sampling_results]
        pos_assigned_gt_inds_list = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        gt_masks_list = [res.masks for res in batch_gt_instances]

        print(f"π SQUARE_FCN_MASK_HEAD: Number of sampling results: {len(sampling_results)}")
        print(f"π SQUARE_FCN_MASK_HEAD: rcnn_train_cfg: {rcnn_train_cfg}")

        # Use our custom square mask target function
        mask_targets = square_mask_target(pos_proposals_list, pos_assigned_gt_inds_list,
                                          gt_masks_list, rcnn_train_cfg)

        print(f"π SQUARE_FCN_MASK_HEAD: Final mask_targets shape: {mask_targets.shape}")
        return mask_targets
custom_models/square_mask_target.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.nn.modules.utils import _pair
|
5 |
+
|
6 |
+
from mmdet.registry import MODELS
|
7 |
+
from mmdet.structures.mask.mask_target import mask_target as original_mask_target
|
8 |
+
|
9 |
+
|
10 |
+
def square_mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, cfg):
    """Compute square mask targets for positive proposals in multiple images.

    Forces all mask targets to be square regardless of the original aspect
    ratio, to avoid tensor size mismatches downstream in the mask head.

    Args:
        pos_proposals_list (list[Tensor]): Positive proposals in multiple images.
        pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each
            positive proposal.
        gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of
            each image.
        cfg (dict): Config dict that specifies the mask size (``mask_size``).

    Returns:
        Tensor: Square mask target of each image, shape (num_pos, size, size).
    """
    from mmengine.config import ConfigDict

    # The configured size may be rectangular, e.g. (14, 28); use the smaller
    # side so the forced square never exceeds either configured dimension.
    mask_size = _pair(cfg.mask_size)
    square_size = min(mask_size)

    # Delegate to mmdet's stock implementation, but with a square target size.
    square_cfg = ConfigDict({'mask_size': (square_size, square_size)})
    mask_targets = original_mask_target(pos_proposals_list, pos_assigned_gt_inds_list,
                                        gt_masks_list, square_cfg)

    print(f"π SQUARE_MASK_TARGET: Original mask_targets shape: {mask_targets.shape}")
    print(f"π SQUARE_MASK_TARGET: Expected square_size: {square_size}")

    # Bug fix: with zero positive proposals the stock mask_target can return a
    # tensor without spatial dims, so size(1)/size(2) below would raise
    # IndexError. There is nothing to square in that case.
    if mask_targets.dim() < 3:
        return mask_targets

    # Force square shape by padding or cropping if necessary.
    if mask_targets.size(1) != square_size or mask_targets.size(2) != square_size:
        print(f"π SQUARE_MASK_TARGET: Forcing square shape from {mask_targets.shape} "
              f"to ({mask_targets.size(0)}, {square_size}, {square_size})")

        # Create a fresh zero tensor with the square shape.
        num_masks = mask_targets.size(0)
        square_targets = torch.zeros(num_masks, square_size, square_size,
                                     device=mask_targets.device, dtype=mask_targets.dtype)

        # Copy the overlapping region; anything outside the overlap stays
        # zero (padding) and anything beyond square_size is dropped (crop).
        h_copy = min(mask_targets.size(1), square_size)
        w_copy = min(mask_targets.size(2), square_size)
        square_targets[:, :h_copy, :w_copy] = mask_targets[:, :h_copy, :w_copy]
        mask_targets = square_targets

        print(f"π SQUARE_MASK_TARGET: Final mask_targets shape: {mask_targets.shape}")
    else:
        print(f"π SQUARE_MASK_TARGET: Masks already square: {mask_targets.shape}")

    return mask_targets
|
64 |
+
|
65 |
+
|
66 |
+
# Register the custom function under MMDetection's MODELS registry so that
# configs can refer to it by the string name 'square_mask_target'.
# NOTE(review): MODELS normally holds nn.Module classes; registering a plain
# function works for name lookup, but confirm the consumer calls it directly
# rather than instantiating it.
MODELS.register_module(name='square_mask_target', module=square_mask_target)
|
debug_api.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Debug script to check API status and endpoints
|
4 |
+
"""
|
5 |
+
|
6 |
+
import requests
|
7 |
+
import json
|
8 |
+
|
9 |
+
def check_space_status():
    """Check whether the Hugging Face Space answers on its main page.

    Prints the HTTP status of the Space's landing page. Purely a diagnostic
    helper: returns nothing and reports (rather than raises) network errors.
    """
    print("π Checking space status...")

    try:
        # Fix: a bounded timeout prevents the debug script from hanging
        # forever if the Space is cold-starting or unreachable.
        response = requests.get(
            "https://hanszhu-dense-captioning-platform.hf.space/",
            timeout=30,
        )
        print(f"Main page status: {response.status_code}")

        if response.status_code == 200:
            print("β Space is accessible")
        else:
            print("β Space is not accessible")

    except Exception as e:
        # Diagnostic tool: log and continue rather than crash.
        print(f"β Error checking space: {e}")
|
25 |
+
|
26 |
+
def check_api_endpoints():
    """Probe a set of candidate routes on the Space and report each result.

    GETs every well-known Gradio route and prints its HTTP status, content
    type, and a short content preview for 200 responses. Errors for one
    endpoint never abort the scan of the others.
    """
    print("\nπ Checking API endpoints...")

    base_url = "https://hanszhu-dense-captioning-platform.hf.space"

    # Routes that different Gradio versions have exposed over time.
    endpoints = [
        "/",
        "/api",
        "/api/predict",
        "/predict",
        "/api/predict/",
        "/predict/"
    ]

    for endpoint in endpoints:
        try:
            # Fix: bounded timeout so one dead route cannot stall the scan.
            response = requests.get(f"{base_url}{endpoint}", timeout=15)
            print(f"{endpoint}: {response.status_code} - {response.headers.get('content-type', 'unknown')}")

            if response.status_code == 200:
                print(f"  Content preview: {response.text[:100]}...")

        except Exception as e:
            print(f"{endpoint}: Error - {e}")
|
51 |
+
|
52 |
+
def test_post_request():
    """POST several payload shapes to /predict and report each response.

    Different Gradio versions accept different JSON envelopes, so the same
    image URL is sent in each of the common formats in turn.
    """
    print("\nπ Testing POST request...")

    try:
        # Publicly hosted sample image used as the prediction input.
        test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"

        # Candidate payload envelopes, from most to least conventional.
        test_data = [
            {"data": [test_url]},
            {"fn_index": 0, "data": [test_url]},
            {"data": test_url},
            test_url
        ]

        for i, data in enumerate(test_data):
            print(f"\nTest {i+1}: {type(data)}")

            try:
                # Fix: bounded timeout so an unresponsive Space cannot hang
                # the debug run.
                response = requests.post(
                    "https://hanszhu-dense-captioning-platform.hf.space/predict",
                    json=data,
                    headers={"Content-Type": "application/json"},
                    timeout=30,
                )

                print(f"  Status: {response.status_code}")
                print(f"  Content-Type: {response.headers.get('content-type', 'unknown')}")
                print(f"  Response: {response.text[:200]}...")

            except Exception as e:
                print(f"  Error: {e}")

    except Exception as e:
        print(f"β Error in POST test: {e}")
|
87 |
+
|
88 |
+
def test_gradio_client_detailed():
    """Test gradio_client with detailed error handling"""
    print("\nπ Testing gradio_client with detailed error handling...")

    try:
        from gradio_client import Client

        print("Creating client...")
        client = Client("hanszhu/Dense-Captioning-Platform")

        # Dump the Space's advertised API surface before calling it.
        print("Getting space info...")
        api_info = client.view_api()
        print(f"API info: {api_info}")

        print("Making prediction...")
        sample_image = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
        prediction = client.predict(sample_image, api_name="/predict")

        print(f"β Success! Result: {prediction}")

    except Exception as e:
        # Full traceback is valuable here: gradio_client errors are often
        # wrapped several layers deep.
        print(f"β gradio_client error: {e}")
        import traceback
        traceback.print_exc()
|
112 |
+
|
113 |
+
if __name__ == "__main__":
    print("π Debugging Dense Captioning Platform API")
    banner = "=" * 60
    print(banner)

    # Run each diagnostic in order; every step handles its own errors.
    for step in (check_space_status,
                 check_api_endpoints,
                 test_post_request,
                 test_gradio_client_detailed):
        step()

    print("\n" + banner)
    print("οΏ½οΏ½ Debug completed!")
|
find_api_endpoint.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Script to find the correct API endpoint
|
4 |
+
"""
|
5 |
+
|
6 |
+
import requests
|
7 |
+
import json
|
8 |
+
|
9 |
+
def try_different_endpoints():
    """POST a test payload to each candidate route, returning the first that works.

    Returns:
        str | None: The first endpoint path that answered 200, or None if
        none of the candidates accepted the request.
    """
    print("π Trying different API endpoints...")

    base_url = "https://hanszhu-dense-captioning-platform.hf.space"

    # Routes and numbered fn-index routes that various Gradio versions expose.
    endpoints = [
        "/api/predict",
        "/predict",
        "/api/run/predict",
        "/run/predict",
        "/api/0",
        "/0",
        "/api/1",
        "/1",
        "/api/2",
        "/2",
        "/api/3",
        "/3",
        "/api/4",
        "/4",
        "/api/5",
        "/5"
    ]

    test_data = {
        "data": ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
    }

    for endpoint in endpoints:
        print(f"\nTrying POST to: {endpoint}")

        try:
            # Fix: bounded timeout so one unresponsive route cannot stall the
            # whole discovery loop.
            response = requests.post(
                f"{base_url}{endpoint}",
                json=test_data,
                headers={"Content-Type": "application/json"},
                timeout=30,
            )

            print(f"  Status: {response.status_code}")
            print(f"  Content-Type: {response.headers.get('content-type', 'unknown')}")

            if response.status_code == 200:
                print("  β SUCCESS! Found working endpoint!")
                print(f"  Response: {response.text[:200]}...")
                return endpoint
            elif response.status_code == 405:
                print("  β οΈ Method not allowed (endpoint exists but wrong method)")
            elif response.status_code == 404:
                print("  β Not found")
            else:
                print(f"  β Unexpected status: {response.text[:100]}...")

        except Exception as e:
            print(f"  β Error: {e}")

    return None
|
67 |
+
|
68 |
+
def try_get_endpoints():
    """GET each candidate API route and print any JSON metadata it returns.

    Purely informational: prints statuses and, for JSON responses, the parsed
    API description. Never raises on a per-endpoint failure.
    """
    print("\nπ Trying GET requests to find API info...")

    base_url = "https://hanszhu-dense-captioning-platform.hf.space"

    get_endpoints = [
        "/api",
        "/api/",
        "/api/predict",
        "/api/predict/",
        "/api/run/predict",
        "/api/run/predict/",
        "/api/0",
        "/api/1",
        "/api/2",
        "/api/3",
        "/api/4",
        "/api/5"
    ]

    for endpoint in get_endpoints:
        print(f"\nTrying GET: {endpoint}")

        try:
            # Fix: bounded timeout to keep the scan moving.
            response = requests.get(f"{base_url}{endpoint}", timeout=15)

            print(f"  Status: {response.status_code}")
            print(f"  Content-Type: {response.headers.get('content-type', 'unknown')}")

            if response.status_code == 200:
                content = response.text[:200]
                print(f"  Content: {content}...")

                # Check if it's JSON
                if response.headers.get('content-type', '').startswith('application/json'):
                    print("  β JSON response - this might be the API info!")
                    try:
                        data = response.json()
                        print(f"  API Info: {json.dumps(data, indent=2)}")
                    except ValueError:
                        # Fix: was a bare `except: pass` that swallowed
                        # everything (even KeyboardInterrupt). Response.json()
                        # raises a ValueError subclass on malformed JSON, so
                        # catch only that.
                        pass

        except Exception as e:
            print(f"  β Error: {e}")
|
113 |
+
|
114 |
+
def try_gradio_client_different_ways():
    """Try gradio_client with different approaches"""
    print("\nπ Trying gradio_client with different approaches...")

    try:
        from gradio_client import Client

        print("Creating client...")
        client = Client("hanszhu/Dense-Captioning-Platform")

        print("Trying different API names...")

        # Named routes first, then numbered fn-index fallbacks.
        candidate_names = ["/predict", "/run/predict", "0", "1", "2", "3", "4", "5"]
        sample_image = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"

        for api_name in candidate_names:
            print(f"\nTrying api_name: {api_name}")

            try:
                result = client.predict(sample_image, api_name=api_name)
            except Exception as e:
                print(f"  β Failed: {e}")
                continue

            print(f"  β SUCCESS with api_name={api_name}!")
            print(f"  Result: {result}")
            return api_name

    except Exception as e:
        print(f"β gradio_client error: {e}")
|
143 |
+
|
144 |
+
if __name__ == "__main__":
    print("π Finding the correct API endpoint")
    separator = "=" * 60
    print(separator)

    # Probe raw POST routes first, then GET metadata, then gradio_client.
    working_endpoint = try_different_endpoints()
    try_get_endpoints()
    working_api_name = try_gradio_client_different_ways()

    print("\n" + separator)
    print("π Endpoint discovery completed!")

    if working_endpoint:
        print(f"β Found working POST endpoint: {working_endpoint}")
    if working_api_name:
        print(f"β Found working gradio_client api_name: {working_api_name}")

    if not (working_endpoint or working_api_name):
        print("β No working endpoints found")
        print("The space might still be loading or need different configuration")
|
models/chart_elementnet_swin.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# cascade_rcnn_r50_fpn_meta.py - Enhanced config with Swin Transformer backbone
#
# PROGRESSIVE LOSS STRATEGY:
# - All 3 Cascade stages start with SmoothL1Loss for stable initial training
# - At epoch 5, Stage 3 (final stage) switches to GIoULoss via ProgressiveLossHook
# - Stage 1 & 2 remain SmoothL1Loss throughout training
# - This ensures model stability before introducing more complex IoU-based losses

# Custom imports - this registers our modules without polluting config namespace.
# MMEngine imports each listed module at config-load time, which triggers their
# registry registrations (datasets, hooks, models).
custom_imports = dict(
    imports=[
        'custom_models.custom_dataset',
        'custom_models.register',
        'custom_models.custom_hooks',
        'custom_models.progressive_loss_hook',
    ],
    allow_failed_imports=False  # fail fast if any custom module cannot be imported
)

# Add to Python path
import sys
import os
# Use a simpler path approach that doesn't rely on __file__
# NOTE(review): resolves relative to the *current working directory*, so this
# assumes the config is loaded from two levels below the project root — confirm.
sys.path.insert(0, os.path.join(os.getcwd(), '..', '..'))
|
25 |
+
|
26 |
+
# Custom Cascade model with coordinate handling for chart data
model = dict(
    type='CustomCascadeWithMeta',  # Use custom model with coordinate handling
    coordinate_standardization=dict(
        enabled=True,
        origin='bottom_left',  # Match annotation creation coordinate system
        normalize=True,
        relative_to_plot=False,  # Keep simple for now
        scale_to_axis=False  # Keep simple for now
    ),
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        # Standard ImageNet mean/std normalization values.
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    # ----- Swin Transformer Base (22K) Backbone + FPN -----
    backbone=dict(
        type='SwinTransformer',
        embed_dims=128,  # Swin Base embedding dimensions
        depths=[2, 2, 18, 2],  # Swin Base depths
        num_heads=[4, 8, 16, 32],  # Swin Base attention heads
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,  # Slightly higher for more complex model
        patch_norm=True,
        out_indices=(0, 1, 2, 3),  # feed all four stages to the FPN
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
        )
    ),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],  # Swin Base: embed_dims * 2^(stage)
        out_channels=256,
        num_outs=6,  # extra pyramid levels for the extended RPN strides below
        start_level=0,
        add_extra_convs='on_input'
    ),
    # Enhanced RPN with smaller anchors for tiny objects + improved losses
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[1, 2, 4, 8],  # Even smaller scales for tiny objects
            ratios=[0.5, 1.0, 2.0],  # Multiple aspect ratios
            strides=[4, 8, 16, 32, 64, 128]),  # Extended FPN strides
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
    # Progressive Loss Strategy: Start with SmoothL1 for all 3 stages
    # Stage 3 (final stage) will switch to GIoU at epoch 5 via ProgressiveLossHook
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            # NOTE(review): only the first 4 FPN levels feed RoI pooling even
            # though the neck emits 6 — confirm this is intentional.
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            # Stage 1: Always SmoothL1Loss (coarse detection)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
            # Stage 2: Always SmoothL1Loss (intermediate refinement)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    # Tighter stds than stage 1: later stages regress finer deltas.
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
            # Stage 3: SmoothL1 β GIoU at epoch 5 (progressive switching)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.02, 0.02, 0.05, 0.05]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ]),
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.8),
            min_bbox_size=0),
        # Per-stage RCNN configs with increasing IoU thresholds (0.4/0.6/0.7),
        # the standard Cascade R-CNN progressive-resampling pattern.
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.4,
                    neg_iou_thr=0.4,
                    min_pos_iou=0.4,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False)
        ]),
    # Enhanced test configuration with soft-NMS and multi-scale support
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.005,  # Even lower threshold to catch more classes
            nms=dict(
                type='soft_nms',  # Soft-NMS for better small object detection
                iou_threshold=0.5,
                min_score=0.005,
                method='gaussian',
                sigma=0.5),
            max_per_img=500)))  # Allow more detections
|
244 |
+
|
245 |
+
# Dataset settings - using cleaned annotations
dataset_type = 'ChartDataset'  # custom dataset, registered via custom_imports above
data_root = ''  # Remove data_root duplication (ann_file/data_prefix below are full paths)

# Define the 21 chart element classes that match the annotations
CLASSES = (
    'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
    'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
    'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
    'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
)
|
256 |
+
|
257 |
+
# Updated to use cleaned annotation files
train_dataloader = dict(
    batch_size=2,  # Increased back to 2
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_data/annotations_JSON_cleaned/train_enriched.json',  # Full path
        data_prefix=dict(img='legend_data/train/images/'),  # Full path
        metainfo=dict(classes=CLASSES),  # Tell dataset what classes to expect
        # class_specific_min_sizes is a non-standard key — presumably consumed
        # by ChartDataset's filtering logic; verify against custom_dataset.py.
        filter_cfg=dict(filter_empty_gt=True, min_size=0, class_specific_min_sizes={
            'data-point': 16,  # Back to 16x16 from 32x32
            'data-bar': 16,  # Back to 16x16 from 32x32
            'tick-label': 16,  # Back to 16x16 from 32x32
            'x-tick-label': 16,  # Back to 16x16 from 32x32
            'y-tick-label': 16  # Back to 16x16 from 32x32
        }),
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='Resize', scale=(1600, 1000), keep_ratio=True),  # Higher resolution for tiny objects
            dict(type='RandomFlip', prob=0.5),
            dict(type='ClampBBoxes'),  # Ensure bboxes stay within image bounds
            dict(type='PackDetInputs')
        ]
    )
)
|
286 |
+
|
287 |
+
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json',  # Full path
        data_prefix=dict(img='legend_data/train/images/'),  # All images are in train/images
        metainfo=dict(classes=CLASSES),  # Tell dataset what classes to expect
        test_mode=True,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(1600, 1000), keep_ratio=True),  # Base resolution for validation
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='ClampBBoxes'),  # Ensure bboxes stay within image bounds
            dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
        ]
    )
)

# Test shares the validation dataloader settings.
test_dataloader = val_dataloader
|
311 |
+
|
312 |
+
# Enhanced evaluators with debugging
val_evaluator = dict(
    type='CocoMetric',
    ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json',  # Using cleaned annotations
    metric='bbox',
    format_only=False,
    classwise=True,  # Enable detailed per-class metrics table
    proposal_nums=(100, 300, 1000))  # More detailed AR metrics

test_evaluator = val_evaluator
|
322 |
+
|
323 |
+
# Add custom hooks for debugging empty results
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    # CompatibleCheckpointHook is project-local (registered via custom_imports).
    checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))

# Add NaN recovery hook for graceful handling like Faster R-CNN
custom_hooks = [
    dict(type='SkipBadSamplesHook', interval=1),  # Skip samples with bad GT data
    dict(type='ChartTypeDistributionHook', interval=500),  # Monitor class distribution
    dict(type='MissingImageReportHook', interval=1000),  # Track missing images
    dict(type='NanRecoveryHook',  # For logging & monitoring
         fallback_loss=1.0,
         max_consecutive_nans=100,
         log_interval=50),
    dict(type='ProgressiveLossHook',  # Progressive loss switching
         switch_epoch=5,  # Switch stage 3 to GIoU at epoch 5
         target_loss_type='GIoULoss',  # Use GIoU for stage 3 (final stage)
         loss_weight=1.0,  # Keep same loss weight
         warmup_epochs=2,  # Monitor for 2 epochs after switch
         monitor_stage_weights=True),  # Log stage loss details
]
|
348 |
+
|
349 |
+
# Training configuration - extended to 40 epochs for Swin Base on small objects
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Optimizer with standard stable settings
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=35.0, norm_type=2)  # gradient clipping for stability
)

# Extended learning rate schedule with cosine annealing for Swin Base
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.05,  # 1e-4 / 2e-2 = 0.05 (warmup from 1e-4 to 2e-2)
        by_epoch=False,
        begin=0,
        end=1000),  # 1k iteration warmup
    # NOTE(review): cosine schedule begins at 0, overlapping the warmup
    # window above — confirm this interaction is intended.
    dict(
        type='CosineAnnealingLR',
        begin=0,
        end=40,  # Match max_epochs
        by_epoch=True,
        T_max=40,
        eta_min=1e-6,  # Minimum learning rate
        convert_to_iter_based=True)
]
|
378 |
+
|
379 |
+
# Work directory
work_dir = './work_dirs/cascade_rcnn_swin_base_40ep_cosine_fpn_meta'

# Multi-scale test configuration (uncomment to enable)
# img_scales = [(800, 500), (1600, 1000), (2400, 1500)]  # 0.5x, 1.0x, 1.5x scales
# tta_model = dict(
#     type='DetTTAModel',
#     tta_cfg=dict(
#         nms=dict(type='nms', iou_threshold=0.5),
#         max_per_img=100)
# )

# Fresh start: no training resume and no detector checkpoint preloaded
# (the backbone still loads its own pretrained weights via init_cfg).
resume = False
load_from = None
|
394 |
+
|
models/chart_pointnet_swin.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mask_rcnn_swin_meta.py - Mask R-CNN with Swin Transformer for data point segmentation
#
# ADAPTED FROM CASCADE R-CNN CONFIG:
# - Uses same Swin Transformer Base backbone with optimizations
# - Maintains data-point class weighting (10x) and IoU strategies
# - Adds mask head for instance segmentation of data points
# - Uses enhanced annotation files with segmentation masks
# - Keeps custom hooks and progressive loss strategies
#
# MASK-SPECIFIC OPTIMIZATIONS:
# - RoI size 14x14 for mask extraction (matches data point size)
# - FCN mask head with 4 convolution layers
# - Mask loss weight balanced with bbox and classification losses
# - Enhanced test-time augmentation for better mask quality
#
# DATA POINT FOCUS:
# - Primary target: data-point class (ID 11) with 10x weight
# - Generates both bounding boxes AND instance masks
# - Optimized for 16x16 pixel data points in scientific charts
# Removed _base_ inheritance to avoid path issues - all configs are inlined below

# Custom imports - same as Cascade R-CNN setup.
# MMEngine imports these at config-load time to trigger registry registration.
custom_imports = dict(
    imports=[
        'custom_models.register',
        'custom_models.custom_hooks',
        'custom_models.progressive_loss_hook',
        'custom_models.flexible_load_annotations',
    ],
    allow_failed_imports=False  # fail fast if a custom module is missing
)

# Add to Python path
import sys
# NOTE(review): '.' resolves against the current working directory, so this
# config assumes it is loaded from the project root — confirm.
sys.path.insert(0, '.')
|
# Mask R-CNN detector with a Swin Transformer Base backbone, tuned for
# segmenting small (~16x16 px) data points in scientific charts.  The
# backbone, FPN and RPN mirror the Cascade R-CNN config; the mask branch
# works at a fixed 14x14 resolution matching the data-point size.
model = dict(
    type='MaskRCNN',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32,
        pad_mask=True,  # masks must be padded alongside images for mask training
        mask_pad_value=0,
    ),
    # Swin Transformer Base backbone (same as the Cascade R-CNN config).
    backbone=dict(
        type='SwinTransformer',
        embed_dims=128,            # Swin Base embedding dimensions
        depths=[2, 2, 18, 2],      # Swin Base depths
        num_heads=[4, 8, 16, 32],  # Swin Base attention heads
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,  # same as the Cascade config
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
        )
    ),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],  # Swin Base: embed_dims * 2**stage
        out_channels=256,
        num_outs=5,  # standard for Mask R-CNN (the Cascade config used 6)
        start_level=0,
        add_extra_convs='on_input'
    ),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[1, 2, 4, 8],  # small anchor scales for tiny objects
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),  # standard FPN strides for Mask R-CNN
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
    ),
    # ROI head with bbox + mask branches.
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]
        ),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=22,  # 22 enhanced categories including boxplot
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]
            ),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=1.0,
                # MMDetection orders the (num_classes + 1)-way softmax as the
                # 22 foreground classes first (indices 0-21, matching CLASSES)
                # with background LAST at index 22 -- not background-first.
                # FIX: the 10x weight now sits at index 11 ('data-point').
                # The previous layout assumed a background-first convention
                # and put the 10x weight at index 12, which actually boosted
                # 'data-line' instead of 'data-point'.
                class_weight=[1.0] * 11 + [10.0] + [1.0] * 10 + [1.0]
            ),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
        ),
        # Mask ROI extractor: exact 14x14 RoIAlign output matches the
        # typical data-point footprint.
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=(14, 14), sampling_ratio=0, aligned=True),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]
        ),
        # Mask head with square mask targets; no upsampling so predictions
        # stay at 14x14.
        mask_head=dict(
            type='SquareFCNMaskHead',
            num_convs=4,  # 4 conv layers for good feature extraction
            in_channels=256,
            roi_feat_size=14,  # explicit ROI feature size
            conv_out_channels=256,
            num_classes=22,
            upsample_cfg=dict(type=None),  # keep 14x14, no deconv upsample
            loss_mask=dict(
                type='CrossEntropyLoss',
                use_mask=True,
                loss_weight=1.0  # balanced with bbox/cls losses
            )
        )
    ),
    # Training configuration adapted from Cascade R-CNN.
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        # RCNN training (Cascade stage-2 settings, balanced for mask training).
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,  # balanced IoU for joint bbox + mask training
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,  # important for small data points
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=(14, 14),  # exact 14x14 mask targets for data points
            pos_weight=-1,
            debug=False)
    ),
    # Test-time configuration with soft NMS.
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.005,  # low threshold so faint data points survive
            nms=dict(
                type='soft_nms',  # soft NMS suits densely packed small objects
                iou_threshold=0.3,
                min_score=0.005,
                method='gaussian',
                sigma=0.5),
            max_per_img=100,
            mask_thr_binary=0.5  # binary mask threshold
        )
    )
)
# Dataset: standard COCO-format dataset type (required for segm support).
dataset_type = 'CocoDataset'
data_root = ''

# The 22 enhanced chart-element categories (includes 'boxplot').
# 'data-point' at index 11 is the primary target and is weighted 10x in
# the classification loss.
CLASSES = (
    'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
    'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
    'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
    'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area',
    'boxplot'
)

# Guard against accidental reordering of the category tuple.
assert CLASSES[11] == 'data-point', f"Expected 'data-point' at index 11 in CLASSES tuple, got '{CLASSES[11]}'"
# ---------------------------------------------------------------------------
# Dataloaders.  Both splits read images from the same training image folder
# but use different enhanced (mask-bearing) annotation files.
# ---------------------------------------------------------------------------

# Training pipeline: annotations are loaded before Resize so boxes and
# masks get rescaled together with the image; random flip for augmentation.
_train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', scale=(1120, 672), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='ClampBBoxes'),
    dict(type='PackDetInputs')
]

# Eval pipeline: deterministic (no flip); annotations are attached after
# Resize and only the metadata required by the evaluator is packed.
_eval_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(1120, 672), keep_ratio=True),
    dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='ClampBBoxes'),
    dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]

_val_ann_file = 'legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json'

train_dataloader = dict(
    batch_size=2,  # same as the Cascade R-CNN recipe
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_match_swin/mask_generation/enhanced_datasets/train_filtered_with_masks_only.json',
        data_prefix=dict(img='legend_data/train/images/'),
        metainfo=dict(classes=CLASSES),
        # Keep empty-GT images so mask-bearing samples are never dropped.
        filter_cfg=dict(filter_empty_gt=False, min_size=12),
        test_mode=False,
        pipeline=_train_pipeline
    )
)

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=_val_ann_file,
        data_prefix=dict(img='legend_data/train/images/'),
        metainfo=dict(classes=CLASSES),
        test_mode=True,
        pipeline=_eval_pipeline
    )
)

# The test split reuses the validation loader verbatim.
test_dataloader = val_dataloader

# COCO-style evaluation for both bounding boxes and instance masks.
val_evaluator = dict(
    type='CocoMetric',
    ann_file=_val_ann_file,
    metric=['bbox', 'segm'],
    format_only=False,
    classwise=True,
    proposal_nums=(100, 300, 1000)
)

test_evaluator = val_evaluator
# Standard MMEngine runtime hooks (same set as the Cascade R-CNN config).
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook')
)

# Project-specific hooks carried over from the Cascade R-CNN setup.
custom_hooks = [
    dict(type='SkipBadSamplesHook', interval=1),
    dict(type='ChartTypeDistributionHook', interval=500),
    dict(type='MissingImageReportHook', interval=1000),
    # Substitute a fallback loss when NaNs appear instead of crashing the run.
    dict(
        type='NanRecoveryHook',
        fallback_loss=1.0,
        max_consecutive_nans=50,
        log_interval=25,
    ),
    # The progressive-loss hook from the Cascade config is intentionally not
    # used with standard Mask R-CNN; it could be adapted for the bbox loss.
]
# Training loops: 20 epochs, validating after every epoch.
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# SGD with gradient clipping (same optimizer recipe as Cascade R-CNN).
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=10.0, norm_type=2)
)

# LR schedule: 1000-iteration linear warmup, then cosine decay to 1e-5
# over the full 20 epochs.
param_scheduler = [
    dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000),
    dict(
        type='CosineAnnealingLR',
        begin=0,
        end=20,
        by_epoch=True,
        T_max=20,
        eta_min=1e-5,
        convert_to_iter_based=True,
    ),
]
# Output directory for checkpoints and logs.
work_dir = '/content/drive/MyDrive/Research Summer 2025/Dense Captioning Toolkit/CHART-DeMatch/work_dirs/mask_rcnn_swin_base_20ep_meta'

# Always start fresh: no checkpoint resume, no weight preload.
resume = False
load_from = None

# Runtime defaults that would normally be inherited from a _base_ config.
default_scope = 'mmdet'
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')

log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
log_level = 'INFO'
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==5.39.0
|
2 |
+
torch>=2.0.0
|
3 |
+
torchvision>=0.15.0
|
4 |
+
transformers>=4.30.0
|
5 |
+
Pillow>=9.0.0
|
6 |
+
numpy>=1.21.0
|
7 |
+
opencv-python>=4.8.0
|
8 |
+
huggingface-hub>=0.16.0
|
9 |
+
openmim
|
10 |
+
mmdet
|
11 |
+
mmengine
|
12 |
+
scikit-image>=0.21.0
|
simple_test.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
Simple test script for the Dense Captioning Platform API
"""

def test_gradio_client():
    """Call the Space via gradio_client; return True iff the call succeeds."""
    print("🧪 Testing with gradio_client...")

    try:
        from gradio_client import Client, handle_file

        # Connecting by direct Space URL is the approach known to work.
        space = Client("https://hanszhu-dense-captioning-platform.hf.space")

        sample = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
        print(f"Testing with URL: {sample}")

        # fn_index (rather than api_name) is what this Space accepts.
        outcome = space.predict(
            image=handle_file(sample),
            fn_index=0
        )

        print("✅ gradio_client test successful!")
        print(f"Result: {outcome}")

        return True

    except Exception as e:
        print(f"❌ gradio_client test failed: {e}")
        return False

def test_direct_api():
    """Plain REST calls are unsupported for Blocks-based Spaces."""
    print("\n🧪 Testing direct API call...")

    print("⚠️ Direct API calls not supported for Blocks-based Spaces")
    print("   Use gradio_client instead for API access")
    return False

if __name__ == "__main__":
    print("🚀 Testing Dense Captioning Platform API")
    print("=" * 50)

    # Run both checks, preserving report order.
    results = {
        "gradio_client": test_gradio_client(),
        "Direct API": test_direct_api(),
    }

    print("\n" + "=" * 50)
    print("📊 Test Results:")
    for label, passed in results.items():
        print(f"{label}: {'✅ PASS' if passed else '❌ FAIL'}")

    if any(results.values()):
        print("\n🎉 API is working!")
    else:
        print("\n⚠️ API needs more configuration")
test_api.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
Test script for the Dense Captioning Platform API
"""

import requests
import json
from PIL import Image
import io
import base64

def test_api_with_url():
    """POST an image URL straight to the Space's /predict endpoint."""
    print("🧪 Testing API with URL...")

    # A simple, publicly hosted test image.
    sample_image = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
    predict_url = "https://hanszhu-dense-captioning-platform.hf.space/predict"

    try:
        resp = requests.post(
            predict_url,
            json={"data": [sample_image]},
            headers={"Content-Type": "application/json"}
        )

        print(f"Status Code: {resp.status_code}")
        print(f"Response: {resp.text[:500]}...")

        if resp.status_code == 200:
            payload = resp.json()
            print("✅ API test successful!")
            print(f"Chart Type: {payload.get('data', [{}])[0].get('chart_type_label', 'Unknown')}")
        else:
            print("❌ API test failed!")

    except Exception as e:
        print(f"❌ Error testing API: {e}")

def test_api_with_gradio_client():
    """Exercise the Space through the gradio_client helper."""
    print("\n🧪 Testing API with gradio_client...")

    try:
        from gradio_client import Client

        space = Client("hanszhu/Dense-Captioning-Platform")

        outcome = space.predict(
            "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
            api_name="/predict"
        )

        print("✅ gradio_client test successful!")
        print(f"Result: {outcome}")

    except Exception as e:
        print(f"❌ Error with gradio_client: {e}")

def test_api_endpoints():
    """Probe a handful of likely endpoints with plain GET requests."""
    print("\n🧪 Testing API endpoints...")

    root = "https://hanszhu-dense-captioning-platform.hf.space"

    for path in ("/", "/api", "/api/predict", "/predict"):
        try:
            resp = requests.get(f"{root}{path}")
            print(f"{path}: {resp.status_code}")
        except Exception as e:
            print(f"{path}: Error - {e}")

if __name__ == "__main__":
    print("🚀 Testing Dense Captioning Platform API")
    print("=" * 50)

    # Try the different access strategies in turn.
    test_api_endpoints()
    test_api_with_url()
    test_api_with_gradio_client()

    print("\n" + "=" * 50)
    print("🏁 API testing completed!")
test_api_endpoints.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
Comprehensive API endpoint testing
"""

import requests
import json

def test_all_possible_endpoints():
    """Sweep endpoint paths x payload formats; return the first working pair."""
    print("🔍 Testing all possible API endpoints...")

    root = "https://hanszhu-dense-captioning-platform.hf.space"
    sample = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"

    # Endpoint patterns Gradio has used across versions.
    candidate_paths = [
        "/api/predict",
        "/predict",
        "/api/run/predict",
        "/run/predict",
        "/api/0",
        "/0",
        "/api/1",
        "/1",
        "/api/2",
        "/2"
    ]

    # Payload shapes Gradio has accepted across versions.
    candidate_payloads = [
        {"data": [sample]},
        {"data": [sample], "fn_index": 0},
        {"data": [sample], "fn_index": 1},
        {"data": sample},
        {"data": [sample], "session_hash": "test123"}
    ]

    for path in candidate_paths:
        print(f"\n🔍 Testing endpoint: {path}")

        for idx, payload in enumerate(candidate_payloads, start=1):
            print(f"  Format {idx}: {payload}")

            try:
                resp = requests.post(
                    f"{root}{path}",
                    json=payload,
                    headers={"Content-Type": "application/json"},
                    timeout=10
                )

                print(f"    Status: {resp.status_code}")

                if resp.status_code == 200:
                    print("    ✅ SUCCESS!")
                    print(f"    Response: {resp.text[:200]}...")
                    return path, payload
                elif resp.status_code == 405:
                    print("    ⚠️ Method not allowed (endpoint exists)")
                elif resp.status_code == 404:
                    print("    ❌ Not found")
                else:
                    print(f"    ❌ Unexpected: {resp.text[:100]}...")

            except Exception as e:
                print(f"    ❌ Error: {e}")

    return None, None

def test_gradio_client_different_ways():
    """Sweep plausible api_name values through gradio_client."""
    print("\n🔍 Testing gradio_client with different approaches...")

    try:
        from gradio_client import Client

        print("Creating client...")
        space = Client("hanszhu/Dense-Captioning-Platform")

        print("Trying different API names...")

        for candidate in ["/predict", "/run/predict", "0", "1", "2", "3", "4", "5"]:
            print(f"\nTrying api_name: {candidate}")

            try:
                sample = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
                outcome = space.predict(sample, api_name=candidate)
                print(f"  ✅ SUCCESS with api_name={candidate}!")
                print(f"  Result: {outcome}")
                return candidate

            except Exception as e:
                print(f"  ❌ Failed: {e}")

    except Exception as e:
        print(f"❌ gradio_client error: {e}")

def check_space_status():
    """Confirm the Space responds and sniff its landing page for API hints."""
    print("\n🔍 Checking space status...")

    try:
        resp = requests.get("https://hanszhu-dense-captioning-platform.hf.space/", timeout=10)
        print(f"Space status: {resp.status_code}")

        if resp.status_code != 200:
            print("❌ Space is not accessible")
            return

        print("✅ Space is running")

        page = resp.text.lower()
        if "api" in page:
            print("✅ API-related content found")
        if "predict" in page:
            print("✅ Predict-related content found")
        if "gradio" in page:
            print("✅ Gradio content found")

    except Exception as e:
        print(f"❌ Error checking space: {e}")

if __name__ == "__main__":
    print("🚀 Comprehensive API Endpoint Testing")
    print("=" * 60)

    check_space_status()

    found_path, found_payload = test_all_possible_endpoints()
    found_api_name = test_gradio_client_different_ways()

    print("\n" + "=" * 60)
    print("🏁 Testing completed!")

    if found_path and found_payload:
        print("✅ Found working combination:")
        print(f"  Endpoint: {found_path}")
        print(f"  Format: {found_payload}")
    if found_api_name:
        print(f"✅ Found working gradio_client api_name: {found_api_name}")

    if not found_path and not found_api_name:
        print("❌ No working endpoints found")
        print("The space might need different configuration or the API is not properly exposed")
web_test.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
Test script to check web interface and understand API issue
"""

import requests
import time

def check_web_interface():
    """Fetch the Space root page and verify the app and Gradio loaded."""
    print("🔍 Checking web interface...")

    try:
        resp = requests.get("https://hanszhu-dense-captioning-platform.hf.space/")

        if resp.status_code != 200:
            print(f"❌ Web interface not accessible: {resp.status_code}")
            return

        print("✅ Web interface is accessible")

        if "Dense Captioning Platform" in resp.text:
            print("✅ App is loaded correctly")
        else:
            print("❌ App content not found")

        if "gradio" in resp.text.lower():
            print("✅ Gradio is loaded")
        else:
            print("❌ Gradio not found")

    except Exception as e:
        print(f"❌ Error checking web interface: {e}")

def check_api_info():
    """Probe the /api info endpoints and report status and content types."""
    print("\n🔍 Checking API info...")

    try:
        candidates = [
            "https://hanszhu-dense-captioning-platform.hf.space/api",
            "https://hanszhu-dense-captioning-platform.hf.space/api/",
            "https://hanszhu-dense-captioning-platform.hf.space/api/predict",
            "https://hanszhu-dense-captioning-platform.hf.space/api/predict/"
        ]

        for url in candidates:
            print(f"\nTrying: {url}")

            try:
                resp = requests.get(url)
                print(f"  Status: {resp.status_code}")
                print(f"  Content-Type: {resp.headers.get('content-type', 'unknown')}")

                if resp.status_code == 200:
                    print(f"  Content: {resp.text[:200]}...")

                    if resp.headers.get('content-type', '').startswith('application/json'):
                        print("  ✅ JSON response")
                    else:
                        print("  ❌ Not JSON response")

            except Exception as e:
                print(f"  Error: {e}")

    except Exception as e:
        print(f"❌ Error checking API info: {e}")

def wait_and_retry():
    """Poll /api up to five times, 30 s apart; True once JSON is served."""
    print("\n⏳ Waiting for API to become available...")

    for attempt in range(1, 6):
        print(f"\nAttempt {attempt}/5:")

        try:
            resp = requests.get("https://hanszhu-dense-captioning-platform.hf.space/api")
            is_json = resp.headers.get('content-type', '').startswith('application/json')

            if resp.status_code == 200 and is_json:
                print("✅ API is now available!")
                return True
            print(f"❌ API not ready yet: {resp.status_code}")

        except Exception as e:
            print(f"❌ Error: {e}")

        if attempt < 5:  # no point sleeping after the final attempt
            print("Waiting 30 seconds...")
            time.sleep(30)

    return False

if __name__ == "__main__":
    print("🚀 Testing Dense Captioning Platform Web Interface")
    print("=" * 60)

    check_web_interface()
    check_api_info()

    if not wait_and_retry():
        print("\n⚠️ API is still not available after waiting")
        print("This might indicate:")
        print("1. The space is still loading models")
        print("2. There's a configuration issue")
        print("3. The API endpoints need different configuration")

    print("\n" + "=" * 60)
    print("🏁 Web interface test completed!")