hanszhu committed
Commit eb4d305 · 0 Parent(s)

build(space): initial Docker Space with Gradio app, MMDet, SAM integration
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,46 @@
+ # Python cache files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+
+ # Model files (downloaded automatically)
+ models/models--*/
+ models/.locks/
+ *.pth
+ *.pkl
+ *.h5
+ *.onnx
+
+ # Environment files
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # IDE files
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # OS files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # Logs
+ *.log
+ logs/
+
+ # Temporary files
+ *.tmp
+ *.temp
Dockerfile ADDED
@@ -0,0 +1,29 @@
+ FROM python:3.10-slim
+
+ ENV DEBIAN_FRONTEND=noninteractive \
+     PIP_NO_CACHE_DIR=1 \
+     MPLBACKEND=Agg \
+     MIM_IGNORE_INSTALL_PYTORCH=1
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     libgl1 libglib2.0-0 git && \
+     rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ COPY requirements.txt /app/requirements.txt
+
+ # Install pip deps and the mm stack with openmim
+ RUN python -m pip install -U pip openmim && \
+     pip install -r requirements.txt && \
+     mim install "mmengine==0.10.4" && \
+     mim install "mmcv==2.1.0" && \
+     mim install "mmdet==3.3.0" && \
+     pip install git+https://github.com/facebookresearch/segment-anything.git
+
+ # Copy the rest of the application
+ COPY . /app
+
+ EXPOSE 7860
+
+ CMD ["python", "app.py"]
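After the image builds, it can be useful to confirm that the mm stack the Dockerfile pins (mmengine, mmcv, mmdet, segment-anything) is actually importable inside the container. The snippet below is a hypothetical sanity check, not part of the repo; `check_mm_stack` is an illustrative name.

```python
# Hypothetical post-build sanity check for the mm stack installed by the
# Dockerfile. Uses find_spec so nothing heavy is actually imported.
from importlib.util import find_spec


def check_mm_stack(modules=("mmengine", "mmcv", "mmdet", "segment_anything")):
    """Return {module_name: True/False} indicating which packages are importable."""
    return {name: find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in check_mm_stack().items():
        print(("OK      " if ok else "MISSING ") + name)
```

Running it via `docker run <image> python check.py` (assuming the file is copied in) would flag a broken mim install before the Gradio app ever starts.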
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Dense Captioning Platform
+ emoji: 🐒
+ colorFrom: purple
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.39.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README_API.md ADDED
@@ -0,0 +1,309 @@
+ # 📊 Dense Captioning Platform API Documentation
+
+ ## Overview
+
+ The Dense Captioning Platform provides comprehensive chart analysis through a Gradio-based API. It can classify chart types, detect chart elements, and segment data points from uploaded images.
+
+ ## API Access
+
+ **Base URL:** `https://hanszhu-dense-captioning-platform.hf.space`
+
+ **API Type:** Gradio Client API (not RESTful)
+
+ ## Installation
+
+ ### Prerequisites
+
+ ```bash
+ pip install gradio-client
+ ```
+
+ ## Quick Start
+
+ ### Python Client (Recommended)
+
+ ```python
+ from gradio_client import Client, handle_file
+
+ # Initialize client with direct URL
+ client = Client("https://hanszhu-dense-captioning-platform.hf.space")
+
+ # Analyze a chart image using file path
+ result = client.predict(
+     image=handle_file('path/to/your/chart.png'),
+     fn_index=0
+ )
+
+ print(result)
+ ```
+
+ ### Using a URL
+
+ ```python
+ from gradio_client import Client, handle_file
+
+ client = Client("https://hanszhu-dense-captioning-platform.hf.space")
+
+ # Use a publicly accessible image URL
+ result = client.predict(
+     image=handle_file("https://example.com/chart.png"),
+     fn_index=0
+ )
+
+ print(result)
+ ```
+
+ ## Input Parameters
+
+ | Parameter | Type | Required | Description |
+ |-----------|------|----------|-------------|
+ | `image` | File/URL | Yes | Chart image to analyze (PNG, JPG, JPEG supported) |
+
+ ## Important Notes
+
+ ### ✅ Working Approach
+ - **Use `fn_index=0`** instead of `api_name="/predict"`
+ - **Use the direct URL** `"https://hanszhu-dense-captioning-platform.hf.space"`
+ - **Always use `handle_file()`** for both local files and URLs
+ - **This is a Gradio Client API**, not a RESTful API
+
+ ### ❌ What Doesn't Work
+ - Direct HTTP POST requests to `/predict`
+ - Using `api_name="/predict"` with this setup
+ - Using `Client("hanszhu/Dense-Captioning-Platform")` (use the direct URL instead)
+
+ ## Output Format
+
+ The API returns a JSON object with the following structure:
+
+ ```json
+ {
+   "chart_type_id": 4,
+   "chart_type_label": "Bar plot",
+   "element_result": {
+     "bboxes": [...],
+     "segments": [...]
+   },
+   "datapoint_result": {
+     "bboxes": [...],
+     "segments": [...]
+   },
+   "status": "Full analysis completed",
+   "processing_time": 2.345
+ }
+ ```
+
+ ### Output Fields
+
+ | Field | Type | Description |
+ |-------|------|-------------|
+ | `chart_type_id` | int | Numeric identifier for chart type (0-27) |
+ | `chart_type_label` | string | Human-readable chart type name |
+ | `element_result` | object/string | Detected chart elements (titles, axes, legends, etc.) |
+ | `datapoint_result` | object/string | Segmented data points and regions |
+ | `status` | string | Processing status message |
+ | `processing_time` | float | Time taken for analysis in seconds |
+
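Because `element_result` and `datapoint_result` are typed object/string (they degrade to error strings on failure), client code should not index into them blindly. The helper below is an illustrative sketch, not part of the API; `count_detections` is an assumed name.

```python
def count_detections(result: dict) -> dict:
    """Count bboxes/segments per detection field, tolerating the
    string (error) form that element_result/datapoint_result may take."""
    counts = {}
    for field in ("element_result", "datapoint_result"):
        value = result.get(field)
        if isinstance(value, dict):
            counts[field] = {
                "bboxes": len(value.get("bboxes", [])),
                "segments": len(value.get("segments", [])),
            }
        else:
            # Field held an error string or was absent entirely
            counts[field] = None
    return counts
```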
+ ## Supported Chart Types
+
+ The platform can classify 28 different chart types (IDs follow the `CHART_TYPE_LABELS` list in `app.py`):
+
+ | ID | Chart Type | ID | Chart Type |
+ |----|------------|----|------------|
+ | 0 | Line graph | 14 | Confusion matrix |
+ | 1 | Natural image | 15 | Histogram |
+ | 2 | Table | 16 | Box plot |
+ | 3 | 3D object | 17 | Vector plot |
+ | 4 | Bar plot | 18 | Pie chart |
+ | 5 | Scatter plot | 19 | Surface plot |
+ | 6 | Medical image | 20 | Algorithm |
+ | 7 | Sketch | 21 | Contour plot |
+ | 8 | Geographic map | 22 | Tree diagram |
+ | 9 | Flow chart | 23 | Bubble chart |
+ | 10 | Heat map | 24 | Polar plot |
+ | 11 | Mask | 25 | Area chart |
+ | 12 | Block diagram | 26 | Pareto chart |
+ | 13 | Venn diagram | 27 | Radar chart |
+
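For offline post-processing, the numeric ID can be mapped back to a label with the same `CHART_TYPE_LABELS` list that the Space's `app.py` defines; `label_for` below is an illustrative helper, not part of the API.

```python
# Label list copied verbatim from app.py (DocFigure, 28 classes); order matters.
CHART_TYPE_LABELS = [
    'Line graph', 'Natural image', 'Table', '3D object', 'Bar plot', 'Scatter plot',
    'Medical image', 'Sketch', 'Geographic map', 'Flow chart', 'Heat map', 'Mask',
    'Block diagram', 'Venn diagram', 'Confusion matrix', 'Histogram', 'Box plot',
    'Vector plot', 'Pie chart', 'Surface plot', 'Algorithm', 'Contour plot',
    'Tree diagram', 'Bubble chart', 'Polar plot', 'Area chart', 'Pareto chart', 'Radar chart'
]


def label_for(chart_type_id):
    """Translate a numeric chart_type_id (0-27) to its label.
    Returns 'Unknown' for out-of-range IDs or error strings."""
    if isinstance(chart_type_id, int) and 0 <= chart_type_id < len(CHART_TYPE_LABELS):
        return CHART_TYPE_LABELS[chart_type_id]
    return "Unknown"
```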
+ ## Chart Elements Detected
+
+ The element detection model identifies:
+
+ - **Titles & Labels**: Chart title, subtitle, axis labels
+ - **Axes**: X-axis, Y-axis, tick labels
+ - **Legend**: Legend title, legend items, legend text
+ - **Data Elements**: Data points, data lines, data bars, data areas
+ - **Structural Elements**: Grid lines, plot areas
+
+ ## Error Handling
+
+ The API returns error messages in the response fields when issues occur:
+
+ ```json
+ {
+   "chart_type_id": "Error: Model not available",
+   "chart_type_label": "Error: Model not available",
+   "element_result": "Error: Invalid image format",
+   "datapoint_result": "Error: Processing failed",
+   "status": "Error in chart classification",
+   "processing_time": 0.0
+ }
+ ```
+
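Since these failures arrive as `Error: ...` strings inside an otherwise successful response, one way to surface them in client code is to scan the fields and raise. `ChartAnalysisError` and `raise_on_error` are illustrative names under that assumption, not part of the API.

```python
class ChartAnalysisError(RuntimeError):
    """Raised when the API response embeds 'Error ...' strings in its fields."""


def raise_on_error(result: dict) -> dict:
    """Return the result unchanged, or raise if any field carries an error string."""
    errors = [f"{k}: {v}" for k, v in result.items()
              if isinstance(v, str) and v.startswith("Error")]
    if errors:
        raise ChartAnalysisError("; ".join(errors))
    return result
```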
+ ## Rate Limits
+
+ - **Free Tier**: Limited requests per hour
+ - **Processing Time**: Typically 2-5 seconds per image
+ - **Image Size**: Recommended max 10MB
+
+ ## Complete Working Example
+
+ Here's a complete example that demonstrates all the working patterns:
+
+ ```python
+ from gradio_client import Client, handle_file
+ import json
+
+ def analyze_chart(image_path_or_url):
+     """
+     Analyze a chart image using the Dense Captioning Platform API
+
+     Args:
+         image_path_or_url (str): Path to local image file or URL to image
+
+     Returns:
+         dict: Analysis results with chart type, elements, and data points
+     """
+     try:
+         # Initialize client with direct URL
+         client = Client("https://hanszhu-dense-captioning-platform.hf.space")
+
+         # Make prediction using the working approach
+         result = client.predict(
+             image=handle_file(image_path_or_url),
+             fn_index=0
+         )
+
+         return result
+
+     except Exception as e:
+         return {
+             "error": f"API call failed: {str(e)}",
+             "status": "Error",
+             "processing_time": 0.0
+         }
+
+ # Example usage
+ if __name__ == "__main__":
+     # Test with a local file
+     local_result = analyze_chart("path/to/your/chart.png")
+     print("Local file result:", json.dumps(local_result, indent=2))
+
+     # Test with a URL
+     url_result = analyze_chart("https://example.com/chart.png")
+     print("URL result:", json.dumps(url_result, indent=2))
+ ```
+
+ ## Examples
+
+ ### Example 1: Bar Chart Analysis
+
+ ```python
+ from gradio_client import Client, handle_file
+
+ client = Client("https://hanszhu-dense-captioning-platform.hf.space")
+
+ # Analyze a bar chart
+ result = client.predict(
+     image=handle_file('bar_chart.png'),
+     fn_index=0
+ )
+
+ print(f"Chart Type: {result['chart_type_label']}")
+ print(f"Processing Time: {result['processing_time']}s")
+ ```
+
+ ### Example 2: Batch Processing
+
+ ```python
+ from gradio_client import Client, handle_file
+ import os
+
+ client = Client("https://hanszhu-dense-captioning-platform.hf.space")
+
+ # Process multiple charts
+ chart_files = ['chart1.png', 'chart2.png', 'chart3.png']
+ results = []
+
+ for chart_file in chart_files:
+     if os.path.exists(chart_file):
+         result = client.predict(
+             image=handle_file(chart_file),
+             fn_index=0
+         )
+         results.append(result)
+         print(f"Processed {chart_file}: {result['chart_type_label']}")
+ ```
+
+ ### Example 3: Test with Public Image
+
+ ```python
+ from gradio_client import Client, handle_file
+
+ client = Client("https://hanszhu-dense-captioning-platform.hf.space")
+
+ # Test with a public image URL
+ result = client.predict(
+     image=handle_file("https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"),
+     fn_index=0
+ )
+
+ print("✅ API Test Successful!")
+ print(f"Chart Type: {result['chart_type_label']}")
+ print(f"Status: {result['status']}")
+ ```
+
+ ## Troubleshooting
+
+ ### Common Issues
+
+ 1. **"Model not available"**: The models are still loading; wait a moment and retry
+ 2. **"Invalid image format"**: Ensure the image is in PNG, JPG, or JPEG format
+ 3. **"Processing failed"**: The image might be corrupted or too large
+ 4. **"Expecting value: line 1 column 1"**: Use `fn_index=0` instead of `api_name="/predict"`
+ 5. **"Cannot find a function with api_name"**: Use the direct URL and `fn_index=0`
+
+ ### Best Practices
+
+ 1. **Image Quality**: Use clear, high-resolution images for best results
+ 2. **Format**: PNG or JPG formats work best
+ 3. **Size**: Keep images under 10MB for faster processing
+ 4. **Client Setup**: Always use the direct URL and `fn_index=0`
+ 5. **File Handling**: Always use `handle_file()` for both local files and URLs
+ 6. **Retry Logic**: Implement retry logic for failed requests
+
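The retry advice can be sketched as a thin wrapper around `client.predict`; `predict_with_retry` and its backoff values are illustrative assumptions, not part of `gradio_client`.

```python
# Illustrative retry wrapper with exponential backoff between attempts.
import time


def predict_with_retry(client, image, retries=3, base_delay=2.0):
    """Call client.predict(image=..., fn_index=0), retrying transient failures.

    `client` is expected to be a gradio_client.Client (or anything with a
    compatible predict method). Re-raises the last error after `retries` tries.
    """
    last_err = None
    for attempt in range(retries):
        try:
            return client.predict(image=image, fn_index=0)
        except Exception as e:  # e.g. model still loading, network hiccup
            last_err = e
            time.sleep(base_delay * (2 ** attempt))
    raise last_err
```

Note this only covers calls that raise on the client side; errors embedded as strings in a successful response still need the field checks shown under Error Handling.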
+ ### Quick Test
+
+ To verify the API is working, run this test:
+
+ ```python
+ from gradio_client import Client, handle_file
+
+ try:
+     client = Client("https://hanszhu-dense-captioning-platform.hf.space")
+     result = client.predict(
+         image=handle_file("https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"),
+         fn_index=0
+     )
+     print("✅ API is working!")
+     print(f"Chart Type: {result['chart_type_label']}")
+ except Exception as e:
+     print(f"❌ API test failed: {e}")
+ ```
+
+ ## Support
+
+ For issues or questions:
+ - Check the [Hugging Face Space](https://huggingface.co/spaces/hanszhu/Dense-Captioning-Platform)
+ - Review the error messages in the API response
+ - Ensure your image format and size are within limits
app.py ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ import cv2
8
+
9
+ # Add custom modules to path - try multiple possible locations
10
+ possible_paths = [
11
+ "./custom_models",
12
+ "../custom_models",
13
+ "./Dense-Captioning-Platform/custom_models"
14
+ ]
15
+
16
+ for path in possible_paths:
17
+ if os.path.exists(path):
18
+ sys.path.insert(0, os.path.abspath(path))
19
+ break
20
+
21
+ # Add mmcv to path if it exists
22
+ if os.path.exists('./mmcv'):
23
+ sys.path.insert(0, os.path.abspath('./mmcv'))
24
+ print("βœ… Added local mmcv to path")
25
+
26
+ # Import and register custom modules
27
+ try:
28
+ from custom_models import register
29
+ print("βœ… Custom modules registered successfully")
30
+ except Exception as e:
31
+ print(f"⚠️ Warning: Could not register custom modules: {e}")
32
+
33
+ # ----------------------
34
+ # Optional MedSAM integration
35
+ # ----------------------
36
+ class MedSAMIntegrator:
37
+ def __init__(self):
38
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ self.medsam_model = None
40
+ self.current_image = None
41
+ self.current_image_path = None
42
+ self.embedding = None
43
+ self._load_medsam_model()
44
+
45
+ def _ensure_segment_anything(self):
46
+ try:
47
+ import segment_anything # noqa: F401
48
+ return True
49
+ except Exception as e:
50
+ print(f"⚠ segment_anything not available: {e}. It must be installed at build time (Dockerfile).")
51
+ return False
52
+
53
+ def _load_medsam_model(self):
54
+ try:
55
+ # Ensure library is present
56
+ if not self._ensure_segment_anything():
57
+ print("MedSAM features disabled (segment_anything not available)")
58
+ return
59
+
60
+ from segment_anything import sam_model_registry as _reg
61
+ import torch as _torch
62
+
63
+ # Preferred local path
64
+ medsam_ckpt_path = "models/medsam_vit_b.pth"
65
+
66
+ # If not present, fetch from HF Hub using provided repo or default
67
+ if not os.path.exists(medsam_ckpt_path):
68
+ try:
69
+ from huggingface_hub import hf_hub_download, list_repo_files
70
+ repo_id = os.environ.get("HF_MEDSAM_REPO", "Aniketg6/Fine-Tuned-MedSAM")
71
+ # Try to find a .pth/.pt in the repo
72
+ print(f"πŸ”„ Trying to download MedSAM checkpoint from {repo_id} ...")
73
+ files = list_repo_files(repo_id)
74
+ candidate = None
75
+ for f in files:
76
+ lf = f.lower()
77
+ if lf.endswith(".pth") or lf.endswith(".pt"):
78
+ candidate = f
79
+ break
80
+ if candidate is None:
81
+ # Fallback to a common name
82
+ candidate = "medsam_vit_b.pth"
83
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=candidate, cache_dir="./models")
84
+ medsam_ckpt_path = ckpt_path
85
+ print(f"βœ… Downloaded MedSAM checkpoint: {medsam_ckpt_path}")
86
+ except Exception as dl_err:
87
+ print(f"⚠ Could not fetch MedSAM checkpoint from HF Hub: {dl_err}")
88
+ print("MedSAM features disabled (no checkpoint)")
89
+ return
90
+
91
+ # Load checkpoint
92
+ checkpoint = _torch.load(medsam_ckpt_path, map_location='cpu')
93
+ self.medsam_model = _reg["vit_b"](checkpoint=None)
94
+ self.medsam_model.load_state_dict(checkpoint)
95
+ self.medsam_model.to(self.device)
96
+ self.medsam_model.eval()
97
+ print("βœ“ MedSAM model loaded successfully")
98
+ except Exception as e:
99
+ print(f"⚠ MedSAM model not available: {e}. MedSAM features disabled.")
100
+
101
+ def is_available(self):
102
+ return self.medsam_model is not None
103
+
104
+ def load_image(self, image_path, precomputed_embedding=None):
105
+ try:
106
+ from skimage import transform, io # local import to avoid hard dep if unused
107
+ img_np = io.imread(image_path)
108
+ if len(img_np.shape) == 2:
109
+ img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
110
+ else:
111
+ img_3c = img_np
112
+ self.current_image = img_3c
113
+ self.current_image_path = image_path
114
+ if precomputed_embedding is not None:
115
+ if not self.set_precomputed_embedding(precomputed_embedding):
116
+ self.get_embeddings()
117
+ else:
118
+ self.get_embeddings()
119
+ return True
120
+ except Exception as e:
121
+ print(f"Error loading image for MedSAM: {e}")
122
+ return False
123
+
124
+ @torch.no_grad()
125
+ def get_embeddings(self):
126
+ if self.current_image is None or self.medsam_model is None:
127
+ return None
128
+ from skimage import transform
129
+ img_1024 = transform.resize(
130
+ self.current_image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
131
+ ).astype(np.uint8)
132
+ img_1024 = (img_1024 - img_1024.min()) / np.clip(img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None)
133
+ img_1024_tensor = (
134
+ torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(self.device)
135
+ )
136
+ self.embedding = self.medsam_model.image_encoder(img_1024_tensor)
137
+ return self.embedding
138
+
139
+ def set_precomputed_embedding(self, embedding_array):
140
+ try:
141
+ if isinstance(embedding_array, np.ndarray):
142
+ embedding_tensor = torch.tensor(embedding_array).to(self.device)
143
+ self.embedding = embedding_tensor
144
+ return True
145
+ return False
146
+ except Exception as e:
147
+ print(f"Error setting precomputed embedding: {e}")
148
+ return False
149
+
150
+ @torch.no_grad()
151
+ def medsam_inference(self, box_1024, height, width):
152
+ if self.embedding is None or self.medsam_model is None:
153
+ return None
154
+ box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=self.embedding.device)
155
+ if len(box_torch.shape) == 2:
156
+ box_torch = box_torch[:, None, :]
157
+ sparse_embeddings, dense_embeddings = self.medsam_model.prompt_encoder(
158
+ points=None, boxes=box_torch, masks=None,
159
+ )
160
+ low_res_logits, _ = self.medsam_model.mask_decoder(
161
+ image_embeddings=self.embedding,
162
+ image_pe=self.medsam_model.prompt_encoder.get_dense_pe(),
163
+ sparse_prompt_embeddings=sparse_embeddings,
164
+ dense_prompt_embeddings=dense_embeddings,
165
+ multimask_output=False,
166
+ )
167
+ low_res_pred = torch.sigmoid(low_res_logits)
168
+ low_res_pred = torch.nn.functional.interpolate(
169
+ low_res_pred, size=(height, width), mode="bilinear", align_corners=False,
170
+ )
171
+ low_res_pred = low_res_pred.squeeze().cpu().numpy()
172
+ medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
173
+ return medsam_seg
174
+
175
+ def segment_with_box(self, bbox):
176
+ if self.embedding is None or self.current_image is None:
177
+ return None
178
+ try:
179
+ H, W, _ = self.current_image.shape
180
+ x1, y1, x2, y2 = bbox
181
+ x1 = max(0, min(int(x1), W - 1))
182
+ y1 = max(0, min(int(y1), H - 1))
183
+ x2 = max(0, min(int(x2), W - 1))
184
+ y2 = max(0, min(int(y2), H - 1))
185
+ if x2 <= x1:
186
+ x2 = min(x1 + 10, W - 1)
187
+ if y2 <= y1:
188
+ y2 = min(y1 + 10, H - 1)
189
+ box_np = np.array([[x1, y1, x2, y2]], dtype=float)
190
+ box_1024 = box_np / np.array([W, H, W, H]) * 1024.0
191
+ medsam_mask = self.medsam_inference(box_1024, H, W)
192
+ if medsam_mask is not None:
193
+ return {"mask": medsam_mask, "confidence": 1.0, "method": "medsam_box"}
194
+ return None
195
+ except Exception as e:
196
+ print(f"Error in MedSAM box-based segmentation: {e}")
197
+ return None
198
+
199
+ # Single global instance
200
+ _medsam = MedSAMIntegrator()
201
+
202
+ # Cache for SAM automatic mask generator
203
+ _sam_auto_generator = None
204
+ _sam_auto_ckpt_path = None
205
+
206
+
207
+ def _get_sam_generator():
208
+ """Load and cache SAM ViT-H automatic mask generator with faster params if checkpoint exists."""
209
+ global _sam_auto_generator, _sam_auto_ckpt_path
210
+ if _sam_auto_generator is not None:
211
+ return _sam_auto_generator
212
+ try:
213
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
214
+ ckpt = "models/sam_vit_h_4b8939.pth"
215
+ if not os.path.exists(ckpt):
216
+ try:
217
+ from huggingface_hub import hf_hub_download
218
+ ckpt = hf_hub_download(
219
+ repo_id="Aniketg6/SAM",
220
+ filename="sam_vit_h_4b8939.pth",
221
+ cache_dir="./models"
222
+ )
223
+ print(f"βœ… Downloaded SAM ViT-H checkpoint to: {ckpt}")
224
+ except Exception as e:
225
+ print(f"⚠ Failed to download SAM ViT-H checkpoint: {e}")
226
+ return None
227
+ _sam_auto_ckpt_path = ckpt
228
+ sam = sam_model_registry["vit_h"](checkpoint=ckpt)
229
+ # Speed-tuned generator params
230
+ _sam_auto_generator = SamAutomaticMaskGenerator(
231
+ sam,
232
+ points_per_side=16,
233
+ pred_iou_thresh=0.88,
234
+ stability_score_thresh=0.9,
235
+ crop_n_layers=0,
236
+ box_nms_thresh=0.7,
237
+ min_mask_region_area=512 # filter tiny masks
238
+ )
239
+ return _sam_auto_generator
240
+ except Exception as e:
241
+ print(f"_get_sam_generator failed: {e}")
242
+ return None
243
+
244
+
245
+ def _extract_bboxes_from_mmdet_result(det_result):
246
+ """Extract Nx4 xyxy bboxes from various MMDet result formats."""
247
+ boxes = []
248
+ try:
249
+ # MMDet 3.x: list of DetDataSample
250
+ if isinstance(det_result, list) and len(det_result) > 0:
251
+ sample = det_result[0]
252
+ if hasattr(sample, 'pred_instances'):
253
+ inst = sample.pred_instances
254
+ if hasattr(inst, 'bboxes'):
255
+ b = inst.bboxes
256
+ # mmengine structures may use .tensor for boxes
257
+ if hasattr(b, 'tensor'):
258
+ b = b.tensor
259
+ boxes = b.detach().cpu().numpy().tolist()
260
+ # Single DetDataSample
261
+ elif hasattr(det_result, 'pred_instances'):
262
+ inst = det_result.pred_instances
263
+ if hasattr(inst, 'bboxes'):
264
+ b = inst.bboxes
265
+ if hasattr(b, 'tensor'):
266
+ b = b.tensor
267
+ boxes = b.detach().cpu().numpy().tolist()
268
+ # MMDet 2.x: tuple of (bbox_result, segm_result)
269
+ elif isinstance(det_result, tuple) and len(det_result) >= 1:
270
+ bbox_result = det_result[0]
271
+ # bbox_result is list per class, each Nx5 [x1,y1,x2,y2,score]
272
+ if isinstance(bbox_result, (list, tuple)):
273
+ for arr in bbox_result:
274
+ try:
275
+ arr_np = np.array(arr)
276
+ if arr_np.ndim == 2 and arr_np.shape[1] >= 4:
277
+ boxes.extend(arr_np[:, :4].tolist())
278
+ except Exception:
279
+ continue
280
+ except Exception as e:
281
+ print(f"Failed to parse MMDet result for boxes: {e}")
282
+ return boxes
283
+
284
+
285
+ def _overlay_masks_on_image(image_pil, mask_list, alpha=0.4):
286
+ """Overlay binary masks on an image with random colors."""
287
+ if image_pil is None or not mask_list:
288
+ return image_pil
289
+ img = np.array(image_pil.convert('RGB'))
290
+ overlay = img.copy()
291
+ for idx, m in enumerate(mask_list):
292
+ if m is None or 'mask' not in m or m['mask'] is None:
293
+ continue
294
+ mask = m['mask'].astype(bool)
295
+ color = np.random.RandomState(seed=idx + 1234).randint(0, 255, size=3)
296
+ overlay[mask] = (0.5 * overlay[mask] + 0.5 * color).astype(np.uint8)
297
+ blended = (alpha * overlay + (1 - alpha) * img).astype(np.uint8)
298
+ return Image.fromarray(blended)
299
+
300
+
301
+ def _mask_to_polygons(mask: np.ndarray):
302
+ """Convert a binary mask (H,W) to a list of polygons ([[x,y], ...]) using OpenCV contours."""
303
+ try:
304
+ mask_u8 = (mask.astype(np.uint8) * 255)
305
+ contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
306
+ polygons = []
307
+ for cnt in contours:
308
+ if cnt is None or len(cnt) < 3:
309
+ continue
310
+ # Simplify contour slightly
311
+ epsilon = 0.002 * cv2.arcLength(cnt, True)
312
+ approx = cv2.approxPolyDP(cnt, epsilon, True)
313
+ poly = approx.reshape(-1, 2).tolist()
314
+ polygons.append(poly)
315
+ return polygons
316
+ except Exception as e:
317
+ print(f"_mask_to_polygons failed: {e}")
318
+ return []
319
+
320
+
321
+ def _find_largest_foreground_bbox(pil_img: Image.Image):
322
+ """Heuristic: find largest foreground region bbox via Otsu threshold on grayscale.
323
+ Returns [x1, y1, x2, y2] or full-image bbox if none found."""
324
+ try:
325
+ img = np.array(pil_img.convert('RGB'))
326
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
327
+ # Otsu threshold (invert if needed by checking mean)
328
+ _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
329
+ # Assume foreground is darker; invert if threshold yields background as white majority
330
+ if th.mean() > 127:
331
+ th = 255 - th
332
+ # Morph close to connect regions
333
+ kernel = np.ones((5, 5), np.uint8)
334
+ th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
335
+ contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
336
+ if not contours:
337
+ W, H = pil_img.size
338
+ return [0, 0, W - 1, H - 1]
339
+ # Largest contour by area
340
+ cnt = max(contours, key=cv2.contourArea)
341
+ x, y, w, h = cv2.boundingRect(cnt)
342
+ # Pad a little
343
+ pad = int(0.02 * max(w, h))
344
+ x1 = max(0, x - pad)
345
+ y1 = max(0, y - pad)
346
+ x2 = min(img.shape[1] - 1, x + w + pad)
347
+ y2 = min(img.shape[0] - 1, y + h + pad)
348
+ return [x1, y1, x2, y2]
349
+ except Exception as e:
350
+ print(f"_find_largest_foreground_bbox failed: {e}")
351
+ W, H = pil_img.size
352
+ return [0, 0, W - 1, H - 1]
353
+
354
+
355
+ def _find_topk_foreground_bboxes(pil_img: Image.Image, max_regions: int = 20, min_area: int = 100):
356
+ """Find top-K foreground bboxes via Otsu threshold + morphology. Returns list of [x1,y1,x2,y2]."""
357
+ try:
358
+ img = np.array(pil_img.convert('RGB'))
359
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
360
+ _, th = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
361
+ if th.mean() > 127:
362
+ th = 255 - th
363
+ kernel = np.ones((3, 3), np.uint8)
364
+ th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
365
+ th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
366
+ contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
367
+ if not contours:
368
+ return []
369
+ contours = sorted(contours, key=cv2.contourArea, reverse=True)
370
+ bboxes = []
371
+ H, W = img.shape[:2]
372
+ for cnt in contours:
373
+ area = cv2.contourArea(cnt)
374
+ if area < min_area:
375
+ continue
376
+ x, y, w, h = cv2.boundingRect(cnt)
377
+ # Filter very thin shapes
378
+ if w < 5 or h < 5:
379
+ continue
380
+ pad = int(0.01 * max(w, h))
381
+ x1 = max(0, x - pad)
382
+ y1 = max(0, y - pad)
383
+ x2 = min(W - 1, x + w + pad)
384
+ y2 = min(H - 1, y + h + pad)
385
+ bboxes.append([x1, y1, x2, y2])
386
+ if len(bboxes) >= max_regions:
387
+ break
388
+ return bboxes
389
+ except Exception as e:
390
+ print(f"_find_topk_foreground_bboxes failed: {e}")
391
+ return []
392
+
393
+ # Try to import mmdet for inference
+ try:
+     from mmdet.apis import init_detector, inference_detector
+     MM_DET_AVAILABLE = True
+     print("βœ… MMDetection available for inference")
+ except ImportError as e:
+     print(f"⚠️ MMDetection import failed: {e}")
+     print("πŸ”„ Attempting to install MMDetection dependencies...")
+     try:
+         import subprocess
+         import sys
+
+         # Use the working solution with mim install
+         print("πŸ”„ Installing MMDetection dependencies with mim...")
+
+         # Install openmim if not already installed
+         subprocess.check_call([sys.executable, "-m", "pip", "install", "openmim"])
+
+         # Install mmengine
+         subprocess.check_call([sys.executable, "-m", "mim", "install", "mmengine"])
+
+         # Install mmcv with mim (this handles compilation properly)
+         subprocess.check_call([sys.executable, "-m", "mim", "install", "mmcv==2.1.0"])
+
+         # Install mmdet
+         subprocess.check_call([sys.executable, "-m", "mim", "install", "mmdet"])
+
+         # Try importing again
+         from mmdet.apis import init_detector, inference_detector
+         MM_DET_AVAILABLE = True
+         print("βœ… MMDetection installed and available for inference")
+     except Exception as install_error:
+         print(f"❌ Failed to install MMDetection: {install_error}")
+         MM_DET_AVAILABLE = False
+
428
+ # === Chart Type Classification (DocFigure) ===
+ print("πŸ”„ Loading Chart Classification Model...")
+
+ # Chart type labels from DocFigure dataset (28 classes)
+ CHART_TYPE_LABELS = [
+     'Line graph', 'Natural image', 'Table', '3D object', 'Bar plot', 'Scatter plot',
+     'Medical image', 'Sketch', 'Geographic map', 'Flow chart', 'Heat map', 'Mask',
+     'Block diagram', 'Venn diagram', 'Confusion matrix', 'Histogram', 'Box plot',
+     'Vector plot', 'Pie chart', 'Surface plot', 'Algorithm', 'Contour plot',
+     'Tree diagram', 'Bubble chart', 'Polar plot', 'Area chart', 'Pareto chart', 'Radar chart'
+ ]
+
+ try:
+     # Load the chart_type.pth model file from Hugging Face Hub
+     from huggingface_hub import hf_hub_download
+     import torch
+     from torchvision import transforms
+
+     print("πŸ”„ Downloading chart_type.pth from Hugging Face Hub...")
+     chart_type_path = hf_hub_download(
+         repo_id="hanszhu/ChartTypeNet-DocFigure",
+         filename="chart_type.pth",
+         cache_dir="./models"
+     )
+     print(f"βœ… Downloaded to: {chart_type_path}")
+
+     # Load the PyTorch model
+     loaded_data = torch.load(chart_type_path, map_location='cpu')
+
+     # Check if it's a state dict or a complete model
+     if isinstance(loaded_data, dict):
+         # Check if it's a checkpoint with model_state_dict
+         if "model_state_dict" in loaded_data:
+             print("πŸ”„ Loading checkpoint, extracting model_state_dict...")
+             state_dict = loaded_data["model_state_dict"]
+         else:
+             # It's a direct state dict
+             print("πŸ”„ Loading state dict, creating model architecture...")
+             state_dict = loaded_data
+
+         # Strip "backbone." prefix from state dict keys if present
+         cleaned_state_dict = {}
+         for key, value in state_dict.items():
+             if key.startswith("backbone."):
+                 # Remove the "backbone." prefix (9 characters)
+                 new_key = key[9:]
+                 cleaned_state_dict[new_key] = value
+             else:
+                 cleaned_state_dict[key] = value
+
+         print(f"πŸ”„ Cleaned state dict: {len(cleaned_state_dict)} keys")
+
+         # Create the model architecture
+         from torchvision.models import resnet50
+         chart_type_model = resnet50(pretrained=False)
+
+         # Create the correct classifier structure to match the state dict
+         import torch.nn as nn
+         in_features = chart_type_model.fc.in_features
+         dropout = nn.Dropout(0.5)
+
+         chart_type_model.fc = nn.Sequential(
+             nn.Linear(in_features, 512),
+             nn.ReLU(inplace=True),
+             dropout,
+             nn.Linear(512, 28)
+         )
+
+         # Load the cleaned state dict
+         chart_type_model.load_state_dict(cleaned_state_dict)
+     else:
+         # It's a complete model
+         chart_type_model = loaded_data
+
+     chart_type_model.eval()
+
+     # Create a simple processor for the model
+     chart_type_processor = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     CHART_TYPE_AVAILABLE = True
+     print("βœ… Chart classification model loaded")
+ except Exception as e:
+     print(f"⚠️ Failed to load chart classification model: {e}")
+     import traceback
+     print("πŸ” Full traceback:")
+     traceback.print_exc()
+     CHART_TYPE_AVAILABLE = False
+
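Downstream, the classifier's prediction path reduces to an argmax over 28 logits mapped through `CHART_TYPE_LABELS`, with an out-of-range guard. A dependency-free sketch of that mapping (label list truncated, `label_from_logits` is our name for it):

```python
# Truncated label list, just for the sketch; app.py defines all 28.
SKETCH_LABELS = ['Line graph', 'Natural image', 'Table']

def label_from_logits(logits, labels):
    """Pick the highest-scoring class index and map it to a label,
    falling back to 'Unknown (i)' for out-of-range indices."""
    idx = max(range(len(logits)), key=lambda i: logits[i])
    label = labels[idx] if 0 <= idx < len(labels) else f"Unknown ({idx})"
    return idx, label

print(label_from_logits([0.1, 2.3, -0.4], SKETCH_LABELS))  # (1, 'Natural image')
```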
520
+ # === Chart Element Detection (Cascade R-CNN) ===
+ element_model = None
+ datapoint_model = None
+
+ print(f"πŸ” MM_DET_AVAILABLE: {MM_DET_AVAILABLE}")
+
+ if MM_DET_AVAILABLE:
+     # Check if config files exist
+     element_config = "models/chart_elementnet_swin.py"
+     point_config = "models/chart_pointnet_swin.py"
+
+     print("πŸ” Checking config files...")
+     print(f"πŸ” Element config exists: {os.path.exists(element_config)}")
+     print(f"πŸ” Point config exists: {os.path.exists(point_config)}")
+     print(f"πŸ” Current working directory: {os.getcwd()}")
+     print(f"πŸ” Files in models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")
+
+     try:
+         print("πŸ”„ Loading ChartElementNet-MultiClass (Cascade R-CNN)...")
+         print(f"πŸ” Config path: {element_config}")
+         print("πŸ” Weights repo: hanszhu/ChartElementNet-MultiClass")
+         print("πŸ” About to call init_detector...")
+
+         # Download model from Hugging Face Hub
+         from huggingface_hub import hf_hub_download
+         print("πŸ”„ Downloading ChartElementNet weights from Hugging Face Hub...")
+         element_checkpoint = hf_hub_download(
+             repo_id="hanszhu/ChartElementNet-MultiClass",
+             filename="chart_label+.pth",
+             cache_dir="./models"
+         )
+         print(f"βœ… Downloaded to: {element_checkpoint}")
+
+         # Use local config with downloaded weights
+         element_model = init_detector(element_config, element_checkpoint, device="cpu")
+         print("βœ… ChartElementNet loaded successfully")
+     except Exception as e:
+         print(f"❌ Failed to load ChartElementNet: {e}")
+         print(f"πŸ” Error type: {type(e).__name__}")
+         print(f"πŸ” Error details: {str(e)}")
+         import traceback
+         print("πŸ” Full traceback:")
+         traceback.print_exc()
+
+     try:
+         print("πŸ”„ Loading ChartPointNet-InstanceSeg (Mask R-CNN)...")
+         print(f"πŸ” Config path: {point_config}")
+         print("πŸ” Weights repo: hanszhu/ChartPointNet-InstanceSeg")
+         print("πŸ” About to call init_detector...")
+
+         # Download model from Hugging Face Hub
+         print("πŸ”„ Downloading ChartPointNet weights from Hugging Face Hub...")
+         datapoint_checkpoint = hf_hub_download(
+             repo_id="hanszhu/ChartPointNet-InstanceSeg",
+             filename="chart_datapoint.pth",
+             cache_dir="./models"
+         )
+         print(f"βœ… Downloaded to: {datapoint_checkpoint}")
+
+         # Use local config with downloaded weights
+         datapoint_model = init_detector(point_config, datapoint_checkpoint, device="cpu")
+         print("βœ… ChartPointNet loaded successfully")
+     except Exception as e:
+         print(f"❌ Failed to load ChartPointNet: {e}")
+         print(f"πŸ” Error type: {type(e).__name__}")
+         print(f"πŸ” Error details: {str(e)}")
+         import traceback
+         print("πŸ” Full traceback:")
+         traceback.print_exc()
+ else:
+     print("❌ MMDetection not available - cannot load custom models")
+     print("πŸ” MM_DET_AVAILABLE was False")
+
+ print("πŸ” Final model status:")
+ print(f"πŸ” element_model: {element_model is not None}")
+ print(f"πŸ” datapoint_model: {datapoint_model is not None}")
+
597
+ # === Main prediction function ===
+ def analyze(image):
+     """
+     Analyze a chart image and return comprehensive results.
+
+     Args:
+         image: Input chart image (filepath string or PIL.Image)
+
+     Returns:
+         dict: Analysis results containing:
+             - chart_type_id (int): Numeric chart type identifier (0-27)
+             - chart_type_label (str): Human-readable chart type name
+             - element_result (str): Detected chart elements (titles, axes, legends, etc.)
+             - datapoint_result (str): Segmented data points and regions
+             - status (str): Processing status message
+             - processing_time (float): Time taken for analysis in seconds
+     """
+     import time
+     from PIL import Image
+
+     start_time = time.time()
+
+     # Handle filepath input (convert to PIL Image)
+     if isinstance(image, str):
+         # It's a filepath, load the image
+         image = Image.open(image).convert("RGB")
+     elif image is None:
+         return {"error": "No image provided"}
+
+     # Ensure we have a PIL Image
+     if not isinstance(image, Image.Image):
+         return {"error": "Invalid image format"}
+
+     result = {
+         "chart_type_id": "Model not available",
+         "chart_type_label": "Model not available",
+         "element_result": "MMDetection models not available",
+         "datapoint_result": "MMDetection models not available",
+         "status": "Basic chart classification only",
+         "processing_time": 0.0,
+         "medsam": {"available": False}
+     }
+
+     # Chart Type Classification
+     if CHART_TYPE_AVAILABLE:
+         try:
+             # Preprocess image for PyTorch model
+             processed_image = chart_type_processor(image).unsqueeze(0)  # Add batch dimension
+
+             # Get prediction
+             with torch.no_grad():
+                 outputs = chart_type_model(processed_image)
+                 # Handle different output formats
+                 if isinstance(outputs, torch.Tensor):
+                     logits = outputs
+                 elif hasattr(outputs, 'logits'):
+                     logits = outputs.logits
+                 else:
+                     logits = outputs
+
+             predicted_class = logits.argmax(dim=-1).item()
+
+             result["chart_type_id"] = predicted_class
+             result["chart_type_label"] = CHART_TYPE_LABELS[predicted_class] if 0 <= predicted_class < len(CHART_TYPE_LABELS) else f"Unknown ({predicted_class})"
+             result["status"] = "Chart classification completed"
+
+         except Exception as e:
+             result["chart_type_id"] = f"Error: {str(e)}"
+             result["chart_type_label"] = f"Error: {str(e)}"
+             result["status"] = "Error in chart classification"
+
+     # Chart Element Detection (Cascade R-CNN)
+     if element_model is not None:
+         try:
+             # If medical image, skip heavy MMDet to speed up
+             if isinstance(result.get("chart_type_label"), str) and result["chart_type_label"].lower() == "medical image":
+                 result["element_result"] = "skipped_for_medical"
+             else:
+                 # Convert PIL image to numpy array for MMDetection
+                 np_img = np.array(image.convert("RGB"))[:, :, ::-1]  # RGB β†’ BGR
+
+                 element_result = inference_detector(element_model, np_img)
+
+                 # Convert result to more API-friendly format
+                 if isinstance(element_result, tuple):
+                     bbox_result, segm_result = element_result
+                     element_data = {
+                         "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
+                         "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
+                     }
+                 else:
+                     element_data = str(element_result)
+
+                 result["element_result"] = element_data
+                 result["status"] = "Chart classification + element detection completed"
+         except Exception as e:
+             result["element_result"] = f"Error: {str(e)}"
+
+     # Chart Data Point Segmentation (Mask R-CNN)
+     if datapoint_model is not None:
+         try:
+             # If medical image, skip heavy MMDet to speed up
+             if isinstance(result.get("chart_type_label"), str) and result["chart_type_label"].lower() == "medical image":
+                 result["datapoint_result"] = "skipped_for_medical"
+             else:
+                 # Convert PIL image to numpy array for MMDetection
+                 np_img = np.array(image.convert("RGB"))[:, :, ::-1]  # RGB β†’ BGR
+
+                 datapoint_result = inference_detector(datapoint_model, np_img)
+
+                 # Convert result to more API-friendly format
+                 if isinstance(datapoint_result, tuple):
+                     bbox_result, segm_result = datapoint_result
+                     datapoint_data = {
+                         "bboxes": bbox_result.tolist() if hasattr(bbox_result, 'tolist') else str(bbox_result),
+                         "segments": segm_result.tolist() if hasattr(segm_result, 'tolist') else str(segm_result)
+                     }
+                 else:
+                     datapoint_data = str(datapoint_result)
+
+                 result["datapoint_result"] = datapoint_data
+                 result["status"] = "Full analysis completed"
+         except Exception as e:
+             result["datapoint_result"] = f"Error: {str(e)}"
+
+     # If predicted as medical image and MedSAM is available, include mask data (polygons)
+     try:
+         label_lower = str(result.get("chart_type_label", "")).strip().lower()
+         if label_lower == "medical image":
+             if _medsam.is_available():
+                 # Do not run heuristics here. Prompts are required and handled in the UI then-chain.
+                 # Indicate availability and that prompts are needed for segmentation.
+                 result["medsam"] = {"available": True, "reason": "provide bbox/points prompts to generate segmentations"}
+             else:
+                 # Not available; include reason
+                 result["medsam"] = {"available": False, "reason": "segment_anything or checkpoint missing"}
+     except Exception as e:
+         print(f"MedSAM JSON augmentation failed: {e}")
+
+     result["processing_time"] = round(time.time() - start_time, 3)
+     return result
+
739
+
+ def analyze_with_medsam(base_result, image):
+     """Auto-generate segmentations for medical images using SAM ViT-H if available,
+     otherwise fall back to MedSAM over top-K foreground boxes. Returns updated JSON and overlay image."""
+     try:
+         if not isinstance(base_result, dict):
+             return base_result, None
+         label = str(base_result.get("chart_type_label", "")).strip().lower()
+         if label != "medical image":
+             return base_result, None
+
+         pil_img = Image.open(image).convert("RGB") if isinstance(image, str) else image
+         if pil_img is None:
+             return base_result, None
+
+         segmentations = []
+         masks_for_overlay = []
+
+         # Try fast SAM generator first; avoid MedSAM embedding when SAM is available
+         gen = _get_sam_generator()
+         if gen is not None and _sam_auto_ckpt_path is not None and os.path.exists(_sam_auto_ckpt_path):
+             try:
+                 import cv2 as _cv2
+                 img_path = image if isinstance(image, str) else None
+                 if img_path is None:
+                     tmp_path = "./_tmp_input_image.png"
+                     pil_img.save(tmp_path)
+                     img_path = tmp_path
+                 img_bgr = _cv2.imread(img_path)
+                 masks = gen.generate(img_bgr)
+                 # Keep top-K by stability_score or area
+                 def _score(m):
+                     s = float(m.get('stability_score', 0.0))
+                     seg = m.get('segmentation', None)
+                     area = int(seg.sum()) if isinstance(seg, np.ndarray) else 0
+                     return (s, area)
+                 masks = sorted(masks, key=_score, reverse=True)[:8]
+                 for m in masks:
+                     seg = m.get('segmentation', None)
+                     if seg is None:
+                         continue
+                     seg_u8 = seg.astype(np.uint8)
+                     segmentations.append({
+                         "mask": seg_u8.tolist(),
+                         "confidence": float(m.get('stability_score', 1.0)),
+                         "method": "sam_auto"
+                     })
+                     masks_for_overlay.append({"mask": seg_u8})
+             except Exception as e:
+                 print(f"SAM generator segmentation failed: {e}")
+
+         # Fall back to MedSAM boxes only if nothing was produced
+         if not segmentations and _medsam.is_available():
+             try:
+                 # Prepare embedding once
+                 img_path = image if isinstance(image, str) else None
+                 if img_path is None:
+                     tmp_path = "./_tmp_input_image.png"
+                     pil_img.save(tmp_path)
+                     img_path = tmp_path
+                 _medsam.load_image(img_path)
+                 cand_bboxes = _find_topk_foreground_bboxes(pil_img, max_regions=5, min_area=400)
+                 for bbox in cand_bboxes:
+                     m = _medsam.segment_with_box(bbox)
+                     if m is None or not isinstance(m.get('mask'), np.ndarray):
+                         continue
+                     segmentations.append({
+                         "mask": m['mask'].astype(np.uint8).tolist(),
+                         "confidence": float(m.get('confidence', 1.0)),
+                         "method": m.get("method", "medsam_box_auto")
+                     })
+                     masks_for_overlay.append(m)
+             except Exception as auto_e:
+                 print(f"MedSAM fallback segmentation failed: {auto_e}")
+
+         W, H = pil_img.size
+         base_result["medsam"] = {
+             "available": True,
+             "height": H,
+             "width": W,
+             "segmentations": segmentations,
+             "num_segments": len(segmentations)
+         }
+
+         overlay_img = _overlay_masks_on_image(pil_img, masks_for_overlay) if masks_for_overlay else None
+         return base_result, overlay_img
+     except Exception as e:
+         print(f"analyze_with_medsam failed: {e}")
+         return base_result, None
+
829
+ # === Gradio UI with API enhancements ===
+ # Create Blocks interface with explicit API name for a stable API surface
+ with gr.Blocks(
+     title="πŸ“Š Dense Captioning Platform"
+ ) as demo:
+
+     gr.Markdown("# πŸ“Š Dense Captioning Platform")
+     gr.Markdown("""
+ **Comprehensive Chart Analysis API**
+
+ Upload a chart image to get:
+ - **Chart Type Classification**: Identifies the type of chart (line, bar, scatter, etc.)
+ - **Element Detection**: Detects chart elements like titles, axes, legends, data points
+ - **Data Point Segmentation**: Segments individual data points and regions
+
+ Masks will be automatically generated for medical images when supported.
+
+ **API Usage:**
+ ```python
+ from gradio_client import Client, handle_file
+
+ client = Client("hanszhu/Dense-Captioning-Platform")
+ result = client.predict(
+     image=handle_file('path/to/your/chart.png'),
+     api_name="/predict"
+ )
+ print(result)
+ ```
+
+ **Supported Chart Types:** Line graphs, Bar plots, Scatter plots, Pie charts, Heat maps, and 23+ more
+ """)
+
+     with gr.Row():
+         with gr.Column():
+             # Input
+             image_input = gr.Image(
+                 type="filepath",  # βœ… REQUIRED for gradio_client
+                 label="Upload Chart Image",
+                 height=400
+             )
+
+             # Analyze button (single)
+             analyze_btn = gr.Button(
+                 "πŸ” Analyze",
+                 variant="primary",
+                 size="lg"
+             )
+
+         with gr.Column():
+             # Output JSON
+             result_output = gr.JSON(
+                 label="Analysis Results",
+                 height=400
+             )
+             # Overlay image output (populated only for medical images)
+             overlay_output = gr.Image(
+                 label="MedSAM Overlay (Medical images)",
+                 height=400
+             )
+
+     # Single API endpoint for JSON
+     analyze_event = analyze_btn.click(
+         fn=analyze,
+         inputs=image_input,
+         outputs=result_output,
+         api_name="/predict"  # βœ… Standard API name that gradio_client expects
+     )
+
+     # Automatic overlay generation step for medical images
+     analyze_event.then(
+         fn=analyze_with_medsam,
+         inputs=[result_output, image_input],
+         outputs=[result_output, overlay_output],
+     )
+
+     # Add some examples
+     gr.Examples(
+         examples=[
+             ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
+         ],
+         inputs=image_input,
+         label="Try with this example"
+     )
+
+ # Launch with API-friendly settings
+ if __name__ == "__main__":
+     launch_kwargs = {
+         "server_name": "0.0.0.0",  # Allow external connections
+         "server_port": 7860,
+         "share": False,  # Set to True if you want a public link
+         "show_error": True,  # Show detailed errors for debugging
+         "quiet": False,  # Show startup messages
+         "show_api": True  # Enable API documentation
+     }
+
+     # Enable queue for gradio_client compatibility
+     demo.queue().launch(**launch_kwargs)  # βœ… required for gradio_client to work
+
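On the client side, the `/predict` endpoint above returns the JSON dict built in `analyze`. A hedged sketch of unpacking it once received (`summarize_result` is our helper name, not part of the Space; the dict shape follows the `result` keys defined in app.py):

```python
def summarize_result(result):
    """Condense the /predict JSON into a one-line summary string."""
    label = result.get("chart_type_label", "unknown")
    status = result.get("status", "")
    medsam = result.get("medsam")
    n_segs = medsam.get("num_segments", 0) if isinstance(medsam, dict) else 0
    return f"{label} | {status} | segments: {n_segs}"

# Example payload mirroring the keys analyze() populates
example = {
    "chart_type_label": "Line graph",
    "status": "Full analysis completed",
    "medsam": {"available": False},
}
print(summarize_result(example))  # Line graph | Full analysis completed | segments: 0
```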
custom_models/custom_cascade_with_meta.py ADDED
@@ -0,0 +1,152 @@
+ from mmdet.models.detectors import CascadeRCNN
+ from mmdet.registry import MODELS
+ from mmdet.structures import DetDataSample  # needed by simple_test below
+ import torch
+ import torch.nn as nn
+
+ @MODELS.register_module()
+ class CustomCascadeWithMeta(CascadeRCNN):
+     """Custom Cascade R-CNN with metadata prediction heads."""
+
+     def __init__(self,
+                  *args,
+                  chart_cls_head=None,
+                  plot_reg_head=None,
+                  axes_info_head=None,
+                  data_series_head=None,
+                  data_points_count_head=None,
+                  coordinate_standardization=None,
+                  data_series_config=None,
+                  axis_aware_feature=None,
+                  **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # Initialize metadata prediction heads
+         if chart_cls_head is not None:
+             self.chart_cls_head = MODELS.build(chart_cls_head)
+         if plot_reg_head is not None:
+             self.plot_reg_head = MODELS.build(plot_reg_head)
+         if axes_info_head is not None:
+             self.axes_info_head = MODELS.build(axes_info_head)
+         if data_series_head is not None:
+             self.data_series_head = MODELS.build(data_series_head)
+         if data_points_count_head is not None:
+             self.data_points_count_head = MODELS.build(data_points_count_head)
+         else:
+             # Default simple regression head for data point count
+             self.data_points_count_head = nn.Sequential(
+                 nn.Linear(2048, 512),  # Assuming ResNet-50 backbone features
+                 nn.ReLU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(512, 1)  # Single output for count
+             )
+
+         # Store configurations
+         self.coordinate_standardization = coordinate_standardization
+         self.data_series_config = data_series_config
+         self.axis_aware_feature = axis_aware_feature
+
+     def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
+         """Forward function during training."""
+         # Get base detector predictions
+         x = self.extract_feat(img)
+         losses = dict()
+
+         # RPN forward and loss
+         if self.with_rpn:
+             proposal_cfg = self.train_cfg.get('rpn_proposal',
+                                               self.test_cfg.rpn)
+             rpn_losses, proposal_list = self.rpn_head.forward_train(
+                 x,
+                 img_metas,
+                 gt_bboxes,
+                 gt_labels=None,
+                 ann_weight=None,
+                 proposal_cfg=proposal_cfg)
+             losses.update(rpn_losses)
+         else:
+             proposal_list = kwargs.get('proposals', None)
+
+         # ROI forward and loss
+         roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
+                                                  gt_bboxes, gt_labels, **kwargs)
+         losses.update(roi_losses)
+
+         # Get global features for metadata prediction
+         global_feat = x[-1].mean(dim=[2, 3])  # Global average pooling
+
+         # Extract ground truth data point counts from img_metas
+         gt_data_point_counts = []
+         for img_meta in img_metas:
+             count = img_meta.get('img_info', {}).get('num_data_points', 0)
+             gt_data_point_counts.append(count)
+         gt_data_point_counts = torch.tensor(gt_data_point_counts, dtype=torch.float32, device=global_feat.device)
+
+         # Predict data point counts and compute loss
+         pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
+         data_points_count_loss = nn.MSELoss()(pred_data_point_counts, gt_data_point_counts)
+         losses['data_points_count_loss'] = data_points_count_loss
+
+         # Use predicted data point count as additional feature for ROI head:
+         # expand the global feature with data point count information
+         normalized_counts = torch.sigmoid(pred_data_point_counts / 100.0)  # Normalize to 0-1 range
+         enhanced_global_feat = torch.cat([global_feat, normalized_counts.unsqueeze(-1)], dim=-1)
+
+         # Metadata prediction losses
+         if hasattr(self, 'chart_cls_head'):
+             chart_cls_loss = self.chart_cls_head(enhanced_global_feat)
+             losses['chart_cls_loss'] = chart_cls_loss
+
+         if hasattr(self, 'plot_reg_head'):
+             plot_reg_loss = self.plot_reg_head(enhanced_global_feat)
+             losses['plot_reg_loss'] = plot_reg_loss
+
+         if hasattr(self, 'axes_info_head'):
+             axes_info_loss = self.axes_info_head(enhanced_global_feat)
+             losses['axes_info_loss'] = axes_info_loss
+
+         if hasattr(self, 'data_series_head'):
+             data_series_loss = self.data_series_head(enhanced_global_feat)
+             losses['data_series_loss'] = data_series_loss
+
+         return losses
+
+     def simple_test(self, img, img_metas, **kwargs):
+         """Test without augmentation."""
+         x = self.extract_feat(img)
+         proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
+         det_bboxes, det_labels = self.roi_head.simple_test_bboxes(
+             x, img_metas, proposal_list, self.test_cfg.rcnn, **kwargs)
+
+         # Get global features for metadata prediction
+         global_feat = x[-1].mean(dim=[2, 3])  # Global average pooling
+
+         # Predict data point counts
+         pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
+
+         # Use predicted data point count as additional feature
+         normalized_counts = torch.sigmoid(pred_data_point_counts / 100.0)  # Normalize to 0-1 range
+         enhanced_global_feat = torch.cat([global_feat, normalized_counts.unsqueeze(-1)], dim=-1)
+
+         # Get metadata predictions
+         results = []
+         for i, (bboxes, labels) in enumerate(zip(det_bboxes, det_labels)):
+             result = DetDataSample()
+             result.bboxes = bboxes
+             result.labels = labels
+
+             # Add data point count prediction
+             result.predicted_data_points = pred_data_point_counts[i].item()
+
+             # Add metadata predictions using enhanced features
+             if hasattr(self, 'chart_cls_head'):
+                 result.chart_type = self.chart_cls_head(enhanced_global_feat[i:i+1])
+             if hasattr(self, 'plot_reg_head'):
+                 result.plot_bb = self.plot_reg_head(enhanced_global_feat[i:i+1])
+             if hasattr(self, 'axes_info_head'):
+                 result.axes_info = self.axes_info_head(enhanced_global_feat[i:i+1])
+             if hasattr(self, 'data_series_head'):
+                 result.data_series = self.data_series_head(enhanced_global_feat[i:i+1])
+
+             results.append(result)
+
+         return results
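The count-conditioning step in this detector (sigmoid-squash the predicted count, append it as one extra feature channel before the metadata heads) can be illustrated without torch; `augment_with_count` is our name for this sketch, not a function in the file:

```python
import math

def augment_with_count(features, predicted_count, scale=100.0):
    """Append sigmoid(predicted_count / scale) to a flat feature vector,
    mirroring torch.cat([global_feat, normalized_counts.unsqueeze(-1)], dim=-1)."""
    normalized = 1.0 / (1.0 + math.exp(-predicted_count / scale))
    return features + [normalized]

feat = augment_with_count([0.5, -1.2], 0.0)
print(feat)  # [0.5, -1.2, 0.5]  (sigmoid(0) == 0.5)
```

The division by 100 keeps typical per-chart counts (tens to a few hundred points) away from sigmoid saturation, so the appended channel remains informative.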
custom_models/custom_dataset.py ADDED
@@ -0,0 +1,537 @@
1
+ import json
2
+ import os.path as osp
3
+ import numpy as np
4
+ from mmcv.transforms import BaseTransform
5
+ from mmcv.transforms import LoadImageFromFile
6
+ from mmdet.registry import DATASETS, TRANSFORMS
7
+ from mmdet.datasets.transforms import PackDetInputs
8
+ from mmdet.datasets.base_det_dataset import BaseDetDataset
9
+ import warnings
10
+
11
+ # ─── Enhanced robust image loader for real images ───
12
+ @TRANSFORMS.register_module()
13
+ class RobustLoadImageFromFile(LoadImageFromFile):
14
+ """Enhanced image loader: tries real images first, falls back to dummy if needed."""
15
+
16
+ # Class variable to track missing images
17
+ missing_count = 0
18
+
19
+ def __init__(self, try_real_images=True, fallback_to_dummy=True, **kwargs):
20
+ super().__init__(**kwargs)
21
+ self.try_real_images = try_real_images
22
+ self.fallback_to_dummy = fallback_to_dummy
23
+
24
+ def transform(self, results):
25
+ """Try to load real image first, fall back to dummy if not found."""
26
+ if self.try_real_images:
27
+ try:
28
+ # Try standard MMDet image loading first
29
+ results = super().transform(results)
30
+ return results
31
+
32
+ except (FileNotFoundError, OSError, Exception) as e:
33
+ # Count missing image
34
+ RobustLoadImageFromFile.missing_count += 1
35
+
36
+ # Log warning every 10 missing images to avoid spam
37
+ if RobustLoadImageFromFile.missing_count % 10 == 1:
38
+ warnings.warn(f"Missing image #{RobustLoadImageFromFile.missing_count}: {results.get('img_path', 'unknown')}. "
39
+ f"Total missing so far: {RobustLoadImageFromFile.missing_count}",
40
+ UserWarning)
41
+
42
+ if not self.fallback_to_dummy:
43
+ raise e
44
+ # Fall through to create dummy image
45
+
46
+ # Create dummy image (either by choice or because real image loading failed)
47
+ if 'img_shape' in results:
48
+ h, w = results['img_shape'][:2]
49
+ else:
50
+ h = results.get('height', 800)
51
+ w = results.get('width', 600)
52
+
53
+ results['img'] = np.zeros((h, w, 3), dtype=np.uint8)
54
+ results['img_shape'] = (h, w, 3)
55
+ results['ori_shape'] = (h, w, 3)
56
+ return results
57
+
58
+ @classmethod
59
+ def get_missing_count(cls):
60
+ """Get the total count of missing images."""
61
+ return cls.missing_count
62
+
63
+ @classmethod
64
+ def reset_missing_count(cls):
65
+ """Reset the missing image counter."""
66
+ cls.missing_count = 0
67
+
68
+ # ─── Legacy support for old transform name ───
69
+ @TRANSFORMS.register_module()
70
+ class CreateDummyImg(RobustLoadImageFromFile):
71
+ """Legacy alias for RobustLoadImageFromFile."""
72
+ pass
73
+
74
+ @TRANSFORMS.register_module()
75
+ class ClampBBoxes(BaseTransform):
76
+ """Simple bbox clamping transform - only clamps coordinates, doesn't filter."""
77
+ def __init__(self, min_size=1):
78
+ self.min_size = min_size
79
+
80
+ def transform(self, results):
81
+ """Clamp bboxes to image bounds without removing any boxes."""
82
+ if 'gt_bboxes' not in results:
83
+ return results
84
+
85
+ h, w = results['img_shape'][:2]
86
+
87
+ # Handle both numpy arrays and MMDet's HorizontalBoxes objects
88
+ gt_bboxes = results['gt_bboxes']
89
+ if hasattr(gt_bboxes, 'tensor'):
90
+ # MMDet HorizontalBoxes object - clamp in place
91
+ gt_bboxes.tensor[:, 0].clamp_(0, w) # x1
92
+ gt_bboxes.tensor[:, 1].clamp_(0, h) # y1
93
+ gt_bboxes.tensor[:, 2].clamp_(0, w) # x2
94
+ gt_bboxes.tensor[:, 3].clamp_(0, h) # y2
95
+ else:
96
+ # Regular numpy array - clamp in place
97
+ if len(gt_bboxes) > 0:
98
+ gt_bboxes[:, 0] = np.clip(gt_bboxes[:, 0], 0, w) # x1
99
+ gt_bboxes[:, 1] = np.clip(gt_bboxes[:, 1], 0, h) # y1
100
+ gt_bboxes[:, 2] = np.clip(gt_bboxes[:, 2], 0, w) # x2
101
+ gt_bboxes[:, 3] = np.clip(gt_bboxes[:, 3], 0, h) # y2
102
+
103
+ # Don't drop anything here - let filter_cfg handle empty GT filtering
104
+ results['gt_bboxes'] = gt_bboxes
105
+ return results
106
+
107
+ @TRANSFORMS.register_module()
108
+ class SetScaleFactor(BaseTransform):
109
+ """Compute scale_factor from data_series & plot_bb before any Resize."""
110
+ def __init__(self, default_scale=(1.0, 1.0)):
111
+ self.default_scale = default_scale
112
+
113
+ def calculate_scale_factor(self, results):
114
+ bb = results.get('plot_bb', {})
115
+ w, h = bb.get('width', 0), bb.get('height', 0)
116
+ xs, ys = [], []
117
+ for series in results.get('data_series', []):
118
+ for pt in series.get('data', []):
119
+ x, y = pt.get('x'), pt.get('y')
120
+ if isinstance(x, (int, float)): xs.append(x)
121
+ if isinstance(y, (int, float)): ys.append(y)
122
+ if xs and max(xs) != min(xs):
123
+ x_scale = w / (max(xs) - min(xs))
124
+ else:
125
+ x_scale = self.default_scale[0]
126
+ if ys and max(ys) != min(ys):
127
+ y_scale = -h / (max(ys) - min(ys))
128
+ else:
129
+ y_scale = self.default_scale[1]
130
+ return (x_scale, y_scale)
131
+
132
+ def transform(self, results):
133
+ try:
134
+ sf = self.calculate_scale_factor(results)
135
+ results['scale_factor'] = np.array(sf, dtype=np.float32)
136
+ except Exception:
137
+ results['scale_factor'] = np.array(self.default_scale, dtype=np.float32)
138
+ H, W = results.get('height', 0), results.get('width', 0)
139
+ results['img_shape'] = (H, W, 3)
140
+ return results
141
+
+
+@TRANSFORMS.register_module()
+class EnsureScaleFactor(BaseTransform):
+    """Fallback if no scale_factor has been set yet."""
+
+    def transform(self, results):
+        # Only fill in a default; don't clobber a factor set upstream
+        if 'scale_factor' not in results:
+            results['scale_factor'] = np.array([1.0, 1.0], dtype=np.float32)
+        return results
+
+
+@TRANSFORMS.register_module()
+class SetInputs(BaseTransform):
+    """Copy dummy img into inputs for DetDataPreprocessor."""
+
+    def transform(self, results):
+        if 'img' in results:
+            results['inputs'] = results['img'].copy()
+        return results
+
+
+@TRANSFORMS.register_module()
+class CustomPackDetInputs(PackDetInputs):
+    """Final packing into DetDataSample; ensure inputs are present."""
+
+    def transform(self, results):
+        if 'img' in results:
+            results['inputs'] = results['img'].copy()
+        return super().transform(results)
+
+
+@DATASETS.register_module()
+class ChartDataset(BaseDetDataset):
+    """Enhanced dataset for comprehensive chart element detection and analysis."""
+
+    # METAINFO with 21 enhanced categories
+    METAINFO = {
+        'classes': [
+            'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
+            'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
+            'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
+            'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
+        ]
+    }
+
+    # Chart-type specific element filtering based on the actual dataset
+    # distribution (from analyze_chart_types.py):
+    #   - line (41.9%): 1710 images -> data-line only
+    #   - scatter (18.2%): 742 images -> data-point only
+    #   - vertical_bar (30.5%): 1246 images -> data-bar only
+    #   - dot (9.2%): 374 images -> data-point only
+    #   - horizontal_bar (0.2%): 9 images -> data-bar only
+    CHART_TYPE_ELEMENT_MAPPING = {
+        'line': {
+            'allowed_data_elements': {'data-line'},
+            'forbidden_data_elements': {'data-point', 'data-bar', 'data-area'}
+        },
+        'scatter': {
+            'allowed_data_elements': {'data-point'},
+            'forbidden_data_elements': {'data-line', 'data-bar', 'data-area'}
+        },
+        'vertical_bar': {
+            'allowed_data_elements': {'data-bar'},
+            'forbidden_data_elements': {'data-point', 'data-line', 'data-area'}
+        },
+        'dot': {
+            'allowed_data_elements': {'data-point'},
+            'forbidden_data_elements': {'data-line', 'data-bar', 'data-area'}
+        },
+        'horizontal_bar': {
+            'allowed_data_elements': {'data-bar'},
+            'forbidden_data_elements': {'data-point', 'data-line', 'data-area'}
+        }
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.metainfo.update(self.METAINFO)
+
+        # Print configuration info
+        print(f"📊 ChartDataset initialized with {len(self.METAINFO['classes'])} categories:")
+        for i, cls_name in enumerate(self.METAINFO['classes']):
+            print(f"   {i}: {cls_name}")
+
+        # Print chart-type filtering info
+        print("🎯 Chart-type specific filtering enabled:")
+        for chart_type, mapping in self.CHART_TYPE_ELEMENT_MAPPING.items():
+            allowed = mapping.get('allowed_data_elements', set())
+            forbidden = mapping.get('forbidden_data_elements', set())
+            print(f"   - {chart_type}: ✅ {allowed} | 🚫 {forbidden}")
+
+        # Debug-print the data configuration
+        print("📁 Dataset configuration:")
+        print(f"   - data_root: {getattr(self, 'data_root', 'None')}")
+        print(f"   - data_prefix: {getattr(self, 'data_prefix', 'None')}")
+        print(f"   - ann_file: {getattr(self, 'ann_file', 'None')}")
+
+    def load_data_list(self):
+        """Load enhanced annotation files with a priority order."""
+
+        # Auto-detect the best annotation file (same logic as the config)
+        def get_best_ann_file(split):
+            ann_dir = osp.join(self.data_root, 'annotations_JSON')
+
+            # Priority order with flexible naming
+            candidates = [
+                f'{split}_enriched_with_info.json',
+                f'{split}_enriched.json',
+                f'{split}_with_info.json',  # handles val_with_info.json
+                f'{split}.json',
+                f'{split}_cleaned.json'
+            ]
+
+            for candidate in candidates:
+                full_path = osp.join(ann_dir, candidate)
+                if osp.exists(full_path):
+                    print(f"📁 ChartDataset using {candidate}")
+                    return full_path
+
+            # Fall back to ann_file if specified
+            if hasattr(self, 'ann_file') and self.ann_file:
+                fallback_path = osp.join(self.data_root, self.ann_file)
+                if osp.exists(fallback_path):
+                    print(f"📁 Using fallback annotation file: {self.ann_file}")
+                    return fallback_path
+
+            raise FileNotFoundError(f"No annotation files found in {ann_dir}")
+
+        # Determine the file path
+        if hasattr(self, 'ann_file') and self.ann_file:
+            ann_file_path = osp.join(self.data_root, self.ann_file)
+        else:
+            # Try to auto-detect based on common split names
+            for split in ['train', 'val']:
+                try:
+                    ann_file_path = get_best_ann_file(split)
+                    break
+                except FileNotFoundError:
+                    continue
+            else:
+                raise FileNotFoundError("Could not find any annotation files")
+
+        # Load the annotation file
+        with open(ann_file_path, 'r') as f:
+            ann = json.load(f)
+
+        print(f"📊 Loading from {ann_file_path}")
+        print(f"   - Images: {len(ann.get('images', []))}")
+        print(f"   - Annotations: {len(ann.get('annotations', []))}")
+
+        # Build the image lookup
+        img_id_to_info = {img['id']: img for img in ann['images']}
+
+        # Group annotations by image
+        img_id_to_anns = {}
+        for ann_data in ann.get('annotations', []):
+            img_id_to_anns.setdefault(ann_data['image_id'], []).append(ann_data)
+
+        # Create the data list with enhanced metadata
+        data_list = []
+        for img_id, img_info in img_id_to_info.items():
+            annotations = img_id_to_anns.get(img_id, [])
+
+            # Skip images without annotations if filter_empty_gt is enabled
+            if not annotations and (self.filter_cfg or {}).get('filter_empty_gt', False):
+                continue
+
+            # Convert annotations to the instances format
+            instances = []
+            for ann_item in annotations:  # renamed from `ann` to avoid shadowing the loaded file
+                bbox = ann_item['bbox']  # COCO [x, y, width, height]
+                # Convert to [x1, y1, x2, y2] format for MMDet
+                bbox_xyxy = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
+
+                instance = {
+                    'bbox': bbox_xyxy,
+                    'bbox_label': ann_item['category_id'],
+                    'ignore_flag': 0,
+                    'annotation_id': ann_item.get('id', -1),
+                    'area': ann_item.get('area', bbox[2] * bbox[3]),
+                    'element_type': ann_item.get('element_type', 'unknown')
+                }
+
+                # Add additional annotation metadata if available
+                for key in ['text', 'role', 'data_point', 'chart_type', 'total_data_points']:
+                    if key in ann_item:
+                        instance[key] = ann_item[key]
+
+                instances.append(instance)
+
+            # Construct the full image path using data_prefix (like standard MMDet datasets)
+            filename = img_info['file_name']
+            if self.data_prefix.get('img'):
+                img_path = osp.join(self.data_prefix['img'], filename)
+            else:
+                img_path = filename  # fall back to the original filename
+
+            data_info = {
+                'img_id': img_info['id'],
+                'img_path': img_path,
+                'height': img_info['height'],
+                'width': img_info['width'],
+                'instances': instances,
+                # Enhanced metadata from enriched annotations
+                'chart_type': img_info.get('chart_type', ''),
+                'plot_bb': img_info.get('plot_bb', {}),
+                'data_series': img_info.get('data_series', []),
+                'data_series_stats': img_info.get('data_series_stats', {}),
+                'axes_info': img_info.get('axes_info', {}),
+                'element_counts': img_info.get('element_counts', {}),
+                'source': img_info.get('source', 'unknown')
+            }
+
+            data_list.append(data_info)
+
+        print(f"✅ Loaded {len(data_list)} images with enhanced metadata")
+        return data_list
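load_data_list converts COCO-style [x, y, width, height] boxes to the [x1, y1, x2, y2] layout MMDet expects, and parse_data_info later clamps them to the image bounds. Both steps combined, as a small hypothetical helper:

```python
def coco_to_xyxy(bbox, img_w, img_h):
    # COCO [x, y, w, h] -> MMDet [x1, y1, x2, y2], clamped to the image bounds
    x, y, w, h = bbox
    x1 = max(0, min(x, img_w))
    y1 = max(0, min(y, img_h))
    x2 = max(x1, min(x + w, img_w))
    y2 = max(y1, min(y + h, img_h))
    return [x1, y1, x2, y2]

print(coco_to_xyxy([10, 20, 30, 40], 100, 100))  # [10, 20, 40, 60]
```

Clamping x2/y2 against x1/y1 guarantees a non-negative width and height even for boxes that start outside the image.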
+
+    def parse_data_info(self, raw_data_info):
+        """Parse data info with enhanced metadata support."""
+        d = raw_data_info.copy()
+
+        # Debug logging for the first few images to verify path construction
+        if hasattr(self, '_debug_count'):
+            self._debug_count += 1
+        else:
+            self._debug_count = 1
+
+        if self._debug_count <= 3:
+            print(f"🔍 Path verification debug #{self._debug_count}:")
+            print(f"   - img_path from load_data_list: {d['img_path']}")
+            print(f"   - data_root: {getattr(self, 'data_root', 'None')}")
+            full_path = osp.join(self.data_root, d['img_path']) if hasattr(self, 'data_root') else d['img_path']
+            print(f"   - Full absolute path: {full_path}")
+            print(f"   - Path exists: {osp.exists(full_path)}")
+
+        # Image dimensions
+        img_h, img_w = d['height'], d['width']
+
+        # Class names for class-specific filtering
+        class_names = self.METAINFO['classes']
+
+        # Filter configuration (guard against filter_cfg=None)
+        filter_cfg = self.filter_cfg or {}
+        min_size = filter_cfg.get('min_size', 1)
+        class_specific_min_sizes = filter_cfg.get('class_specific_min_sizes', {})
+
+        # Handle bboxes and labels from instances with enhanced filtering
+        bboxes, labels = [], []
+        filtered_count = 0
+        enlarged_count = 0
+        chart_type_filtered_count = 0
+
+        # Chart type for filtering
+        chart_type = d.get('chart_type', '').lower()
+        chart_mapping = self.CHART_TYPE_ELEMENT_MAPPING.get(chart_type, {})
+        allowed_data_elements = chart_mapping.get('allowed_data_elements', set())
+        forbidden_data_elements = chart_mapping.get('forbidden_data_elements', set())
+
+        for inst in d.get('instances', []):
+            bbox = inst['bbox']
+            label_id = inst['bbox_label']
+
+            # Class name for this label
+            class_name = class_names[label_id] if 0 <= label_id < len(class_names) else 'unknown'
+
+            # Chart-type specific filtering: skip forbidden data elements
+            if chart_type and class_name in forbidden_data_elements:
+                chart_type_filtered_count += 1
+                if self._debug_count <= 3 and chart_type_filtered_count <= 3:
+                    print(f"   🚫 Filtered {class_name} from {chart_type} chart (inappropriate data element)")
+                continue
+
+            # Chart-type specific validation: log allowed data elements
+            if chart_type and class_name in allowed_data_elements:
+                if self._debug_count <= 3:
+                    print(f"   ✅ Keeping {class_name} for {chart_type} chart (appropriate data element)")
+
+            # Validate and clamp the bbox
+            x1, y1, x2, y2 = bbox
+            x1 = max(0, min(x1, img_w))
+            y1 = max(0, min(y1, img_h))
+            x2 = max(x1, min(x2, img_w))
+            y2 = max(y1, min(y2, img_h))
+
+            # Skip invalid bboxes
+            if x2 <= x1 or y2 <= y1:
+                filtered_count += 1
+                continue
+
+            # Current bbox dimensions
+            bbox_w = x2 - x1
+            bbox_h = y2 - y1
+            bbox_min_dim = min(bbox_w, bbox_h)
+
+            # Class-specific minimum size
+            required_min_size = class_specific_min_sizes.get(class_name, min_size)
+
+            # If the bbox is smaller than required, enlarge it to meet the minimum size
+            if bbox_min_dim < required_min_size:
+                # Expansion needed on each side
+                expand_w = max(0, required_min_size - bbox_w) / 2
+                expand_h = max(0, required_min_size - bbox_h) / 2
+
+                # Expand while keeping the bbox within image bounds
+                new_x1 = max(0, x1 - expand_w)
+                new_y1 = max(0, y1 - expand_h)
+                new_x2 = min(img_w, x2 + expand_w)
+                new_y2 = min(img_h, y2 + expand_h)
+
+                x1, y1, x2, y2 = new_x1, new_y1, new_x2, new_y2
+                enlarged_count += 1
+
+                if self._debug_count <= 3 and enlarged_count <= 3:
+                    print(f"   📏 Enlarged {class_name} bbox: {bbox_w:.1f}x{bbox_h:.1f} -> {(x2-x1):.1f}x{(y2-y1):.1f}")
+
+            bboxes.append([x1, y1, x2, y2])
+            labels.append(label_id)
+
+        # Log filtering and enlargement statistics for the first few images
+        if self._debug_count <= 3:
+            print(f"   📊 Bbox processing: {len(bboxes)} kept, {filtered_count} filtered (invalid), "
+                  f"{chart_type_filtered_count} filtered (chart-type), {enlarged_count} enlarged")
+            if chart_type:
+                print(f"   📈 Chart type: {chart_type} | Allowed data elements: {allowed_data_elements}")
+                if forbidden_data_elements:
+                    print(f"   🚫 Forbidden data elements for {chart_type}: {forbidden_data_elements}")
+
+        # Convert to arrays
+        d['gt_bboxes'] = np.array(bboxes, dtype=np.float32) if bboxes else np.zeros((0, 4), dtype=np.float32)
+        d['gt_bboxes_labels'] = np.array(labels, dtype=np.int64) if labels else np.zeros((0,), dtype=np.int64)
+
+        # Enhanced scale factor calculation using data_series_stats
+        d['scale_factor'] = np.array([1.0, 1.0], dtype=np.float32)
+
+        data_series_stats = d.get('data_series_stats', {})
+        plot_bb = d.get('plot_bb', {})
+
+        if data_series_stats and plot_bb and all(k in plot_bb for k in ['width', 'height']):
+            x_range = data_series_stats.get('x_range')
+            y_range = data_series_stats.get('y_range')
+
+            if x_range and len(x_range) == 2 and x_range[1] != x_range[0]:
+                d['scale_factor'][0] = plot_bb['width'] / (x_range[1] - x_range[0])
+            if y_range and len(y_range) == 2 and y_range[1] != y_range[0]:
+                d['scale_factor'][1] = -plot_bb['height'] / (y_range[1] - y_range[0])
+
+        # Required MMDet fields
+        d.update({
+            'img_shape': (img_h, img_w, 3),
+            'ori_shape': (img_h, img_w, 3),
+            'pad_shape': (img_h, img_w, 3),
+            'flip': False,
+            'flip_direction': None,
+            'img_fields': ['img'],
+            'bbox_fields': ['bbox'],
+        })
+
+        # Additional metadata for training
+        d['img_info'] = {
+            'height': img_h,
+            'width': img_w,
+            'img_shape': d['img_shape'],
+            'ori_shape': d['ori_shape'],
+            'pad_shape': d['pad_shape'],
+            'scale_factor': d['scale_factor'].copy(),
+            'flip': d['flip'],
+            'flip_direction': d['flip_direction'],
+            # Enhanced metadata
+            'chart_type': d.get('chart_type', ''),
+            'num_data_points': data_series_stats.get('num_data_points', 0),
+            'element_counts': d.get('element_counts', {})
+        }
+
+        return d
+
+
+def print_missing_image_summary():
+    """Print a summary of missing images."""
+    count = RobustLoadImageFromFile.get_missing_count()
+    if count > 0:
+        print(f"📊 MISSING IMAGES SUMMARY: {count} images were not found and replaced with dummy images")
+    else:
+        print("✅ All images loaded successfully!")
+
+
+def print_dataset_summary():
+    """Print a summary of the dataset configuration."""
+    print("📊 ENHANCED CHART DATASET SUMMARY:")
+    print("   - 21 categories supported for comprehensive chart element detection")
+    print("   - Auto-detects the best annotation files (enriched_with_info > enriched > regular)")
+    print("   - Enhanced metadata: chart_type, data_series_stats, element_counts, axes_info")
+    print("   - Robust image loading with fallback to dummy images")
+    print("   - Multiple annotations per image (not just plot areas)")
+
+
+print("✅ [PLUGIN] Enhanced ChartDataset + transforms registered!")
+print_dataset_summary()
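The chart-type filtering inside parse_data_info reduces to a set-membership test against CHART_TYPE_ELEMENT_MAPPING; a minimal sketch with an abbreviated mapping (only two chart types shown for illustration):

```python
MAPPING = {
    'line': {'forbidden_data_elements': {'data-point', 'data-bar', 'data-area'}},
    'scatter': {'forbidden_data_elements': {'data-line', 'data-bar', 'data-area'}},
}

def keep_instance(chart_type, class_name):
    # An instance is dropped only when its class is forbidden for the chart
    # type; unknown chart types fall through to an empty mapping and keep everything.
    mapping = MAPPING.get(chart_type.lower(), {})
    return class_name not in mapping.get('forbidden_data_elements', set())

print(keep_instance('line', 'data-point'))  # False
```

Non-data classes (titles, axes, legends) never appear in a forbidden set, so they survive regardless of chart type.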
custom_models/custom_faster_rcnn_with_meta.py ADDED
@@ -0,0 +1,166 @@
+# custom_faster_rcnn_with_meta.py - Faster R-CNN with coordinate handling for chart data
+import torch
+import torch.nn as nn
+
+from mmdet.models.detectors.faster_rcnn import FasterRCNN
+from mmdet.registry import MODELS
+
+
+@MODELS.register_module()
+class CustomFasterRCNNWithMeta(FasterRCNN):
+    """Faster R-CNN with coordinate standardization for chart detection.
+
+    NOTE: forward_train/simple_test follow the MMDet 2.x detector API; under
+    the MMDet 3.x registry imported above, the equivalent hooks are
+    loss()/predict().
+    """
+
+    def __init__(self,
+                 *args,
+                 coordinate_standardization=None,
+                 data_points_count_head=None,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Store coordinate standardization settings
+        self.coord_std = coordinate_standardization or {}
+
+        # Initialize the data points count head
+        if data_points_count_head is not None:
+            self.data_points_count_head = MODELS.build(data_points_count_head)
+        else:
+            # Default simple regression head for the data point count
+            self.data_points_count_head = nn.Sequential(
+                nn.Linear(2048, 512),  # assumes ResNet-50 C5 features (2048 channels)
+                nn.ReLU(),
+                nn.Dropout(0.1),
+                nn.Linear(512, 1)  # single output for the count
+            )
+
+        print("🎯 CustomFasterRCNNWithMeta initialized with coordinate handling:")
+        print(f"   - Enabled: {self.coord_std.get('enabled', False)}")
+        print(f"   - Origin: {self.coord_std.get('origin', 'top_left')}")
+        print(f"   - Normalize: {self.coord_std.get('normalize', False)}")
+        print("   - Data points count prediction: Enabled")
+
+    def transform_coordinates(self, coords, img_shape, plot_bb=None, axes_info=None):
+        """Transform coordinates based on the standardization settings."""
+        if not self.coord_std.get('enabled', False):
+            return coords
+
+        # Image dimensions
+        img_height, img_width = img_shape[-2:]
+
+        # Convert to a tensor if not already
+        if not isinstance(coords, torch.Tensor):
+            coords = torch.tensor(coords, device=img_shape.device if hasattr(img_shape, 'device') else 'cpu')
+
+        # Ensure coords is 2D
+        if coords.dim() == 1:
+            coords = coords.view(-1, 2)
+
+        # Normalize coordinates if needed
+        if self.coord_std.get('normalize', True):
+            coords = coords / torch.tensor([img_width, img_height], device=coords.device)
+
+        # Handle bottom-left to top-left origin conversion
+        if self.coord_std.get('origin', 'bottom_left') == 'bottom_left':
+            # Flip y-coordinates to convert from a bottom-left to a top-left origin
+            coords[:, 1] = 1.0 - coords[:, 1]
+
+        # Convert back to pixel coordinates
+        if self.coord_std.get('normalize', True):
+            coords = coords * torch.tensor([img_width, img_height], device=coords.device)
+
+        return coords
+
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      **kwargs):
+        """Forward function during training with coordinate transformation."""
+
+        # Transform ground-truth bboxes if coordinate standardization is enabled
+        if self.coord_std.get('enabled', False) and gt_bboxes is not None:
+            transformed_gt_bboxes = []
+            for i, bboxes in enumerate(gt_bboxes):
+                if len(bboxes) > 0:
+                    # MMDet uses the [x1, y1, x2, y2] format; transform the box centers
+                    centers = torch.stack([
+                        (bboxes[:, 0] + bboxes[:, 2]) / 2,  # center_x
+                        (bboxes[:, 1] + bboxes[:, 3]) / 2   # center_y
+                    ], dim=1)
+
+                    img_shape = img.shape if hasattr(img, 'shape') else (img_metas[i]['img_shape'][0], img_metas[i]['img_shape'][1])
+                    transformed_centers = self.transform_coordinates(
+                        centers, img_shape,
+                        plot_bb=img_metas[i].get('plot_bb'),
+                        axes_info=img_metas[i].get('axes_info')
+                    )
+
+                    # Reconstruct bboxes around the transformed centers
+                    widths = bboxes[:, 2] - bboxes[:, 0]
+                    heights = bboxes[:, 3] - bboxes[:, 1]
+
+                    transformed_bboxes = torch.stack([
+                        transformed_centers[:, 0] - widths / 2,   # x1
+                        transformed_centers[:, 1] - heights / 2,  # y1
+                        transformed_centers[:, 0] + widths / 2,   # x2
+                        transformed_centers[:, 1] + heights / 2   # y2
+                    ], dim=1)
+
+                    transformed_gt_bboxes.append(transformed_bboxes)
+                else:
+                    transformed_gt_bboxes.append(bboxes)
+
+            gt_bboxes = transformed_gt_bboxes
+
+        # Call the parent forward_train with transformed coordinates to get losses
+        losses = super().forward_train(
+            img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore, **kwargs)
+
+        # Extract features for data point count prediction
+        x = self.extract_feat(img)
+        global_feat = x[-1].mean(dim=[2, 3])  # global average pooling
+
+        # Ground-truth data point counts from img_metas
+        gt_data_point_counts = []
+        for img_meta in img_metas:
+            gt_data_point_counts.append(img_meta.get('img_info', {}).get('num_data_points', 0))
+        gt_data_point_counts = torch.tensor(gt_data_point_counts, dtype=torch.float32, device=global_feat.device)
+
+        # Predict data point counts and compute the loss
+        pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
+        losses['data_points_count_loss'] = nn.MSELoss()(pred_data_point_counts, gt_data_point_counts)
+
+        return losses
+
+    def simple_test(self, img, img_metas, proposals=None, rescale=False):
+        """Simple test function with coordinate inverse transformation."""
+        # Predictions from the parent
+        results = super().simple_test(img, img_metas, proposals, rescale)
+
+        # Extract features for data point count prediction
+        x = self.extract_feat(img)
+        global_feat = x[-1].mean(dim=[2, 3])  # global average pooling
+
+        # Predict data point counts
+        pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
+
+        # Attach data point count predictions to the results
+        if results is not None:
+            for i, result in enumerate(results):
+                if hasattr(result, 'pred_instances'):
+                    result.pred_instances.predicted_data_points = pred_data_point_counts[i].item()
+                elif hasattr(result, 'bboxes'):
+                    # For older MMDet versions, add as an extra attribute
+                    result.predicted_data_points = pred_data_point_counts[i].item()
+
+        # Note: no inverse transform is applied at test time; the coordinate
+        # system is kept consistent with training.
+
+        return results
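The normalize / flip / denormalize sequence in transform_coordinates amounts to mirroring y about the image midline; the same arithmetic as a torch-free sketch (hypothetical helper, for illustration only):

```python
def flip_y_origin(points, img_w, img_h):
    # Convert a bottom-left-origin y-coordinate to a top-left origin:
    # normalize, flip (y' = 1 - y), then scale back to pixels.
    out = []
    for x, y in points:
        y_norm = y / img_h
        out.append((x, (1.0 - y_norm) * img_h))
    return out

print(flip_y_origin([(50, 20)], 100, 100))  # y=20 from the bottom becomes y=80 from the top
```

x is untouched because only the vertical axis direction differs between the two conventions.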
custom_models/custom_heads.py ADDED
@@ -0,0 +1,267 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from mmdet.registry import MODELS
5
+
6
+ @MODELS.register_module()
7
+ class FCHead(nn.Module):
8
+ """Enhanced fully connected head for classification tasks with attention."""
9
+
10
+ def __init__(self, in_channels, num_classes, loss=None):
11
+ super().__init__()
12
+ self.attention = nn.MultiheadAttention(in_channels, num_heads=8)
13
+ self.fc1 = nn.Linear(in_channels, in_channels // 2)
14
+ self.fc2 = nn.Linear(in_channels // 2, num_classes)
15
+ self.loss = loss
16
+
17
+ def forward(self, x):
18
+ # Apply self-attention
19
+ x = self.attention(x, x, x)[0]
20
+ # Apply MLP
21
+ x = F.relu(self.fc1(x))
22
+ return self.fc2(x)
23
+
24
+ @MODELS.register_module()
25
+ class RegHead(nn.Module):
26
+ """Enhanced regression head for coordinate prediction with distance-based loss."""
27
+
28
+ def __init__(self, in_channels, out_dims, max_points=None, loss=None, attention=False, use_axis_info=False):
29
+ super().__init__()
30
+ self.fc = nn.Linear(in_channels, out_dims)
31
+ self.max_points = max_points
32
+ self.loss = loss
33
+ self.attention = attention
34
+ self.use_axis_info = use_axis_info
35
+
36
+ if attention:
37
+ self.attention_layer = nn.MultiheadAttention(in_channels, num_heads=8)
38
+
39
+ # Add axis orientation detection
40
+ if use_axis_info:
41
+ self.axis_orientation = nn.Linear(in_channels, 2) # 2 for x/y axis orientation
42
+
43
+ def compute_distance_loss(self, pred_points, gt_points):
44
+ """Compute distance-based loss between predicted and ground truth points."""
45
+ # Ensure points are in the same format
46
+ if pred_points.dim() == 2:
47
+ pred_points = pred_points.unsqueeze(0)
48
+ if gt_points.dim() == 2:
49
+ gt_points = gt_points.unsqueeze(0)
50
+
51
+ # Compute pairwise distances
52
+ dist = torch.cdist(pred_points, gt_points)
53
+
54
+ # Get minimum distance for each point
55
+ min_dist, _ = torch.min(dist, dim=2)
56
+
57
+ # Compute loss (using smooth L1 loss for robustness)
58
+ return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))
59
+
60
+ def forward(self, x):
61
+ if self.attention:
62
+ x = self.attention_layer(x, x, x)[0]
63
+
64
+ # Get base predictions
65
+ pred = self.fc(x)
66
+
67
+ # If using axis info, also predict axis orientation
68
+ if self.use_axis_info:
69
+ axis_orientation = self.axis_orientation(x)
70
+ return pred, axis_orientation
71
+
72
+ return pred
73
+
74
+ class CoordinateTransformer:
75
+ """Helper class to transform coordinates between different spaces."""
76
+
77
+ @staticmethod
78
+ def to_axis_relative(points, axis_info):
79
+ """Transform points to be relative to axis coordinates.
80
+
81
+ Args:
82
+ points (torch.Tensor): Points in image coordinates (N, 2)
83
+ axis_info (torch.Tensor): Axis information [x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale]
84
+ """
85
+ # Extract axis information
86
+ x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = axis_info.unbind(1)
87
+
88
+ # Normalize to [0, 1] range
89
+ x_norm = (points[..., 0] - x_min) / (x_max - x_min)
90
+ y_norm = (points[..., 1] - y_min) / (y_max - y_min)
91
+
92
+ # Scale to axis units
93
+ x_axis = x_norm * x_scale + x_origin
94
+ y_axis = y_norm * y_scale + y_origin
95
+
96
+ return torch.stack([x_axis, y_axis], dim=-1)
97
+
98
+ @staticmethod
99
+ def to_image_coordinates(points, axis_info):
100
+ """Transform points from axis coordinates to image coordinates."""
101
+ # Extract axis information
102
+ x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = axis_info.unbind(1)
103
+
104
+ # Convert from axis units to normalized coordinates
105
+ x_norm = (points[..., 0] - x_origin) / x_scale
106
+ y_norm = (points[..., 1] - y_origin) / y_scale
107
+
108
+ # Convert to image coordinates
109
+ x_img = x_norm * (x_max - x_min) + x_min
110
+ y_img = y_norm * (y_max - y_min) + y_min
111
+
112
+ return torch.stack([x_img, y_img], dim=-1)
113
+
114
+ @MODELS.register_module()
115
+ class DataSeriesHead(nn.Module):
116
+ """Specialized head for data series prediction with dual attention to coordinates and axis-relative positions."""
117
+
118
+ def __init__(self, in_channels, max_points=50, loss=None):
119
+ super().__init__()
120
+ self.max_points = max_points
121
+ self.loss = loss
122
+
123
+ # Feature extraction
124
+ self.fc1 = nn.Linear(in_channels, in_channels // 2)
125
+
126
+ # Separate branches for absolute and relative coordinates
127
+ self.absolute_branch = nn.Sequential(
128
+ nn.Linear(in_channels // 2, in_channels // 4),
129
+ nn.ReLU(),
130
+ nn.Linear(in_channels // 4, max_points * 2) # 2 coordinates per point
131
+ )
132
+
133
+ self.relative_branch = nn.Sequential(
134
+ nn.Linear(in_channels // 2, in_channels // 4),
135
+ nn.ReLU(),
136
+ nn.Linear(in_channels // 4, max_points * 2) # 2 coordinates per point
137
+ )
138
+
139
+ # Attention mechanisms
140
+ self.coord_attention = nn.MultiheadAttention(in_channels, num_heads=8)
141
+ self.axis_attention = nn.MultiheadAttention(in_channels, num_heads=8)
142
+ self.sequence_attention = nn.MultiheadAttention(in_channels, num_heads=8)
143
+
144
+ # Sequence-aware processing
145
+ self.sequence_encoder = nn.TransformerEncoder(
146
+ nn.TransformerEncoderLayer(
147
+ d_model=in_channels,
148
+ nhead=8,
149
+ dim_feedforward=in_channels * 4,
150
+ dropout=0.1
151
+ ),
152
+ num_layers=2
153
+ )
154
+
155
+ # Pattern recognition
156
+ self.pattern_recognizer = nn.Sequential(
157
+ nn.Linear(in_channels, in_channels // 2),
158
+ nn.ReLU(),
159
+ nn.Linear(in_channels // 2, 5) # 5 for different chart patterns
160
+ )
161
+
162
+ # Coordinate transformer
163
+ self.coord_transformer = CoordinateTransformer()
164
+
165
+ def check_monotonicity(self, points, chart_type):
166
+ """Check if points follow expected monotonicity based on chart type."""
167
+ if chart_type in ['line', 'scatter']:
168
+ # For line/scatter, check if points are generally increasing or decreasing
169
+ diffs = points[..., 1].diff()
170
+ return torch.all(diffs >= 0) or torch.all(diffs <= 0)
171
+ return True
172
+
173
+ def forward(self, x, axis_info=None, chart_type=None):
174
+ # Apply coordinate attention
175
+ coord_feat = self.coord_attention(x, x, x)[0]
176
+
177
+ # Apply axis attention if axis info is available
178
+ if axis_info is not None:
179
+ axis_feat = self.axis_attention(x, x, x)[0]
180
+ # Combine features
181
+ x = coord_feat + axis_feat
182
+ else:
183
+ x = coord_feat
184
+
185
+ # Apply sequence attention
186
+ seq_feat = self.sequence_attention(x, x, x)[0]
187
+ x = x + seq_feat
188
+
189
+ # Process through sequence encoder
190
+ x = self.sequence_encoder(x.unsqueeze(0)).squeeze(0)
191
+
192
+         # Extract base features
+         x = F.relu(self.fc1(x))
+
+         # Get predictions from both branches
+         absolute_points = self.absolute_branch(x)
+         relative_points = self.relative_branch(x)
+
+         # Reshape to (batch_size, max_points, 2)
+         absolute_points = absolute_points.view(-1, self.max_points, 2)
+         relative_points = relative_points.view(-1, self.max_points, 2)
+
+         # If axis information is provided, transform relative points
+         if axis_info is not None:
+             relative_points = self.coord_transformer.to_axis_relative(relative_points, axis_info)
+
+         # Get pattern prediction
+         pattern_logits = self.pattern_recognizer(x)
+
+         # Check monotonicity if chart type is provided
+         if chart_type is not None:
+             monotonicity = self.check_monotonicity(absolute_points, chart_type)
+         else:
+             monotonicity = None
+
+         return absolute_points, relative_points, pattern_logits, monotonicity
+
+     def compute_loss(self, pred_absolute, pred_relative, gt_absolute, gt_relative,
+                      pattern_logits, gt_pattern, axis_info=None, chart_type=None):
+         """Compute combined loss for both absolute and relative coordinates."""
+         # Ensure points are batched: (batch_size, max_points, 2)
+         if pred_absolute.dim() == 2:
+             pred_absolute = pred_absolute.unsqueeze(0)
+         if pred_relative.dim() == 2:
+             pred_relative = pred_relative.unsqueeze(0)
+         if gt_absolute.dim() == 2:
+             gt_absolute = gt_absolute.unsqueeze(0)
+         if gt_relative.dim() == 2:
+             gt_relative = gt_relative.unsqueeze(0)
+
+         # Compute absolute coordinate loss
+         absolute_loss = self.compute_distance_loss(pred_absolute, gt_absolute)
+
+         # Compute relative coordinate loss
+         if axis_info is not None:
+             # Transform predicted absolute points to relative coordinates
+             pred_absolute_relative = self.coord_transformer.to_axis_relative(pred_absolute, axis_info)
+             relative_loss = self.compute_distance_loss(pred_absolute_relative, gt_relative)
+         else:
+             relative_loss = torch.tensor(0.0, device=pred_absolute.device)
+
+         # Compute pattern recognition loss
+         pattern_loss = F.cross_entropy(pattern_logits, gt_pattern)
+
+         # Add monotonicity penalty if applicable
+         if chart_type is not None:
+             monotonicity = self.check_monotonicity(pred_absolute, chart_type)
+             monotonicity_loss = F.binary_cross_entropy(
+                 monotonicity.float(), torch.ones_like(monotonicity.float()))
+         else:
+             monotonicity_loss = torch.tensor(0.0, device=pred_absolute.device)
+
+         # Combine losses with weights
+         total_loss = (absolute_loss + relative_loss +
+                       0.5 * pattern_loss + 0.3 * monotonicity_loss)
+
+         return total_loss
+
+     def compute_distance_loss(self, pred_points, gt_points):
+         """Compute distance-based loss between predicted and ground truth points."""
+         # Compute pairwise distances
+         dist = torch.cdist(pred_points, gt_points)
+
+         # Get minimum distance for each predicted point
+         min_dist, _ = torch.min(dist, dim=2)
+
+         # Compute loss (using smooth L1 loss for robustness)
+         return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))
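The one-sided nearest-neighbour matching inside `compute_distance_loss` (each predicted point is pulled toward its closest ground-truth point) can be exercised without any framework; a stdlib-only sketch of the same rule, with hypothetical example points:

```python
import math

def chamfer_min_distances(pred_points, gt_points):
    """For each predicted (x, y) point, return the distance to its nearest
    ground-truth point -- the quantity the head's smooth-L1 term penalizes."""
    return [min(math.hypot(px - gx, py - gy) for gx, gy in gt_points)
            for px, py in pred_points]

pred = [(0.0, 0.0), (2.0, 2.0)]
gt = [(0.0, 1.0), (2.0, 2.0)]
print(chamfer_min_distances(pred, gt))  # [1.0, 0.0]
```

Note the matching is one-sided: predictions that all collapse onto a single ground-truth point also reach zero distance, which is one reason the combined loss adds the pattern and monotonicity terms.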
custom_models/flexible_load_annotations.py ADDED
@@ -0,0 +1,191 @@
+ import numpy as np
+ from typing import Dict, Optional
+ from mmcv.transforms.base import BaseTransform
+ from mmdet.registry import TRANSFORMS
+ from mmdet.datasets.transforms.loading import LoadAnnotations
+ import logging
+ from mmdet.structures.mask import BitmapMasks
+
+ logger = logging.getLogger(__name__)
+
+ @TRANSFORMS.register_module()
+ class FlexibleLoadAnnotations(LoadAnnotations):
+     """Flexible annotation loader that handles mixed mask/bbox datasets."""
+
+     def __init__(self,
+                  with_bbox: bool = True,
+                  with_mask: bool = True,
+                  with_seg: bool = False,
+                  poly2mask: bool = True,
+                  **kwargs):
+         super().__init__(
+             with_bbox=with_bbox,
+             with_mask=with_mask,
+             with_seg=with_seg,
+             poly2mask=poly2mask,
+             **kwargs
+         )
+         self.mask_stats = {'total': 0, 'with_masks': 0, 'without_masks': 0}
+
+     def _load_masks(self, results: dict) -> dict:
+         """Load mask annotations from COCO format instances."""
+         if not self.with_mask or not isinstance(results, dict):
+             return results
+
+         # Check for ann_info format (what the COCO dataset actually provides)
+         ann_info = results.get('ann_info')
+         if isinstance(ann_info, dict):
+             # Check if segmentation is in ann_info
+             if 'segmentation' in ann_info:
+                 segmentation = ann_info['segmentation']
+                 if segmentation and isinstance(segmentation, list) and len(segmentation) > 0:
+                     # Convert to mask format; the parent modifies results in place
+                     ann_info['masks'] = segmentation
+                     super()._load_masks(results)
+                     return results
+
+             # Check for polygon data in ann_info
+             if 'polygon' in ann_info:
+                 polygon = ann_info['polygon']
+                 if polygon and isinstance(polygon, dict):
+                     try:
+                         # Convert polygon to COCO segmentation format
+                         coords = []
+                         for j in range(4):  # Assuming 4-point polygons
+                             x_key = f'x{j}'
+                             y_key = f'y{j}'
+                             if x_key in polygon and y_key in polygon:
+                                 coords.extend([polygon[x_key], polygon[y_key]])
+
+                         if len(coords) >= 6:  # Need at least 3 points (6 coordinates)
+                             # COCO format: [x1, y1, x2, y2, x3, y3, ...]
+                             segmentation = [coords]
+                             ann_info['segmentation'] = segmentation
+                             ann_info['masks'] = segmentation
+                             super()._load_masks(results)
+                             return results
+                     except Exception as e:
+                         logger.debug(f"Polygon conversion failed: {e}")
+
+         # Handle COCO format: instances with segmentation
+         instances = results.get('instances')
+         if isinstance(instances, list):
+             # Process ALL instances - keep both with and without masks
+             valid_instances = []
+
+             for instance in instances:
+                 self.mask_stats['total'] += 1
+
+                 # The COCO dataset stores segmentation in the 'mask' field
+                 segmentation = instance.get('mask') or instance.get('segmentation')
+                 if segmentation and isinstance(segmentation, list) and len(segmentation) > 0:
+                     # Handle nested list format: [[x1, y1, x2, y2, ...]]
+                     if isinstance(segmentation[0], list):
+                         # Nested format - check if the inner list has enough coordinates
+                         inner_seg = segmentation[0]
+                         if len(inner_seg) >= 6:  # Need at least 3 points (6 coordinates)
+                             instance['mask'] = segmentation  # Keep nested format for the parent
+                             valid_instances.append(instance)
+                             self.mask_stats['with_masks'] += 1
+                         else:
+                             # Keep instance for bbox training even without a valid mask
+                             instance['mask'] = []
+                             valid_instances.append(instance)
+                             self.mask_stats['without_masks'] += 1
+                     else:
+                         # Flat format - already correct
+                         instance['mask'] = segmentation
+                         valid_instances.append(instance)
+                         self.mask_stats['with_masks'] += 1
+                 else:
+                     # Check for polygon data and convert to segmentation
+                     polygon = instance.get('polygon')
+                     if polygon and isinstance(polygon, dict):
+                         try:
+                             # Convert polygon to COCO segmentation format
+                             coords = []
+                             for j in range(4):  # Assuming 4-point polygons
+                                 x_key = f'x{j}'
+                                 y_key = f'y{j}'
+                                 if x_key in polygon and y_key in polygon:
+                                     coords.extend([polygon[x_key], polygon[y_key]])
+
+                             if len(coords) >= 6:  # Need at least 3 points (6 coordinates)
+                                 # COCO format: [x1, y1, x2, y2, x3, y3, ...]
+                                 segmentation = [coords]
+                                 instance['segmentation'] = segmentation
+                                 instance['mask'] = segmentation
+                                 valid_instances.append(instance)
+                                 self.mask_stats['with_masks'] += 1
+                             else:
+                                 # Keep instance for bbox training; an empty mask
+                                 # field prevents a KeyError in the parent class
+                                 instance['mask'] = []
+                                 valid_instances.append(instance)
+                                 self.mask_stats['without_masks'] += 1
+                         except Exception:
+                             # Keep instance for bbox training even if conversion fails
+                             instance['mask'] = []
+                             valid_instances.append(instance)
+                             self.mask_stats['without_masks'] += 1
+                     else:
+                         # Keep instance for bbox training even without segmentation
+                         instance['mask'] = []
+                         valid_instances.append(instance)
+                         self.mask_stats['without_masks'] += 1
+
+             # Update results with the processed instances
+             results['instances'] = valid_instances
+
+             if valid_instances:
+                 super()._load_masks(results)  # Parent modifies results in place
+                 return results
+             else:
+                 # No valid masks, create an empty mask structure
+                 h, w = results.get('img_shape', (0, 0))
+                 results['gt_masks'] = BitmapMasks([], h, w)
+                 results['gt_ignore_flags'] = np.array([], dtype=bool)
+                 return results
+
+         # Check for direct segmentation in results
+         if 'segmentation' in results:
+             segmentation = results['segmentation']
+             if segmentation and isinstance(segmentation, list) and len(segmentation) > 0:
+                 results['masks'] = segmentation
+                 super()._load_masks(results)
+                 return results
+
+         return results
+
+     def transform(self, results: dict) -> dict:
+         """Transform function to load annotations."""
+         # Ensure we always return a dict
+         if not isinstance(results, dict):
+             logger.error(f"Expected dict, got {type(results)}")
+             return {}
+
+         # Call the parent transform to handle bbox loading
+         results = super().transform(results)
+
+         # Handle mask loading with our custom logic
+         results = self._load_masks(results)
+
+         # Periodic logging
+         total = self.mask_stats['total']
+         if total > 0 and total % 1000 == 0:
+             logger.info(
+                 f"Mask stats - total: {total}, "
+                 f"with_masks: {self.mask_stats['with_masks']}, "
+                 f"without_masks: {self.mask_stats['without_masks']}")
+
+         return results
+
+     def __repr__(self) -> str:
+         """String representation."""
+         return (f'{self.__class__.__name__}('
+                 f'with_bbox={self.with_bbox}, '
+                 f'with_mask={self.with_mask}, '
+                 f'with_seg={self.with_seg}, '
+                 f'poly2mask={self.poly2mask})')
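The polygon-to-COCO conversion that `FlexibleLoadAnnotations` applies in both branches is small enough to exercise on its own; a stdlib-only sketch of the same rule (the `x0..y3` key layout follows the loader above):

```python
def polygon_to_segmentation(polygon, num_points=4):
    """Flatten {'x0': ..., 'y0': ..., ...} into COCO [[x1, y1, x2, y2, ...]];
    return None when fewer than 3 points (6 coordinates) are present."""
    coords = []
    for j in range(num_points):
        if f'x{j}' in polygon and f'y{j}' in polygon:
            coords.extend([polygon[f'x{j}'], polygon[f'y{j}']])
    return [coords] if len(coords) >= 6 else None

box = {'x0': 0, 'y0': 0, 'x1': 10, 'y1': 0, 'x2': 10, 'y2': 5, 'x3': 0, 'y3': 5}
print(polygon_to_segmentation(box))            # [[0, 0, 10, 0, 10, 5, 0, 5]]
print(polygon_to_segmentation({'x0': 1, 'y0': 1}))  # None (only one point)
```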
custom_models/mask_filter.py ADDED
@@ -0,0 +1,48 @@
+ import numpy as np
+ from mmcv.transforms.base import BaseTransform
+ from mmdet.registry import TRANSFORMS
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ @TRANSFORMS.register_module()
+ class MaskFilter(BaseTransform):
+     """Filter out images with no valid masks during training.
+
+     This transform checks whether the image carries any valid masks and
+     returns None when none are found, which causes the sample to be skipped.
+     """
+
+     def __init__(self, min_masks=1):
+         self.min_masks = min_masks
+
+     def transform(self, results):
+         """Filter results based on mask availability.
+
+         Args:
+             results (dict): Result dict from the dataset.
+
+         Returns:
+             dict or None: ``results`` if enough valid masks are found, else None.
+         """
+         gt_masks = results.get('gt_masks')
+
+         if gt_masks is None:
+             logger.warning("MaskFilter: No gt_masks found, skipping image")
+             return None
+
+         # Count valid masks (BitmapMasks exposes .masks, PolygonMasks .polygons)
+         if hasattr(gt_masks, 'masks'):
+             num_masks = len(gt_masks.masks)
+         elif hasattr(gt_masks, 'polygons'):
+             num_masks = len(gt_masks.polygons)
+         else:
+             num_masks = 0
+
+         if num_masks < self.min_masks:
+             logger.info(f"MaskFilter: Only {num_masks} masks found (min: {self.min_masks}), skipping image")
+             return None
+
+         logger.debug(f"MaskFilter: {num_masks} masks found, keeping image for training")
+         return results
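In an MMDetection config, the filter would typically sit right after annotation loading; a hypothetical pipeline fragment (the resize scale and the later steps are illustrative placeholders, and the type strings resolve through mmdet's TRANSFORMS registry at runtime):

```python
# Hypothetical train pipeline using the custom transforms above.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='MaskFilter', min_masks=1),   # drop samples with no usable masks
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='PackDetInputs'),
]

assert [step['type'] for step in train_pipeline][:3] == [
    'LoadImageFromFile', 'FlexibleLoadAnnotations', 'MaskFilter']
```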
custom_models/nan_recovery_hook.py ADDED
@@ -0,0 +1,181 @@
+ # nan_recovery_hook.py - Graceful NaN loss recovery for Cascade R-CNN
+ import math
+ from typing import Optional, Dict, Any
+
+ import torch
+ from mmengine.hooks import Hook
+ from mmengine.runner import Runner
+ from mmdet.registry import HOOKS
+
+
+ @HOOKS.register_module()
+ class NanRecoveryHook(Hook):
+     """Hook that handles NaN losses gracefully without crashing training.
+
+     This hook detects NaN losses and handles them by:
+     1. Replacing NaN losses with the last valid loss value
+     2. Skipping gradient updates for that iteration (the replacement is detached)
+     3. Logging the recovery for monitoring
+     4. Allowing training to continue normally
+     """
+
+     def __init__(self,
+                  fallback_loss: float = 0.5,
+                  max_consecutive_nans: int = 100,  # increased from 10
+                  log_interval: int = 50):  # log less frequently
+         self.fallback_loss = fallback_loss
+         self.max_consecutive_nans = max_consecutive_nans
+         self.log_interval = log_interval
+
+         # State tracking
+         self.last_valid_loss = fallback_loss
+         self.consecutive_nans = 0
+         self.total_nans = 0
+         self.nan_iterations = []
+
+     def before_train_iter(self,
+                           runner: Runner,
+                           batch_idx: int,
+                           data_batch: Optional[dict] = None) -> None:
+         """Reset any state before the training iteration."""
+         pass
+
+     def after_train_iter(self,
+                          runner: Runner,
+                          batch_idx: int,
+                          data_batch: Optional[dict] = None,
+                          outputs: Optional[Dict[str, Any]] = None) -> None:
+         """Handle NaN losses after the training iteration."""
+         if outputs is None:
+             return
+
+         # Check ALL loss components for NaN, not just the main loss
+         has_nan = False
+
+         # Check main loss
+         total_loss = outputs.get('loss')
+         if total_loss is not None and (torch.isnan(total_loss) or torch.isinf(total_loss)):
+             has_nan = True
+
+         # Check all individual loss components
+         for key, value in outputs.items():
+             if isinstance(value, torch.Tensor) and 'loss' in key.lower():
+                 if torch.isnan(value) or torch.isinf(value):
+                     has_nan = True
+                     break
+
+         if has_nan:
+             self._handle_nan_loss(runner, batch_idx, outputs)
+         else:
+             # Valid loss - update tracking
+             if total_loss is not None:
+                 self.last_valid_loss = float(total_loss.item())
+             if self.consecutive_nans > 0:
+                 runner.logger.info(
+                     f"πŸŽ‰ Loss recovered after {self.consecutive_nans} NaN iterations")
+             self.consecutive_nans = 0
+
+     def _handle_nan_loss(self, runner: Runner, batch_idx: int, outputs: Dict[str, Any]) -> None:
+         """Handle a NaN loss by replacing it with a detached fallback and managing state."""
+         self.consecutive_nans += 1
+         self.total_nans += 1
+         self.nan_iterations.append(batch_idx)
+
+         # Try to get the last good state from SkipBadSamplesHook if available
+         last_good_iteration = batch_idx
+         last_good_loss = self.last_valid_loss
+
+         for hook in runner.hooks:
+             if hasattr(hook, 'last_good_iteration') and hasattr(hook, 'last_good_loss'):
+                 if hook.last_good_loss is not None:
+                     last_good_iteration = hook.last_good_iteration
+                     last_good_loss = hook.last_good_loss
+                     break
+
+         # Replace the NaN loss with a detached fallback (no gradients = true no-op)
+         if 'loss' in outputs and outputs['loss'] is not None:
+             fallback_tensor = torch.tensor(
+                 last_good_loss,
+                 device=outputs['loss'].device,
+                 dtype=outputs['loss'].dtype
+                 # NOTE: no requires_grad=True - the replacement stays detached
+             )
+             outputs['loss'] = fallback_tensor
+
+         # Also fix individual loss components with detached tensors
+         self._fix_loss_components(outputs, last_good_loss)
+
+         # Log the recovery with state info
+         if self.consecutive_nans <= 5 or self.consecutive_nans % self.log_interval == 0:
+             runner.logger.warning(
+                 f"πŸ”„ NaN Recovery at iteration {batch_idx}: "
+                 f"Using last good loss {last_good_loss:.4f} from iteration {last_good_iteration}. "
+                 f"Consecutive NaNs: {self.consecutive_nans}, Total: {self.total_nans}"
+             )
+
+         # Reset training state if there are too many consecutive NaNs
+         if self.consecutive_nans >= self.max_consecutive_nans:
+             self._reset_nan_state(runner, last_good_iteration)
+
+     def _reset_nan_state(self, runner: Runner, last_good_iteration: int) -> None:
+         """Reset training state after too many consecutive NaNs."""
+         runner.logger.error(
+             f"πŸ”„ Too many consecutive NaN losses ({self.consecutive_nans}). "
+             f"Resetting to last good state from iteration {last_good_iteration}"
+         )
+
+         try:
+             # Clear model gradients
+             if hasattr(runner.model, 'zero_grad'):
+                 runner.model.zero_grad()
+
+             # Clear the CUDA cache
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             # Reset the consecutive counter
+             self.consecutive_nans = 0
+
+             runner.logger.info("βœ… NaN state reset. Resuming training...")
+
+         except Exception as e:
+             runner.logger.error(f"❌ Failed to reset NaN state: {e}")
+
+     def _fix_loss_components(self, outputs: Dict[str, Any],
+                              fallback_loss: Optional[float] = None) -> None:
+         """Fix ALL loss components with detached tensors (no gradients)."""
+         if fallback_loss is None:
+             fallback_loss = self.last_valid_loss
+
+         fallback_small = max(0.01, fallback_loss * 0.1)  # ensure a non-zero minimum
+
+         # Fix all tensors with 'loss' in the key name using detached tensors
+         for key, value in outputs.items():
+             if isinstance(value, torch.Tensor) and 'loss' in key.lower():
+                 if torch.isnan(value) or torch.isinf(value):
+                     # Create a detached replacement tensor (no gradients)
+                     replacement = torch.tensor(
+                         fallback_small,
+                         device=value.device,
+                         dtype=value.dtype
+                         # NOTE: no requires_grad=True - detached for a true no-op
+                     )
+                     outputs[key] = replacement
+                     print(f"   πŸ”§ Fixed {key}: {value.item():.4f} -> detached {fallback_small:.4f}")
+
+         # Also fix any scalar values that might be NaN
+         for key, value in list(outputs.items()):
+             if isinstance(value, (int, float)) and 'loss' in key.lower():
+                 if not math.isfinite(value):
+                     outputs[key] = fallback_small
+                     print(f"   πŸ”§ Fixed scalar {key}: {value} -> {fallback_small:.4f}")
+
+     def after_train_epoch(self, runner: Runner) -> None:
+         """Summary statistics after each epoch."""
+         if self.total_nans > 0:
+             runner.logger.info(
+                 f"πŸ“Š NaN Recovery Summary for Epoch: "
+                 f"{self.total_nans} NaN losses recovered. "
+                 f"Training continued successfully."
+             )
+
+         # Reset for the next epoch
+         self.consecutive_nans = 0
+         self.total_nans = 0
+         self.nan_iterations.clear()
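The core replacement rule (swap any non-finite loss entry for a known-good value, with a scaled-down floor for individual components) is framework-independent; a stdlib-only sketch, with hypothetical names and values:

```python
import math

def sanitize_losses(losses, last_valid, component_scale=0.1):
    """Return a copy of `losses` with non-finite entries replaced: the total
    'loss' gets last_valid, components get a scaled non-zero floor."""
    floor = max(0.01, last_valid * component_scale)
    return {key: (value if math.isfinite(value)
                  else (last_valid if key == 'loss' else floor))
            for key, value in losses.items()}

out = sanitize_losses(
    {'loss': float('nan'), 'loss_bbox': float('inf'), 'loss_cls': 0.2},
    last_valid=0.5)
print(out)  # {'loss': 0.5, 'loss_bbox': 0.05, 'loss_cls': 0.2}
```

Unlike the hook, this sketch ignores gradients entirely; the hook additionally keeps the replacements detached so the iteration becomes a no-op for the optimizer.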
custom_models/progressive_loss_hook.py ADDED
@@ -0,0 +1,286 @@
+ # progressive_loss_hook.py - Progressive Loss Switching Hook for Cascade R-CNN
+ import torch
+ from mmengine.hooks import Hook
+ from mmdet.registry import HOOKS
+ from mmdet.models.losses import SmoothL1Loss, GIoULoss, CIoULoss, DIoULoss
+
+ _LOSS_CLASSES = {'GIoULoss': GIoULoss, 'CIoULoss': CIoULoss, 'DIoULoss': DIoULoss}
+
+
+ def _build_target_loss(loss_type, loss_weight):
+     """Instantiate one of the supported IoU-family losses."""
+     if loss_type not in _LOSS_CLASSES:
+         raise ValueError(f"Unsupported target loss type: {loss_type}")
+     return _LOSS_CLASSES[loss_type](loss_weight=loss_weight)
+
+
+ @HOOKS.register_module()
+ class ProgressiveLossHook(Hook):
+     """Progressive Loss Switching Hook for Cascade R-CNN.
+
+     Starts with SmoothL1Loss for all stages, then switches stage 3 (the
+     final stage, index 2) to GIoU/CIoU/DIoU once the model has stabilized.
+
+     Args:
+         switch_epoch (int): Epoch at which to switch stage 3 from SmoothL1 to the target loss.
+         target_loss_type (str): Target loss for stage 3 ('GIoULoss', 'CIoULoss', or 'DIoULoss').
+         loss_weight (float): Loss weight for the new loss function.
+         warmup_epochs (int): Number of epochs to monitor after the switch.
+         monitor_stage_weights (bool): Whether to log the per-stage loss configuration.
+         nan_detection (bool): Whether to enable NaN detection and rollback.
+         max_nan_tolerance (int): Maximum consecutive NaN losses before rollback.
+     """
+
+     def __init__(self,
+                  switch_epoch=5,
+                  target_loss_type='GIoULoss',
+                  loss_weight=1.0,
+                  warmup_epochs=2,
+                  monitor_stage_weights=True,
+                  nan_detection=False,
+                  max_nan_tolerance=5):
+         super().__init__()
+         self.switch_epoch = switch_epoch
+         self.target_loss_type = target_loss_type
+         self.loss_weight = loss_weight
+         self.warmup_epochs = warmup_epochs
+         self.monitor_stage_weights = monitor_stage_weights
+         self.nan_detection = nan_detection
+         self.max_nan_tolerance = max_nan_tolerance
+         self.switched = False
+         self.original_loss = None
+         self.consecutive_nans = 0
+         self.rollback_performed = False
+
+     def before_train_epoch(self, runner):
+         """Check whether the loss function should be switched."""
+         current_epoch = runner.epoch
+
+         # Switch at the specified epoch
+         if current_epoch >= self.switch_epoch and not self.switched:
+             self._switch_stage3_loss(runner)
+             self.switched = True
+             runner.logger.info(
+                 f"Epoch {current_epoch}: Switched Stage 3 loss to {self.target_loss_type}")
+
+         # Monitor during the warmup period
+         elif self.switch_epoch <= current_epoch < self.switch_epoch + self.warmup_epochs:
+             if self.monitor_stage_weights:
+                 self._log_loss_info(runner, current_epoch)
+
+     def _switch_stage3_loss(self, runner):
+         """Switch the stage 3 bbox loss from SmoothL1 to the target loss."""
+         try:
+             # Navigate to the stage 3 bbox head (index 2), handling the DDP wrapper
+             model = runner.model
+             roi_head = model.module.roi_head if hasattr(model, 'module') else model.roi_head
+             stage3_head = roi_head.bbox_head[2]
+
+             # Store the original loss for comparison
+             self.original_loss = stage3_head.loss_bbox
+
+             # Create the new loss; IoU-family losses regress decoded boxes
+             new_loss = _build_target_loss(self.target_loss_type, self.loss_weight)
+             stage3_head.reg_decoded_bbox = True
+
+             # Log loss-specific benefits
+             if self.target_loss_type == 'CIoULoss':
+                 runner.logger.info("🎯 CIoU Loss Benefits for Data Points:")
+                 runner.logger.info("   β€’ Directly optimizes center point distance")
+                 runner.logger.info("   β€’ Enforces aspect ratio consistency (square-ish data points)")
+                 runner.logger.info("   β€’ Better convergence for small objects")
+                 runner.logger.info("   β€’ Most complete bounding box quality metric")
+             elif self.target_loss_type == 'DIoULoss':
+                 runner.logger.info("🎯 DIoU Loss Benefits for Data Points:")
+                 runner.logger.info("   β€’ Directly optimizes center point distance")
+                 runner.logger.info("   β€’ Better convergence for small objects")
+                 runner.logger.info("   β€’ More precise localization for data points")
+             elif self.target_loss_type == 'GIoULoss':
+                 runner.logger.info("🎯 GIoU Loss Benefits:")
+                 runner.logger.info("   β€’ Improved IoU-based optimization")
+                 runner.logger.info("   β€’ Better than standard IoU loss")
+
+             # Replace the loss function
+             stage3_head.loss_bbox = new_loss
+
+             runner.logger.info(
+                 f"Progressive Loss Switch: Stage 3 changed from "
+                 f"{type(self.original_loss).__name__} to {self.target_loss_type}")
+
+         except Exception as e:
+             runner.logger.error(f"Failed to switch loss function: {e}")
+
+     def _log_loss_info(self, runner, epoch):
+         """Log the current per-stage loss configuration."""
+         try:
+             model = runner.model
+             roi_head = model.module.roi_head if hasattr(model, 'module') else model.roi_head
+             bbox_heads = roi_head.bbox_head
+
+             loss_info = {}
+             for i, head in enumerate(bbox_heads):
+                 loss_type = type(head.loss_bbox).__name__
+                 loss_weight = head.loss_bbox.loss_weight
+                 loss_info[f'stage_{i + 1}'] = f"{loss_type}(w={loss_weight})"
+
+             runner.logger.info(f"Epoch {epoch} Loss Configuration: {loss_info}")
+
+         except Exception as e:
+             runner.logger.warning(f"Could not log loss info: {e}")
+
+     def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
+         """Monitor loss values during training and detect NaN."""
+         if not (self.switched and outputs is not None and isinstance(outputs, dict)):
+             return
+
+         # NaN detection and rollback logic
+         if self.nan_detection and not self.rollback_performed:
+             total_loss = outputs.get('loss', None)
+             if total_loss is not None and torch.isnan(total_loss):
+                 self.consecutive_nans += 1
+                 runner.logger.warning(
+                     f"🚨 NaN detected in total loss! Consecutive: "
+                     f"{self.consecutive_nans}/{self.max_nan_tolerance}")
+
+                 if self.consecutive_nans >= self.max_nan_tolerance:
+                     self._rollback_loss(runner)
+                     self.consecutive_nans = 0
+                     self.rollback_performed = True
+                     runner.logger.error(
+                         f"πŸ”„ EMERGENCY ROLLBACK: Switched back to SmoothL1Loss due to "
+                         f"{self.max_nan_tolerance} consecutive NaN losses")
+                     return
+             elif total_loss is not None and torch.isfinite(total_loss):
+                 # Reset the NaN counter on a successful iteration
+                 self.consecutive_nans = 0
+
+         # Log individual stage losses if available
+         log_vars = outputs.get('log_vars', {})
+         stage_losses = {k: v for k, v in log_vars.items()
+                         if 'loss_bbox' in k and isinstance(v, (int, float))}
+
+         if stage_losses and self.monitor_stage_weights:
+             # Log every 100 iterations to avoid spam
+             if runner.iter % 100 == 0:
+                 loss_summary = ", ".join(f"{k}: {v:.4f}" for k, v in stage_losses.items())
+                 runner.logger.info(f"Stage Losses - {loss_summary}")
+
+     def after_train_epoch(self, runner):
+         """Report epoch completion and NaN status."""
+         if self.nan_detection and self.switched:
+             if self.consecutive_nans > 0:
+                 runner.logger.warning(
+                     f"Epoch {runner.epoch} completed with {self.consecutive_nans} NaN occurrences")
+             else:
+                 runner.logger.info(
+                     f"Epoch {runner.epoch} completed successfully with {self.target_loss_type}")
+
+     def _rollback_loss(self, runner):
+         """Roll stage 3 back to SmoothL1Loss."""
+         try:
+             model = runner.model
+             roi_head = model.module.roi_head if hasattr(model, 'module') else model.roi_head
+             stage3_head = roi_head.bbox_head[2]
+
+             # Restore SmoothL1Loss, which regresses encoded deltas
+             stage3_head.loss_bbox = SmoothL1Loss(beta=1.0, loss_weight=1.0)
+             stage3_head.reg_decoded_bbox = False
+
+             runner.logger.info(
+                 f"βœ… Successfully rolled back Stage 3 from {self.target_loss_type} to SmoothL1Loss")
+
+         except Exception as e:
+             runner.logger.error(f"❌ Failed to rollback loss function: {e}")
+
+
+ @HOOKS.register_module()
+ class AdaptiveLossHook(Hook):
+     """Adaptive variant that switches based on training stability metrics.
+
+     Monitors IoU overlap quality and switches when the model is stable.
+     """
+
+     def __init__(self,
+                  min_epoch=3,
+                  min_avg_iou=0.4,
+                  target_loss_type='GIoULoss',
+                  loss_weight=1.0,
+                  check_interval=100):
+         super().__init__()
+         self.min_epoch = min_epoch
+         self.min_avg_iou = min_avg_iou
+         self.target_loss_type = target_loss_type
+         self.loss_weight = loss_weight
+         self.check_interval = check_interval
+         self.switched = False
+         self.iou_history = []
+
+     def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
+         """Monitor training stability through IoU metrics."""
+         if (not self.switched and
+                 runner.epoch >= self.min_epoch and
+                 runner.iter % self.check_interval == 0):
+
+             # Extract IoU information from outputs if available
+             if outputs and isinstance(outputs, dict):
+                 log_vars = outputs.get('log_vars', {})
+
+                 # Look for any IoU-related metrics
+                 iou_metrics = [v for k, v in log_vars.items()
+                                if 'iou' in k.lower() and isinstance(v, (int, float))]
+
+                 if iou_metrics:
+                     avg_iou = sum(iou_metrics) / len(iou_metrics)
+                     self.iou_history.append(avg_iou)
+
+                     # Keep only recent history
+                     if len(self.iou_history) > 10:
+                         self.iou_history.pop(0)
+
+                     # Check whether to switch
+                     if (len(self.iou_history) >= 5 and
+                             sum(self.iou_history[-5:]) / 5 >= self.min_avg_iou):
+
+                         self._switch_stage3_loss(runner)
+                         self.switched = True
+
+                         recent_iou = sum(self.iou_history[-5:]) / 5
+                         runner.logger.info(
+                             f"Adaptive switch at epoch {runner.epoch}, iter {runner.iter}: "
+                             f"avg IoU {recent_iou:.3f} >= {self.min_avg_iou}")
+
+     def _switch_stage3_loss(self, runner):
+         """Same switching logic as ProgressiveLossHook."""
+         try:
+             model = runner.model
+             roi_head = model.module.roi_head if hasattr(model, 'module') else model.roi_head
+             stage3_head = roi_head.bbox_head[2]
+
+             stage3_head.loss_bbox = _build_target_loss(self.target_loss_type, self.loss_weight)
+             stage3_head.reg_decoded_bbox = True
+
+             runner.logger.info(f"Adaptive Loss Switch: Stage 3 β†’ {self.target_loss_type}")
+
+         except Exception as e:
+             runner.logger.error(f"Failed to switch loss function: {e}")
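Wiring these hooks into a training config is a matter of listing them under `custom_hooks`; a hypothetical fragment (the epochs and thresholds are illustrative, and the type strings resolve through mmdet's HOOKS registry once `custom_models` is imported):

```python
# Hypothetical custom_hooks section combining NaN recovery with the
# progressive SmoothL1 -> CIoU switch sketched above.
custom_hooks = [
    dict(type='NanRecoveryHook', fallback_loss=0.5, max_consecutive_nans=100),
    dict(type='ProgressiveLossHook',
         switch_epoch=5,           # keep SmoothL1 for the first 5 epochs
         target_loss_type='CIoULoss',
         nan_detection=True,       # roll back to SmoothL1 on repeated NaNs
         max_nan_tolerance=5),
]

assert {h['type'] for h in custom_hooks} == {'NanRecoveryHook', 'ProgressiveLossHook'}
```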
custom_models/register.py ADDED
@@ -0,0 +1,29 @@
+ from mmdet.registry import MODELS, DATASETS, TRANSFORMS, HOOKS
+ from .custom_heads import FCHead, RegHead, DataSeriesHead
+ from .custom_cascade_with_meta import CustomCascadeWithMeta
+ from .custom_dataset import (
+     ChartDataset, RobustLoadImageFromFile, CreateDummyImg,
+     ClampBBoxes, SetScaleFactor, EnsureScaleFactor, SetInputs, CustomPackDetInputs
+ )
+ from .flexible_load_annotations import FlexibleLoadAnnotations
+ from .custom_hooks import (
+     ChartTypeDistributionHook, SkipInvalidLossHook, RuntimeErrorHook,
+     MissingImageReportHook, SkipBadSamplesHook, CompatibleCheckpointHook
+ )
+ from .nan_recovery_hook import NanRecoveryHook
+ from .progressive_loss_hook import ProgressiveLossHook, AdaptiveLossHook
+ from .square_fcn_mask_head import SquareFCNMaskHead
+
+ def register_all_modules():
+     """Report the enhanced modules; registration itself already happened via
+     the @*.register_module() decorators when this package was imported."""
+     print("βœ… Enhanced chart detection modules registered via decorators:")
+     print("   πŸ“Š Models: FCHead, RegHead, DataSeriesHead, CustomCascadeWithMeta, SquareFCNMaskHead")
+     print("   πŸ“ Datasets: ChartDataset (21 categories)")
+     print("   πŸ”„ Transforms: RobustLoadImageFromFile, CreateDummyImg, ClampBBoxes, SetScaleFactor, EnsureScaleFactor, SetInputs, CustomPackDetInputs, FlexibleLoadAnnotations")
+     print("   🎯 Hooks: ChartTypeDistributionHook, SkipInvalidLossHook, RuntimeErrorHook, MissingImageReportHook, SkipBadSamplesHook, NanRecoveryHook, CompatibleCheckpointHook, ProgressiveLossHook, AdaptiveLossHook")
+
+ # Run once at import time; the decorators handle the actual registration.
+ register_all_modules()
custom_models/square_fcn_mask_head.py ADDED
@@ -0,0 +1,170 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from typing import List, Union, Dict, Any
+
+ import torch
+ from torch import Tensor
+ from mmengine.config import ConfigDict
+ from mmengine.structures import InstanceData
+
+ from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead
+ from mmdet.registry import MODELS
+ from mmdet.structures.mask.mask_target import mask_target
+
+ from .square_mask_target import square_mask_target
+
+
+ @MODELS.register_module()
+ class SquareFCNMaskHead(FCNMaskHead):
+     """FCN mask head that forces square mask targets.
+
+     This head ensures that all mask targets are square regardless of the
+     original aspect ratio, to avoid tensor size mismatches during training.
+     """
+
+     def __init__(self, *args, **kwargs):
+         print("πŸ” SQUARE_FCN_MASK_HEAD: Initializing SquareFCNMaskHead")
+         print(f"πŸ” SQUARE_FCN_MASK_HEAD: args: {args}")
+         print(f"πŸ” SQUARE_FCN_MASK_HEAD: kwargs: {kwargs}")
+         super().__init__(*args, **kwargs)
+         print("πŸ” SQUARE_FCN_MASK_HEAD: SquareFCNMaskHead initialized successfully")
33
+ def forward(self, x: Tensor) -> Tensor:
34
+ """Forward features from the upstream network.
35
+
36
+ Args:
37
+ x (Tensor): Extract mask RoI features.
38
+
39
+ Returns:
40
+ Tensor: Predicted foreground masks.
41
+ """
42
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: Input shape: {x.shape}")
43
+
44
+ for i, conv in enumerate(self.convs):
45
+ x = conv(x)
46
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: After conv {i} shape: {x.shape}")
47
+
48
+ if self.upsample is not None:
49
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: Upsampling from {x.shape}")
50
+ x = self.upsample(x)
51
+ if self.upsample_method == 'deconv':
52
+ x = self.relu(x)
53
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: After upsample shape: {x.shape}")
54
+ else:
55
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: No upsampling, shape: {x.shape}")
56
+
57
+ mask_preds = self.conv_logits(x)
58
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: Final mask_preds shape: {mask_preds.shape}")
59
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_preds device: {mask_preds.device}")
60
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_preds dtype: {mask_preds.dtype}")
61
+
62
+ return mask_preds
63
+
64
+ def loss_and_target(self,
65
+ mask_preds: Tensor,
66
+ sampling_results: List[Any],
67
+ batch_gt_instances: List[InstanceData],
68
+ rcnn_train_cfg: Union[Dict[str, Any], ConfigDict]) -> dict:
69
+ """Calculate the loss based on the features extracted by the mask head.
70
+
71
+ Args:
72
+ mask_preds (Tensor): Predicted foreground masks, has shape
73
+ (num_pos, num_classes, mask_h, mask_w).
74
+ sampling_results (List[:obj:`SamplingResult`]): Assign results of
75
+ all images in a batch after sampling.
76
+ batch_gt_instances (List[:obj:`InstanceData`]): Batch of
77
+ gt_instance. It usually includes ``bboxes``, ``labels``,
78
+ and ``masks`` attributes.
79
+ rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
80
+
81
+ Returns:
82
+ dict: A dictionary of loss components.
83
+ """
84
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: loss_and_target called")
85
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_preds shape: {mask_preds.shape}")
86
+
87
+ # Get mask targets
88
+ mask_targets = self.get_targets(sampling_results, batch_gt_instances,
89
+ rcnn_train_cfg)
90
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_targets shape: {mask_targets.shape}")
91
+
92
+ # Get labels for positive proposals
93
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
94
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: pos_labels shape: {pos_labels.shape}")
95
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: pos_labels: {pos_labels}")
96
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: pos_labels min: {pos_labels.min()}")
97
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: pos_labels max: {pos_labels.max()}")
98
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: num_classes: {self.num_classes}")
99
+
100
+ # Check for out-of-bounds labels
101
+ if pos_labels.max() >= self.num_classes:
102
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: ERROR! Found label {pos_labels.max()} >= num_classes {self.num_classes}")
103
+ # Clamp labels to valid range
104
+ pos_labels = torch.clamp(pos_labels, 0, self.num_classes - 1)
105
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: Clamped pos_labels max: {pos_labels.max()}")
106
+
107
+ # Check for size mismatch between predictions and targets
108
+ if mask_preds.shape[-2:] != mask_targets.shape[-2:]:
109
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: SIZE MISMATCH!")
110
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_preds shape: {mask_preds.shape}")
111
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_targets shape: {mask_targets.shape}")
112
+
113
+ # Calculate loss - use the original approach like FCNMaskHead
114
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: About to call loss_mask")
115
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_preds shape: {mask_preds.shape}")
116
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_targets shape: {mask_targets.shape}")
117
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: pos_labels shape: {pos_labels.shape}")
118
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_preds device: {mask_preds.device}")
119
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: mask_targets device: {mask_targets.device}")
120
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: pos_labels device: {pos_labels.device}")
121
+
122
+ # Call loss function with full mask_preds and pos_labels like the original FCN mask head
123
+ loss_mask = self.loss_mask(mask_preds, mask_targets, pos_labels)
124
+
125
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: Loss calculated successfully: {loss_mask}")
126
+
127
+ # only return the *nested* loss dict that StandardRoIHead.update() expects
128
+ return dict(
129
+ loss_mask={'loss_mask': loss_mask},
130
+ # if you really need mask_targets downstream you can still return it under a
131
+ # different key, but it will be ignored by the standard loss updater
132
+ mask_targets=mask_targets
133
+ )
134
+
135
+ def get_targets(self,
136
+ sampling_results: List[Any],
137
+ batch_gt_instances: List[InstanceData],
138
+ rcnn_train_cfg: Union[Dict[str, Any], ConfigDict]) -> Tensor:
139
+ """Calculate the ground truth for all samples in a batch according to
140
+ the sampling_results.
141
+
142
+ Args:
143
+ sampling_results (List[:obj:`SamplingResult`]): Assign results of
144
+ all images in a batch after sampling.
145
+ batch_gt_instances (List[:obj:`InstanceData`]): Batch of
146
+ gt_instance. It usually includes ``bboxes``, ``labels``,
147
+ and ``masks`` attributes.
148
+ rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
149
+
150
+ Returns:
151
+ Tensor: Mask targets of each positive proposals in the image,
152
+ has shape (num_pos, mask_h, mask_w).
153
+ """
154
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: get_targets called")
155
+
156
+ pos_proposals_list = [res.pos_priors for res in sampling_results]
157
+ pos_assigned_gt_inds_list = [
158
+ res.pos_assigned_gt_inds for res in sampling_results
159
+ ]
160
+ gt_masks_list = [res.masks for res in batch_gt_instances]
161
+
162
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: Number of sampling results: {len(sampling_results)}")
163
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: rcnn_train_cfg: {rcnn_train_cfg}")
164
+
165
+ # Use our custom square mask target function
166
+ mask_targets = square_mask_target(pos_proposals_list, pos_assigned_gt_inds_list,
167
+ gt_masks_list, rcnn_train_cfg)
168
+
169
+ print(f"πŸ” SQUARE_FCN_MASK_HEAD: Final mask_targets shape: {mask_targets.shape}")
170
+ return mask_targets
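The out-of-bounds guard in `loss_and_target` clamps labels into `[0, num_classes - 1]` before the loss call. The same idea in a dependency-free sketch (a hypothetical helper on plain lists, standing in for `torch.clamp` on tensors):

```python
def clamp_labels(labels, num_classes):
    """Clamp class indices into [0, num_classes - 1], mirroring the
    torch.clamp guard in SquareFCNMaskHead.loss_and_target."""
    return [min(max(label, 0), num_classes - 1) for label in labels]


# Labels 21 and 25 are out of range for a 21-class head (valid: 0..20).
print(clamp_labels([0, 5, 21, 25], 21))  # → [0, 5, 20, 20]
```

Note that clamping only avoids the indexing crash; an out-of-range label still signals an annotation/category mismatch that is worth fixing upstream.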
custom_models/square_mask_target.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ from torch.nn.modules.utils import _pair
+
+ from mmengine.config import ConfigDict
+
+ from mmdet.registry import MODELS
+ from mmdet.structures.mask.mask_target import mask_target as original_mask_target
+
+
+ def square_mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, cfg):
+     """Compute square mask targets for positive proposals in multiple images.
+
+     This function forces all mask targets to be square regardless of the
+     original aspect ratio, to avoid tensor size mismatches.
+
+     Args:
+         pos_proposals_list (list[Tensor]): Positive proposals in multiple images.
+         pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each
+             positive proposal.
+         gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of
+             each image.
+         cfg (dict): Config dict that specifies the mask size.
+
+     Returns:
+         Tensor: Square mask target of each image, has shape (num_pos, size, size).
+     """
+     # Get the target size (a tuple such as (14, 14))
+     mask_size = _pair(cfg.mask_size)
+
+     # Force a square size by using the minimum dimension
+     square_size = min(mask_size)
+
+     # Create a proper ConfigDict object with the square size
+     square_cfg = ConfigDict({'mask_size': (square_size, square_size)})
+
+     # Call the original mask target function with the square size
+     mask_targets = original_mask_target(pos_proposals_list, pos_assigned_gt_inds_list,
+                                         gt_masks_list, square_cfg)
+
+     print(f"🔍 SQUARE_MASK_TARGET: Original mask_targets shape: {mask_targets.shape}")
+     print(f"🔍 SQUARE_MASK_TARGET: Expected square_size: {square_size}")
+
+     # Force a square shape by padding or cropping if necessary
+     if mask_targets.size(1) != square_size or mask_targets.size(2) != square_size:
+         print(f"🔍 SQUARE_MASK_TARGET: Forcing square shape from {mask_targets.shape} to ({mask_targets.size(0)}, {square_size}, {square_size})")
+
+         # Create a new tensor with the square shape
+         num_masks = mask_targets.size(0)
+         square_targets = torch.zeros(num_masks, square_size, square_size,
+                                      device=mask_targets.device, dtype=mask_targets.dtype)
+
+         # Copy the mask data, padding with zeros if necessary
+         h, w = mask_targets.size(1), mask_targets.size(2)
+         h_copy = min(h, square_size)
+         w_copy = min(w, square_size)
+
+         square_targets[:, :h_copy, :w_copy] = mask_targets[:, :h_copy, :w_copy]
+         mask_targets = square_targets
+
+         print(f"🔍 SQUARE_MASK_TARGET: Final mask_targets shape: {mask_targets.shape}")
+     else:
+         print(f"🔍 SQUARE_MASK_TARGET: Masks already square: {mask_targets.shape}")
+
+     return mask_targets
+
+
+ # Register the custom function
+ MODELS.register_module(name='square_mask_target', module=square_mask_target)
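The pad-or-crop step in `square_mask_target` can be exercised in isolation. Below is a plain-Python sketch of the same logic on nested lists instead of tensors (the helper name `to_square` is illustrative): the region that overlaps the target square is copied, everything else is zero-filled.

```python
def to_square(mask, size):
    """Zero-pad or crop a 2D mask to (size, size), mirroring the tensor
    copy in square_mask_target: the overlapping region is kept, the
    remainder is filled with zeros."""
    out = [[0] * size for _ in range(size)]
    h_copy = min(len(mask), size)
    w_copy = min(len(mask[0]), size) if mask else 0
    for i in range(h_copy):
        for j in range(w_copy):
            out[i][j] = mask[i][j]
    return out


# A 1x3 mask forced into a 2x2 square: width is cropped, height is padded.
print(to_square([[1, 1, 1]], 2))  # → [[1, 1], [0, 0]]
```

Cropping silently discards mask content outside the square, which is the trade-off this project accepts in exchange for uniform target shapes.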
debug_api.py ADDED
@@ -0,0 +1,123 @@
+ #!/usr/bin/env python3
+ """
+ Debug script to check API status and endpoints.
+ """
+
+ import requests
+
+
+ def check_space_status():
+     """Check whether the Space is running."""
+     print("🔍 Checking space status...")
+
+     try:
+         # Check the main page
+         response = requests.get("https://hanszhu-dense-captioning-platform.hf.space/", timeout=30)
+         print(f"Main page status: {response.status_code}")
+
+         if response.status_code == 200:
+             print("✅ Space is accessible")
+         else:
+             print("❌ Space is not accessible")
+
+     except Exception as e:
+         print(f"❌ Error checking space: {e}")
+
+
+ def check_api_endpoints():
+     """Check various API endpoints."""
+     print("\n🔍 Checking API endpoints...")
+
+     base_url = "https://hanszhu-dense-captioning-platform.hf.space"
+
+     endpoints = [
+         "/",
+         "/api",
+         "/api/predict",
+         "/predict",
+         "/api/predict/",
+         "/predict/"
+     ]
+
+     for endpoint in endpoints:
+         try:
+             response = requests.get(f"{base_url}{endpoint}", timeout=30)
+             print(f"{endpoint}: {response.status_code} - {response.headers.get('content-type', 'unknown')}")
+
+             if response.status_code == 200:
+                 print(f"  Content preview: {response.text[:100]}...")
+
+         except Exception as e:
+             print(f"{endpoint}: Error - {e}")
+
+
+ def test_post_request():
+     """Test POST requests to the predict endpoint."""
+     print("\n🔍 Testing POST request...")
+
+     try:
+         # Test URL
+         test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
+
+         # Try different POST payload formats
+         test_data = [
+             {"data": [test_url]},
+             {"fn_index": 0, "data": [test_url]},
+             {"data": test_url},
+             test_url
+         ]
+
+         for i, data in enumerate(test_data):
+             print(f"\nTest {i+1}: {type(data)}")
+
+             try:
+                 response = requests.post(
+                     "https://hanszhu-dense-captioning-platform.hf.space/predict",
+                     json=data,
+                     headers={"Content-Type": "application/json"},
+                     timeout=30
+                 )
+
+                 print(f"  Status: {response.status_code}")
+                 print(f"  Content-Type: {response.headers.get('content-type', 'unknown')}")
+                 print(f"  Response: {response.text[:200]}...")
+
+             except Exception as e:
+                 print(f"  Error: {e}")
+
+     except Exception as e:
+         print(f"❌ Error in POST test: {e}")
+
+
+ def test_gradio_client_detailed():
+     """Test gradio_client with detailed error handling."""
+     print("\n🔍 Testing gradio_client with detailed error handling...")
+
+     try:
+         from gradio_client import Client
+
+         print("Creating client...")
+         client = Client("hanszhu/Dense-Captioning-Platform")
+
+         print("Getting space info...")
+         info = client.view_api()
+         print(f"API info: {info}")
+
+         print("Making prediction...")
+         test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
+         result = client.predict(test_url, api_name="/predict")
+
+         print(f"✅ Success! Result: {result}")
+
+     except Exception as e:
+         print(f"❌ gradio_client error: {e}")
+         import traceback
+         traceback.print_exc()
+
+
+ if __name__ == "__main__":
+     print("🚀 Debugging Dense Captioning Platform API")
+     print("=" * 60)
+
+     check_space_status()
+     check_api_endpoints()
+     test_post_request()
+     test_gradio_client_detailed()
+
+     print("\n" + "=" * 60)
+     print("🏁 Debug completed!")
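The probes in this script can receive responses in several shapes: Gradio's REST layer typically wraps outputs as `{"data": [...]}`, while other endpoints return a bare JSON value. A small hypothetical normalizer for those shapes (the name `extract_result` is not part of any Gradio API):

```python
import json


def extract_result(payload):
    """Hypothetical helper: normalize the response shapes probed above.
    Accepts either a JSON string or an already-parsed object; unwraps a
    Gradio-style {"data": [...]} envelope, passes bare values through."""
    obj = json.loads(payload) if isinstance(payload, str) else payload
    if isinstance(obj, dict) and "data" in obj:
        data = obj["data"]
        if isinstance(data, list) and data:
            return data[0]  # single-output endpoints: first entry is the result
        return data
    return obj


print(extract_result('{"data": ["a caption"]}'))  # → a caption
```

Centralizing this unwrapping keeps the per-endpoint probe loops free of response-format special cases.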
find_api_endpoint.py ADDED
@@ -0,0 +1,167 @@
+ #!/usr/bin/env python3
+ """
+ Script to find the correct API endpoint.
+ """
+
+ import json
+
+ import requests
+
+
+ def try_different_endpoints():
+     """Try different possible API endpoints."""
+     print("🔍 Trying different API endpoints...")
+
+     base_url = "https://hanszhu-dense-captioning-platform.hf.space"
+
+     # Different possible endpoints
+     endpoints = [
+         "/api/predict",
+         "/predict",
+         "/api/run/predict",
+         "/run/predict",
+         "/api/0",
+         "/0",
+         "/api/1",
+         "/1",
+         "/api/2",
+         "/2",
+         "/api/3",
+         "/3",
+         "/api/4",
+         "/4",
+         "/api/5",
+         "/5"
+     ]
+
+     test_data = {
+         "data": ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
+     }
+
+     for endpoint in endpoints:
+         print(f"\nTrying POST to: {endpoint}")
+
+         try:
+             response = requests.post(
+                 f"{base_url}{endpoint}",
+                 json=test_data,
+                 headers={"Content-Type": "application/json"},
+                 timeout=30
+             )
+
+             print(f"  Status: {response.status_code}")
+             print(f"  Content-Type: {response.headers.get('content-type', 'unknown')}")
+
+             if response.status_code == 200:
+                 print("  ✅ SUCCESS! Found working endpoint!")
+                 print(f"  Response: {response.text[:200]}...")
+                 return endpoint
+             elif response.status_code == 405:
+                 print("  ⚠️ Method not allowed (endpoint exists but wrong method)")
+             elif response.status_code == 404:
+                 print("  ❌ Not found")
+             else:
+                 print(f"  ❌ Unexpected status: {response.text[:100]}...")
+
+         except Exception as e:
+             print(f"  ❌ Error: {e}")
+
+     return None
+
+
+ def try_get_endpoints():
+     """Try GET requests to find API info."""
+     print("\n🔍 Trying GET requests to find API info...")
+
+     base_url = "https://hanszhu-dense-captioning-platform.hf.space"
+
+     get_endpoints = [
+         "/api",
+         "/api/",
+         "/api/predict",
+         "/api/predict/",
+         "/api/run/predict",
+         "/api/run/predict/",
+         "/api/0",
+         "/api/1",
+         "/api/2",
+         "/api/3",
+         "/api/4",
+         "/api/5"
+     ]
+
+     for endpoint in get_endpoints:
+         print(f"\nTrying GET: {endpoint}")
+
+         try:
+             response = requests.get(f"{base_url}{endpoint}", timeout=30)
+
+             print(f"  Status: {response.status_code}")
+             print(f"  Content-Type: {response.headers.get('content-type', 'unknown')}")
+
+             if response.status_code == 200:
+                 content = response.text[:200]
+                 print(f"  Content: {content}...")
+
+                 # Check whether it's JSON
+                 if response.headers.get('content-type', '').startswith('application/json'):
+                     print("  ✅ JSON response - this might be the API info!")
+                     try:
+                         data = response.json()
+                         print(f"  API Info: {json.dumps(data, indent=2)}")
+                     except ValueError:
+                         pass
+
+         except Exception as e:
+             print(f"  ❌ Error: {e}")
+
+
+ def try_gradio_client_different_ways():
+     """Try gradio_client with different approaches."""
+     print("\n🔍 Trying gradio_client with different approaches...")
+
+     try:
+         from gradio_client import Client
+
+         print("Creating client...")
+         client = Client("hanszhu/Dense-Captioning-Platform")
+
+         print("Trying different API names...")
+
+         api_names = ["/predict", "/run/predict", "0", "1", "2", "3", "4", "5"]
+
+         for api_name in api_names:
+             print(f"\nTrying api_name: {api_name}")
+
+             try:
+                 test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
+                 result = client.predict(test_url, api_name=api_name)
+                 print(f"  ✅ SUCCESS with api_name={api_name}!")
+                 print(f"  Result: {result}")
+                 return api_name
+
+             except Exception as e:
+                 print(f"  ❌ Failed: {e}")
+
+     except Exception as e:
+         print(f"❌ gradio_client error: {e}")
+
+
+ if __name__ == "__main__":
+     print("🚀 Finding the correct API endpoint")
+     print("=" * 60)
+
+     # Try different POST endpoints
+     working_endpoint = try_different_endpoints()
+
+     # Try GET endpoints for API info
+     try_get_endpoints()
+
+     # Try gradio_client with different approaches
+     working_api_name = try_gradio_client_different_ways()
+
+     print("\n" + "=" * 60)
+     print("🏁 Endpoint discovery completed!")
+
+     if working_endpoint:
+         print(f"✅ Found working POST endpoint: {working_endpoint}")
+     if working_api_name:
+         print(f"✅ Found working gradio_client api_name: {working_api_name}")
+
+     if not working_endpoint and not working_api_name:
+         print("❌ No working endpoints found")
+         print("The space might still be loading or need different configuration")
models/chart_elementnet_swin.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cascade_rcnn_r50_fpn_meta.py - Enhanced config with Swin Transformer backbone
2
+ #
3
+ # PROGRESSIVE LOSS STRATEGY:
4
+ # - All 3 Cascade stages start with SmoothL1Loss for stable initial training
5
+ # - At epoch 5, Stage 3 (final stage) switches to GIoULoss via ProgressiveLossHook
6
+ # - Stage 1 & 2 remain SmoothL1Loss throughout training
7
+ # - This ensures model stability before introducing more complex IoU-based losses
8
+
9
+ # Custom imports - this registers our modules without polluting config namespace
10
+ custom_imports = dict(
11
+ imports=[
12
+ 'custom_models.custom_dataset',
13
+ 'custom_models.register',
14
+ 'custom_models.custom_hooks',
15
+ 'custom_models.progressive_loss_hook',
16
+ ],
17
+ allow_failed_imports=False
18
+ )
19
+
20
+ # Add to Python path
21
+ import sys
22
+ import os
23
+ # Use a simpler path approach that doesn't rely on __file__
24
+ sys.path.insert(0, os.path.join(os.getcwd(), '..', '..'))
25
+
26
+ # Custom Cascade model with coordinate handling for chart data
27
+ model = dict(
28
+ type='CustomCascadeWithMeta', # Use custom model with coordinate handling
29
+ coordinate_standardization=dict(
30
+ enabled=True,
31
+ origin='bottom_left', # Match annotation creation coordinate system
32
+ normalize=True,
33
+ relative_to_plot=False, # Keep simple for now
34
+ scale_to_axis=False # Keep simple for now
35
+ ),
36
+ data_preprocessor=dict(
37
+ type='DetDataPreprocessor',
38
+ mean=[123.675, 116.28, 103.53],
39
+ std=[58.395, 57.12, 57.375],
40
+ bgr_to_rgb=True,
41
+ pad_size_divisor=32),
42
+ # ----- Swin Transformer Base (22K) Backbone + FPN -----
43
+ backbone=dict(
44
+ type='SwinTransformer',
45
+ embed_dims=128, # Swin Base embedding dimensions
46
+ depths=[2, 2, 18, 2], # Swin Base depths
47
+ num_heads=[4, 8, 16, 32], # Swin Base attention heads
48
+ window_size=7,
49
+ mlp_ratio=4,
50
+ qkv_bias=True,
51
+ qk_scale=None,
52
+ drop_rate=0.0,
53
+ attn_drop_rate=0.0,
54
+ drop_path_rate=0.3, # Slightly higher for more complex model
55
+ patch_norm=True,
56
+ out_indices=(0, 1, 2, 3),
57
+ with_cp=False,
58
+ convert_weights=True,
59
+ init_cfg=dict(
60
+ type='Pretrained',
61
+ checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
62
+ )
63
+ ),
64
+ neck=dict(
65
+ type='FPN',
66
+ in_channels=[128, 256, 512, 1024], # Swin Base: embed_dims * 2^(stage)
67
+ out_channels=256,
68
+ num_outs=6,
69
+ start_level=0,
70
+ add_extra_convs='on_input'
71
+ ),
72
+ # Enhanced RPN with smaller anchors for tiny objects + improved losses
73
+ rpn_head=dict(
74
+ type='RPNHead',
75
+ in_channels=256,
76
+ feat_channels=256,
77
+ anchor_generator=dict(
78
+ type='AnchorGenerator',
79
+ scales=[1, 2, 4, 8], # Even smaller scales for tiny objects
80
+ ratios=[0.5, 1.0, 2.0], # Multiple aspect ratios
81
+ strides=[4, 8, 16, 32, 64, 128]), # Extended FPN strides
82
+ bbox_coder=dict(
83
+ type='DeltaXYWHBBoxCoder',
84
+ target_means=[.0, .0, .0, .0],
85
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
86
+ loss_cls=dict(
87
+ type='CrossEntropyLoss',
88
+ use_sigmoid=True,
89
+ loss_weight=1.0),
90
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
91
+ # Progressive Loss Strategy: Start with SmoothL1 for all 3 stages
92
+ # Stage 3 (final stage) will switch to GIoU at epoch 5 via ProgressiveLossHook
93
+ roi_head=dict(
94
+ type='CascadeRoIHead',
95
+ num_stages=3,
96
+ stage_loss_weights=[1, 0.5, 0.25],
97
+ bbox_roi_extractor=dict(
98
+ type='SingleRoIExtractor',
99
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
100
+ out_channels=256,
101
+ featmap_strides=[4, 8, 16, 32]),
102
+ bbox_head=[
103
+ # Stage 1: Always SmoothL1Loss (coarse detection)
104
+ dict(
105
+ type='Shared2FCBBoxHead',
106
+ in_channels=256,
107
+ fc_out_channels=1024,
108
+ roi_feat_size=7,
109
+ num_classes=21, # 21 enhanced categories
110
+ bbox_coder=dict(
111
+ type='DeltaXYWHBBoxCoder',
112
+ target_means=[0., 0., 0., 0.],
113
+ target_stds=[0.05, 0.05, 0.1, 0.1]),
114
+ reg_class_agnostic=True,
115
+ loss_cls=dict(
116
+ type='CrossEntropyLoss',
117
+ use_sigmoid=False,
118
+ loss_weight=1.0),
119
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
120
+ # Stage 2: Always SmoothL1Loss (intermediate refinement)
121
+ dict(
122
+ type='Shared2FCBBoxHead',
123
+ in_channels=256,
124
+ fc_out_channels=1024,
125
+ roi_feat_size=7,
126
+ num_classes=21, # 21 enhanced categories
127
+ bbox_coder=dict(
128
+ type='DeltaXYWHBBoxCoder',
129
+ target_means=[0., 0., 0., 0.],
130
+ target_stds=[0.033, 0.033, 0.067, 0.067]),
131
+ reg_class_agnostic=True,
132
+ loss_cls=dict(
133
+ type='CrossEntropyLoss',
134
+ use_sigmoid=False,
135
+ loss_weight=1.0),
136
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
137
+ # Stage 3: SmoothL1 β†’ GIoU at epoch 5 (progressive switching)
138
+ dict(
139
+ type='Shared2FCBBoxHead',
140
+ in_channels=256,
141
+ fc_out_channels=1024,
142
+ roi_feat_size=7,
143
+ num_classes=21, # 21 enhanced categories
144
+ bbox_coder=dict(
145
+ type='DeltaXYWHBBoxCoder',
146
+ target_means=[0., 0., 0., 0.],
147
+ target_stds=[0.02, 0.02, 0.05, 0.05]),
148
+ reg_class_agnostic=True,
149
+ loss_cls=dict(
150
+ type='CrossEntropyLoss',
151
+ use_sigmoid=False,
152
+ loss_weight=1.0),
153
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
154
+ ]),
155
+ train_cfg=dict(
156
+ rpn=dict(
157
+ assigner=dict(
158
+ type='MaxIoUAssigner',
159
+ pos_iou_thr=0.7,
160
+ neg_iou_thr=0.3,
161
+ min_pos_iou=0.3,
162
+ match_low_quality=True,
163
+ ignore_iof_thr=-1),
164
+ sampler=dict(
165
+ type='RandomSampler',
166
+ num=256,
167
+ pos_fraction=0.5,
168
+ neg_pos_ub=-1,
169
+ add_gt_as_proposals=False),
170
+ allowed_border=0,
171
+ pos_weight=-1,
172
+ debug=False),
173
+ rpn_proposal=dict(
174
+ nms_pre=2000,
175
+ max_per_img=2000,
176
+ nms=dict(type='nms', iou_threshold=0.8),
177
+ min_bbox_size=0),
178
+ rcnn=[
179
+ dict(
180
+ assigner=dict(
181
+ type='MaxIoUAssigner',
182
+ pos_iou_thr=0.4,
183
+ neg_iou_thr=0.4,
184
+ min_pos_iou=0.4,
185
+ match_low_quality=False,
186
+ ignore_iof_thr=-1),
187
+ sampler=dict(
188
+ type='RandomSampler',
189
+ num=512,
190
+ pos_fraction=0.25,
191
+ neg_pos_ub=-1,
192
+ add_gt_as_proposals=True),
193
+ pos_weight=-1,
194
+ debug=False),
195
+ dict(
196
+ assigner=dict(
197
+ type='MaxIoUAssigner',
198
+ pos_iou_thr=0.6,
199
+ neg_iou_thr=0.6,
200
+ min_pos_iou=0.6,
201
+ match_low_quality=False,
202
+ ignore_iof_thr=-1),
203
+ sampler=dict(
204
+ type='RandomSampler',
205
+ num=512,
206
+ pos_fraction=0.25,
207
+ neg_pos_ub=-1,
208
+ add_gt_as_proposals=True),
209
+ pos_weight=-1,
210
+ debug=False),
211
+ dict(
212
+ assigner=dict(
213
+ type='MaxIoUAssigner',
214
+ pos_iou_thr=0.7,
215
+ neg_iou_thr=0.7,
216
+ min_pos_iou=0.7,
217
+ match_low_quality=False,
218
+ ignore_iof_thr=-1),
219
+ sampler=dict(
220
+ type='RandomSampler',
221
+ num=512,
222
+ pos_fraction=0.25,
223
+ neg_pos_ub=-1,
224
+ add_gt_as_proposals=True),
225
+ pos_weight=-1,
226
+ debug=False)
227
+ ]),
228
+ # Enhanced test configuration with soft-NMS and multi-scale support
229
+ test_cfg=dict(
230
+ rpn=dict(
231
+ nms_pre=1000,
232
+ max_per_img=1000,
233
+ nms=dict(type='nms', iou_threshold=0.7),
234
+ min_bbox_size=0),
235
+ rcnn=dict(
236
+ score_thr=0.005, # Even lower threshold to catch more classes
237
+ nms=dict(
238
+ type='soft_nms', # Soft-NMS for better small object detection
239
+ iou_threshold=0.5,
240
+ min_score=0.005,
241
+ method='gaussian',
242
+ sigma=0.5),
243
+ max_per_img=500))) # Allow more detections
244
+
245
+ # Dataset settings - using cleaned annotations
246
+ dataset_type = 'ChartDataset'
247
+ data_root = '' # Remove data_root duplication
248
+
249
+ # Define the 21 chart element classes that match the annotations
250
+ CLASSES = (
251
+ 'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
252
+ 'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
253
+ 'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
254
+ 'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
255
+ )
256
+
257
+ # Updated to use cleaned annotation files
258
+ train_dataloader = dict(
259
+ batch_size=2, # Increased back to 2
260
+ num_workers=2,
261
+ persistent_workers=True,
262
+ sampler=dict(type='DefaultSampler', shuffle=True),
263
+ dataset=dict(
264
+ type=dataset_type,
265
+ data_root=data_root,
266
+ ann_file='legend_data/annotations_JSON_cleaned/train_enriched.json', # Full path
267
+ data_prefix=dict(img='legend_data/train/images/'), # Full path
268
+ metainfo=dict(classes=CLASSES), # Tell dataset what classes to expect
269
+ filter_cfg=dict(filter_empty_gt=True, min_size=0, class_specific_min_sizes={
270
+ 'data-point': 16, # Back to 16x16 from 32x32
271
+ 'data-bar': 16, # Back to 16x16 from 32x32
272
+ 'tick-label': 16, # Back to 16x16 from 32x32
273
+ 'x-tick-label': 16, # Back to 16x16 from 32x32
274
+ 'y-tick-label': 16 # Back to 16x16 from 32x32
275
+ }),
276
+ pipeline=[
277
+ dict(type='LoadImageFromFile'),
278
+ dict(type='LoadAnnotations', with_bbox=True),
279
+ dict(type='Resize', scale=(1600, 1000), keep_ratio=True), # Higher resolution for tiny objects
280
+ dict(type='RandomFlip', prob=0.5),
281
+ dict(type='ClampBBoxes'), # Ensure bboxes stay within image bounds
282
+ dict(type='PackDetInputs')
283
+ ]
284
+ )
285
+ )
286
+
287
+ val_dataloader = dict(
288
+ batch_size=1,
289
+ num_workers=2,
290
+ persistent_workers=True,
291
+ drop_last=False,
292
+ sampler=dict(type='DefaultSampler', shuffle=False),
293
+ dataset=dict(
294
+ type=dataset_type,
295
+ data_root=data_root,
296
+ ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json', # Full path
297
+ data_prefix=dict(img='legend_data/train/images/'), # All images are in train/images
298
+ metainfo=dict(classes=CLASSES), # Tell dataset what classes to expect
299
+ test_mode=True,
300
+ pipeline=[
301
+ dict(type='LoadImageFromFile'),
302
+ dict(type='Resize', scale=(1600, 1000), keep_ratio=True), # Base resolution for validation
303
+ dict(type='LoadAnnotations', with_bbox=True),
304
+ dict(type='ClampBBoxes'), # Ensure bboxes stay within image bounds
305
+ dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
306
+ ]
307
+ )
308
+ )
309
+
310
+ test_dataloader = val_dataloader
311
+
312
+ # Enhanced evaluators with debugging
313
+ val_evaluator = dict(
314
+ type='CocoMetric',
315
+ ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json', # Using cleaned annotations
316
+ metric='bbox',
317
+ format_only=False,
318
+ classwise=True, # Enable detailed per-class metrics table
319
+ proposal_nums=(100, 300, 1000)) # More detailed AR metrics
320
+
321
+ test_evaluator = val_evaluator
322
+
323
+ # Add custom hooks for debugging empty results
324
+ default_hooks = dict(
325
+ timer=dict(type='IterTimerHook'),
326
+ logger=dict(type='LoggerHook', interval=50),
327
+ param_scheduler=dict(type='ParamSchedulerHook'),
328
+ checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
329
+ sampler_seed=dict(type='DistSamplerSeedHook'),
330
+ visualization=dict(type='DetVisualizationHook'))
331
+
332
+ # Add NaN recovery hook for graceful handling like Faster R-CNN
333
+ custom_hooks = [
334
+ dict(type='SkipBadSamplesHook', interval=1), # Skip samples with bad GT data
335
+ dict(type='ChartTypeDistributionHook', interval=500), # Monitor class distribution
336
+ dict(type='MissingImageReportHook', interval=1000), # Track missing images
337
+ dict(type='NanRecoveryHook', # For logging & monitoring
338
+ fallback_loss=1.0,
339
+ max_consecutive_nans=100,
340
+ log_interval=50),
341
+ dict(type='ProgressiveLossHook', # Progressive loss switching
342
+ switch_epoch=5, # Switch stage 3 to GIoU at epoch 5
343
+ target_loss_type='GIoULoss', # Use GIoU for stage 3 (final stage)
344
+ loss_weight=1.0, # Keep same loss weight
345
+ warmup_epochs=2, # Monitor for 2 epochs after switch
346
+ monitor_stage_weights=True), # Log stage loss details
347
+ ]
348
+
349
+ # Training configuration - extended to 40 epochs for Swin Base on small objects
350
+ train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
351
+ val_cfg = dict(type='ValLoop')
352
+ test_cfg = dict(type='TestLoop')
353
+
354
+ # Optimizer with standard stable settings
355
+ optim_wrapper = dict(
356
+ type='OptimWrapper',
357
+ optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
358
+ clip_grad=dict(max_norm=35.0, norm_type=2)
359
+ )
360
+
361
+ # Extended learning rate schedule with cosine annealing for Swin Base
362
+ param_scheduler = [
363
+ dict(
364
+ type='LinearLR',
365
+ start_factor=0.05, # 0.05 * 2e-2 = 1e-3 (warmup from 1e-3 to the base lr of 2e-2)
366
+ by_epoch=False,
367
+ begin=0,
368
+ end=1000), # 1k iteration warmup
369
+ dict(
370
+ type='CosineAnnealingLR',
371
+ begin=0,
372
+ end=40, # Match max_epochs
373
+ by_epoch=True,
374
+ T_max=40,
375
+ eta_min=1e-6, # Minimum learning rate
376
+ convert_to_iter_based=True)
377
+ ]
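The combined linear-warmup-plus-cosine schedule above can be sanity-checked with a plain-Python approximation (the iteration counts are hypothetical; mmengine performs the actual iter-based conversion and scheduler composition):

```python
import math

BASE_LR = 0.02          # optimizer lr from optim_wrapper
START_FACTOR = 0.05     # LinearLR start_factor
WARMUP_ITERS = 1000     # LinearLR end (iteration-based)
MAX_EPOCHS = 40         # CosineAnnealingLR end / T_max
ETA_MIN = 1e-6          # CosineAnnealingLR eta_min

def lr_at(iteration, iters_per_epoch):
    """Approximate lr: linear warmup for the first 1k iters, then cosine decay."""
    if iteration < WARMUP_ITERS:
        frac = iteration / WARMUP_ITERS
        return BASE_LR * (START_FACTOR + (1 - START_FACTOR) * frac)
    t = iteration / (iters_per_epoch * MAX_EPOCHS)  # training progress in [0, 1]
    return ETA_MIN + 0.5 * (BASE_LR - ETA_MIN) * (1 + math.cos(math.pi * t))

# With e.g. 2000 iters/epoch: starts at 1e-3, reaches ~2e-2 after warmup,
# and decays to eta_min = 1e-6 by epoch 40.
```
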
378
+
379
+ # Work directory
380
+ work_dir = './work_dirs/cascade_rcnn_swin_base_40ep_cosine_fpn_meta'
381
+
382
+ # Multi-scale test configuration (uncomment to enable)
383
+ # img_scales = [(800, 500), (1600, 1000), (2400, 1500)] # 0.5x, 1.0x, 1.5x scales
384
+ # tta_model = dict(
385
+ # type='DetTTAModel',
386
+ # tta_cfg=dict(
387
+ # nms=dict(type='nms', iou_threshold=0.5),
388
+ # max_per_img=100)
389
+ # )
390
+
391
+ # Fresh start
392
+ resume = False
393
+ load_from = None
394
+
models/chart_pointnet_swin.py ADDED
@@ -0,0 +1,374 @@
1
+ # chart_pointnet_swin.py - Mask R-CNN with Swin Transformer for data point segmentation
2
+ #
3
+ # ADAPTED FROM CASCADE R-CNN CONFIG:
4
+ # - Uses same Swin Transformer Base backbone with optimizations
5
+ # - Maintains data-point class weighting (10x) and IoU strategies
6
+ # - Adds mask head for instance segmentation of data points
7
+ # - Uses enhanced annotation files with segmentation masks
8
+ # - Keeps custom hooks and progressive loss strategies
9
+ #
10
+ # MASK-SPECIFIC OPTIMIZATIONS:
11
+ # - RoI size 14x14 for mask extraction (matches data point size)
12
+ # - FCN mask head with 4 convolution layers
13
+ # - Mask loss weight balanced with bbox and classification losses
14
+ # - Enhanced test-time augmentation for better mask quality
15
+ #
16
+ # DATA POINT FOCUS:
17
+ # - Primary target: data-point class (ID 11) with 10x weight
18
+ # - Generates both bounding boxes AND instance masks
19
+ # - Optimized for 16x16 pixel data points in scientific charts
20
+ # Removed _base_ inheritance to avoid path issues - all configs are inlined below
21
+
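MMDetection's SingleRoIExtractor assigns each RoI to an FPN level by its scale; a sketch of that mapping (mirroring mmdet's `map_roi_levels` with its default `finest_scale=56`) shows why ~16x16 data points are pooled from the highest-resolution stride-4 map:

```python
import math

def map_roi_level(w, h, finest_scale=56, num_levels=4):
    """FPN level assignment as in MMDetection's SingleRoIExtractor:
    level = floor(log2(sqrt(w*h) / finest_scale)), clamped to [0, num_levels-1]."""
    scale = math.sqrt(w * h)
    lvl = int(math.floor(math.log2(scale / finest_scale + 1e-6)))
    return min(max(lvl, 0), num_levels - 1)

# A 16x16 data point maps to level 0 (stride 4), so the 14x14 RoIAlign
# samples from the finest feature grid available.
```
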
22
+ # Custom imports - same as Cascade R-CNN setup
23
+ custom_imports = dict(
24
+ imports=[
25
+ 'custom_models.register',
26
+ 'custom_models.custom_hooks',
27
+ 'custom_models.progressive_loss_hook',
28
+ 'custom_models.flexible_load_annotations',
29
+ ],
30
+ allow_failed_imports=False
31
+ )
32
+
33
+ # Add to Python path
34
+ import sys
35
+ sys.path.insert(0, '.')
36
+
37
+ # Mask R-CNN model with Swin Transformer backbone
38
+ model = dict(
39
+ type='MaskRCNN',
40
+ data_preprocessor=dict(
41
+ type='DetDataPreprocessor',
42
+ mean=[123.675, 116.28, 103.53],
43
+ std=[58.395, 57.12, 57.375],
44
+ bgr_to_rgb=True,
45
+ pad_size_divisor=32,
46
+ pad_mask=True, # Important for mask training
47
+ mask_pad_value=0,
48
+ ),
49
+ # Same Swin Transformer Base backbone as Cascade R-CNN
50
+ backbone=dict(
51
+ type='SwinTransformer',
52
+ embed_dims=128, # Swin Base embedding dimensions
53
+ depths=[2, 2, 18, 2], # Swin Base depths
54
+ num_heads=[4, 8, 16, 32], # Swin Base attention heads
55
+ window_size=7,
56
+ mlp_ratio=4,
57
+ qkv_bias=True,
58
+ qk_scale=None,
59
+ drop_rate=0.0,
60
+ attn_drop_rate=0.0,
61
+ drop_path_rate=0.3, # Same as Cascade config
62
+ patch_norm=True,
63
+ out_indices=(0, 1, 2, 3),
64
+ with_cp=False,
65
+ convert_weights=True,
66
+ init_cfg=dict(
67
+ type='Pretrained',
68
+ checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
69
+ )
70
+ ),
71
+ # Same FPN as Cascade R-CNN
72
+ neck=dict(
73
+ type='FPN',
74
+ in_channels=[128, 256, 512, 1024], # Swin Base: embed_dims * 2^(stage)
75
+ out_channels=256,
76
+ num_outs=5, # Standard for Mask R-CNN (was 6 in Cascade)
77
+ start_level=0,
78
+ add_extra_convs='on_input'
79
+ ),
80
+ # Same RPN configuration as Cascade R-CNN
81
+ rpn_head=dict(
82
+ type='RPNHead',
83
+ in_channels=256,
84
+ feat_channels=256,
85
+ anchor_generator=dict(
86
+ type='AnchorGenerator',
87
+ scales=[1, 2, 4, 8], # Same small scales for tiny objects
88
+ ratios=[0.5, 1.0, 2.0],
89
+ strides=[4, 8, 16, 32, 64]), # Standard FPN strides for Mask R-CNN
90
+ bbox_coder=dict(
91
+ type='DeltaXYWHBBoxCoder',
92
+ target_means=[.0, .0, .0, .0],
93
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
94
+ loss_cls=dict(
95
+ type='CrossEntropyLoss',
96
+ use_sigmoid=True,
97
+ loss_weight=1.0),
98
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
99
+ ),
100
+ # Mask R-CNN ROI head with bbox + mask branches
101
+ roi_head=dict(
102
+ type='StandardRoIHead',
103
+ # Bbox ROI extractor (same as Cascade R-CNN final stage)
104
+ bbox_roi_extractor=dict(
105
+ type='SingleRoIExtractor',
106
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
107
+ out_channels=256,
108
+ featmap_strides=[4, 8, 16, 32]
109
+ ),
110
+ # Bbox head with data-point class weighting
111
+ bbox_head=dict(
112
+ type='Shared2FCBBoxHead',
113
+ in_channels=256,
114
+ fc_out_channels=1024,
115
+ roi_feat_size=7,
116
+ num_classes=22, # 22 enhanced categories including boxplot
117
+ bbox_coder=dict(
118
+ type='DeltaXYWHBBoxCoder',
119
+ target_means=[0., 0., 0., 0.],
120
+ target_stds=[0.1, 0.1, 0.2, 0.2]
121
+ ),
122
+ reg_class_agnostic=False,
123
+ loss_cls=dict(
124
+ type='CrossEntropyLoss',
125
+ use_sigmoid=False,
126
+ loss_weight=1.0,
127
+ class_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, # foreground classes 0-10
128
+ 10.0, # data-point at index 11 gets 10x weight
129
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, # classes 12-21 (incl. boxplot)
130
+ 1.0] # background class (last index in MMDetection's softmax head)
131
+ ),
132
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
133
+ ),
134
+ # Mask ROI extractor (optimized for 16x16 data points)
135
+ mask_roi_extractor=dict(
136
+ type='SingleRoIExtractor',
137
+ roi_layer=dict(type='RoIAlign', output_size=(14, 14), sampling_ratio=0, aligned=True), # Exact 14x14 output with aligned RoIAlign
138
+ out_channels=256,
139
+ featmap_strides=[4, 8, 16, 32]
140
+ ),
141
+ # Mask head optimized for data points with square mask targets
142
+ mask_head=dict(
143
+ type='SquareFCNMaskHead',
144
+ num_convs=4, # 4 conv layers for good feature extraction
145
+ in_channels=256,
146
+ roi_feat_size=14, # Explicitly set ROI feature size
147
+ conv_out_channels=256,
148
+ num_classes=22, # 22 enhanced categories including boxplot
149
+ upsample_cfg=dict(type=None), # No upsampling - keep 14x14
150
+ loss_mask=dict(
151
+ type='CrossEntropyLoss',
152
+ use_mask=True,
153
+ loss_weight=1.0 # Balanced with bbox loss
154
+ )
155
+ )
156
+ ),
157
+ # Training configuration adapted from Cascade R-CNN
158
+ train_cfg=dict(
159
+ rpn=dict(
160
+ assigner=dict(
161
+ type='MaxIoUAssigner',
162
+ pos_iou_thr=0.7,
163
+ neg_iou_thr=0.3,
164
+ min_pos_iou=0.3,
165
+ match_low_quality=True,
166
+ ignore_iof_thr=-1),
167
+ sampler=dict(
168
+ type='RandomSampler',
169
+ num=256,
170
+ pos_fraction=0.5,
171
+ neg_pos_ub=-1,
172
+ add_gt_as_proposals=False),
173
+ allowed_border=0,
174
+ pos_weight=-1,
175
+ debug=False),
176
+ rpn_proposal=dict(
177
+ nms_pre=2000,
178
+ max_per_img=1000,
179
+ nms=dict(type='nms', iou_threshold=0.7),
180
+ min_bbox_size=0),
181
+ # RCNN training (using Cascade stage 2 settings - balanced for mask training)
182
+ rcnn=dict(
183
+ assigner=dict(
184
+ type='MaxIoUAssigner',
185
+ pos_iou_thr=0.5, # Balanced IoU for bbox + mask training
186
+ neg_iou_thr=0.5,
187
+ min_pos_iou=0.5,
188
+ match_low_quality=True, # Important for small data points
189
+ ignore_iof_thr=-1),
190
+ sampler=dict(
191
+ type='RandomSampler',
192
+ num=512,
193
+ pos_fraction=0.25,
194
+ neg_pos_ub=-1,
195
+ add_gt_as_proposals=True),
196
+ mask_size=(14, 14), # Force exact 14x14 size for data points
197
+ pos_weight=-1,
198
+ debug=False)
199
+ ),
200
+ # Test configuration with soft NMS
201
+ test_cfg=dict(
202
+ rpn=dict(
203
+ nms_pre=1000,
204
+ max_per_img=1000,
205
+ nms=dict(type='nms', iou_threshold=0.7),
206
+ min_bbox_size=0),
207
+ rcnn=dict(
208
+ score_thr=0.005, # Low threshold to catch data points
209
+ nms=dict(
210
+ type='soft_nms', # Soft NMS for better small object detection
211
+ iou_threshold=0.3, # Low for data points
212
+ min_score=0.005,
213
+ method='gaussian',
214
+ sigma=0.5),
215
+ max_per_img=100,
216
+ mask_thr_binary=0.5 # Binary mask threshold
217
+ )
218
+ )
219
+ )
220
+
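The Gaussian soft-NMS configured in `test_cfg` decays the scores of overlapping boxes instead of suppressing them outright, which helps keep densely packed data points. The decay rule can be sketched standalone (illustrative function, using the `sigma=0.5` from the config):

```python
import math

def soft_nms_decay(score, iou, sigma=0.5):
    """Gaussian soft-NMS decay: score *= exp(-iou^2 / sigma)."""
    return score * math.exp(-(iou ** 2) / sigma)

# A box overlapping a higher-scoring one at IoU 0.3 keeps most of its score,
# while heavy overlap (IoU 0.9) is decayed much more strongly.
light = soft_nms_decay(0.8, 0.3)   # ~0.67
heavy = soft_nms_decay(0.8, 0.9)   # ~0.16
```
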
221
+ # Dataset settings - using standard COCO dataset for mask support
222
+ dataset_type = 'CocoDataset'
223
+ data_root = ''
224
+
225
+ # 22 enhanced categories including boxplot
226
+ CLASSES = (
227
+ 'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label', # 0-5
228
+ 'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item', # 6-10
229
+ 'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line', # 11-15 (data-point at index 11)
230
+ 'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area', # 16-20
231
+ 'boxplot' # 21
232
+ )
233
+
234
+ # Verify data-point class index
235
+ assert CLASSES[11] == 'data-point', f"Expected 'data-point' at index 11 in CLASSES tuple, got '{CLASSES[11]}'"
236
+
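Given the background-last convention MMDetection uses for softmax classification heads, the 23-entry `class_weight` vector is easy to get wrong by hand; a small sketch that builds it programmatically (names here are illustrative, not part of the config):

```python
NUM_CLASSES = 22          # number of foreground classes in CLASSES
DATA_POINT_IDX = 11       # index of 'data-point' in CLASSES
DATA_POINT_WEIGHT = 10.0  # 10x weight for the primary target class

# One weight per foreground class, plus background as the final entry.
class_weight = [1.0] * (NUM_CLASSES + 1)
class_weight[DATA_POINT_IDX] = DATA_POINT_WEIGHT

assert len(class_weight) == 23
assert class_weight[-1] == 1.0  # background stays unweighted
```
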
237
+ # Training dataloader with mask annotations
238
+ train_dataloader = dict(
239
+ batch_size=2, # Same as Cascade R-CNN
240
+ num_workers=2,
241
+ persistent_workers=True,
242
+ sampler=dict(type='DefaultSampler', shuffle=True),
243
+ dataset=dict(
244
+ type=dataset_type,
245
+ data_root=data_root,
246
+ ann_file='legend_match_swin/mask_generation/enhanced_datasets/train_filtered_with_masks_only.json',
247
+ data_prefix=dict(img='legend_data/train/images/'),
248
+ metainfo=dict(classes=CLASSES),
249
+ filter_cfg=dict(filter_empty_gt=False, min_size=12), # Keep images without GT; skip images smaller than 12px
250
+ # Disable any built-in filtering that might remove annotations
251
+ test_mode=False,
252
+ pipeline=[
253
+ dict(type='LoadImageFromFile'),
254
+ dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
255
+ dict(type='Resize', scale=(1120, 672), keep_ratio=True),
256
+ dict(type='RandomFlip', prob=0.5),
257
+ dict(type='ClampBBoxes'),
258
+ dict(type='PackDetInputs')
259
+ ]
260
+ )
261
+ )
262
+
263
+ # Validation dataloader with mask annotations
264
+ val_dataloader = dict(
265
+ batch_size=1,
266
+ num_workers=2,
267
+ persistent_workers=True,
268
+ drop_last=False,
269
+ sampler=dict(type='DefaultSampler', shuffle=False),
270
+ dataset=dict(
271
+ type=dataset_type,
272
+ data_root=data_root,
273
+ ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
274
+ data_prefix=dict(img='legend_data/train/images/'),
275
+ metainfo=dict(classes=CLASSES),
276
+ test_mode=True,
277
+ pipeline=[
278
+ dict(type='LoadImageFromFile'),
279
+ dict(type='Resize', scale=(1120, 672), keep_ratio=True),
280
+ dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
281
+ dict(type='ClampBBoxes'),
282
+ dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
283
+ ]
284
+ )
285
+ )
286
+
287
+ test_dataloader = val_dataloader
288
+
289
+ # Enhanced evaluators for both bbox and mask metrics
290
+ val_evaluator = dict(
291
+ type='CocoMetric',
292
+ ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
293
+ metric=['bbox', 'segm'],
294
+ format_only=False,
295
+ classwise=True,
296
+ proposal_nums=(100, 300, 1000)
297
+ )
298
+
299
+ test_evaluator = val_evaluator
300
+
301
+ # Same custom hooks as Cascade R-CNN
302
+ default_hooks = dict(
303
+ timer=dict(type='IterTimerHook'),
304
+ logger=dict(type='LoggerHook', interval=50),
305
+ param_scheduler=dict(type='ParamSchedulerHook'),
306
+ checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
307
+ sampler_seed=dict(type='DistSamplerSeedHook'),
308
+ visualization=dict(type='DetVisualizationHook')
309
+ )
310
+
311
+ # Same custom hooks as Cascade R-CNN (adapted for Mask R-CNN)
312
+ custom_hooks = [
313
+ dict(type='SkipBadSamplesHook', interval=1),
314
+ dict(type='ChartTypeDistributionHook', interval=500),
315
+ dict(type='MissingImageReportHook', interval=1000),
316
+ dict(type='NanRecoveryHook',
317
+ fallback_loss=1.0,
318
+ max_consecutive_nans=50,
319
+ log_interval=25),
320
+ # Note: Progressive loss hook not used in standard Mask R-CNN
321
+ # but could be adapted if needed for bbox loss only
322
+ ]
323
+
324
+ # Training configuration - reduced to 20 epochs
325
+ train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)
326
+ val_cfg = dict(type='ValLoop')
327
+ test_cfg = dict(type='TestLoop')
328
+
329
+ # Same optimizer settings as Cascade R-CNN
330
+ optim_wrapper = dict(
331
+ type='OptimWrapper',
332
+ optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
333
+ clip_grad=dict(max_norm=10.0, norm_type=2)
334
+ )
335
+
336
+ # Same learning rate schedule as Cascade R-CNN
337
+ param_scheduler = [
338
+ dict(
339
+ type='LinearLR',
340
+ start_factor=0.1,
341
+ by_epoch=False,
342
+ begin=0,
343
+ end=1000),
344
+ dict(
345
+ type='CosineAnnealingLR',
346
+ begin=0,
347
+ end=20,
348
+ by_epoch=True,
349
+ T_max=20,
350
+ eta_min=1e-5,
351
+ convert_to_iter_based=True)
352
+ ]
353
+
354
+ # Work directory
355
+ work_dir = '/content/drive/MyDrive/Research Summer 2025/Dense Captioning Toolkit/CHART-DeMatch/work_dirs/mask_rcnn_swin_base_20ep_meta'
356
+
357
+ # Fresh start
358
+ resume = False
359
+ load_from = None
360
+
361
+ # Default runtime settings (normally inherited from _base_)
362
+ default_scope = 'mmdet'
363
+ env_cfg = dict(
364
+ cudnn_benchmark=False,
365
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
366
+ dist_cfg=dict(backend='nccl'),
367
+ )
368
+
369
+ vis_backends = [dict(type='LocalVisBackend')]
370
+ visualizer = dict(
371
+ type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
372
+
373
+ log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
374
+ log_level = 'INFO'
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ gradio==5.39.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ transformers>=4.30.0
5
+ Pillow>=9.0.0
6
+ numpy>=1.21.0
7
+ opencv-python>=4.8.0
8
+ huggingface-hub>=0.16.0
9
+ openmim
10
+ mmdet
11
+ mmengine
12
+ scikit-image>=0.21.0
simple_test.py ADDED
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple test script for the Dense Captioning Platform API
4
+ """
5
+
6
+ def test_gradio_client():
7
+ """Test using gradio_client"""
8
+ print("πŸ§ͺ Testing with gradio_client...")
9
+
10
+ try:
11
+ from gradio_client import Client, handle_file
12
+
13
+ # Initialize client with direct URL (working approach)
14
+ client = Client("https://hanszhu-dense-captioning-platform.hf.space")
15
+
16
+ # Test with a simple image URL
17
+ test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
18
+
19
+ print(f"Testing with URL: {test_url}")
20
+
21
+ # Make prediction using handle_file and fn_index (working approach)
22
+ result = client.predict(
23
+ image=handle_file(test_url),
24
+ fn_index=0 # Use fn_index instead of api_name
25
+ )
26
+
27
+ print("βœ… gradio_client test successful!")
28
+ print(f"Result: {result}")
29
+
30
+ return True
31
+
32
+ except Exception as e:
33
+ print(f"❌ gradio_client test failed: {e}")
34
+ return False
35
+
36
+ def test_direct_api():
37
+ """Test direct API call (Blocks don't support RESTful APIs)"""
38
+ print("\nπŸ§ͺ Testing direct API call...")
39
+
40
+ print("⚠️ Direct API calls not supported for Blocks-based Spaces")
41
+ print(" Use gradio_client instead for API access")
42
+ return False
43
+
44
+ if __name__ == "__main__":
45
+ print("πŸš€ Testing Dense Captioning Platform API")
46
+ print("=" * 50)
47
+
48
+ # Test both methods
49
+ gradio_success = test_gradio_client()
50
+ direct_success = test_direct_api()
51
+
52
+ print("\n" + "=" * 50)
53
+ print("🏁 Test Results:")
54
+ print(f"gradio_client: {'βœ… PASS' if gradio_success else '❌ FAIL'}")
55
+ print(f"Direct API: {'βœ… PASS' if direct_success else '❌ FAIL'}")
56
+
57
+ if gradio_success or direct_success:
58
+ print("\nπŸŽ‰ API is working!")
59
+ else:
60
+ print("\n⚠️ API needs more configuration")
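Rather than probing `fn_index` values blindly, gradio_client can report which endpoints a Space actually exposes. A sketch (the call needs network access to the live Space, so it is wrapped in a helper here):

```python
def list_space_endpoints(space="hanszhu/Dense-Captioning-Platform"):
    """Return the Space's API surface as reported by gradio_client.

    view_api(return_format="dict") describes each endpoint's name,
    parameters, and return types, which tells you whether to call
    predict() with api_name or fall back to fn_index.
    """
    from gradio_client import Client
    client = Client(space)
    return client.view_api(return_format="dict")
```
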
test_api.py ADDED
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for the Dense Captioning Platform API
4
+ """
5
+
6
+ import requests
7
+ import json
8
+ from PIL import Image
9
+ import io
10
+ import base64
11
+
12
+ def test_api_with_url():
13
+ """Test the API using a URL"""
14
+ print("πŸ§ͺ Testing API with URL...")
15
+
16
+ # Test URL (a simple chart image)
17
+ test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
18
+
19
+ # API endpoint
20
+ api_url = "https://hanszhu-dense-captioning-platform.hf.space/predict"
21
+
22
+ try:
23
+ # Make the request
24
+ response = requests.post(
25
+ api_url,
26
+ json={"data": [test_url]},
27
+ headers={"Content-Type": "application/json"}
28
+ )
29
+
30
+ print(f"Status Code: {response.status_code}")
31
+ print(f"Response: {response.text[:500]}...")
32
+
33
+ if response.status_code == 200:
34
+ result = response.json()
35
+ print("βœ… API test successful!")
36
+ print(f"Chart Type: {result.get('data', [{}])[0].get('chart_type_label', 'Unknown')}")
37
+ else:
38
+ print("❌ API test failed!")
39
+
40
+ except Exception as e:
41
+ print(f"❌ Error testing API: {e}")
42
+
43
+ def test_api_with_gradio_client():
44
+ """Test the API using gradio_client"""
45
+ print("\nπŸ§ͺ Testing API with gradio_client...")
46
+
47
+ try:
48
+ from gradio_client import Client
49
+
50
+ # Initialize client
51
+ client = Client("hanszhu/Dense-Captioning-Platform")
52
+
53
+ # Test with a URL
54
+ result = client.predict(
55
+ "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
56
+ api_name="/predict"
57
+ )
58
+
59
+ print("βœ… gradio_client test successful!")
60
+ print(f"Result: {result}")
61
+
62
+ except Exception as e:
63
+ print(f"❌ Error with gradio_client: {e}")
64
+
65
+ def test_api_endpoints():
66
+ """Test available API endpoints"""
67
+ print("\nπŸ§ͺ Testing API endpoints...")
68
+
69
+ base_url = "https://hanszhu-dense-captioning-platform.hf.space"
70
+
71
+ endpoints = [
72
+ "/",
73
+ "/api",
74
+ "/api/predict",
75
+ "/predict"
76
+ ]
77
+
78
+ for endpoint in endpoints:
79
+ try:
80
+ response = requests.get(f"{base_url}{endpoint}")
81
+ print(f"{endpoint}: {response.status_code}")
82
+ except Exception as e:
83
+ print(f"{endpoint}: Error - {e}")
84
+
85
+ if __name__ == "__main__":
86
+ print("πŸš€ Testing Dense Captioning Platform API")
87
+ print("=" * 50)
88
+
89
+ # Test different approaches
90
+ test_api_endpoints()
91
+ test_api_with_url()
92
+ test_api_with_gradio_client()
93
+
94
+ print("\n" + "=" * 50)
95
+ print("🏁 API testing completed!")
test_api_endpoints.py ADDED
@@ -0,0 +1,152 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive API endpoint testing
4
+ """
5
+
6
+ import requests
7
+ import json
8
+
9
+ def test_all_possible_endpoints():
10
+ """Test all possible API endpoint combinations"""
11
+ print("πŸ” Testing all possible API endpoints...")
12
+
13
+ base_url = "https://hanszhu-dense-captioning-platform.hf.space"
14
+ test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
15
+
16
+ # Different endpoint patterns
17
+ endpoints = [
18
+ "/api/predict",
19
+ "/predict",
20
+ "/api/run/predict",
21
+ "/run/predict",
22
+ "/api/0",
23
+ "/0",
24
+ "/api/1",
25
+ "/1",
26
+ "/api/2",
27
+ "/2"
28
+ ]
29
+
30
+ # Different request formats
31
+ request_formats = [
32
+ {"data": [test_url]},
33
+ {"data": [test_url], "fn_index": 0},
34
+ {"data": [test_url], "fn_index": 1},
35
+ {"data": test_url},
36
+ {"data": [test_url], "session_hash": "test123"}
37
+ ]
38
+
39
+ for endpoint in endpoints:
40
+ print(f"\nπŸ” Testing endpoint: {endpoint}")
41
+
42
+ for i, format_data in enumerate(request_formats):
43
+ print(f" Format {i+1}: {format_data}")
44
+
45
+ try:
46
+ response = requests.post(
47
+ f"{base_url}{endpoint}",
48
+ json=format_data,
49
+ headers={"Content-Type": "application/json"},
50
+ timeout=10
51
+ )
52
+
53
+ print(f" Status: {response.status_code}")
54
+
55
+ if response.status_code == 200:
56
+ print(" βœ… SUCCESS!")
57
+ print(f" Response: {response.text[:200]}...")
58
+ return endpoint, format_data
59
+ elif response.status_code == 405:
60
+ print(" ⚠️ Method not allowed (endpoint exists)")
61
+ elif response.status_code == 404:
62
+ print(" ❌ Not found")
63
+ else:
64
+ print(f" ❌ Unexpected: {response.text[:100]}...")
65
+
66
+ except Exception as e:
67
+ print(f" ❌ Error: {e}")
68
+
69
+ return None, None
70
+
71
+ def test_gradio_client_different_ways():
72
+ """Test gradio_client with different approaches"""
73
+ print("\nπŸ” Testing gradio_client with different approaches...")
74
+
75
+ try:
76
+ from gradio_client import Client
77
+
78
+ print("Creating client...")
79
+ client = Client("hanszhu/Dense-Captioning-Platform")
80
+
81
+ print("Trying different API names...")
82
+
83
+ api_names = ["/predict", "/run/predict", "0", "1", "2", "3", "4", "5"]
84
+
85
+ for api_name in api_names:
86
+ print(f"\nTrying api_name: {api_name}")
87
+
88
+ try:
89
+ test_url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
90
+ result = client.predict(test_url, api_name=api_name)
91
+ print(f" βœ… SUCCESS with api_name={api_name}!")
92
+ print(f" Result: {result}")
93
+ return api_name
94
+
95
+ except Exception as e:
96
+ print(f" ❌ Failed: {e}")
97
+
98
+ except Exception as e:
99
+ print(f"❌ gradio_client error: {e}")
100
+
101
+ def check_space_status():
102
+ """Check if the space is running and accessible"""
103
+ print("\nπŸ” Checking space status...")
104
+
105
+ try:
106
+ response = requests.get("https://hanszhu-dense-captioning-platform.hf.space/", timeout=10)
107
+ print(f"Space status: {response.status_code}")
108
+
109
+ if response.status_code == 200:
110
+ print("βœ… Space is running")
111
+
112
+ # Check for API-related content
113
+ content = response.text.lower()
114
+ if "api" in content:
115
+ print("βœ… API-related content found")
116
+ if "predict" in content:
117
+ print("βœ… Predict-related content found")
118
+ if "gradio" in content:
119
+ print("βœ… Gradio content found")
120
+
121
+ else:
122
+ print("❌ Space is not accessible")
123
+
124
+ except Exception as e:
125
+ print(f"❌ Error checking space: {e}")
126
+
127
+ if __name__ == "__main__":
128
+ print("πŸš€ Comprehensive API Endpoint Testing")
129
+ print("=" * 60)
130
+
131
+ # Check space status
132
+ check_space_status()
133
+
134
+ # Test all endpoints
135
+ working_endpoint, working_format = test_all_possible_endpoints()
136
+
137
+ # Test gradio_client
138
+ working_api_name = test_gradio_client_different_ways()
139
+
140
+ print("\n" + "=" * 60)
141
+ print("🏁 Testing completed!")
142
+
143
+ if working_endpoint and working_format:
144
+ print(f"βœ… Found working combination:")
145
+ print(f" Endpoint: {working_endpoint}")
146
+ print(f" Format: {working_format}")
147
+ if working_api_name:
148
+ print(f"βœ… Found working gradio_client api_name: {working_api_name}")
149
+
150
+ if not working_endpoint and not working_api_name:
151
+ print("❌ No working endpoints found")
152
+ print("The space might need different configuration or the API is not properly exposed")
web_test.py ADDED
@@ -0,0 +1,115 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to check web interface and understand API issue
4
+ """
5
+
6
+ import requests
7
+ import time
8
+
9
+ def check_web_interface():
10
+ """Check if the web interface is working"""
11
+ print("πŸ” Checking web interface...")
12
+
13
+ try:
14
+ response = requests.get("https://hanszhu-dense-captioning-platform.hf.space/")
15
+
16
+ if response.status_code == 200:
17
+ print("βœ… Web interface is accessible")
18
+
19
+ # Check if it contains our app content
20
+ if "Dense Captioning Platform" in response.text:
21
+ print("βœ… App is loaded correctly")
22
+ else:
23
+ print("❌ App content not found")
24
+
25
+ # Check if it contains Gradio elements
26
+ if "gradio" in response.text.lower():
27
+ print("βœ… Gradio is loaded")
28
+ else:
29
+ print("❌ Gradio not found")
30
+
31
+ else:
32
+ print(f"❌ Web interface not accessible: {response.status_code}")
33
+
34
+ except Exception as e:
35
+ print(f"❌ Error checking web interface: {e}")
36
+
37
+ def check_api_info():
38
+ """Check API info endpoint"""
39
+ print("\nπŸ” Checking API info...")
40
+
41
+ try:
42
+ # Try different API info endpoints
43
+ endpoints = [
44
+ "https://hanszhu-dense-captioning-platform.hf.space/api",
45
+ "https://hanszhu-dense-captioning-platform.hf.space/api/",
46
+ "https://hanszhu-dense-captioning-platform.hf.space/api/predict",
47
+ "https://hanszhu-dense-captioning-platform.hf.space/api/predict/"
48
+ ]
49
+
50
+ for endpoint in endpoints:
51
+ print(f"\nTrying: {endpoint}")
52
+
53
+ try:
54
+ response = requests.get(endpoint)
55
+ print(f" Status: {response.status_code}")
56
+ print(f" Content-Type: {response.headers.get('content-type', 'unknown')}")
57
+
58
+ if response.status_code == 200:
59
+ content = response.text[:200]
60
+ print(f" Content: {content}...")
61
+
62
+ # Check if it's JSON
63
+ if response.headers.get('content-type', '').startswith('application/json'):
64
+ print(" βœ… JSON response")
65
+ else:
66
+ print(" ❌ Not JSON response")
67
+
68
+ except Exception as e:
69
+ print(f" Error: {e}")
70
+
71
+ except Exception as e:
72
+ print(f"❌ Error checking API info: {e}")
73
+
74
+ def wait_and_retry():
75
+ """Wait and retry to see if the API becomes available"""
76
+ print("\n⏳ Waiting for API to become available...")
77
+
78
+ for i in range(5):
79
+ print(f"\nAttempt {i+1}/5:")
80
+
81
+ try:
82
+ response = requests.get("https://hanszhu-dense-captioning-platform.hf.space/api")
83
+
84
+ if response.status_code == 200 and response.headers.get('content-type', '').startswith('application/json'):
85
+ print("βœ… API is now available!")
86
+ return True
87
+ else:
88
+ print(f"❌ API not ready yet: {response.status_code}")
89
+
90
+ except Exception as e:
91
+ print(f"❌ Error: {e}")
92
+
93
+ if i < 4: # Don't sleep after the last attempt
94
+ print("Waiting 30 seconds...")
95
+ time.sleep(30)
96
+
97
+ return False
98
+
99
+ if __name__ == "__main__":
100
+ print("πŸš€ Testing Dense Captioning Platform Web Interface")
101
+ print("=" * 60)
102
+
103
+ check_web_interface()
104
+ check_api_info()
105
+
106
+ # Wait and retry
107
+ if not wait_and_retry():
108
+ print("\n⚠️ API is still not available after waiting")
109
+ print("This might indicate:")
110
+ print("1. The space is still loading models")
111
+ print("2. There's a configuration issue")
112
+ print("3. The API endpoints need different configuration")
113
+
114
+ print("\n" + "=" * 60)
115
+ print("🏁 Web interface test completed!")