# mask_rcnn_swin_meta.py - Mask R-CNN with Swin Transformer for data point segmentation
#
# ADAPTED FROM CASCADE R-CNN CONFIG:
# - Uses same Swin Transformer Base backbone with optimizations
# - Maintains data-point class weighting (10x) and IoU strategies
# - Adds mask head for instance segmentation of data points
# - Uses enhanced annotation files with segmentation masks
# - Keeps custom hooks and progressive loss strategies
#
# MASK-SPECIFIC OPTIMIZATIONS:
# - RoI size 14x14 for mask extraction (matches data point size)
# - FCN mask head with 4 convolution layers
# - Mask loss weight balanced with bbox and classification losses
# - Enhanced test-time augmentation for better mask quality
#
# DATA POINT FOCUS:
# - Primary target: data-point class (ID 11) with 10x weight
# - Generates both bounding boxes AND instance masks
# - Optimized for 16x16 pixel data points in scientific charts
#
# Removed _base_ inheritance to avoid path issues - all configs are inlined below.
# Custom imports - same as Cascade R-CNN setup.
# These project-local modules are listed so the framework imports them and
# the custom components referenced by name in this config (hooks, transforms,
# heads) get registered. With allow_failed_imports=False a missing module is
# a hard error rather than a silent skip.
custom_imports = dict(
    imports=[
        'custom_models.register',
        'custom_models.custom_hooks',
        'custom_models.progressive_loss_hook',
        'custom_models.flexible_load_annotations',
    ],
    allow_failed_imports=False,
)

# Make the current working directory importable so `custom_models` resolves.
# NOTE(review): '.' is relative to the process CWD, so training must be
# launched from the directory that contains `custom_models/` - confirm.
# NOTE(review): this runs after `custom_imports` is *defined*; presumably the
# framework only acts on `custom_imports` once the whole file is evaluated.
import sys

sys.path.insert(0, '.')
# Mask R-CNN model with Swin Transformer backbone.
#
# NOTE: `train_cfg` / `test_cfg` nested inside `model` hold the detector's
# assigner / sampler / NMS settings. They are distinct from the top-level
# `train_cfg` / `test_cfg` loop configs defined later in this file.
model = dict(
    type='MaskRCNN',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32,
        pad_mask=True,  # Important for mask training
        mask_pad_value=0,
    ),
    # Same Swin Transformer Base backbone as Cascade R-CNN
    backbone=dict(
        type='SwinTransformer',
        embed_dims=128,  # Swin Base embedding dimensions
        depths=[2, 2, 18, 2],  # Swin Base depths
        num_heads=[4, 8, 16, 32],  # Swin Base attention heads
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,  # Same as Cascade config
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth',
        ),
    ),
    # Same FPN as Cascade R-CNN
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],  # Swin Base: embed_dims * 2^(stage)
        out_channels=256,
        num_outs=5,  # Standard for Mask R-CNN (was 6 in Cascade)
        start_level=0,
        add_extra_convs='on_input',
    ),
    # Same RPN configuration as Cascade R-CNN
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[1, 2, 4, 8],  # Same small scales for tiny objects
            ratios=[0.5, 1.0, 2.0],
            # One stride per FPN output level (num_outs=5 above).
            strides=[4, 8, 16, 32, 64],
        ),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0],
        ),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
        ),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
    ),
    # Mask R-CNN ROI head with bbox + mask branches
    roi_head=dict(
        type='StandardRoIHead',
        # Bbox ROI extractor (same as Cascade R-CNN final stage)
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
        ),
        # Bbox head with data-point class weighting
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=22,  # 22 enhanced categories including boxplot
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2],
            ),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=1.0,
                # Label layout in MMDetection >= 2.0: foreground classes use
                # indices 0..num_classes-1 and background is the LAST index
                # (num_classes). 'data-point' is class 11, so its weight
                # belongs at index 11. The previous list assumed background
                # at index 0 and therefore boosted class 12 instead.
                class_weight=(
                    [1.0] * 11      # classes 0-10
                    + [10.0]        # class 11: data-point, 10x weight
                    + [1.0] * 10    # classes 12-21 (incl. boxplot)
                    + [1.0]         # index 22: background
                ),
            ),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
        ),
        # Mask ROI extractor (optimized for 16x16 data points)
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            # Fixed 14x14 output. aligned=True selects the pixel-aligned
            # (non-legacy) RoIAlign variant.
            roi_layer=dict(
                type='RoIAlign',
                output_size=(14, 14),
                sampling_ratio=0,
                aligned=True,
            ),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32],
        ),
        # Mask head optimized for data points with square mask targets
        mask_head=dict(
            type='SquareFCNMaskHead',  # project-local head (see custom_imports)
            num_convs=4,  # 4 conv layers for good feature extraction
            in_channels=256,
            roi_feat_size=14,  # Explicitly set ROI feature size
            conv_out_channels=256,
            num_classes=22,  # 22 enhanced categories including boxplot
            upsample_cfg=dict(type=None),  # No upsampling - keep 14x14
            loss_mask=dict(
                type='CrossEntropyLoss',
                use_mask=True,
                loss_weight=1.0,  # Balanced with bbox loss
            ),
        ),
    ),
    # Training configuration adapted from Cascade R-CNN
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1,
            ),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False,
            ),
            allowed_border=0,
            pos_weight=-1,
            debug=False,
        ),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0,
        ),
        # RCNN training (using Cascade stage 2 settings - balanced for mask training)
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,  # Balanced IoU for bbox + mask training
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,  # Important for small data points
                ignore_iof_thr=-1,
            ),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True,
            ),
            # Must match the mask head's 14x14 output (no upsampling above).
            mask_size=(14, 14),
            pos_weight=-1,
            debug=False,
        ),
    ),
    # Test configuration with soft NMS
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0,
        ),
        rcnn=dict(
            score_thr=0.005,  # Low threshold to catch data points
            nms=dict(
                type='soft_nms',  # Soft NMS for better small object detection
                iou_threshold=0.3,  # Low for data points
                min_score=0.005,
                method='gaussian',
                sigma=0.5,
            ),
            max_per_img=100,
            mask_thr_binary=0.5,  # Binary mask threshold
        ),
    ),
)
# Dataset settings - using standard COCO dataset for mask support
dataset_type = 'CocoDataset'
# Empty root: the annotation and image paths used by the dataloaders are
# given relative to the working directory.
data_root = ''

# 22 enhanced categories including boxplot. The position in this tuple is the
# class index, so the order must not change.
CLASSES = (
    'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',  # 0-5
    'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',  # 6-10
    'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',  # 11-15 (data-point at index 11)
    'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area',  # 16-20
    'boxplot',  # 21
)

# Verify data-point class index: the class-specific loss weighting elsewhere
# in this config is keyed on 'data-point' being class 11.
assert CLASSES[11] == 'data-point', f"Expected 'data-point' at index 11 in CLASSES tuple, got '{CLASSES[11]}'"
# Training dataloader with mask annotations
train_dataloader = dict(
    batch_size=2,  # Same as Cascade R-CNN
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_match_swin/mask_generation/enhanced_datasets/train_filtered_with_masks_only.json',
        data_prefix=dict(img='legend_data/train/images/'),
        metainfo=dict(classes=CLASSES),
        # filter_empty_gt=False keeps images that have no ground truth.
        # NOTE(review): min_size=12 presumably drops images whose shorter
        # side is under 12 px - confirm against the dataset class.
        filter_cfg=dict(filter_empty_gt=False, min_size=12),
        # Disable any built-in filtering that might remove annotations
        test_mode=False,
        pipeline=[
            dict(type='LoadImageFromFile'),
            # Project-local loader (see custom_imports); loads boxes + masks.
            dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
            dict(type='Resize', scale=(1120, 672), keep_ratio=True),
            dict(type='RandomFlip', prob=0.5),
            dict(type='ClampBBoxes'),
            dict(type='PackDetInputs'),
        ],
    ),
)
# Validation dataloader with mask annotations
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
        # NOTE(review): validation images are read from the *train* image
        # folder - presumably the val split was carved out of train; confirm.
        data_prefix=dict(img='legend_data/train/images/'),
        metainfo=dict(classes=CLASSES),
        test_mode=True,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(1120, 672), keep_ratio=True),
            # Annotations are loaded after Resize here (unlike the training
            # pipeline, where they are loaded before it).
            dict(type='FlexibleLoadAnnotations', with_bbox=True, with_mask=True),
            dict(type='ClampBBoxes'),
            dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')),
        ],
    ),
)
# Testing reuses the validation dataloader (same dict object).
test_dataloader = val_dataloader
# Enhanced evaluators for both bbox and mask metrics.
# ann_file is the same annotation file the validation dataloader reads.
val_evaluator = dict(
    type='CocoMetric',
    ann_file='legend_match_swin/mask_generation/enhanced_datasets/val_enriched_with_masks_only.json',
    metric=['bbox', 'segm'],
    format_only=False,
    classwise=True,
    proposal_nums=(100, 300, 1000),
)
# Testing reuses the validation evaluator (same dict object).
test_evaluator = val_evaluator
# Default runtime hooks (same as Cascade R-CNN).
# The checkpoint hook is a project-local type; it saves every epoch, tracks
# the best checkpoint automatically and keeps at most 3 checkpoints.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'),
)

# Same custom hooks as Cascade R-CNN (adapted for Mask R-CNN)
custom_hooks = [
    dict(type='SkipBadSamplesHook', interval=1),
    dict(type='ChartTypeDistributionHook', interval=500),
    dict(type='MissingImageReportHook', interval=1000),
    dict(
        type='NanRecoveryHook',
        fallback_loss=1.0,
        max_consecutive_nans=50,
        log_interval=25,
    ),
    # Note: Progressive loss hook not used in standard Mask R-CNN
    # but could be adapted if needed for bbox loss only
]
# Training configuration - reduced to 20 epochs, validating after every epoch.
# These top-level loop configs are separate from model.train_cfg / model.test_cfg.
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Same optimizer settings as Cascade R-CNN
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=10.0, norm_type=2),
)

# Same learning rate schedule as Cascade R-CNN
param_scheduler = [
    # Linear warm-up over the first 1000 iterations, starting at 0.1x the LR.
    dict(
        type='LinearLR',
        start_factor=0.1,
        by_epoch=False,
        begin=0,
        end=1000,
    ),
    # Cosine decay down to eta_min. end / T_max are set to 20 to match
    # max_epochs above; keep them in sync if the epoch count changes.
    dict(
        type='CosineAnnealingLR',
        begin=0,
        end=20,
        by_epoch=True,
        T_max=20,
        eta_min=1e-5,
        convert_to_iter_based=True,
    ),
]
# Work directory (absolute Google Drive path as mounted under /content/drive).
work_dir = '/content/drive/MyDrive/Research Summer 2025/Dense Captioning Toolkit/CHART-DeMatch/work_dirs/mask_rcnn_swin_base_20ep_meta'

# Fresh start: do not resume a previous run and do not load a detector
# checkpoint.
resume = False
load_from = None

# Default runtime settings (normally inherited from _base_)
default_scope = 'mmdet'
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer',
    vis_backends=vis_backends,
    name='visualizer',
)
# window_size matches the LoggerHook interval of 50 used in default_hooks.
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
log_level = 'INFO'