# cascade_rcnn_r50_fpn_meta.py - Enhanced config with Swin Transformer backbone
#
# PROGRESSIVE LOSS STRATEGY:
# - All 3 Cascade stages start with SmoothL1Loss for stable initial training
# - At epoch 5, Stage 3 (final stage) switches to GIoULoss via ProgressiveLossHook
# - Stage 1 & 2 remain SmoothL1Loss throughout training
# - This ensures model stability before introducing more complex IoU-based losses

import os
import sys

# Make the project root importable before the custom modules are loaded.
# NOTE(review): this is cwd-relative (two levels up from wherever training is
# launched), not __file__-relative — confirm the expected launch directory.
sys.path.insert(0, os.path.join(os.getcwd(), '..', '..'))

# Register project modules (dataset, model, hooks) without polluting the
# config namespace; abort immediately if any registration import fails.
custom_imports = dict(
    imports=[
        'custom_models.custom_dataset',
        'custom_models.register',
        'custom_models.custom_hooks',
        'custom_models.progressive_loss_hook',
    ],
    allow_failed_imports=False
)
# Custom Cascade model with coordinate handling for chart data
model = dict(
    type='CustomCascadeWithMeta',  # Use custom model with coordinate handling
    # Coordinate options consumed by CustomCascadeWithMeta (project code).
    # NOTE(review): 'bottom_left' must match the coordinate system used when
    # the annotations were generated — verify against the annotation tooling.
    coordinate_standardization=dict(
        enabled=True,
        origin='bottom_left',  # Match annotation creation coordinate system
        normalize=True,
        relative_to_plot=False,  # Keep simple for now
        scale_to_axis=False  # Keep simple for now
    ),
    # Standard ImageNet mean/std; convert BGR->RGB and pad to a multiple of 32
    # so every FPN stride divides the padded input evenly.
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    # ----- Swin Transformer Base (22K) Backbone + FPN -----
    backbone=dict(
        type='SwinTransformer',
        embed_dims=128,  # Swin Base embedding dimensions
        depths=[2, 2, 18, 2],  # Swin Base depths
        num_heads=[4, 8, 16, 32],  # Swin Base attention heads
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,  # Slightly higher for more complex model
        patch_norm=True,
        out_indices=(0, 1, 2, 3),  # expose all 4 stages to the FPN
        with_cp=False,  # no activation checkpointing
        convert_weights=True,  # remap the mmseg-format checkpoint keys below
        init_cfg=dict(
            type='Pretrained',
            # ImageNet-22K pretrained Swin-Base weights (hosted by
            # mmsegmentation; convert_weights=True handles key conversion).
            checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
        )
    ),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],  # Swin Base: embed_dims * 2^(stage)
        out_channels=256,
        num_outs=6,  # 6 output levels — one per RPN anchor stride below
        start_level=0,
        add_extra_convs='on_input'
    ),
    # Enhanced RPN with smaller anchors for tiny objects + improved losses
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[1, 2, 4, 8],  # Even smaller scales for tiny objects
            ratios=[0.5, 1.0, 2.0],  # Multiple aspect ratios
            strides=[4, 8, 16, 32, 64, 128]),  # Extended FPN strides (6, matching num_outs)
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,  # binary objectness classification
            loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
    # Progressive Loss Strategy: Start with SmoothL1 for all 3 stages
    # Stage 3 (final stage) will switch to GIoU at epoch 5 via ProgressiveLossHook
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),  # RoI pooling uses only the first 4 FPN levels
        bbox_head=[
            # Stage 1: Always SmoothL1Loss (coarse detection)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories (must match CLASSES below)
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    # NOTE(review): tighter than typical cascade stage-1 stds —
                    # presumably intentional for tiny boxes; confirm.
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
            # Stage 2: Always SmoothL1Loss (intermediate refinement)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
            # Stage 3: SmoothL1 -> GIoU at epoch 5 (ProgressiveLossHook rewrites
            # this loss_bbox entry at runtime; keep the key name stable)
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=21,  # 21 enhanced categories
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.02, 0.02, 0.05, 0.05]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ]),
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.8),
            min_bbox_size=0),
        # Per-stage assigner IoU thresholds rise 0.4 -> 0.6 -> 0.7, the
        # cascade's progressive refinement (0.4 starting point is lower than
        # the common 0.5, presumably to keep tiny objects as positives).
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.4,
                    neg_iou_thr=0.4,
                    min_pos_iou=0.4,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False)
        ]),
    # Enhanced test configuration with soft-NMS and multi-scale support
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.005,  # Even lower threshold to catch more classes
            nms=dict(
                type='soft_nms',  # Soft-NMS for better small object detection
                iou_threshold=0.5,
                min_score=0.005,
                method='gaussian',
                sigma=0.5),
            max_per_img=500)))  # Allow more detections
# Dataset settings - annotations come from the cleaned JSON exports.
dataset_type = 'ChartDataset'
data_root = ''  # paths below are already full, so no data_root prefix

# The 21 chart-element categories, in the order used by the annotations.
CLASSES = (
    'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
    'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
    'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
    'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
)
# Training dataloader: cleaned annotation files, full paths (data_root is '').
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_data/annotations_JSON_cleaned/train_enriched.json',
        data_prefix=dict(img='legend_data/train/images/'),
        # Restrict the label space to the 21 chart-element classes above.
        metainfo=dict(classes=CLASSES),
        # Drop images without GT; tiny-object classes additionally require a
        # minimum 16x16 box (relaxed from an earlier 32x32 experiment).
        filter_cfg=dict(
            filter_empty_gt=True,
            min_size=0,
            class_specific_min_sizes={
                'data-point': 16,
                'data-bar': 16,
                'tick-label': 16,
                'x-tick-label': 16,
                'y-tick-label': 16
            }),
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),
            # High-resolution input helps the many tiny chart elements.
            dict(type='Resize', scale=(1600, 1000), keep_ratio=True),
            dict(type='RandomFlip', prob=0.5),
            dict(type='ClampBBoxes'),  # keep boxes inside image bounds
            dict(type='PackDetInputs')
        ]
    )
)
# Validation dataloader: deterministic ordering, no augmentation. All images
# (train and val) live under legend_data/train/images/.
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json',
        data_prefix=dict(img='legend_data/train/images/'),
        metainfo=dict(classes=CLASSES),  # same 21-class label space as training
        test_mode=True,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(1600, 1000), keep_ratio=True),
            dict(type='LoadAnnotations', with_bbox=True),
            dict(type='ClampBBoxes'),  # keep boxes inside image bounds
            dict(type='PackDetInputs',
                 meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                            'scale_factor'))
        ]
    )
)

# Testing reuses the validation setup verbatim.
test_dataloader = val_dataloader
# COCO-style bbox evaluation with a per-class breakdown for debugging
# empty/low-scoring categories.
val_evaluator = dict(
    type='CocoMetric',
    ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json',
    metric='bbox',
    format_only=False,
    classwise=True,  # detailed per-class metrics table
    proposal_nums=(100, 300, 1000))  # finer-grained AR reporting
test_evaluator = val_evaluator
# Standard runtime hooks; checkpointing goes through the project's
# compatibility hook instead of the stock CheckpointHook.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CompatibleCheckpointHook',
        interval=1,
        save_best='auto',
        max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))
# Project hooks: data hygiene, distribution monitoring, NaN recovery, and the
# progressive loss switch described in the header of this config.
custom_hooks = [
    dict(type='SkipBadSamplesHook', interval=1),           # skip samples with bad GT data
    dict(type='ChartTypeDistributionHook', interval=500),  # monitor class distribution
    dict(type='MissingImageReportHook', interval=1000),    # track missing images
    dict(type='NanRecoveryHook',          # graceful NaN handling + monitoring
         fallback_loss=1.0,
         max_consecutive_nans=100,
         log_interval=50),
    dict(type='ProgressiveLossHook',      # progressive loss switching
         switch_epoch=5,                  # switch stage 3 to GIoU at epoch 5
         target_loss_type='GIoULoss',     # final-stage regression loss after switch
         loss_weight=1.0,                 # keep the same loss weight
         warmup_epochs=2,                 # monitor for 2 epochs after switch
         monitor_stage_weights=True),     # log stage loss details
]
# 40-epoch schedule with validation after every epoch — extended run for
# Swin Base on small objects.
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# SGD with gradient clipping for stability.
# NOTE(review): lr=0.02 is the conventional 16-image global-batch SGD rate;
# with batch_size=2 per GPU, confirm the effective batch size / GPU count
# justifies it (linear scaling rule).
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=35.0, norm_type=2)
)
# Linear warmup followed by cosine annealing over the full 40-epoch run.
param_scheduler = [
    dict(
        type='LinearLR',
        # Warmup starts at lr = start_factor * base_lr = 0.05 * 0.02 = 1e-3.
        # (The previous comment claimed this was 1e-4; that would require
        # start_factor=0.005. The configured value 0.05 is kept — only the
        # documentation was wrong.)
        start_factor=0.05,
        by_epoch=False,
        begin=0,
        end=1000),  # 1k-iteration warmup
    dict(
        type='CosineAnnealingLR',
        begin=0,
        end=40,  # match max_epochs
        by_epoch=True,
        T_max=40,
        eta_min=1e-6,  # floor learning rate
        convert_to_iter_based=True)
]
# Output directory for checkpoints and logs.
work_dir = './work_dirs/cascade_rcnn_swin_base_40ep_cosine_fpn_meta'

# Multi-scale test-time augmentation (disabled; uncomment to enable).
# img_scales = [(800, 500), (1600, 1000), (2400, 1500)]  # 0.5x, 1.0x, 1.5x scales
# tta_model = dict(
#     type='DetTTAModel',
#     tta_cfg=dict(
#         nms=dict(type='nms', iou_threshold=0.5),
#         max_per_img=100)
# )

# Fresh start: no resume, and no full-model checkpoint to initialize from
# (the backbone is still initialized from the pretrained Swin weights).
resume = False
load_from = None