hanszhu commited on
Commit
c26f22f
·
verified ·
1 Parent(s): c598cec

Upload chart_pointnet_swin.py

Browse files
Files changed (1) hide show
  1. chart_pointnet_swin.py +374 -0
chart_pointnet_swin.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mask_rcnn_swin_meta.py - Mask R-CNN with Swin Transformer for data point segmentation
2
+ #
3
+ # ADAPTED FROM CASCADE R-CNN CONFIG:
4
+ # - Uses same Swin Transformer Base backbone with optimizations
5
+ # - Maintains data-point class weighting (10x) and IoU strategies
6
+ # - Adds mask head for instance segmentation of data points
7
+ # - Uses enhanced annotation files with segmentation masks
8
+ # - Keeps custom hooks and progressive loss strategies
9
+ #
10
+ # MASK-SPECIFIC OPTIMIZATIONS:
11
+ # - RoI size 14x14 for mask extraction (matches data point size)
12
+ # - FCN mask head with 4 convolution layers
13
+ # - Mask loss weight balanced with bbox and classification losses
14
+ # - Enhanced test-time augmentation for better mask quality
15
+ #
16
+ # DATA POINT FOCUS:
17
+ # - Primary target: data-point class (ID 11) with 10x weight
18
+ # - Generates both bounding boxes AND instance masks
19
+ # - Optimized for 16x16 pixel data points in scientific charts
20
+ # Removed _base_ inheritance to avoid path issues - all configs are inlined below
21
+
22
+ # Custom imports - same as Cascade R-CNN setup
23
+ custom_imports = dict(
24
+ imports=[
25
+ 'legend_match_swin.custom_models.register',
26
+ 'legend_match_swin.custom_models.custom_hooks',
27
+ 'legend_match_swin.custom_models.progressive_loss_hook',
28
+ 'legend_match_swin.custom_models.flexible_load_annotations',
29
+ ],
30
+ allow_failed_imports=False
31
+ )
32
+
33
+ # Add to Python path
34
+ import sys
35
+ sys.path.insert(0, '.')
36
+
37
+ # Mask R-CNN model with Swin Transformer backbone
38
+ model = dict(
39
+ type='MaskRCNN',
40
+ data_preprocessor=dict(
41
+ type='DetDataPreprocessor',
42
+ mean=[123.675, 116.28, 103.53],
43
+ std=[58.395, 57.12, 57.375],
44
+ bgr_to_rgb=True,
45
+ pad_size_divisor=32,
46
+ pad_mask=True, # Important for mask training
47
+ mask_pad_value=0,
48
+ ),
49
+ # Same Swin Transformer Base backbone as Cascade R-CNN
50
+ backbone=dict(
51
+ type='SwinTransformer',
52
+ embed_dims=128, # Swin Base embedding dimensions
53
+ depths=[2, 2, 18, 2], # Swin Base depths
54
+ num_heads=[4, 8, 16, 32], # Swin Base attention heads
55
+ window_size=7,
56
+ mlp_ratio=4,
57
+ qkv_bias=True,
58
+ qk_scale=None,
59
+ drop_rate=0.0,
60
+ attn_drop_rate=0.0,
61
+ drop_path_rate=0.3, # Same as Cascade config
62
+ patch_norm=True,
63
+ out_indices=(0, 1, 2, 3),
64
+ with_cp=False,
65
+ convert_weights=True,
66
+ init_cfg=dict(
67
+ type='Pretrained',
68
+ checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
69
+ )
70
+ ),
71
+ # Same FPN as Cascade R-CNN
72
+ neck=dict(
73
+ type='FPN',
74
+ in_channels=[128, 256, 512, 1024], # Swin Base: embed_dims * 2^(stage)
75
+ out_channels=256,
76
+ num_outs=5, # Standard for Mask R-CNN (was 6 in Cascade)
77
+ start_level=0,
78
+ add_extra_convs='on_input'
79
+ ),
80
+ # Same RPN configuration as Cascade R-CNN
81
+ rpn_head=dict(
82
+ type='RPNHead',
83
+ in_channels=256,
84
+ feat_channels=256,
85
+ anchor_generator=dict(
86
+ type='AnchorGenerator',
87
+ scales=[1, 2, 4, 8], # Same small scales for tiny objects
88
+ ratios=[0.5, 1.0, 2.0],
89
+ strides=[4, 8, 16, 32, 64]), # Standard FPN strides for Mask R-CNN
90
+ bbox_coder=dict(
91
+ type='DeltaXYWHBBoxCoder',
92
+ target_means=[.0, .0, .0, .0],
93
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
94
+ loss_cls=dict(
95
+ type='CrossEntropyLoss',
96
+ use_sigmoid=True,
97
+ loss_weight=1.0),
98
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
99
+ ),
100
+ # Mask R-CNN ROI head with bbox + mask branches
101
+ roi_head=dict(
102
+ type='StandardRoIHead',
103
+ # Bbox ROI extractor (same as Cascade R-CNN final stage)
104
+ bbox_roi_extractor=dict(
105
+ type='SingleRoIExtractor',
106
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
107
+ out_channels=256,
108
+ featmap_strides=[4, 8, 16, 32]
109
+ ),
110
+ # Bbox head with data-point class weighting
111
+ bbox_head=dict(
112
+ type='Shared2FCBBoxHead',
113
+ in_channels=256,
114
+ fc_out_channels=1024,
115
+ roi_feat_size=7,
116
+ num_classes=22, # 22 enhanced categories including boxplot
117
+ bbox_coder=dict(
118
+ type='DeltaXYWHBBoxCoder',
119
+ target_means=[0., 0., 0., 0.],
120
+ target_stds=[0.1, 0.1, 0.2, 0.2]
121
+ ),
122
+ reg_class_agnostic=False,
123
+ loss_cls=dict(
124
+ type='CrossEntropyLoss',
125
+ use_sigmoid=False,
126
+ loss_weight=1.0,
127
+ class_weight=[1.0, # background class (index 0)
128
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
129
+ 10.0, # data-point at index 12 gets 10x weight (11+1 for background)
130
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] # Added boxplot class
131
+ ),
132
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)
133
+ ),
134
+ # Mask ROI extractor (optimized for 16x16 data points)
135
+ mask_roi_extractor=dict(
136
+ type='SingleRoIExtractor',
137
+ roi_layer=dict(type='RoIAlign', output_size=(14, 14), sampling_ratio=0, aligned=True), # Force exact 14x14 with legacy alignment
138
+ out_channels=256,
139
+ featmap_strides=[4, 8, 16, 32]
140
+ ),
141
+ # Mask head optimized for data points with square mask targets
142
+ mask_head=dict(
143
+ type='SquareFCNMaskHead',
144
+ num_convs=4, # 4 conv layers for good feature extraction
145
+ in_channels=256,
146
+ roi_feat_size=14, # Explicitly set ROI feature size
147
+ conv_out_channels=256,
148
+ num_classes=22, # 22 enhanced categories including boxplot
149
+ upsample_cfg=dict(type=None), # No upsampling - keep 14x14
150
+ loss_mask=dict(
151
+ type='CrossEntropyLoss',
152
+ use_mask=True,
153
+ loss_weight=1.0 # Balanced with bbox loss
154
+ )
155
+ )
156
+ ),
157
+ # Training configuration adapted from Cascade R-CNN
158
+ train_cfg=dict(
159
+ rpn=dict(
160
+ assigner=dict(
161
+ type='MaxIoUAssigner',
162
+ pos_iou_thr=0.7,
163
+ neg_iou_thr=0.3,
164
+ min_pos_iou=0.3,
165
+ match_low_quality=True,
166
+ ignore_iof_thr=-1),
167
+ sampler=dict(
168
+ type='RandomSampler',
169
+ num=256,
170
+ pos_fraction=0.5,
171
+ neg_pos_ub=-1,
172
+ add_gt_as_proposals=False),
173
+ allowed_border=0,
174
+ pos_weight=-1,
175
+ debug=False),
176
+ rpn_proposal=dict(
177
+ nms_pre=2000,
178
+ max_per_img=1000,
179
+ nms=dict(type='nms', iou_threshold=0.7),
180
+ min_bbox_size=0),
181
+ # RCNN training (using Cascade stage 2 settings - balanced for mask training)
182
+ rcnn=dict(
183
+ assigner=dict(
184
+ type='MaxIoUAssigner',
185
+ pos_iou_thr=0.5, # Balanced IoU for bbox + mask training
186
+ neg_iou_thr=0.5,
187
+ min_pos_iou=0.5,
188
+ match_low_quality=True, # Important for small data points
189
+ ignore_iof_thr=-1),
190
+ sampler=dict(
191
+ type='RandomSampler',
192
+ num=512,
193
+ pos_fraction=0.25,
194
+ neg_pos_ub=-1,
195
+ add_gt_as_proposals=True),
196
+ mask_size=(14, 14), # Force exact 14x14 size for data points
197
+ pos_weight=-1,
198
+ debug=False)
199
+ ),
200
+ # Test configuration with soft NMS
201
+ test_cfg=dict(
202
+ rpn=dict(
203
+ nms_pre=1000,
204
+ max_per_img=1000,
205
+ nms=dict(type='nms', iou_threshold=0.7),
206
+ min_bbox_size=0),
207
+ rcnn=dict(
208
+ score_thr=0.005, # Low threshold to catch data points
209
+ nms=dict(
210
+ type='soft_nms', # Soft NMS for better small object detection
211
+ iou_threshold=0.3, # Low for data points
212
+ min_score=0.005,
213
+ method='gaussian',
214
+ sigma=0.5),
215
+ max_per_img=100,
216
+ mask_thr_binary=0.5 # Binary mask threshold
217
+ )
218
+ )
219
+ )
220
+
221
+ # Dataset settings - using standard COCO dataset for mask support
222
+ dataset_type = 'CocoDataset'
223
+ data_root = ''
224
+
225
+ # 22 enhanced categories including boxplot
226
+ CLASSES = (
227
+ 'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label', # 0-5
228
+ 'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item', # 6-10
229
+ 'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line', # 11-15 (data-point at index 11)
230
+ 'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area', # 16-20
231
+ 'boxplot' # 21
232
+ )
233
+
234
+ # Verify data-point class index
235
+ assert CLASSES[11] == 'data-point', f"Expected 'data-point' at index 11 in CLASSES tuple, got '{CLASSES[11]}'"
236
+
237
+ # Training dataloader with mask annotations
238
+ train_dataloader = dict(
239
+ batch_size=2, # Same as Cascade R-CNN
240
+ num_workers=2,
241
+ persistent_workers=True,
242
+ sampler=dict(type='DefaultSampler', shuffle=True),
243
+ dataset=dict(
244
+ type=dataset_type,
245
+ data_root=data_root,
246
+ ann_file='legend_match_swin/mask_generation/enhanced_datasets/train_filtered_with_masks_only.json',
247
+ data_prefix=dict(img='legend_data/train/images/'),
248
+ metainfo=dict(classes=CLASSES),
249
+ filter_cfg=dict(filter_empty_gt=False, min_size=12), # Don't filter out images with masks
250
+ # Disable any built-in filtering that might remove annotations
251
+ test_mode=False,
252
+ pipeline=[
253
+ dict(type='LoadImageFromFile'),
254
+ dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
255
+ dict(type='Resize', scale=(1120, 672), keep_ratio=True),
256
+ dict(type='RandomFlip', prob=0.5),
257
+ dict(type='ClampBBoxes'),
258
+ dict(type='PackDetInputs')
259
+ ]
260
+ )
261
+ )
262
+
263
+ # Validation dataloader with mask annotations
264
+ val_dataloader = dict(
265
+ batch_size=1,
266
+ num_workers=2,
267
+ persistent_workers=True,
268
+ drop_last=False,
269
+ sampler=dict(type='DefaultSampler', shuffle=False),
270
+ dataset=dict(
271
+ type=dataset_type,
272
+ data_root=data_root,
273
+ ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
274
+ data_prefix=dict(img='legend_data/train/images/'),
275
+ metainfo=dict(classes=CLASSES),
276
+ test_mode=True,
277
+ pipeline=[
278
+ dict(type='LoadImageFromFile'),
279
+ dict(type='Resize', scale=(1120, 672), keep_ratio=True),
280
+ dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
281
+ dict(type='ClampBBoxes'),
282
+ dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
283
+ ]
284
+ )
285
+ )
286
+
287
+ test_dataloader = val_dataloader
288
+
289
+ # Enhanced evaluators for both bbox and mask metrics
290
+ val_evaluator = dict(
291
+ type='CocoMetric',
292
+ ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
293
+ metric=['bbox', 'segm'],
294
+ format_only=False,
295
+ classwise=True,
296
+ proposal_nums=(100, 300, 1000)
297
+ )
298
+
299
+ test_evaluator = val_evaluator
300
+
301
+ # Same custom hooks as Cascade R-CNN
302
+ default_hooks = dict(
303
+ timer=dict(type='IterTimerHook'),
304
+ logger=dict(type='LoggerHook', interval=50),
305
+ param_scheduler=dict(type='ParamSchedulerHook'),
306
+ checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
307
+ sampler_seed=dict(type='DistSamplerSeedHook'),
308
+ visualization=dict(type='DetVisualizationHook')
309
+ )
310
+
311
+ # Same custom hooks as Cascade R-CNN (adapted for Mask R-CNN)
312
+ custom_hooks = [
313
+ dict(type='SkipBadSamplesHook', interval=1),
314
+ dict(type='ChartTypeDistributionHook', interval=500),
315
+ dict(type='MissingImageReportHook', interval=1000),
316
+ dict(type='NanRecoveryHook',
317
+ fallback_loss=1.0,
318
+ max_consecutive_nans=50,
319
+ log_interval=25),
320
+ # Note: Progressive loss hook not used in standard Mask R-CNN
321
+ # but could be adapted if needed for bbox loss only
322
+ ]
323
+
324
+ # Training configuration - reduced to 20 epochs
325
+ train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)
326
+ val_cfg = dict(type='ValLoop')
327
+ test_cfg = dict(type='TestLoop')
328
+
329
+ # Same optimizer settings as Cascade R-CNN
330
+ optim_wrapper = dict(
331
+ type='OptimWrapper',
332
+ optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
333
+ clip_grad=dict(max_norm=10.0, norm_type=2)
334
+ )
335
+
336
+ # Same learning rate schedule as Cascade R-CNN
337
+ param_scheduler = [
338
+ dict(
339
+ type='LinearLR',
340
+ start_factor=0.1,
341
+ by_epoch=False,
342
+ begin=0,
343
+ end=1000),
344
+ dict(
345
+ type='CosineAnnealingLR',
346
+ begin=0,
347
+ end=20,
348
+ by_epoch=True,
349
+ T_max=20,
350
+ eta_min=1e-5,
351
+ convert_to_iter_based=True)
352
+ ]
353
+
354
+ # Work directory
355
+ work_dir = '/content/drive/MyDrive/Research Summer 2025/Dense Captioning Toolkit/CHART-DeMatch/work_dirs/mask_rcnn_swin_base_20ep_meta'
356
+
357
+ # Fresh start
358
+ resume = False
359
+ load_from = None
360
+
361
+ # Default runtime settings (normally inherited from _base_)
362
+ default_scope = 'mmdet'
363
+ env_cfg = dict(
364
+ cudnn_benchmark=False,
365
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
366
+ dist_cfg=dict(backend='nccl'),
367
+ )
368
+
369
+ vis_backends = [dict(type='LocalVisBackend')]
370
+ visualizer = dict(
371
+ type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
372
+
373
+ log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
374
+ log_level = 'INFO'