# Copyright (c) OpenMMLab. All rights reserved.
# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0
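# This config turns the RTMDet-L detector from the base config into the
# instance segmentation variant (RTMDet-Ins-L). It is typically launched with
# mmdetection's training script, e.g. (illustrative path only):
#   python tools/train.py mmdet/configs/rtmdet/rtmdet_ins_l_8xb32_300e_coco.py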
from mmengine.config import read_base
with read_base():
    from .rtmdet_l_8xb32_300e_coco import *
from mmcv.transforms.loading import LoadImageFromFile
from mmcv.transforms.processing import RandomResize
from mmengine.hooks.ema_hook import EMAHook
from torch.nn.modules.activation import SiLU
from mmdet.datasets.transforms.formatting import PackDetInputs
from mmdet.datasets.transforms.loading import (FilterAnnotations,
LoadAnnotations)
from mmdet.datasets.transforms.transforms import (CachedMixUp, CachedMosaic,
Pad, RandomCrop, RandomFlip,
Resize, YOLOXHSVRandomAug)
from mmdet.engine.hooks.pipeline_switch_hook import PipelineSwitchHook
from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsSepBNHead
from mmdet.models.layers.ema import ExpMomentumEMA
from mmdet.models.losses.dice_loss import DiceLoss
from mmdet.models.losses.gfocal_loss import QualityFocalLoss
from mmdet.models.losses.iou_loss import GIoULoss
from mmdet.models.task_modules.coders.distance_point_bbox_coder import \
DistancePointBBoxCoder
from mmdet.models.task_modules.prior_generators.point_generator import \
MlvlPointGenerator
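# Replace the detection head from the base RTMDet-L config with the instance
# segmentation head (RTMDetInsSepBNHead), which adds a mask branch supervised
# by a Dice loss, and extend test_cfg with the mask binarization threshold.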
model.merge(
dict(
bbox_head=dict(
_delete_=True,
type=RTMDetInsSepBNHead,
num_classes=80,
in_channels=256,
stacked_convs=2,
share_conv=True,
pred_kernel_size=1,
feat_channels=256,
act_cfg=dict(type=SiLU, inplace=True),
norm_cfg=dict(type='SyncBN', requires_grad=True),
anchor_generator=dict(
type=MlvlPointGenerator, offset=0, strides=[8, 16, 32]),
bbox_coder=dict(type=DistancePointBBoxCoder),
loss_cls=dict(
type=QualityFocalLoss,
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox=dict(type=GIoULoss, loss_weight=2.0),
loss_mask=dict(
type=DiceLoss, loss_weight=2.0, eps=5e-6, reduction='mean')),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100,
mask_thr_binary=0.5),
))
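# Stage-1 training pipeline: loads instance masks (with_mask=True) and uses
# the heavy cached Mosaic / MixUp augmentations at 640x640 resolution.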
train_pipeline = [
dict(type=LoadImageFromFile, backend_args=backend_args),
dict(
type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False),
dict(type=CachedMosaic, img_scale=(640, 640), pad_val=114.0),
dict(
type=RandomResize,
scale=(1280, 1280),
ratio_range=(0.1, 2.0),
resize_type=Resize,
keep_ratio=True),
dict(
type=RandomCrop,
crop_size=(640, 640),
recompute_bbox=True,
allow_negative_crop=True),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(
type=CachedMixUp,
img_scale=(640, 640),
ratio_range=(1.0, 1.0),
max_cached_images=20,
pad_val=(114, 114, 114)),
dict(type=FilterAnnotations, min_gt_bbox_wh=(1, 1)),
dict(type=PackDetInputs)
]
train_dataloader.update(
dict(pin_memory=True, dataset=dict(pipeline=train_pipeline)))
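# Stage-2 training pipeline: swapped in near the end of training (see
# PipelineSwitchHook below); drops Mosaic / MixUp in favor of plain
# RandomResize + RandomCrop augmentation.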
train_pipeline_stage2 = [
dict(type=LoadImageFromFile, backend_args=backend_args),
dict(
type=LoadAnnotations, with_bbox=True, with_mask=True, poly2mask=False),
dict(
type=RandomResize,
scale=(640, 640),
ratio_range=(0.1, 2.0),
resize_type=Resize,
keep_ratio=True),
dict(
type=RandomCrop,
crop_size=(640, 640),
recompute_bbox=True,
allow_negative_crop=True),
dict(type=FilterAnnotations, min_gt_bbox_wh=(1, 1)),
dict(type=YOLOXHSVRandomAug),
dict(type=RandomFlip, prob=0.5),
dict(type=Pad, size=(640, 640), pad_val=dict(img=(114, 114, 114))),
dict(type=PackDetInputs)
]
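# Custom hooks: exponential-momentum EMA of the model weights, plus a hook
# that switches to the lighter stage-2 pipeline at epoch 280 of the
# 300-epoch schedule inherited from the base config.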
custom_hooks = [
dict(
type=EMAHook,
ema_type=ExpMomentumEMA,
momentum=0.0002,
update_buffers=True,
priority=49),
dict(
type=PipelineSwitchHook,
switch_epoch=280,
switch_pipeline=train_pipeline_stage2)
]
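# Report mask AP ('segm') in addition to the box AP ('bbox') configured in
# the base config.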
val_evaluator.update(dict(metric=['bbox', 'segm']))
test_evaluator = val_evaluator
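# For reference only (a minimal sketch, not part of this config): the file can
# be loaded programmatically with mmengine, e.g.
#   from mmengine.config import Config
#   cfg = Config.fromfile('path/to/this_config.py')  # hypothetical path
#   print(cfg.model.bbox_head.type)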