# cascade_rcnn_r50_fpn_meta.py - Enhanced config with Swin Transformer backbone
# (the "r50" in the filename is legacy; the backbone below is Swin Base, not ResNet-50)
# 
# PROGRESSIVE LOSS STRATEGY:
# - All 3 Cascade stages start with SmoothL1Loss for stable initial training
# - At epoch 5, Stage 3 (final stage) switches to GIoULoss via ProgressiveLossHook  
# - Stage 1 & 2 remain SmoothL1Loss throughout training
# - This ensures model stability before introducing more complex IoU-based losses

# Custom imports - this registers our modules without polluting config namespace
custom_imports = dict(
    imports=[
        'custom_models.custom_dataset',
        'custom_models.register',
        'custom_models.custom_hooks',
        'custom_models.progressive_loss_hook',
    ],
    allow_failed_imports=False
)
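
# Each imported module registers itself with the mmengine registries at import
# time. A minimal sketch of the pattern (illustrative only; the real code
# lives in the custom_models package):
#
#   # custom_models/progressive_loss_hook.py
#   from mmengine.hooks import Hook
#   from mmdet.registry import HOOKS
#
#   @HOOKS.register_module()
#   class ProgressiveLossHook(Hook):
#       pass  # fuller sketch next to custom_hooks below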

# Add the project root to the Python path. A cwd-relative path is used rather
# than __file__, which is unreliable inside mmengine-loaded configs; this
# assumes training is launched from a directory two levels below the root.
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd(), '..', '..'))

# Custom Cascade model with coordinate handling for chart data
model = dict(
    type='CustomCascadeWithMeta',  # Use custom model with coordinate handling
    coordinate_standardization=dict(
        enabled=True,
        origin='bottom_left',      # Match annotation creation coordinate system
        normalize=True,
        relative_to_plot=False,    # Keep simple for now
        scale_to_axis=False        # Keep simple for now
    ),
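    # Hedged sketch of the bottom-left -> top-left conversion implied above
    # (the actual logic lives in CustomCascadeWithMeta; the formula is an
    # illustrative assumption):
    #   y_top = img_h - (y_bottom_left + box_h)   # flip the vertical axis
    #   if normalize: x, w are divided by img_w and y, h by img_h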
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    # ----- Swin Transformer Base (22K) Backbone + FPN -----
    backbone=dict(
        type='SwinTransformer',
        embed_dims=128,  # Swin Base embedding dimensions
        depths=[2, 2, 18, 2],  # Swin Base depths
        num_heads=[4, 8, 16, 32],  # Swin Base attention heads
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,  # Slightly higher for more complex model
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
        )
    ),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],  # Swin Base: embed_dims * 2^(stage)
        out_channels=256,
        num_outs=6,
        start_level=0,
        add_extra_convs='on_input'
    ),
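    # With start_level=0, num_outs=6 and add_extra_convs='on_input', the FPN
    # emits six levels (P2-P7) at strides 4, 8, 16, 32, 64, 128, matching the
    # RPN anchor_generator strides below.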
    # Enhanced RPN with smaller anchors for tiny objects + improved losses
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[1, 2, 4, 8],  # Even smaller scales for tiny objects
            ratios=[0.5, 1.0, 2.0],  # Multiple aspect ratios
            strides=[4, 8, 16, 32, 64, 128]),  # Extended FPN strides
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
    # Progressive Loss Strategy: Start with SmoothL1 for all 3 stages
    # Stage 3 (final stage) will switch to GIoU at epoch 5 via ProgressiveLossHook
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            # Stage 1: Always SmoothL1Loss (coarse detection)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
            # Stage 2: Always SmoothL1Loss (intermediate refinement)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
            # Stage 3: SmoothL1 → GIoU at epoch 5 (progressive switching)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.02, 0.02, 0.05, 0.05]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ]),
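    # Note how target_stds shrink across stages (0.05/0.1 -> 0.033/0.067 ->
    # 0.02/0.05): the standard cascade design, where each stage regresses
    # progressively smaller residuals as proposals get tighter.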
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.8),
            min_bbox_size=0),
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.4,
                    neg_iou_thr=0.4,
                    min_pos_iou=0.4,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False)
        ]),
    # Enhanced test configuration with soft-NMS and multi-scale support
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.005,  # Even lower threshold to catch more classes
            nms=dict(
                type='soft_nms',  # Soft-NMS for better small object detection
                iou_threshold=0.5,
                min_score=0.005,
                method='gaussian',
                sigma=0.5),
            max_per_img=500)))  # Allow more detections
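
# Quick sanity check for the model section (run from a Python shell, not as
# part of this config; the file path is an assumption):
#   from mmengine.config import Config
#   cfg = Config.fromfile('cascade_rcnn_r50_fpn_meta.py')
#   assert cfg.model.roi_head.num_stages == len(cfg.model.roi_head.bbox_head) == 3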

# Dataset settings - using cleaned annotations
dataset_type = 'ChartDataset'
data_root = ''  # Empty on purpose: ann_file and data_prefix already carry the full relative paths

# Define the 21 chart element classes that match the annotations
CLASSES = (
    'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
    'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
    'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
    'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
)

# Updated to use cleaned annotation files
train_dataloader = dict(
    batch_size=2,  # Increased back to 2
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_data/annotations_JSON_cleaned/train_enriched.json',  # Full path
        data_prefix=dict(img='legend_data/train/images/'),  # Full path
        metainfo=dict(classes=CLASSES),  # Tell dataset what classes to expect
        filter_cfg=dict(filter_empty_gt=True, min_size=0, class_specific_min_sizes={
            'data-point': 16,    # Back to 16x16 from 32x32 
            'data-bar': 16,      # Back to 16x16 from 32x32
            'tick-label': 16,    # Back to 16x16 from 32x32
            'x-tick-label': 16,  # Back to 16x16 from 32x32  
            'y-tick-label': 16   # Back to 16x16 from 32x32
        }),
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='Resize', scale=(1600, 1000), keep_ratio=True),  # Higher resolution for tiny objects
            dict(type='RandomFlip', prob=0.5),
            dict(type='ClampBBoxes'),  # Ensure bboxes stay within image bounds
            dict(type='PackDetInputs')
        ]
    )
)
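
# ClampBBoxes is a custom transform registered via custom_imports. A minimal
# sketch of the idea (illustrative assumption, not the actual code):
#   h, w = results['img_shape']
#   results['gt_bboxes'].clip_((h, w))  # clamp boxes in-place to image bounds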

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json',  # Full path
        data_prefix=dict(img='legend_data/train/images/'),  # All images are in train/images
        metainfo=dict(classes=CLASSES),  # Tell dataset what classes to expect
        test_mode=True,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(1600, 1000), keep_ratio=True),  # Base resolution for validation
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='ClampBBoxes'),  # Ensure bboxes stay within image bounds
            dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
        ]
    )
)

test_dataloader = val_dataloader

# Enhanced evaluators with debugging
val_evaluator = dict(
    type='CocoMetric',
    ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json',  # Using cleaned annotations
    metric='bbox',
    format_only=False,
    classwise=True,  # Enable detailed per-class metrics table
    proposal_nums=(100, 300, 1000))  # More detailed AR metrics

test_evaluator = val_evaluator

# Add custom hooks for debugging empty results
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))
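
# CompatibleCheckpointHook comes from custom_models.custom_hooks (registered
# via custom_imports above); presumably a drop-in replacement for the stock
# CheckpointHook, taking the same interval/save_best/max_keep_ckpts arguments.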

# Add NaN recovery hook for graceful handling like Faster R-CNN
custom_hooks = [
    dict(type='SkipBadSamplesHook', interval=1),           # Skip samples with bad GT data
    dict(type='ChartTypeDistributionHook', interval=500),  # Monitor class distribution
    dict(type='MissingImageReportHook', interval=1000),    # Track missing images
    dict(type='NanRecoveryHook',                           # For logging & monitoring
         fallback_loss=1.0,
         max_consecutive_nans=100,
         log_interval=50),
    dict(type='ProgressiveLossHook',                       # Progressive loss switching
         switch_epoch=5,                                   # Switch stage 3 to GIoU at epoch 5
         target_loss_type='GIoULoss',                      # Use GIoU for stage 3 (final stage)
         loss_weight=1.0,                                  # Keep same loss weight
         warmup_epochs=2,                                  # Monitor for 2 epochs after switch
         monitor_stage_weights=True),                      # Log stage loss details
]
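
# A minimal sketch of what ProgressiveLossHook could look like (the actual
# implementation lives in custom_models/progressive_loss_hook.py; everything
# below is an illustrative assumption, not the real code):
#
#   from mmengine.hooks import Hook
#   from mmengine.model import is_model_wrapper
#   from mmdet.registry import HOOKS, MODELS
#
#   @HOOKS.register_module()
#   class ProgressiveLossHook(Hook):
#       def __init__(self, switch_epoch=5, target_loss_type='GIoULoss',
#                    loss_weight=1.0, warmup_epochs=2,
#                    monitor_stage_weights=True):
#           self.switch_epoch = switch_epoch
#           self.target_loss_type = target_loss_type
#           self.loss_weight = loss_weight
#
#       def before_train_epoch(self, runner):
#           if runner.epoch + 1 != self.switch_epoch:
#               return
#           model = runner.model
#           if is_model_wrapper(model):
#               model = model.module
#           # Swap the final stage's bbox loss from SmoothL1 to GIoU.
#           # Caveat: mmdet's IoU-based losses expect decoded boxes, so the
#           # real hook would likely also set reg_decoded_bbox = True.
#           head = model.roi_head.bbox_head[-1]
#           head.loss_bbox = MODELS.build(
#               dict(type=self.target_loss_type, loss_weight=self.loss_weight))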

# Training configuration - extended to 40 epochs for Swin Base on small objects
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Optimizer with standard stable settings
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=35.0, norm_type=2)
)

# Extended learning rate schedule with cosine annealing for Swin Base
param_scheduler = [
    dict(
        type='LinearLR', 
        start_factor=0.05,  # 0.05 * 2e-2 = 1e-3 (warmup from 1e-3 to 2e-2)
        by_epoch=False, 
        begin=0, 
        end=1000),  # 1k iteration warmup
    dict(
        type='CosineAnnealingLR',
        begin=0,
        end=40,  # Match max_epochs
        by_epoch=True,
        T_max=40,
        eta_min=1e-6,  # Minimum learning rate
        convert_to_iter_based=True)
]
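
# Worked example of the schedule above (standard cosine formula, my
# arithmetic): once the 1k-iteration warmup reaches lr0 = 2e-2,
#   lr(t) = eta_min + 0.5 * (lr0 - eta_min) * (1 + cos(pi * t / 40))
# giving roughly lr(10) ~ 1.7e-2, lr(20) = 1e-2, lr(30) ~ 2.9e-3, lr(40) = 1e-6.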

# Work directory 
work_dir = './work_dirs/cascade_rcnn_swin_base_40ep_cosine_fpn_meta'

# Multi-scale test configuration (uncomment to enable)
# img_scales = [(800, 500), (1600, 1000), (2400, 1500)]  # 0.5x, 1.0x, 1.5x scales
# tta_model = dict(
#     type='DetTTAModel',
#     tta_cfg=dict(
#         nms=dict(type='nms', iou_threshold=0.5),
#         max_per_img=100)
# )

# Fresh start
resume = False
load_from = None
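
# Typical launch command (assuming the standard MMDetection repo layout and
# that this file lives under configs/; adjust the path to your checkout):
#   python tools/train.py configs/custom/cascade_rcnn_r50_fpn_meta.py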