File size: 4,695 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base

with read_base():
    from .rtmdet_l_8xb32_300e_coco import *

from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import RandomResize
from mmengine.hooks.ema_hook import EMAHook
from torch.nn.modules.activation import SiLU

from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import (FilterAnnotations,
                                               LoadAnnotations)
from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic,
                                                  Pad, RandomCrop, RandomFlip,
                                                  Resize, YOLOXHSVRandomAug)
from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook
from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsSepBNHead
from mmdet.models.layers.ema import ExpMomentumEMA
from mmdet.models.losses.dice_loss import DiceLoss
from mmdet.models.losses.gfocal_loss import QualityFocalLoss
from mmdet.models.losses.iou_loss import GIoULoss
from mmdet.models.task_modules.coders.distance_point_bbox_coder import \
    DistancePointBBoxCoder
from mmdet.models.task_modules.prior_generators.point_generator import \
    MlvlPointGenerator

# Turn the inherited RTMDet-L detector into RTMDet-Ins-L: replace its head with
# an instance-segmentation head and add mask-specific test-time settings.
# `model` is not defined in this file; it is expected to come from the star
# import of the base config inside `read_base()` above, and is merged into via
# `merge` rather than rebound.
model.merge(
    dict(
        bbox_head=dict(
            # Intended to make mmengine drop the inherited bbox_head settings
            # entirely instead of merging these keys into them.
            _delete_=True,
            type=RTMDetInsSepBNHead,
            num_classes=80,  # presumably the 80 COCO categories — confirm
            in_channels=256,
            stacked_convs=2,
            share_conv=True,
            pred_kernel_size=1,
            feat_channels=256,
            act_cfg=dict(type=SiLU, inplace=True),
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            anchor_generator=dict(
                type=MlvlPointGenerator, offset=0, strides=[8, 16, 32]),
            bbox_coder=dict(type=DistancePointBBoxCoder),
            loss_cls=dict(
                type=QualityFocalLoss,
                use_sigmoid=True,
                beta=2.0,
                loss_weight=1.0),
            loss_bbox=dict(type=GIoULoss, loss_weight=2.0),
            # The mask branch is supervised with a Dice loss.
            loss_mask=dict(
                type=DiceLoss, loss_weight=2.0, eps=5e-6, reduction='mean')),
        # No `_delete_` here, so these keys are merged into the inherited
        # test_cfg; keys not listed are presumably kept from the base — confirm.
        test_cfg=dict(
            nms_pre=1000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.6),
            max_per_img=100,
            # presumably the threshold used to binarize predicted masks
            mask_thr_binary=0.5),
    ))

# Stage-1 (strong-augmentation) training pipeline:
#   load image + boxes/masks -> cached mosaic -> random resize around 1280x1280
#   -> 640x640 random crop -> HSV jitter -> horizontal flip -> pad to 640x640
#   -> cached mix-up -> drop degenerate boxes -> pack for the detector.
# It is swapped out for `train_pipeline_stage2` by the PipelineSwitchHook
# configured in `custom_hooks`.
train_pipeline = [
    {'type': LoadImageFromFile, 'backend_args': backend_args},
    # `poly2mask=False`: presumably keeps instance masks in polygon form.
    {
        'type': LoadAnnotations,
        'with_bbox': True,
        'with_mask': True,
        'poly2mask': False,
    },
    {'type': CachedMosaic, 'img_scale': (640, 640), 'pad_val': 114.0},
    {
        'type': RandomResize,
        'scale': (1280, 1280),
        'ratio_range': (0.1, 2.0),
        'resize_type': Resize,
        'keep_ratio': True,
    },
    {
        'type': RandomCrop,
        'crop_size': (640, 640),
        'recompute_bbox': True,
        'allow_negative_crop': True,
    },
    {'type': YOLOXHSVRandomAug},
    {'type': RandomFlip, 'prob': 0.5},
    {'type': Pad, 'size': (640, 640), 'pad_val': {'img': (114, 114, 114)}},
    {
        'type': CachedMixUp,
        'img_scale': (640, 640),
        'ratio_range': (1.0, 1.0),
        'max_cached_images': 20,
        'pad_val': (114, 114, 114),
    },
    {'type': FilterAnnotations, 'min_gt_bbox_wh': (1, 1)},
    {'type': PackDetInputs},
]

# Attach the stage-1 pipeline to the training dataset and enable pinned host
# memory for the dataloader. `train_dataloader` is not defined in this file;
# it is expected to come from the base config star import and is updated in
# place. The nested `dataset=dict(pipeline=...)` is presumably merged
# key-by-key into the inherited dataset config (leaving e.g. its data paths
# intact) rather than replacing it — confirm against mmengine ConfigDict.update.
train_dataloader.update(
    dict(pin_memory=True, dataset=dict(pipeline=train_pipeline)))

# Stage-2 (fine-tuning) training pipeline, used after the PipelineSwitchHook
# fires: the same as stage 1 but without mosaic and mix-up, and with resize
# jitter around 640x640 instead of 1280x1280. Degenerate boxes are filtered
# right after the crop, before color/flip/pad.
train_pipeline_stage2 = [
    {'type': LoadImageFromFile, 'backend_args': backend_args},
    {
        'type': LoadAnnotations,
        'with_bbox': True,
        'with_mask': True,
        'poly2mask': False,
    },
    {
        'type': RandomResize,
        'scale': (640, 640),
        'ratio_range': (0.1, 2.0),
        'resize_type': Resize,
        'keep_ratio': True,
    },
    {
        'type': RandomCrop,
        'crop_size': (640, 640),
        'recompute_bbox': True,
        'allow_negative_crop': True,
    },
    {'type': FilterAnnotations, 'min_gt_bbox_wh': (1, 1)},
    {'type': YOLOXHSVRandomAug},
    {'type': RandomFlip, 'prob': 0.5},
    {'type': Pad, 'size': (640, 640), 'pad_val': {'img': (114, 114, 114)}},
    {'type': PackDetInputs},
]
# Extra training hooks. This is a plain assignment, so it replaces whatever
# `custom_hooks` list the base config defined; presumably it is redefined so
# that the switch hook points at this file's mask-aware stage-2 pipeline.
custom_hooks = [
    # Exponential-moving-average copy of the model weights (ExpMomentumEMA);
    # `update_buffers=True` presumably also tracks buffers — confirm.
    {
        'type': EMAHook,
        'ema_type': ExpMomentumEMA,
        'momentum': 0.0002,
        'update_buffers': True,
        'priority': 49,
    },
    # Swap to the lighter stage-2 pipeline (no mosaic / mix-up) at epoch 280.
    # NOTE(review): 280 is hard-coded; it presumably must stay equal to the
    # base schedule's total epochs (300, per the base config name) minus the
    # stage-2 epochs. Keep it in sync if the schedule changes.
    {
        'type': PipelineSwitchHook,
        'switch_epoch': 280,
        'switch_pipeline': train_pipeline_stage2,
    },
]

# Evaluate both box AP and mask AP. `val_evaluator` is not defined in this
# file; it is expected to come from the base config star import and is
# updated in place.
val_evaluator.update(dict(metric=['bbox', 'segm']))
# NOTE(review): this binds the very same object, not a copy, so any later
# in-place change to either name (e.g. `.update(...)` in a config that
# inherits this one) affects both the val and test evaluators.
test_evaluator = val_evaluator